From d80dde21dd0f1f71666a0b9cdd117e50815cd0ca Mon Sep 17 00:00:00 2001 From: archis Date: Thu, 4 Jun 2026 10:49:04 -0700 Subject: [PATCH] Fix GPU-only NaN gradient in form factor pole integral The rationally-centered integrand in ratcen() selects between a regular form (rf) and a near-pole form (rfn) via jnp.where. The unused rf branch divides by gav (= average of the denominator grid), which goes to ~0 near a pole. jnp.where returns the correct value, but reverse-mode autodiff differentiates both branches and propagates 0*inf from the dead rf branch. On CPU gav lands at a tiny non-zero so the rf gradient is finite (0*huge=0); on GPU FMA/rounding lands gav at exactly 0, giving 0*inf=nan. The nan flows back through vTe=sqrt(Te/Me) into grad.electron.normed_Te and blows up fits. Guard the unused branch's denominator with the double-where idiom so its gradient stays finite. The selected value is unchanged, so the fix is backend-independent and numerically identical in the forward pass. Also pre-initialize best_weights in _1d_optax_loop_ so a never-improving (e.g. NaN) batch returns valid params instead of raising UnboundLocalError. Co-Authored-By: Claude Opus 4.8 (1M context) --- tsadar/core/physics/ratintn.py | 12 ++++++++++-- tsadar/inverse/loops.py | 3 +++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/tsadar/core/physics/ratintn.py b/tsadar/core/physics/ratintn.py index dd56a476..1fc00f6e 100644 --- a/tsadar/core/physics/ratintn.py +++ b/tsadar/core/physics/ratintn.py @@ -44,9 +44,17 @@ def ratcen(f: jnp.ndarray, g: jnp.ndarray) -> jnp.ndarray: gav = 0.5 * (g[1:-1] + g[0:-2]) tmp = fav * gdif - gav * fdif - rf = fav / gav + tmp * gdif / (12.0 * gav**3) + + use_rf = jnp.abs(gdif) < 1.0e-4 * jnp.abs(gav) + # Guard the denominator of the *unused* branch so it stays finite: near a pole gav -> 0, + # so the (unselected) rf branch is inf there. jnp.where picks rfn for the value, but autodiff + # differentiates both branches and propagates 0*inf = nan. On CPU gav is a tiny non-zero so the + # rf gradient is finite (0*huge=0); on GPU FMA/rounding lands gav at exactly 0 -> inf -> nan. + # The double-where keeps each branch's gradient finite where it is not selected. + gav_safe = jnp.where(use_rf, gav, 1.0) + rf = fav / gav_safe + tmp * gdif / (12.0 * gav_safe**3) rfn = fdif / gdif + tmp * jnp.log((gav + (0.5 + 0j) * gdif) / (gav - 0.5 * gdif)) / gdif**2 - out = jnp.where((jnp.abs(gdif) < 1.0e-4 * jnp.abs(gav))[None, :], rf, rfn) + out = jnp.where(use_rf[None, :], rf, rfn) return jnp.real(out) diff --git a/tsadar/inverse/loops.py b/tsadar/inverse/loops.py index 62e8ed88..f10f18d1 100644 --- a/tsadar/inverse/loops.py +++ b/tsadar/inverse/loops.py @@ -92,6 +92,9 @@ def _1d_optax_loop_( best_loss = 1e16 epoch_loss = 1e19 + # fall back to the starting weights so a never-improving (e.g. NaN) loss + # still returns valid params instead of raising UnboundLocalError + best_weights = eqx.combine(diff_params, static_params) for i_epoch in range(config["optimizer"]["num_epochs"]): tbatch.set_description(f"Epoch {i_epoch + 1}, Prev Epoch Loss {epoch_loss:.2e}")