from mpi4py import MPI
import numpy as np
from time import time
- from pyclassify.cxx_utils import QR_algorithm, secular_solver_cxx, deflate_eigenpairs_cxx
+ from pyclassify.cxx_utils import (
+     QR_algorithm,
+     secular_solver_cxx,
+     deflate_eigenpairs_cxx,
+ )
from pyclassify.zero_finder import secular_solver_python as secular_solver
from line_profiler import profile, LineProfiler
import scipy.sparse as sp
@@ -158,19 +162,22 @@ def deflate_eigenpairs(D, v, beta, tol_factor=1e-12):

    return deflated_eigvals, np.array(deflated_eigvecs), D_keep, v_keep, P_3 @ P_2 @ P

+
def find_interval_extreme(total_dimension, n_processor):
    """
    Computes the per-rank intervals for a vector to be scattered.
    Input:
        -total_dimension: the dimension of the vector that has to be split
        -n_processor: the number of processors to which the scattered vector has to be sent
-
+
    """

-     base = total_dimension // n_processor
+     base = total_dimension // n_processor
    rest = total_dimension % n_processor

-     counts = np.array([base + 1 if i < rest else base for i in range(n_processor)], dtype=int)
+     counts = np.array(
+         [base + 1 if i < rest else base for i in range(n_processor)], dtype=int
+     )
    displs = np.insert(np.cumsum(counts), 0, 0)[:-1]

    return counts, displs
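# Example: total_dimension=10, n_processor=3 -> counts=[4, 3, 3],
# displs=[0, 4, 7]; the first `rest` ranks each take one extra element,
# which is exactly the (counts, displs) layout Scatterv expects.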
@@ -210,7 +217,6 @@ def parallel_tridiag_eigen(
    n = len(diag)
    prof_filename = f"Profile_folder/profile.rank{current_rank}.depth{depth}.lprof"

-
    if n <= min_size or size == 1:
        eigvals, eigvecs = QR_algorithm(diag, off, 1e-16, max_iterQR)
        eigvecs = np.array(eigvecs)
@@ -245,8 +251,8 @@ def parallel_tridiag_eigen(
            depth=depth + 1,
            profiler=profiler,
        )
-         eigvals_right = None
-         eigvecs_right = None
+         eigvals_right = None
+         eigvecs_right = None
    else:
        eigvals_right, eigvecs_right = parallel_tridiag_eigen(
            diag2,
@@ -257,50 +263,49 @@ def parallel_tridiag_eigen(
            depth=depth + 1,
            profiler=profiler,
        )
-         eigvals_left = None
-         eigvecs_left = None
-
+         eigvals_left = None
+         eigvecs_left = None

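    # At this point each half of the communicator holds only its own block's
    # eigenpairs; the placeholders set to None above are filled in by the
    # root-to-root exchange below.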
    # 1) Identify the two “root” ranks in MPI.COMM_WORLD
-     left_size = size // 2 if size > 1 else 1
-     root_left = 0
+     left_size = size // 2 if size > 1 else 1
+     root_left = 0
    root_right = left_size
-     other_root = root_right if color == 0 else root_left
+     other_root = root_right if color == 0 else root_left

-     # now exchange between the two roots
+     # now exchange between the two roots
    if subcomm.Get_rank() == 0:
-         send_data = (eigvals_left, eigvecs_left) \
-             if color == 0 else (eigvals_right, eigvecs_right)
+         send_data = (
+             (eigvals_left, eigvecs_left)
+             if color == 0
+             else (eigvals_right, eigvecs_right)
+         )
        recv_data = comm.sendrecv(
-             send_data, dest=other_root, sendtag=depth,
-             source=other_root, recvtag=depth
+             send_data, dest=other_root, sendtag=depth, source=other_root, recvtag=depth
        )
        # unpack
        if color == 0:
            eigvals_right, eigvecs_right = recv_data
        else:
-             eigvals_left, eigvecs_left = recv_data
+             eigvals_left, eigvecs_left = recv_data

-     eigvals_left = subcomm.bcast(eigvals_left, root=0)
-     eigvecs_left = subcomm.bcast(eigvecs_left, root=0)
+     eigvals_left = subcomm.bcast(eigvals_left, root=0)
+     eigvecs_left = subcomm.bcast(eigvecs_left, root=0)
    eigvals_right = subcomm.bcast(eigvals_right, root=0)
    eigvecs_right = subcomm.bcast(eigvecs_right, root=0)
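    # sendrecv pairs the send and the matching receive in a single call, so the
    # two roots can swap their halves without the careful send/recv ordering
    # (and deadlock risk) of the commented-out point-to-point version below.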

-
    # if rank == 0:
    #     eigvals_right = comm.recv(source=left_size, tag=77)
    #     eigvecs_right = comm.recv(source=left_size, tag=78)
    # elif rank == left_size:
    #     comm.send(eigvals_right, dest=0, tag=77)
    #     comm.send(eigvecs_right, dest=0, tag=78)
-

    if rank == 0:

        # Merge Step
        n1 = len(eigvals_left)
        D = np.concatenate((eigvals_left, eigvals_right))
-         D_size = D.size
+         D_size = D.size
        v_vec = np.concatenate((eigvecs_left[-1, :], eigvecs_right[0, :]))

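        # The two tridiagonal halves are coupled by a rank-one correction,
        # T = diag(D) + beta * v v^T, where v stacks the last row of the left
        # block's eigenvectors and the first row of the right block's; the
        # deflation step below removes the (near-)decoupled entries.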
        deflated_eigvals, deflated_eigvecs, D_keep, v_keep, P = deflate_eigenpairs_cxx(
@@ -329,37 +334,37 @@ def parallel_tridiag_eigen(
                beta, D_keep[idx], v_keep[idx], np.arange(reduced_dim)
            )
            lam = np.array(lam)
-             delta = np.array(delta)
-             changing_position = np.array(changing_position)
-
+             delta = np.array(delta)
+             changing_position = np.array(changing_position)
+             # #diff=lam_s-lam
        else:
            lam = np.array([])
-
+
        counts, displs = find_interval_extreme(reduced_dim, size)

    else:
        counts = None
        displs = None
-         lam = None
-         D_keep = None
-         v_keep = None
-         delta = None
-         reduced_dim = None
-         D_size = None
-         changing_position = None
-         type_lam = None
-         type_D = None
-         P = None
-         idx_inv = None
-         n1 = None
+         lam = None
+         D_keep = None
+         v_keep = None
+         delta = None
+         reduced_dim = None
+         D_size = None
+         changing_position = None
+         type_lam = None
+         type_D = None
+         P = None
+         idx_inv = None
+         n1 = None
        deflated_eigvals = None
        deflated_eigvecs = None
-
+
    counts = comm.bcast(counts, root=0)
    displs = comm.bcast(displs, root=0)
-     lam = comm.bcast(lam, root=0)
-     D_keep = comm.bcast(D_keep, root=0)
-     v_keep = comm.bcast(v_keep, root=0)
+     lam = comm.bcast(lam, root=0)
+     D_keep = comm.bcast(D_keep, root=0)
+     v_keep = comm.bcast(v_keep, root=0)
    my_count = counts[rank]
    type_lam = comm.bcast(lam.dtype, root=0)
@@ -389,28 +394,30 @@ def parallel_tridiag_eigen(
    # now do the scatterv
    comm.Scatterv(
        [lam, counts, displs, mpi_type],  # send tuple, only root’s lam is used here
-         lam_buffer,  # recvbuf on every rank
-         root=0
+         lam_buffer,  # recvbuf on every rank
+         root=0,
    )

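    # Each rank receives counts[rank] of the solved eigenvalues, starting at
    # offset displs[rank] in root's `lam`; lam_buffer therefore holds only the
    # slice this rank is responsible for post-processing.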

-     initial_point = displs[rank]
+     initial_point = displs[rank]

    for k_rel in range(lam_buffer.size):
-         k = k_rel + initial_point
+         k = k_rel + initial_point
        numerator = lam - D_keep[k]
        denominator = np.concatenate((D_keep[:k], D_keep[k + 1:])) - D_keep[k]
        numerator[:-1] = numerator[:-1] / denominator
        v_keep[k] = np.sqrt(np.abs(np.prod(numerator) / beta)) * np.sign(v_keep[k])
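        # The product formula above is the Gu/Eisenstat-style recomputation of
        # the rank-one vector: |v_k|^2 = prod_j (lam_j - d_k) /
        # (beta * prod_{j != k} (d_j - d_k)), evaluated with the original sign,
        # which keeps the computed eigenvectors numerically orthogonal.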

    # eigenpairs = []

-     eig_vecs = np.empty((D_size, my_count),dtype=type_lam)
-     eig_val = np.empty(my_count, dtype=type_lam)
+     eig_vecs = np.empty((D_size, my_count), dtype=type_lam)
+     eig_val = np.empty(my_count, dtype=type_lam)

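    # For each eigenvalue lam[j] of diag(D_keep) + beta * v v^T, the reduced
    # eigenvector is proportional to v_keep / (lam[j] - D_keep); it is then
    # un-permuted with P.T and rotated back through the two half-problems'
    # eigenvector blocks.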
    for j_rel in range(lam_buffer.size):
        y = np.zeros(D_size)
-         j = j_rel + initial_point
+         # y[:reduced_dim]=v_keep/(lam[j]-D_keep)
+         # y /= np.linalg.norm(y)
+         j = j_rel + initial_point
        diff = lam[j] - D_keep
        diff[idx_inv[changing_position[j]]] = delta[j]
        y[:reduced_dim] = v_keep / (diff)
@@ -421,75 +428,72 @@ def parallel_tridiag_eigen(
        y = P.T @ y
        vec = np.concatenate((eigvecs_left @ y[:n1], eigvecs_right @ y[n1:]))
        vec /= np.linalg.norm(vec)
-         eig_vecs[:, j_rel]= vec
-         eig_val[j_rel]= lam[j]
-         #eigenpairs.append((lam[j], vec))
+         eig_vecs[:, j_rel] = vec
+         eig_val[j_rel] = lam[j]
+         # eigenpairs.append((lam[j], vec))

-     if reduced_dim < D_size:
+     if reduced_dim < D_size:

-         if rank == 0:
-             le_deflation = len(deflated_eigvals)
+         if rank == 0:
+             le_deflation = len(deflated_eigvals)
            counts, displs = find_interval_extreme(le_deflation, size)

        counts = comm.bcast(counts, root=0)
        displs = comm.bcast(displs, root=0)
        my_count = counts[rank]
-
-         deflated_eigvals_buffer = np.empty(my_count, dtype=type_lam)
+
+         deflated_eigvals_buffer = np.empty(my_count, dtype=type_lam)
        if rank == 0:
            char = deflated_eigvals.dtype.char
-             type_eig = deflated_eigvals.dtype
+             type_eig = deflated_eigvals.dtype
        else:
            char = None
-             type_eig = None
+             type_eig = None

        # now everyone learns the character code:
        char = comm.bcast(char, root=0)
        type_eig = comm.bcast(type_eig, root=0)
        comm.Scatterv(
-             [deflated_eigvals, counts, displs, MPI._typedict[char]],
-             deflated_eigvals_buffer,
-             root=0,
+             [deflated_eigvals, counts, displs, MPI._typedict[char]],
+             deflated_eigvals_buffer,
+             root=0,
        )
-         if rank == 0:
-             _, k = deflated_eigvecs.shape
+         if rank == 0:
+             _, k = deflated_eigvecs.shape
        else:
            mat = None
-             k = None
-         k = comm.bcast(k, root=0)
+             k = None
+         k = comm.bcast(k, root=0)
        # each row of `mat` is one deflated vec
        sendcounts = [c * k for c in counts]
-         senddispls = [d * k for d in displs]
-         deflated_eigvecs_buffer = np.empty( (my_count, k), dtype=type_eig)
+         senddispls = [d * k for d in displs]
+         deflated_eigvecs_buffer = np.empty((my_count, k), dtype=type_eig)
        if rank == 0:
-             flat_send = deflated_eigvecs.copy().flatten()  # shape (M*k,)
+             flat_send = deflated_eigvecs.copy().flatten()  # shape (M*k,)
            sendbuf = [flat_send, sendcounts, senddispls, MPI._typedict[char]]
        else:
            sendbuf = None
-

        # now scatter to everyone
        comm.Scatterv(
-             sendbuf,  # only meaningful on rank 0
-             deflated_eigvecs_buffer,  # each rank’s recv-buffer of length k × my_count
-             root=0
+             sendbuf,  # only meaningful on rank 0
+             deflated_eigvecs_buffer,  # each rank’s recv-buffer of length k × my_count
+             root=0,
        )
-

-         #local_final_vecs = np.empty((k, my_count), dtype=deflated_eigvecs.dtype)
+         # local_final_vecs = np.empty((k, my_count), dtype=deflated_eigvecs.dtype)
        for i in range(my_count):
            small_vec = deflated_eigvecs_buffer[i]
            # apply the two block Q’s:
-             left_part = eigvecs_left @ small_vec[:n1]
+             left_part = eigvecs_left @ small_vec[:n1]
            right_part = eigvecs_right @ small_vec[n1:]
            local_final_vecs = np.concatenate((left_part, right_part))
            local_final_vecs = local_final_vecs.reshape(k, 1)
-             eig_val = np.append(eig_val, deflated_eigvals_buffer[i])
-             eig_vecs = np.concatenate([eig_vecs, local_final_vecs], axis=1)
+             eig_val = np.append(eig_val, deflated_eigvals_buffer[i])
+             eig_vecs = np.concatenate([eig_vecs, local_final_vecs], axis=1)
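        # The deflated eigenpairs are appended after this rank's non-deflated
        # ones; global ascending order is restored only after the final
        # Allgatherv by the argsort below.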

-
    # 1) Each rank computes its local length:
-     local_count = eig_val.size  # or however many elements you’ll send
+     local_count = eig_val.size  # or however many elements you’ll send

    # 2) Everyone exchanges counts via allgather:
    #    this returns a Python list of length `size` on every rank
@@ -500,48 +504,39 @@ def parallel_tridiag_eigen(

    # # 2) Broadcast that list from rank 0 back to everyone
    # recvcounts = comm.bcast(counts, root=0)
-

    final_eig_val = np.empty(D_size, dtype=eig_val.dtype)

-
-
-     displs = np.append([0], np.cumulative_sum(recvcounts[:-1]).astype(int))
+     displs = np.append([0], np.cumulative_sum(recvcounts[:-1]).astype(int))

    mpi_t = MPI._typedict[eig_val.dtype.char]
-     comm.Allgatherv(
-         [eig_val, mpi_t],
-         [final_eig_val, recvcounts, displs, mpi_t]
-     )
-
+     comm.Allgatherv([eig_val, mpi_t], [final_eig_val, recvcounts, displs, mpi_t])
+
    # 1) Flatten local eigenvector block
-     #    eig_vecs has shape (D_size, local_count)
-     local_flat = eig_vecs.T.flatten()
+     #    eig_vecs has shape (D_size, local_count)
+     local_flat = eig_vecs.T.flatten()

    # 2) Build sendcounts & displacements for the flattened arrays
    sendcounts_vecs = [c * D_size for c in recvcounts]
-     senddispls_vecs = [d * D_size for d in displs]
+     senddispls_vecs = [d * D_size for d in displs]

    # 3) Allocate full receive buffer on every rank
    flat_all = np.empty(sum(sendcounts_vecs), dtype=eig_vecs.dtype)

    # 4) Perform the all-gather-variable-counts
    mpi_tvec = MPI._typedict[eig_vecs.dtype.char]
    comm.Allgatherv(
-         [local_flat, mpi_tvec],  # sendbuf
-         [flat_all,
-          sendcounts_vecs,
-          senddispls_vecs,
-          mpi_tvec]  # recvbuf spec
+         [local_flat, mpi_tvec],  # sendbuf
+         [flat_all, sendcounts_vecs, senddispls_vecs, mpi_tvec],  # recvbuf spec
    )

    # 5) Reshape on every rank (or just on rank 0 if you prefer)
    #    total_pairs == sum(recvcounts)
    final_eig_vecs = flat_all.reshape(D_size, D_size)
-     final_eig_vecs = final_eig_vecs.T
-     index_sort = np.argsort(final_eig_val)
-     final_eig_vecs = final_eig_vecs[:, index_sort]
-     final_eig_val = final_eig_val[index_sort]
+     final_eig_vecs = final_eig_vecs.T
+     index_sort = np.argsort(final_eig_val)
+     final_eig_vecs = final_eig_vecs[:, index_sort]
+     final_eig_val = final_eig_val[index_sort]
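    # Each row of flat_all is one gathered eigenvector of length D_size, so the
    # reshape/transpose turns them back into columns; both arrays are then put
    # in ascending eigenvalue order.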
    # if rank == 0:
    #     print(final_eig_val)
    #     print(final_eig_vecs)