diff --git a/tsadar/core/physics/ratintn.py b/tsadar/core/physics/ratintn.py index dd56a476..1fc00f6e 100644 --- a/tsadar/core/physics/ratintn.py +++ b/tsadar/core/physics/ratintn.py @@ -44,9 +44,17 @@ def ratcen(f: jnp.ndarray, g: jnp.ndarray) -> jnp.ndarray: gav = 0.5 * (g[1:-1] + g[0:-2]) tmp = fav * gdif - gav * fdif - rf = fav / gav + tmp * gdif / (12.0 * gav**3) + + use_rf = jnp.abs(gdif) < 1.0e-4 * jnp.abs(gav) + # Guard the denominator of the *unused* branch so it stays finite: near a pole gav -> 0, + # so the (unselected) rf branch is inf there. jnp.where picks rfn for the value, but autodiff + # differentiates both branches and propagates 0*inf = nan. On CPU gav is a tiny non-zero so the + # rf gradient is finite (0*huge=0); on GPU FMA/rounding lands gav at exactly 0 -> inf -> nan. + # The double-where keeps each branch's gradient finite where it is not selected. + gav_safe = jnp.where(use_rf, gav, 1.0) + rf = fav / gav_safe + tmp * gdif / (12.0 * gav_safe**3) rfn = fdif / gdif + tmp * jnp.log((gav + (0.5 + 0j) * gdif) / (gav - 0.5 * gdif)) / gdif**2 - out = jnp.where((jnp.abs(gdif) < 1.0e-4 * jnp.abs(gav))[None, :], rf, rfn) + out = jnp.where(use_rf[None, :], rf, rfn) return jnp.real(out) diff --git a/tsadar/inverse/loops.py b/tsadar/inverse/loops.py index 62e8ed88..f10f18d1 100644 --- a/tsadar/inverse/loops.py +++ b/tsadar/inverse/loops.py @@ -92,6 +92,9 @@ def _1d_optax_loop_( best_loss = 1e16 epoch_loss = 1e19 + # fall back to the starting weights so a never-improving (e.g. NaN) loss + # still returns valid params instead of raising UnboundLocalError + best_weights = eqx.combine(diff_params, static_params) for i_epoch in range(config["optimizer"]["num_epochs"]): tbatch.set_description(f"Epoch {i_epoch + 1}, Prev Epoch Loss {epoch_loss:.2e}")