Skip to content

Commit b8cc36e

Browse files
committed
Minor fixes in sliding intervals and documentation
1 parent e83c3d7 commit b8cc36e

2 files changed

Lines changed: 50 additions & 20 deletions

File tree

scripts/mpi_running.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def parallel_eig(d, off_d, nprocs):
99

1010
print("inside parallel_eig")
1111
comm = MPI.COMM_SELF.Spawn(
12-
sys.executable, args=["./parallel_tridiag_eigen.py"], maxprocs=nprocs
12+
sys.executable, args=["src/pyclassify/parallel_tridiag_eigen.py"], maxprocs=nprocs
1313
)
1414
print("sending")
1515
comm.send(d, dest=0, tag=11)

src/pyclassify/cxx_utils.cpp

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,11 @@ QR_algorithm(std::vector<double> diag, std::vector<double> off_diag, const dou
130130
diag[0]=c*c*a_0+s*s*diag[1]-2*s*c*b_1;
131131
diag[1]=c*c*diag[1]+s*s*a_0+2*s*c*b_1;
132132

133-
134133
}
135-
136134
}
137-
138-
139135
}
140136

141137

142-
143138
unsigned j, k;
144139
for(unsigned int i=0; i<m; i++){
145140
c=Matrix_trigonometric[i][0];
@@ -170,8 +165,7 @@ QR_algorithm(std::vector<double> diag, std::vector<double> off_diag, const dou
170165
Q[k]=tmp*c-Q[k+n]*s;
171166
Q[k+n]=tmp*s+Q[k+n]*c;
172167
}
173-
174-
168+
175169
}
176170

177171
iter++;
@@ -181,9 +175,7 @@ QR_algorithm(std::vector<double> diag, std::vector<double> off_diag, const dou
181175
}
182176
}
183177

184-
if(iter==max_iter){
185-
std::cout<<"Converges failed"<<std::endl;
186-
}
178+
if(iter==max_iter) { std::cout<<"The QR method did not converge."<<std::endl; }
187179

188180
std::vector<std::vector<double>> eig_vec(n,std::vector<double> (n, 0));
189181
//std::cout<<"Iteration: "<<iter<<std::endl;
@@ -210,7 +202,6 @@ Eigen_value_calculator(std::vector<double> diag, std::vector<double> off_diag, c
210202

211203

212204

213-
214205
std::vector<std::array<double, 2>> Matrix_trigonometric(n-1, {0, 0});
215206

216207
unsigned int iter = 0;
@@ -331,6 +322,12 @@ Eigen_value_calculator(std::vector<double> diag, std::vector<double> off_diag, c
331322
}
332323

333324

325+
/* Now we implement all the functions that are needed to solve the secular equation following the ETH lecture notes in the references.
326+
* This procedure involves defining the secular function, creating a nonlinear solver and computing the zero for each interval.
327+
* Notice that the outer zero is computed using bisection as a consequence of the fact that the procedure described in the notes only ensures
328+
* convergence in the inner intervals. */
329+
330+
334331
double compute_sum(
335332
const std::vector<double>& v,
336333
const std::vector<double>& d,
@@ -352,6 +349,7 @@ double compute_sum(
352349
return rho * sum;
353350
}
354351

352+
355353
void compute_Psi(
356354
const unsigned int i,
357355
const std::vector<double>& v,
@@ -362,6 +360,7 @@ void compute_Psi(
362360
std::function<double(double)>& dPsi_1,
363361
std::function<double(double)>& dPsi_2
364362
) {
363+
/* Function to compute the psi_s that appear in the secular function. */
365364
Psi_1 = [&](double x) {
366365
return compute_sum(v, d, x, 0, i + 1, false, rho);
367366
};
@@ -376,6 +375,7 @@ void compute_Psi(
376375
};
377376
}
378377

378+
379379
std::pair<double, double> find_root(
380380
const unsigned int i,
381381
const bool left_center,
@@ -385,6 +385,8 @@ std::pair<double, double> find_root(
385385
double lam_0,
386386
const double tol = 1e-15,
387387
const unsigned int maxiter = 100) {
388+
389+
/* Function to compute the inner root in the i-th interval */
388390
std::vector<double> diag = d;
389391
double shift;
390392

@@ -435,6 +437,7 @@ double bisection(
435437
double b,
436438
const double tol,
437439
const unsigned int max_iter) {
440+
438441
unsigned int iter_count = 0;
439442

440443
while ((b - a) / 2.0 > tol) {
@@ -462,16 +465,26 @@ double compute_outer_zero(
462465
const double rho,
463466
const double interval_end,
464467
const double tol = 1e-14,
465-
const unsigned int max_iter = 1000){
468+
const unsigned int max_iter = 100){
469+
470+
// This function calls bisection on a sliding interval: the interval is moved until f changes sign across it, which is the condition bisection requires.
466471

467-
const double threshold = 1e-11;
472+
double threshold = 1e-11;
468473
double update = 0.0;
469474

470-
// Compute L2 norm of v
475+
// Compute L2 norm of v and use it for the update
476+
471477
for (size_t i = 0; i < v.size(); ++i) {
472478
update += v[i] * v[i];
473479
}
474-
update = std::sqrt(update);
480+
// update = std::sqrt(update); // actually we use the square of the norm to avoid having to compute the square root
481+
482+
// another possibility for the update is this one (which is cheaper, but might cause troubles if the elements of d are too close to each other):
483+
484+
//if (rho>=0)
485+
// update = d[d.size() - 1] - d[d.size() - 2]; // cheaper than computing the norm each time
486+
//else
487+
//        update = d[1] - d[0];
475488

476489
auto f = [&](double x) -> double {
477490
double sum = 0.0;
@@ -487,13 +500,30 @@ double compute_outer_zero(
487500
if (rho > 0.0) {
488501
a = interval_end + threshold;
489502
b = interval_end + 1.0;
490-
while (f(a) * f(b) > 0.0) {
491-
a = b;
492-
b += update;
503+
504+
// Ensure that f(a) has the correct sign
505+
while (f(a)>0){
506+
// we are on the wrong side of the zero!
507+
b = a;
508+
threshold = threshold*0.5;
509+
a = interval_end + threshold;
493510
}
511+
512+
while (f(a) * f(b) > 0.0) {
513+
a = b;
514+
b += update;
515+
}
516+
494517
} else if (rho < 0.0) {
495518
b = interval_end - threshold;
496519
a = interval_end - 1.0;
520+
521+
// also in this case, ensure we are on the correct side of the zero
522+
while (f(b)>0){
523+
a = b;
524+
threshold = threshold*0.5;
525+
b = interval_end - threshold;
526+
}
497527
while (f(a) * f(b) > 0.0) {
498528
b = a;
499529
a -= update;
@@ -787,4 +817,4 @@ PYBIND11_MODULE(cxx_utils, m) {
787817
m.def("Eigen_value_calculator", &Eigen_value_calculator, py::arg("diag"), py::arg("off_diag"), py::arg("tol")=1e-8, py::arg("max_iter")=5000);
788818
m.def("secular_solver_cxx", &secular_solver, py::arg("rho"), py::arg("d"), py::arg("v"), py::arg("indices"));
789819
m.def("deflate_eigenpairs_cxx", &deflateEigenpairs, py::arg("D"), py::arg("v"), py::arg("beta"), py::arg("tol_factor") = 1e-12);
790-
}
820+
}

0 commit comments

Comments
 (0)