@@ -316,22 +316,20 @@ def parallel_tridiag_eigen(
316316 # D, v_vec, beta, tol_factor
317317 # )
318318
319-
320-
321- D_keep = np .array (D_keep )
319+ D_keep = np .array (D_keep )
322320
323321 reduced_dim = len (D_keep )
324322
325323 if D_keep .size > 0 :
326324 idx = np .argsort (D_keep )
327325 idx_inv = np .arange (0 , reduced_dim )
328326 idx_inv = idx_inv [idx ]
329-
327+
330328 # T= np.diag(D_keep) + beta * np.outer(v_keep, v_keep)
331329 # lam , _ = np.linalg.eigh(T)
332330
333331 lam , changing_position , delta = secular_solver_cxx (
334- beta , D_keep [idx ], v_keep [idx ] , np .arange (reduced_dim )
332+ beta , D_keep [idx ], v_keep [idx ], np .arange (reduced_dim )
335333 )
336334 lam = np .array (lam )
337335 delta = np .array (delta )
@@ -366,18 +364,17 @@ def parallel_tridiag_eigen(
366364 D_keep = comm .bcast (D_keep , root = 0 )
367365 v_keep = comm .bcast (v_keep , root = 0 )
368366 my_count = counts [rank ]
369- type_lam = comm .bcast (lam .dtype , root = 0 )
370-
371- lam_buffer = np .empty (my_count , dtype = type_lam )
367+ type_lam = comm .bcast (lam .dtype , root = 0 )
372368
373- P = comm .bcast (P , root = 0 )
374- D_size = comm .bcast (D_size )
375- changing_position = comm .bcast (changing_position , root = 0 )
376- delta = comm .bcast (delta , root = 0 )
377- idx_inv = comm .bcast (idx_inv , root = 0 )
378- n1 = comm .bcast (n1 , root = 0 )
379- reduced_dim = comm .bcast (reduced_dim , root = 0 )
369+ lam_buffer = np .empty (my_count , dtype = type_lam )
380370
371+ P = comm .bcast (P , root = 0 )
372+ D_size = comm .bcast (D_size )
373+ changing_position = comm .bcast (changing_position , root = 0 )
374+ delta = comm .bcast (delta , root = 0 )
375+ idx_inv = comm .bcast (idx_inv , root = 0 )
376+ n1 = comm .bcast (n1 , root = 0 )
377+ reduced_dim = comm .bcast (reduced_dim , root = 0 )
381378
382379 # map numpy dtype → MPI datatype
383380 if lam .dtype == np .float64 :
@@ -398,7 +395,6 @@ def parallel_tridiag_eigen(
398395 root = 0 ,
399396 )
400397
401-
402398 initial_point = displs [rank ]
403399
404400 for k_rel in range (lam_buffer .size ):
@@ -543,8 +539,6 @@ def parallel_tridiag_eigen(
543539 return final_eig_val , final_eig_vecs
544540
545541
546-
547-
548542def parallel_eigen (
549543 main_diag , off_diag , tol_QR = 1e-15 , max_iterQR = 5000 , tol_deflation = 1e-15
550544):
@@ -578,7 +572,7 @@ def parallel_eigen(
578572 # main_diag = np.ones(n, dtype=np.float64) * 2.0
579573 # off_diag = np.ones(n - 1, dtype=np.float64) *1.0
580574 main_diag = (np .random .rand (n ) * 2 ).astype (np .float64 )
581- off_diag = (np .random .rand (n - 1 ) * 1 ).astype (np .float64 )
575+ off_diag = (np .random .rand (n - 1 ) * 1 ).astype (np .float64 )
582576 # eig = np.arange(1, n + 1)
583577 # A = np.diag(eig)
584578 # U = scipy.stats.ortho_group.rvs(n)
0 commit comments