From cf3cafebaaf0032decab4faa512ecfbd00283ec5 Mon Sep 17 00:00:00 2001 From: archis Date: Sun, 14 Jun 2026 22:24:30 -0700 Subject: [PATCH 1/3] feat(hermite-legendre-1d): mixed Hermite-Legendre Vlasov-Poisson solver New ADEPTModule implementing the mixed Hermite-Legendre spectral method of Issan, Delzanno & Roytershteyn (arXiv:2606.12322) for the 1D-1V electrostatic Vlasov-Poisson system. The electron distribution is split f = f0 + df: the near-Maxwellian bulk f0 is expanded in the AW-Hermite basis, while the strongly non-Maxwellian part df is expanded in the Legendre basis on a bounded velocity window. The two are coupled one-way (highest Hermite mode -> Legendre) and through Poisson. Modeled structurally on BaseVlasov1D and reusing the Lawson-RK4 design of _hermite_poisson_1d. Both free-streaming operators are symmetric-tridiagonal in mode index, so each is integrated exactly via a prediagonalized matrix exponential; the E-field force, the Legendre Dirichlet penalty, and the Hermite->Legendre coupling are advanced explicitly. Space is spectral (Fourier, periodic). An explicit integrator is used by choice (the paper's implicit midpoint has a large memory footprint): mass and momentum are conserved to machine precision, energy to the time-integrator order (dt-convergent). - adept/_hermite_legendre_1d/{vector_field,modules,storage}.py + public entry - registered as solver "hermite-legendre-1d" in _base_ dispatch and adept.__init__ - conservation constraint J_{Nh,0..2}=0 and Lenard-Bernstein artificial collisions - configs: linear-advection, two-stream, bump-on-tail (paper parameters) - tests: streaming/J-integral units, linear advection vs analytic, conservation - docs: solvers/hermite_legendre_1d/config.md + quick links Co-Authored-By: Claude Opus 4.8 (1M context) --- adept/__init__.py | 2 +- adept/_base_.py | 3 + adept/_hermite_legendre_1d/__init__.py | 3 + adept/_hermite_legendre_1d/modules.py | 404 +++++++++++++++ adept/_hermite_legendre_1d/storage.py | 189 +++++++ adept/_hermite_legendre_1d/vector_field.py | 490 ++++++++++++++++++ adept/hermite_legendre_1d.py | 5 + configs/hermite-legendre-1d/bump-on-tail.yaml | 54 ++ .../hermite-legendre-1d/linear-advection.yaml | 51 ++ configs/hermite-legendre-1d/two-stream.yaml | 51 ++ docs/ARCHITECTURE.md | 1 + docs/RUNNING_A_SIM.md | 1 + .../solvers/hermite_legendre_1d/config.md | 136 +++++ tests/test_hermite_legendre_1d/__init__.py | 0 .../test_conservation.py | 77 +++ .../test_linear_advection.py | 105 ++++ .../test_streaming.py | 103 ++++ 17 files changed, 1674 insertions(+), 1 deletion(-) create mode 100644 adept/_hermite_legendre_1d/__init__.py create mode 100644 adept/_hermite_legendre_1d/modules.py create mode 100644 adept/_hermite_legendre_1d/storage.py create mode 100644 adept/_hermite_legendre_1d/vector_field.py create mode 100644 adept/hermite_legendre_1d.py create mode 100644 configs/hermite-legendre-1d/bump-on-tail.yaml create mode 100644 configs/hermite-legendre-1d/linear-advection.yaml create mode 100644 configs/hermite-legendre-1d/two-stream.yaml create mode 100644 docs/source/solvers/hermite_legendre_1d/config.md create mode 100644 tests/test_hermite_legendre_1d/__init__.py create mode 100644 tests/test_hermite_legendre_1d/test_conservation.py create mode 100644 tests/test_hermite_legendre_1d/test_linear_advection.py create mode 100644 tests/test_hermite_legendre_1d/test_streaming.py diff --git a/adept/__init__.py b/adept/__init__.py index 1c311e14..3b5effaf 100644 --- a/adept/__init__.py +++ b/adept/__init__.py @@ -1,3 +1,3 @@ from ._base_ import ADEPTModule, ergoExo # noqa: I001 from .mlflow_logging import MlflowLoggingModule -from . import hermite_poisson_1d, lpse2d, vlasov1d, vlasov2d +from . import hermite_legendre_1d, hermite_poisson_1d, lpse2d, vlasov1d, vlasov2d diff --git a/adept/_base_.py b/adept/_base_.py index 949bb5a5..aa21bb83 100644 --- a/adept/_base_.py +++ b/adept/_base_.py @@ -326,6 +326,9 @@ def _get_adept_module_(self, cfg: dict) -> ADEPTModule: elif cfg["solver"] == "pic-1d": from adept.pic1d import BasePIC1D as this_module + elif cfg["solver"] == "hermite-legendre-1d": + from adept.hermite_legendre_1d import BaseHermiteLegendre1D as this_module + else: raise NotImplementedError("This solver approach has not been implemented yet") diff --git a/adept/_hermite_legendre_1d/__init__.py b/adept/_hermite_legendre_1d/__init__.py new file mode 100644 index 00000000..4e8ef1dc --- /dev/null +++ b/adept/_hermite_legendre_1d/__init__.py @@ -0,0 +1,3 @@ +from .modules import BaseHermiteLegendre1D + +__all__ = ["BaseHermiteLegendre1D"] diff --git a/adept/_hermite_legendre_1d/modules.py b/adept/_hermite_legendre_1d/modules.py new file mode 100644 index 00000000..1dc9cc13 --- /dev/null +++ b/adept/_hermite_legendre_1d/modules.py @@ -0,0 +1,404 @@ +"""Base ADEPTModule for the 1D mixed Hermite-Legendre Vlasov-Poisson solver. + +Implements the method of Issan, Delzanno & Roytershteyn (arXiv:2606.12322). The +electron distribution is split f = f0 + df with an AW-Hermite expansion for the +near-Maxwellian bulk f0 and a Legendre expansion (on a bounded velocity window) for +the strongly non-Maxwellian part df. Electrostatic, single electron species with an +immobile neutralizing ion background; periodic in x (Fourier); explicit Lawson-RK4. + +Normalization (paper sec 2.1): t*wpe, x/lambda_D, v/vthe. + +Config keys in cfg["physics"]: + Lx (float), alpha (float, Hermite scale), u (float, Hermite shift, default 0), + v_a, v_b (float, Legendre velocity window), gamma (penalty, default 0.5), + nu_H, nu_L (artificial collision rates, default 0), enforce_conservation (bool, default True) + +Config keys in cfg["grid"]: + Nx (int), Nh (int, Hermite modes), Nl (int, Legendre modes), tmax (float), dt (float, default 0.01) + +Config cfg["initialization"]: + type: "linear-advection" | "two-stream" | "bump-on-tail" | "custom" (+ type-specific params) +""" + +import os +import sys + +import numpy as np +from diffrax import ConstantStepSize, NoProgressMeter, ODETerm, SaveAt, SubSaveAt, TqdmProgressMeter, diffeqsolve +from jax import numpy as jnp + +from adept._base_ import ADEPTModule, Stepper +from adept._hermite_legendre_1d.storage import ( + get_save_quantities, + store_coeff_timeseries, + store_fields_timeseries, +) +from adept._hermite_legendre_1d.vector_field import ( + CombinedLinearExp1D, + DiagonalCollisionExp1D, + HermiteLegendre1DVectorField, + PoissonSolver1D, + StreamingExp1D, + hermite_legendre_coupling_vector, + hermite_streaming_matrix, + legendre_constants, + safe_col, +) + + +def _density_profile(x: np.ndarray, Lx: float, base: float, eps: float, mode: int) -> np.ndarray: + """base * (1 + eps cos(2*pi*mode*x/Lx)).""" + return base * (1.0 + eps * np.cos(2.0 * np.pi * mode * x / Lx)) + + +def _project_legendre(g_v, Nl: int, v_a: float, v_b: float) -> np.ndarray: + """Project a velocity profile g(v) onto the Legendre basis: B_m = (1/width) int g xi_m dv.""" + from adept._hermite_legendre_1d.vector_field import _legendre_basis_values + + width = v_b - v_a + deg = max(4 * Nl, 400) + nodes, weights = np.polynomial.legendre.leggauss(deg) + v = 0.5 * width * nodes + 0.5 * (v_b + v_a) + w = 0.5 * width * weights + xi = _legendre_basis_values(Nl, v, v_a, v_b) # (Nl, len(v)) + return (xi @ (g_v(v) * w)) / width + + +class BaseHermiteLegendre1D(ADEPTModule): + """1D mixed Hermite-Legendre Vlasov-Poisson base solver (normalized units).""" + + def __init__(self, cfg: dict) -> None: + super().__init__(cfg) + + # ------------------------------------------------------------------ + def write_units(self) -> dict: + """Normalized units throughout; pass through any precomputed derived block.""" + return self.cfg.get("units", {}).get("derived", {}) + + # ------------------------------------------------------------------ + def get_derived_quantities(self) -> None: + physics = self.cfg["physics"] + grid = self.cfg["grid"] + + Lx = float(physics["Lx"]) + Nx = int(grid["Nx"]) + grid["dx"] = Lx / Nx + + tmax = float(grid["tmax"]) + dt = float(grid.get("dt", 0.01)) + nt = round(tmax / dt) + grid["dt"] = dt + grid["nt"] = nt + grid["tmax"] = dt * nt # snap to exact multiple + grid["max_steps"] = nt + 4 + + for save_cfg in self.cfg.get("save", {}).values(): + if isinstance(save_cfg, dict) and "t" in save_cfg: + t_cfg = save_cfg["t"] + t_cfg.setdefault("tmin", 0.0) + t_cfg.setdefault("tmax", grid["tmax"]) + + self.cfg["grid"] = grid + + # ------------------------------------------------------------------ + def get_solver_quantities(self) -> None: + physics = self.cfg["physics"] + grid = self.cfg["grid"] + + Lx = float(physics["Lx"]) + Nx = int(grid["Nx"]) + Nh = int(grid["Nh"]) + Nl = int(grid["Nl"]) + alpha = float(physics["alpha"]) + u = float(physics.get("u", 0.0)) + v_a = float(physics["v_a"]) + v_b = float(physics["v_b"]) + width = v_b - v_a + + kx_1d = jnp.fft.fftfreq(Nx) * Nx * 2.0 * jnp.pi / Lx + one_over_kx = np.zeros(Nx, dtype=np.float64) + one_over_kx[1:] = 1.0 / np.asarray(kx_1d[1:]) + kx_sq = kx_1d**2 + + x = jnp.linspace(0.0, Lx, Nx, endpoint=False) + modes = jnp.fft.fftfreq(Nx) * Nx + mask23 = jnp.abs(modes) <= (Nx // 3) + + grid.update( + { + "x": x, + "kx_1d": kx_1d, + "one_over_kx": jnp.asarray(one_over_kx), + "kx_sq": kx_sq, + "mask23": mask23, + "width": width, + } + ) + self.cfg["grid"] = grid + + # ------------------------------------------------------------------ + def init_state_and_args(self) -> None: + physics = self.cfg["physics"] + grid = self.cfg["grid"] + + Lx = float(physics["Lx"]) + Nx = int(grid["Nx"]) + Nh = int(grid["Nh"]) + Nl = int(grid["Nl"]) + alpha = float(physics["alpha"]) + v_a = float(physics["v_a"]) + v_b = float(physics["v_b"]) + + x = np.asarray(grid["x"]) + C = np.zeros((Nh, Nx), dtype=np.float64) # real-space C_n(x) + B = np.zeros((Nl, Nx), dtype=np.float64) # real-space B_m(x) + + init = self.cfg.get("initialization", {"type": "linear-advection"}) + itype = init.get("type", "linear-advection") + + if itype == "linear-advection": + eps = float(init.get("eps", 1.0)) + mode = int(init.get("mode", 1)) + n_bulk = _density_profile(x, Lx, 1.0, eps, mode) + C[0] = n_bulk / alpha + + elif itype == "two-stream": + eps = float(init.get("eps", 0.01)) + mode = int(init.get("mode", 1)) + n_bulk = _density_profile(x, Lx, 1.0, eps, mode) + C[0] = n_bulk / alpha + # v^2 Maxwellian -> C_2 = sqrt(2) C_0 in the AW-Hermite basis (alpha=sqrt(2)) + if Nh > 2: + C[2] = np.sqrt(2.0) * C[0] + + elif itype == "bump-on-tail": + eps = float(init.get("eps", 1e-4)) + mode = int(init.get("mode", 1)) + n_beam = float(init.get("n_beam", 0.01)) + v_drift = float(init.get("v_drift", 10.0)) + v_th = float(init.get("v_th", 1.0)) + n_bulk = _density_profile(x, Lx, 1.0 - n_beam, eps, mode) + C[0] = n_bulk / alpha + amp = n_beam / (np.sqrt(2.0 * np.pi) * v_th) + B_beam = _project_legendre(lambda v: amp * np.exp(-((v - v_drift) ** 2) / (2.0 * v_th**2)), Nl, v_a, v_b) + B[:] = B_beam[:, None] # spatially uniform beam + + elif itype == "custom": + # hermite: {n: {base, eps, mode}} ; df: {beams: [{amp, v_drift, v_th}], eps, mode} + for n_str, spec in init.get("hermite", {}).items(): + n = int(n_str) + C[n] = _density_profile( + x, Lx, float(spec.get("base", 0.0)), float(spec.get("eps", 0.0)), int(spec.get("mode", 1)) + ) + df = init.get("df", None) + if df: + beams = df.get("beams", []) + + def g(v): + out = np.zeros_like(v) + for b in beams: + out = out + float(b["amp"]) * np.exp( + -((v - float(b.get("v_drift", 0.0))) ** 2) / (2.0 * float(b.get("v_th", 1.0)) ** 2) + ) + return out + + B_beam = _project_legendre(g, Nl, v_a, v_b) + spatial = _density_profile(x, Lx, 1.0, float(df.get("eps", 0.0)), int(df.get("mode", 1))) + B[:] = B_beam[:, None] * spatial[None, :] + else: + raise ValueError(f"Unknown initialization.type {itype!r}") + + Ck = jnp.fft.fft(jnp.asarray(C), axis=-1, norm="forward").astype(jnp.complex128) + Bk = jnp.fft.fft(jnp.asarray(B), axis=-1, norm="forward").astype(jnp.complex128) + + # Seed the field diagnostics with the actual t=0 Poisson field so the energy + # diagnostic is self-consistent at t=0 (the perturbed initial state carries a + # nonzero E ~ eps); leaving e=0 would record a spurious O(eps^2) jump at step 1. + field_on = bool(physics.get("field", True)) + if field_on: + poisson = PoissonSolver1D( + one_over_kx=grid["one_over_kx"], kx_sq=grid["kx_sq"], alpha=alpha, width=(v_b - v_a) + ) + e0 = poisson.electric_field(Ck, Bk) + phi0 = poisson.potential(Ck, Bk) + else: + e0 = jnp.zeros(Nx) + phi0 = jnp.zeros(Nx) + + self.state = { + "Ck": Ck.view(jnp.float64), + "Bk": Bk.view(jnp.float64), + "e": e0, + "phi": phi0, + } + self.args = {} + + # ------------------------------------------------------------------ + def init_diffeqsolve(self) -> None: + physics = self.cfg["physics"] + grid = self.cfg["grid"] + + Nh = int(grid["Nh"]) + Nl = int(grid["Nl"]) + alpha = float(physics["alpha"]) + u = float(physics.get("u", 0.0)) + v_a = float(physics["v_a"]) + v_b = float(physics["v_b"]) + width = v_b - v_a + gamma = float(physics.get("gamma", 0.5)) + nu_H = float(physics.get("nu_H", 0.0)) + nu_L = float(physics.get("nu_L", 0.0)) + enforce = bool(physics.get("enforce_conservation", True)) + field_on = bool(physics.get("field", True)) + dt = float(grid["dt"]) + + kx_1d = grid["kx_1d"] + one_over_kx = grid["one_over_kx"] + kx_sq = grid["kx_sq"] + mask23 = grid["mask23"] + + # Streaming exponentials (exact, prediagonalized symmetric tridiagonals) + T_H = hermite_streaming_matrix(Nh, u, alpha) + leg = legendre_constants(Nl, v_a, v_b) + hermite_stream = StreamingExp1D(T_H, prefactor=-1j * alpha, kx_1d=kx_1d) + legendre_stream = StreamingExp1D(np.asarray(leg["T_L"]), prefactor=-1j, kx_1d=kx_1d) + + hermite_coll = DiagonalCollisionExp1D(nu_H, safe_col(Nh)) + legendre_coll = DiagonalCollisionExp1D(nu_L, safe_col(Nl)) + combined_exp = CombinedLinearExp1D(hermite_stream, legendre_stream, hermite_coll, legendre_coll) + + poisson = PoissonSolver1D(one_over_kx=one_over_kx, kx_sq=kx_sq, alpha=alpha, width=width) + + # Explicit-term constants + n = jnp.arange(Nh, dtype=jnp.float64) + sqrt_2n_over_alpha = jnp.sqrt(2.0 * n) / alpha + gamma_vec = jnp.where(jnp.arange(Nl) >= 3, gamma, 0.0) + + J = hermite_legendre_coupling_vector(Nh, Nl, alpha, u, v_a, v_b, enforce_conservation=enforce) + coupling_vec = -(alpha / width) * jnp.sqrt(Nh / 2.0) * J # folds prefactor into J_{Nh,m} + + vector_field = HermiteLegendre1DVectorField( + combined_exp=combined_exp, + poisson=poisson, + kx_1d=kx_1d, + sqrt_2n_over_alpha=sqrt_2n_over_alpha, + deriv=leg["deriv"], + gamma_vec=gamma_vec, + xi_a=leg["xi_a"], + xi_b=leg["xi_b"], + coupling_vec=coupling_vec, + alpha=alpha, + width=width, + dt=dt, + mask23=mask23, + field_on=field_on, + ) + + self.cfg = get_save_quantities(self.cfg) + tmax = float(grid["tmax"]) + max_steps = int(grid["max_steps"]) + + subsaves = {} + for k, v in self.cfg["save"].items(): + if isinstance(v, dict) and "t" in v and "func" in v: + subsaves[k] = SubSaveAt(ts=v["t"]["ax"], fn=v["func"]) + + self.time_quantities = {"t0": 0.0, "t1": tmax, "max_steps": max_steps} + self.diffeqsolve_quants = { + "terms": ODETerm(vector_field), + "solver": Stepper(), + "saveat": {"subs": subsaves}, + } + + # ------------------------------------------------------------------ + def __call__(self, trainable_modules: dict, args: dict | None = None) -> dict: + if args is None: + args = self.args + grid = self.cfg["grid"] + sol = diffeqsolve( + terms=self.diffeqsolve_quants["terms"], + solver=self.diffeqsolve_quants["solver"], + stepsize_controller=ConstantStepSize(), + t0=self.time_quantities["t0"], + t1=self.time_quantities["t1"], + dt0=float(grid["dt"]), + y0=self.state, + args=args, + saveat=SaveAt(**self.diffeqsolve_quants["saveat"]), + max_steps=self.time_quantities["max_steps"], + progress_meter=TqdmProgressMeter() if sys.stdout.isatty() else NoProgressMeter(), + ) + return {"solver result": sol} + + # ------------------------------------------------------------------ + def post_process(self, run_output: dict, td: str) -> dict: + import matplotlib.pyplot as plt + + sol = run_output["solver result"] + binary_dir = os.path.join(td, "binary") + plots_dir = os.path.join(td, "plots") + os.makedirs(binary_dir, exist_ok=True) + os.makedirs(plots_dir, exist_ok=True) + + x = np.asarray(self.cfg["grid"]["x"]) + datasets = {} + metrics = {"simulation_completed": True} + + for k in sol.ys.keys(): + t_arr = np.asarray(sol.ts[k]) + data = sol.ys[k] + + if k == "fields": + fields_dict = {name: np.asarray(arr) for name, arr in data.items()} + datasets["fields"] = store_fields_timeseries(fields_dict, t_arr, binary_dir, x) + for field_name, arr in fields_dict.items(): + if arr.ndim == 2: + try: + arr_plot = np.where(np.isfinite(arr), arr, np.nan) + fig, ax = plt.subplots(figsize=(9, 5), tight_layout=True) + im = ax.pcolormesh(x, t_arr, arr_plot, shading="auto", cmap="RdBu_r") + plt.colorbar(im, ax=ax) + ax.set_xlabel("x [norm]") + ax.set_ylabel("t [norm]") + ax.set_title(field_name) + fig.savefig(os.path.join(plots_dir, f"spacetime-{field_name}.png"), bbox_inches="tight") + except Exception as e: + print(f"post_process: spacetime plot for {field_name} failed: {e}") + finally: + plt.close("all") + + elif k in ("hermite", "legendre"): + name = "Ck" if k == "hermite" else "Bk" + arr = np.asarray(data[name]) + datasets[k] = store_coeff_timeseries(name, arr, t_arr, binary_dir) + + elif k == "default": + scalars = {name: np.asarray(arr) for name, arr in data.items()} + for name, arr in scalars.items(): + if arr.ndim == 1: + final_val = float(arr[-1]) + metrics[f"final_{name}"] = final_val if np.isfinite(final_val) else float("nan") + # relative drift for conserved invariants + if name in ("mass", "momentum", "energy") and np.isfinite(arr[0]) and abs(arr[0]) > 0: + metrics[f"reldrift_{name}"] = float(np.max(np.abs(arr - arr[0])) / abs(arr[0])) + try: + arr_plot = np.where(np.isfinite(arr), arr, np.nan) + fig, axes = plt.subplots(1, 2, figsize=(10, 4), tight_layout=True) + axes[0].plot(t_arr, arr_plot) + axes[0].set_xlabel("t") + axes[0].set_ylabel(name) + axes[0].grid(alpha=0.3) + axes[1].semilogy(t_arr, np.abs(arr_plot) + 1e-30) + axes[1].set_xlabel("t") + axes[1].set_ylabel(f"|{name}|") + axes[1].grid(alpha=0.3) + fig.savefig(os.path.join(plots_dir, f"scalar-{name}.png"), bbox_inches="tight") + except Exception as e: + print(f"post_process: scalar plot for {name} failed: {e}") + finally: + plt.close("all") + + if "default" in sol.ts: + metrics["n_timesteps"] = len(sol.ts["default"]) + + return {"metrics": metrics, "datasets": datasets} diff --git a/adept/_hermite_legendre_1d/storage.py b/adept/_hermite_legendre_1d/storage.py new file mode 100644 index 00000000..c423ce9d --- /dev/null +++ b/adept/_hermite_legendre_1d/storage.py @@ -0,0 +1,189 @@ +""" +Storage / save functions for the 1D mixed Hermite-Legendre solver. + +State keys: Ck (Nh, Nx) complex viewed as float64, Bk (Nl, Nx) complex viewed as +float64, e (Nx,), phi (Nx,). + +Supported cfg["save"] keys: + "fields" -> {e, phi} spacetime arrays + "hermite" -> Ck Hermite-Fourier coefficient timeseries + "legendre" -> Bk Legendre-Fourier coefficient timeseries + "default" -> scalar invariants (mass, momentum, energy) + field energy. Always added. + +The scalar invariants follow the analytic definitions of the mixed method (paper +eqns 26, 28, 30-31). With the conservation constraint J_{Nh,0..2}=0 enforced they +are conserved to solver tolerance, which is the primary correctness gate. +""" + +import os + +import numpy as np +import xarray as xr +from jax import numpy as jnp + +# --------------------------------------------------------------------------- +# Save functions (called by diffrax SubSaveAt at specified timesteps) +# --------------------------------------------------------------------------- + + +def get_fields_save_func(): + """Save electric field e and potential phi.""" + + def fields_save_func(t, y, args): + return {"e": y["e"], "phi": y["phi"]} + + return fields_save_func + + +def get_hermite_save_func(): + def hermite_save_func(t, y, args): + return {"Ck": y["Ck"].view(jnp.complex128)} + + return hermite_save_func + + +def get_legendre_save_func(): + def legendre_save_func(t, y, args): + return {"Bk": y["Bk"].view(jnp.complex128)} + + return legendre_save_func + + +def get_default_save_func( + alpha: float, u: float, width: float, sigma1: float, sigma2: float, sigma_bar: float, Lx: float +): + """Save scalar invariants: total mass, momentum, energy, and field energy. + + Spatial integrals reduce to the domain length times the k=0 Fourier mode, e.g. + integral C_n dx = Lx * Re(Ck[n, 0]) since norm="forward" puts the mean in [0]. + """ + alpha = float(alpha) + u = float(u) + width = float(width) + sigma1 = float(sigma1) + sigma2 = float(sigma2) + sigma_bar = float(sigma_bar) + Lx = float(Lx) + + def default_save_func(t, y, args): + Ck = y["Ck"].view(jnp.complex128) + Bk = y["Bk"].view(jnp.complex128) + e = y["e"] + nx = e.shape[0] + dx_local = Lx / nx + + def intg(arr_k, n): + return Lx * jnp.real(arr_k[n, 0]) if arr_k.shape[0] > n else 0.0 + + C0, C1, C2 = intg(Ck, 0), intg(Ck, 1), intg(Ck, 2) + B0, B1, B2 = intg(Bk, 0), intg(Bk, 1), intg(Bk, 2) + + mass = alpha * C0 + width * B0 + momentum = (alpha**2 / jnp.sqrt(2.0)) * C1 + u * alpha * C0 + width * sigma1 * B1 + sigma_bar * width * B0 + e_kin = (alpha / 2.0) * ( + (alpha**2 / jnp.sqrt(2.0)) * C2 + jnp.sqrt(2.0) * u * alpha * C1 + (alpha**2 / 2.0 + u**2) * C0 + ) + (width / 2.0) * (sigma2 * sigma1 * B2 + 2.0 * sigma1 * sigma_bar * B1 + (sigma1**2 + sigma_bar**2) * B0) + e_pot = 0.5 * jnp.sum(e**2) * dx_local + energy = e_kin + e_pot + + # density extrema for blow-up monitoring + n_f0 = alpha * jnp.fft.ifft(Ck[0], norm="forward").real + n_df = width * jnp.fft.ifft(Bk[0], norm="forward").real + n_e = n_f0 + n_df + + return { + "mass": mass, + "momentum": momentum, + "energy": energy, + "e_energy": e_pot, + "n_e_max": jnp.max(n_e), + "n_e_min": jnp.min(n_e), + "Bk_max": jnp.max(jnp.abs(Bk)), + "Ck_max": jnp.max(jnp.abs(Ck)), + } + + return default_save_func + + +# --------------------------------------------------------------------------- +# Configure save axes and attach save functions +# --------------------------------------------------------------------------- + + +def get_save_quantities(cfg: dict) -> dict: + """Attach time axes and save functions to cfg["save"]. Modifies cfg in place.""" + grid = cfg["grid"] + physics = cfg["physics"] + + tmax = float(grid["tmax"]) + nt = int(grid["nt"]) + + alpha = float(physics["alpha"]) + u = float(physics.get("u", 0.0)) + v_a = float(physics["v_a"]) + v_b = float(physics["v_b"]) + width = v_b - v_a + sigma1 = (width / 2.0) * 1.0 / np.sqrt(3.0 * 1.0) # n=1 + sigma2 = (width / 2.0) * 2.0 / np.sqrt(5.0 * 3.0) # n=2 + sigma_bar = 0.5 * (v_a + v_b) + Lx = float(physics["Lx"]) + + for save_key, save_cfg in cfg.get("save", {}).items(): + if not isinstance(save_cfg, dict): + continue + if "t" in save_cfg and isinstance(save_cfg["t"], dict): + t_cfg = save_cfg["t"] + if "ax" not in t_cfg: + t_cfg["ax"] = np.linspace( + float(t_cfg.get("tmin", 0.0)), float(t_cfg.get("tmax", tmax)), int(t_cfg.get("nt", nt)) + ) + if "func" in save_cfg: + continue + if save_key == "fields": + save_cfg["func"] = get_fields_save_func() + elif save_key == "hermite": + save_cfg["func"] = get_hermite_save_func() + elif save_key == "legendre": + save_cfg["func"] = get_legendre_save_func() + + if "save" not in cfg: + cfg["save"] = {} + if "default" not in cfg["save"]: + cfg["save"]["default"] = {"t": {"ax": np.linspace(0.0, tmax, nt)}} + elif "t" not in cfg["save"]["default"]: + cfg["save"]["default"]["t"] = {"ax": np.linspace(0.0, tmax, nt)} + elif "ax" not in cfg["save"]["default"]["t"]: + cfg["save"]["default"]["t"]["ax"] = np.linspace(0.0, tmax, nt) + + cfg["save"]["default"]["func"] = get_default_save_func(alpha, u, width, sigma1, sigma2, sigma_bar, Lx) + return cfg + + +# --------------------------------------------------------------------------- +# Post-processing storage helpers +# --------------------------------------------------------------------------- + + +def store_fields_timeseries(fields_dict: dict, t_array: np.ndarray, binary_dir: str, x: np.ndarray) -> xr.Dataset: + das = { + k: xr.DataArray(np.asarray(v), coords=[("t", t_array), ("x", x)], name=k) + for k, v in fields_dict.items() + if np.asarray(v).ndim == 2 + } + ds = xr.Dataset(das) + ds.to_netcdf(os.path.join(binary_dir, f"fields-t={round(float(t_array[-1]), 4)}.nc")) + return ds + + +def store_coeff_timeseries(name: str, arr: np.ndarray, t_array: np.ndarray, binary_dir: str) -> xr.Dataset: + """Save (nt, Nmodes, Nx) complex coefficient timeseries to netCDF.""" + Nx = arr.shape[-1] + Nmodes = arr.shape[-2] + kx = np.fft.fftshift(np.fft.fftfreq(Nx, d=1.0 / Nx)) + arr_shifted = np.fft.fftshift(arr, axes=-1) + ds = xr.Dataset( + {name: (["t", "mode", "kx"], arr_shifted)}, + coords={"t": t_array, "mode": np.arange(Nmodes), "kx": kx}, + ) + ds.to_netcdf(os.path.join(binary_dir, f"{name}-t={round(float(t_array[-1]), 4)}.nc")) + return ds diff --git a/adept/_hermite_legendre_1d/vector_field.py b/adept/_hermite_legendre_1d/vector_field.py new file mode 100644 index 00000000..0324c71e --- /dev/null +++ b/adept/_hermite_legendre_1d/vector_field.py @@ -0,0 +1,490 @@ +""" +Core vector field for the 1D mixed Hermite-Legendre Vlasov-Poisson solver. + +Implements the mixed method of Issan, Delzanno & Roytershteyn (arXiv:2606.12322, +"Mixed Hermite-Legendre spectral method for kinetic plasma simulations"). + +The electron distribution is split f = f0 + df, with + f0(x, v, t) = sum_{n=0}^{Nh-1} C_n(x, t) psi_n(v; alpha, u) [AW Hermite] + df(x, v, t) = sum_{m=0}^{Nl-1} B_m(x, t) xi_m(v; v_a, v_b) [Legendre] + +evolved by the coupled system (paper eqns 23-25) on a periodic spatial domain. +Spatial dependence is carried in Fourier space (d_x -> i*kx); velocity is spectral. + +Normalization (paper sec 2.1): t*wpe, x/lambda_D, v/vthe, phi*e*lambda_D/Te, so +the electric field is E = -d_x phi and the ion background density is 1. + +Integration: Lawson-RK4 (free-streaming exact for both bases via prediagonalized +symmetric-tridiagonal streaming matrices; E-field force, Dirichlet penalty, and the +Hermite->Legendre coupling are explicit). Uses the Stepper (discrete-map) convention +from adept._base_, mirroring adept._hermite_poisson_1d.vector_field. +""" + +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array + +# --------------------------------------------------------------------------- +# Basis constants +# --------------------------------------------------------------------------- + + +def safe_col(N: int) -> Array: + """Artificial-collision profile col[n] = n(n-1)(n-2) / ((N-1)(N-2)(N-3)). + + This is the Lenard-Bernstein-based hyper-collision spectrum of paper sec 2.5: + cubic in the mode index, normalized to 1 at n = N-1, and identically zero for + n = 0, 1, 2 so the operator conserves mass, momentum, and energy exactly. The + damping is applied per Lawson substep as exp(-nu * col[n] * s). + """ + n = jnp.arange(N, dtype=jnp.float64) + term = n * (n - 1) * (n - 2) + denom = (N - 1) * (N - 2) * (N - 3) if N > 3 else 1.0 + return jnp.where(N > 3, term / denom, jnp.zeros(N, dtype=jnp.float64)) + + +def hermite_streaming_matrix(Nh: int, u: float, alpha: float) -> np.ndarray: + """Symmetric tridiagonal T_H for AW-Hermite free streaming (paper eqn 7). + + d_t C_n + alpha*sqrt((n+1)/2) d_x C_{n+1} + alpha*sqrt(n/2) d_x C_{n-1} + + u d_x C_n = 0 ==> d_t C = -alpha * T_H * (d_x C), + with off-diagonal sqrt((n+1)/2) and diagonal u/alpha. The scalar prefactor + -i*alpha*kx is applied by StreamingExp1D. + """ + n = np.arange(Nh, dtype=np.float64) + T = np.zeros((Nh, Nh), dtype=np.float64) + if Nh > 1: + off = np.sqrt((n[:-1] + 1.0) / 2.0) # sqrt((n+1)/2), n = 0..Nh-2 + T += np.diag(off, 1) + np.diag(off, -1) + T += np.diag(np.full(Nh, u / alpha)) + return T + + +def legendre_constants(Nl: int, v_a: float, v_b: float) -> dict: + """Constants for the shifted/scaled Legendre basis xi_m on [v_a, v_b]. + + Returns: + T_L: symmetric tridiagonal velocity matrix (paper eqn 11), + off-diagonal sigma_{n+1}, diagonal sigma_bar = (v_a+v_b)/2. + sigma: sub/super-diagonal coefficients sigma_n (sigma_0 = 0). + sigma_bar: scalar (v_a+v_b)/2. + deriv: strictly-lower-triangular derivative matrix sigma_{m,i} (paper eqn 10), + d xi_m / dv = sum_{i 1: + sigma[1:] = (width / 2.0) * n[1:] / np.sqrt((2.0 * n[1:] + 1.0) * (2.0 * n[1:] - 1.0)) + + T_L = np.diag(np.full(Nl, sigma_bar)) + if Nl > 1: + off = sigma[1:] # T_L[m, m+1] = T_L[m+1, m] = sigma_{m+1} + T_L += np.diag(off, 1) + np.diag(off, -1) + + # Derivative matrix sigma_{m,i}: nonzero only when (m - i) is odd, i < m. + deriv = np.zeros((Nl, Nl), dtype=np.float64) + for m in range(Nl): + for i in range(m): + if (m - i) % 2 == 1: + deriv[m, i] = 2.0 * np.sqrt((2.0 * m + 1.0) * (2.0 * i + 1.0)) / width + + sqrt_2m1 = np.sqrt(2.0 * n + 1.0) + xi_b = sqrt_2m1 + xi_a = sqrt_2m1 * ((-1.0) ** n) + + return { + "T_L": T_L, + "sigma": jnp.asarray(sigma), + "sigma_bar": float(sigma_bar), + "deriv": jnp.asarray(deriv), + "xi_b": jnp.asarray(xi_b), + "xi_a": jnp.asarray(xi_a), + } + + +def _hermite_function_values(Nh_plus_1: int, v: np.ndarray, u: float, alpha: float) -> np.ndarray: + """psi_n(v; u, alpha) for n = 0..Nh, shape (Nh+1, len(v)). Used for J integrals. + + psi_n = (pi 2^n n!)^{-1/2} H_n(z) exp(-z^2), z = (v - u)/alpha, + built from the physicists' Hermite recurrence H_{n+1} = 2 z H_n - 2 n H_{n-1}. + """ + z = (v - u) / alpha + H = np.zeros((Nh_plus_1, v.size), dtype=np.float64) + H[0] = 1.0 + if Nh_plus_1 > 1: + H[1] = 2.0 * z + for k in range(1, Nh_plus_1 - 1): + H[k + 1] = 2.0 * z * H[k] - 2.0 * k * H[k - 1] + norm = 1.0 / np.sqrt(np.pi * (2.0 ** np.arange(Nh_plus_1)) * _factorials(Nh_plus_1)) + return norm[:, None] * H * np.exp(-(z**2))[None, :] + + +def _legendre_basis_values(Nl: int, v: np.ndarray, v_a: float, v_b: float) -> np.ndarray: + """xi_m(v; v_a, v_b) = sqrt(2m+1) L_m(s), s = (2v-(v_a+v_b))/(v_b-v_a), + for m = 0..Nl-1, shape (Nl, len(v)). Built from the Legendre recurrence.""" + s = (2.0 * v - (v_a + v_b)) / (v_b - v_a) + L = np.zeros((Nl, v.size), dtype=np.float64) + L[0] = 1.0 + if Nl > 1: + L[1] = s + for k in range(1, Nl - 1): + L[k + 1] = ((2.0 * k + 1.0) * s * L[k] - k * L[k - 1]) / (k + 1.0) + scale = np.sqrt(2.0 * np.arange(Nl) + 1.0) + return scale[:, None] * L + + +def _factorials(n: int) -> np.ndarray: + out = np.ones(n, dtype=np.float64) + for k in range(1, n): + out[k] = out[k - 1] * k + return out + + +def hermite_legendre_coupling_vector( + Nh: int, Nl: int, alpha: float, u: float, v_a: float, v_b: float, enforce_conservation: bool = True +) -> Array: + """J_{Nh, m} = integral_{v_a}^{v_b} psi_{Nh}(v; alpha, u) xi_m(v) dv, m = 0..Nl-1. + + This is the one-way coupling from the (closed-off) highest Hermite coefficient + C_{Nh-1} into the Legendre modes (paper eqn 24). Evaluated by Gauss-Legendre + quadrature on [v_a, v_b]. When ``enforce_conservation`` is set, J_{Nh,0..2} are + zeroed (paper sec 3.4/4): removing the coupling to the first three Legendre + coefficients makes the discrete method conserve mass, momentum, and energy + independent of Nh parity and the spectral shift/scale parameters. + """ + deg = max(4 * (Nh + Nl), 200) + nodes, weights = np.polynomial.legendre.leggauss(deg) + v = 0.5 * (v_b - v_a) * nodes + 0.5 * (v_b + v_a) + w = 0.5 * (v_b - v_a) * weights + + psi_Nh = _hermite_function_values(Nh + 1, v, u, alpha)[Nh] # (len(v),) + xi = _legendre_basis_values(Nl, v, v_a, v_b) # (Nl, len(v)) + J = xi @ (psi_Nh * w) # (Nl,) + + if enforce_conservation: + J[: min(3, Nl)] = 0.0 + return jnp.asarray(J) + + +# --------------------------------------------------------------------------- +# Exact exponential operators +# --------------------------------------------------------------------------- + + +class StreamingExp1D: + """Exact free-streaming exponential exp(prefactor * kx * T * s) for one basis. + + Prediagonalizes the symmetric tridiagonal streaming matrix T (Hermite or + Legendre) once, then applies exp(L*s) to a coefficient array (Nmodes, Nx) by + rotating into the eigenbasis, scaling by exp(prefactor * kx * eigval * s), and + rotating back. Mirrors adept._hermite_poisson_1d.vector_field.FreeStreamingExp1D + (Hermite: prefactor = -i*alpha; Legendre: prefactor = -i). + """ + + def __init__(self, T: np.ndarray, prefactor: complex, kx_1d: Array): + self.prefactor = prefactor + self.kx_1d = kx_1d + eigvals, V = np.linalg.eigh(T) + self.V = jnp.asarray(V) + self.eigenvalues = jnp.asarray(eigvals) + + def apply(self, Ck: Array, s: float) -> Array: + C_eig = self.V.T @ Ck # (Nmodes, Nx) + exp_fac = jnp.exp(self.prefactor * s * self.eigenvalues[:, None] * self.kx_1d[None, :]) + return self.V @ (C_eig * exp_fac) + + +class DiagonalCollisionExp1D: + """Exact exponential exp(-nu * col[n] * s), diagonal in the mode index.""" + + def __init__(self, nu: float, col: Array): + self.nu = float(nu) + self.col = col + + def apply(self, Ck: Array, s: float) -> Array: + if self.nu == 0.0: + return Ck + return Ck * jnp.exp(-self.nu * self.col[:, None] * s) + + +class CombinedLinearExp1D: + """Applies exp(L*s) to the full {Ck, Bk} state; real diagnostics pass through.""" + + def __init__( + self, + hermite_stream: StreamingExp1D, + legendre_stream: StreamingExp1D, + hermite_coll: DiagonalCollisionExp1D, + legendre_coll: DiagonalCollisionExp1D, + ): + self.hermite_stream = hermite_stream + self.legendre_stream = legendre_stream + self.hermite_coll = hermite_coll + self.legendre_coll = legendre_coll + + def apply(self, state: dict, s: float) -> dict: + Ck = state["Ck"].view(jnp.complex128) + Bk = state["Bk"].view(jnp.complex128) + Ck = self.hermite_coll.apply(self.hermite_stream.apply(Ck, s), s) + Bk = self.legendre_coll.apply(self.legendre_stream.apply(Bk, s), s) + out = dict(state) + out["Ck"] = Ck.view(jnp.float64) + out["Bk"] = Bk.view(jnp.float64) + return out + + +# --------------------------------------------------------------------------- +# Poisson solver +# --------------------------------------------------------------------------- + + +class PoissonSolver1D: + """Spectral Poisson solve for the mixed-method charge density (paper eqn 25). + + rho(x) = 1 - alpha*C_0(x) - (v_b - v_a)*B_0(x) (immobile ion background = 1) + -d_x^2 phi = rho ==> phi_k = rho_k / kx^2, E = -d_x phi ==> E_k = -i rho_k / kx. + The k=0 component is set to zero (quasineutral domain). + """ + + def __init__(self, one_over_kx: Array, kx_sq: Array, alpha: float, width: float): + self.one_over_kx = one_over_kx + self.kx_sq = kx_sq + self.alpha = float(alpha) + self.width = float(width) + + def _rho_k(self, Ck: Array, Bk: Array) -> Array: + # C_0(x), B_0(x) from their Fourier representations (norm="forward" -> [0] is mean) + n_f0 = self.alpha * jnp.fft.ifft(Ck[0], norm="forward").real + n_df = self.width * jnp.fft.ifft(Bk[0], norm="forward").real + rho = 1.0 - n_f0 - n_df + return jnp.fft.fft(rho, norm="forward") + + def electric_field(self, Ck: Array, Bk: Array) -> Array: + E_k = -1j * self.one_over_kx * self._rho_k(Ck, Bk) + return jnp.fft.ifft(E_k, norm="forward").real + + def potential(self, Ck: Array, Bk: Array) -> Array: + rho_k = self._rho_k(Ck, Bk) + phi_k = jnp.where(self.kx_sq > 0, rho_k / jnp.where(self.kx_sq > 0, self.kx_sq, 1.0), 0.0) + return jnp.fft.ifft(phi_k, norm="forward").real + + +# --------------------------------------------------------------------------- +# Explicit nonlinear terms (E-field force, penalty, Hermite->Legendre coupling) +# --------------------------------------------------------------------------- + + +def _to_real(Ck: Array, mask23: Array | None) -> Array: + """IFFT (Nmodes, Nx) k-space coeffs to real space, optionally 2/3-dealiased.""" + if mask23 is not None: + Ck = Ck * mask23[None, :] + return jnp.fft.ifft(Ck, axis=-1, norm="forward") + + +def _hermite_force(Ck: Array, E: Array, sqrt_2n_over_alpha: Array, mask23: Array | None) -> Array: + """d_t C_n |force = -(sqrt(2n)/alpha) E(x) C_{n-1}(x) (electron, paper eqn 19/23). + + E = -d_x phi is the self-consistent field; the n-th coefficient is forced by the + (n-1)-th (Hermite differentiation raises the index). Returns k-space (Nh, Nx). + """ + Nh, Nx = Ck.shape + if mask23 is not None: + E = jnp.fft.ifft(jnp.fft.fft(E, norm="forward") * mask23, norm="forward").real + C = _to_real(Ck, mask23) + C_up = jnp.concatenate([jnp.zeros((1, Nx), dtype=C.dtype), C[:-1, :]], axis=0) # C_{n-1} + integrand = -sqrt_2n_over_alpha[:, None] * E[None, :] * C_up + out = jnp.fft.fft(integrand, axis=-1, norm="forward") + if mask23 is not None: + out = out * mask23[None, :] + return out + + +def _legendre_force( + Bk: Array, + E: Array, + deriv: Array, + gamma_vec: Array, + xi_a: Array, + xi_b: Array, + width: float, + mask23: Array | None, +) -> Array: + """d_t B_m |force = -E(x) [ sum_{i Array: + """Hermite -> Legendre coupling (paper eqn 24, RHS): + + d_t B_m |coupling = -(alpha/width) J_{Nh,m} sqrt(Nh/2) [ d_x C_{Nh-1} + (2/alpha^2) E C_{Nh-1} ]. + + The d_x C_{Nh-1} part is linear (i*kx in Fourier); the E*C_{Nh-1} part is the + nonlinear product. Returns k-space (Nl, Nx). coupling_vec already folds in + -(alpha/width) sqrt(Nh/2) J_{Nh,m} (and the J_{Nh,0..2}=0 conservation gate). + """ + Nh = Ck.shape[0] + C_last_k = Ck[Nh - 1] # (Nx,) k-space + dx_C = 1j * kx_1d * C_last_k # d_x C_{Nh-1} in k-space + + C_last_real = _to_real(Ck[Nh - 1 : Nh], mask23)[0] # (Nx,) + if mask23 is not None: + E = jnp.fft.ifft(jnp.fft.fft(E, norm="forward") * mask23, norm="forward").real + nl_real = (2.0 / alpha**2) * E * C_last_real + nl_k = jnp.fft.fft(nl_real, norm="forward") + if mask23 is not None: + nl_k = nl_k * mask23 + + bracket = dx_C + nl_k # (Nx,) k-space + return coupling_vec[:, None] * bracket[None, :] # (Nl, Nx) + + +# --------------------------------------------------------------------------- +# Lawson-RK4 vector field (Stepper / discrete-map convention) +# --------------------------------------------------------------------------- + + +def _tree_add(a: dict, b: dict) -> dict: + return jax.tree.map(lambda x, y: x + y, a, b) + + +def _tree_scale(a: dict, c: float) -> dict: + return jax.tree.map(lambda x: c * x, a) + + +class HermiteLegendre1DVectorField: + """Advances the mixed Hermite-Legendre state by one timestep dt. + + State dict (complex arrays stored as float64 views, like _hermite_poisson_1d): + Ck: (Nh, Nx) Hermite-Fourier coefficients of f0 + Bk: (Nl, Nx) Legendre-Fourier coefficients of df + e: (Nx,) electric field (diagnostic) + phi: (Nx,) potential (diagnostic) + + Called by adept._base_.Stepper, which treats the return value as the new state. + """ + + def __init__( + self, + combined_exp: CombinedLinearExp1D, + poisson: PoissonSolver1D, + kx_1d: Array, + sqrt_2n_over_alpha: Array, + deriv: Array, + gamma_vec: Array, + xi_a: Array, + xi_b: Array, + coupling_vec: Array, + alpha: float, + width: float, + dt: float, + mask23: Array | None = None, + field_on: bool = True, + ): + self.combined_exp = combined_exp + self.poisson = poisson + self.kx_1d = kx_1d + self.sqrt_2n_over_alpha = sqrt_2n_over_alpha + self.deriv = deriv + self.gamma_vec = gamma_vec + self.xi_a = xi_a + self.xi_b = xi_b + self.coupling_vec = coupling_vec + self.alpha = float(alpha) + self.width = float(width) + self.dt = float(dt) + self.mask23 = mask23 + self.field_on = bool(field_on) + + def _nonlinear_rhs(self, t: float, state: dict, args: dict) -> dict: + Ck = state["Ck"].view(jnp.complex128) + Bk = state["Bk"].view(jnp.complex128) + + # field_on=False -> pure advection (phi=0); the linear Hermite->Legendre + # closure flux (d_x C_{Nh-1}) still acts, only E-dependent terms vanish. + E = self.poisson.electric_field(Ck, Bk) if self.field_on else jnp.zeros(Ck.shape[1]) # (Nx,) + + dCk = _hermite_force(Ck, E, self.sqrt_2n_over_alpha, self.mask23) + dBk = _legendre_force( + Bk, E, self.deriv, self.gamma_vec, self.xi_a, self.xi_b, self.width, self.mask23 + ) + _cross_coupling(Ck, E, self.kx_1d, self.coupling_vec, self.alpha, self.width, self.mask23) + + out = dict(state) + out["Ck"] = dCk.view(jnp.float64) + out["Bk"] = dBk.view(jnp.float64) + for k in ("e", "phi"): + if k in out: + out[k] = jnp.zeros_like(state[k]) + return out + + def _lawson_rk4(self, t: float, state: dict, args: dict) -> dict: + dt = self.dt + exp_L = self.combined_exp.apply + + Eh_y = exp_L(state, dt / 2) + Ef_y = exp_L(state, dt) + + N1 = self._nonlinear_rhs(t, state, args) + Eh_N1 = exp_L(N1, dt / 2) + y_star = _tree_add(Eh_y, _tree_scale(Eh_N1, dt / 2)) + N2 = self._nonlinear_rhs(t + dt / 2, y_star, args) + + y_dstar = _tree_add(Eh_y, _tree_scale(N2, dt / 2)) + N3 = self._nonlinear_rhs(t + dt / 2, y_dstar, args) + + Eh_N3 = exp_L(N3, dt / 2) + y_tstar = _tree_add(Ef_y, _tree_scale(Eh_N3, dt)) + N4 = self._nonlinear_rhs(t + dt, y_tstar, args) + + Ef_N1 = exp_L(N1, dt) + Eh_N2 = exp_L(N2, dt / 2) + + weighted = _tree_scale( + _tree_add( + _tree_add(Ef_N1, _tree_scale(Eh_N2, 2.0)), + _tree_add(_tree_scale(Eh_N3, 2.0), N4), + ), + dt / 6.0, + ) + return _tree_add(Ef_y, weighted) + + def __call__(self, t: float, y: dict, args: dict) -> dict: + y_new = self._lawson_rk4(t, y, args) + Ck = y_new["Ck"].view(jnp.complex128) + Bk = y_new["Bk"].view(jnp.complex128) + if self.field_on: + e = self.poisson.electric_field(Ck, Bk) + phi = self.poisson.potential(Ck, Bk) + else: + e = jnp.zeros(Ck.shape[1]) # phi = 0 : pure advection + phi = jnp.zeros(Ck.shape[1]) + return {"Ck": y_new["Ck"], "Bk": y_new["Bk"], "e": e, "phi": phi} diff --git a/adept/hermite_legendre_1d.py b/adept/hermite_legendre_1d.py new file mode 100644 index 00000000..967ba447 --- /dev/null +++ b/adept/hermite_legendre_1d.py @@ -0,0 +1,5 @@ +"""Public entry point for the mixed Hermite-Legendre 1D ADEPT module.""" + +from ._hermite_legendre_1d.modules import BaseHermiteLegendre1D + +__all__ = ["BaseHermiteLegendre1D"] diff --git a/configs/hermite-legendre-1d/bump-on-tail.yaml b/configs/hermite-legendre-1d/bump-on-tail.yaml new file mode 100644 index 00000000..e6d7a244 --- /dev/null +++ b/configs/hermite-legendre-1d/bump-on-tail.yaml @@ -0,0 +1,54 @@ +# Mixed Hermite-Legendre 1D --- Bump-on-tail instability (paper sec 4.4) +# +# f0(x,v,0) = (1-n_b)/sqrt(2*pi) [1 + eps cos(x/10)] exp(-v^2/2) +# df(x,v,0) = n_b/sqrt(2*pi) exp(-(v-10)^2/2) (beam projected onto Legendre) +# The beam is well separated from the bulk, so f0<->df coupling is weak and a large +# Hermite collision rate nu_H = 10 is acceptable. Parameters match Chapurin et al. + +solver: hermite-legendre-1d + +mlflow: + experiment: hermite-legendre-1d + run: bump-on-tail + +units: + normalizing_density: 1e20/cc + normalizing_temperature: 1keV + +physics: + Lx: 62.83185307179586 # 20*pi + alpha: 1.4142135623730951 # sqrt(2) + u: 0.0 + v_a: 4.0 + v_b: 15.0 + gamma: 0.5 + nu_H: 10.0 + nu_L: 1.0 + enforce_conservation: true + field: true + +grid: + Nx: 128 + Nh: 128 + Nl: 128 + tmax: 120.0 + dt: 0.01 + +initialization: + type: bump-on-tail + eps: 1.0e-4 + mode: 1 + n_beam: 0.01 + v_drift: 10.0 + v_th: 1.0 + +save: + fields: + t: + nt: 481 + hermite: + t: + nt: 61 + legendre: + t: + nt: 61 diff --git a/configs/hermite-legendre-1d/linear-advection.yaml b/configs/hermite-legendre-1d/linear-advection.yaml new file mode 100644 index 00000000..933cd89b --- /dev/null +++ b/configs/hermite-legendre-1d/linear-advection.yaml @@ -0,0 +1,51 @@ +# Mixed Hermite-Legendre 1D --- Linear advection benchmark (paper sec 4.2) +# +# Pure advection (phi = 0): d_t f + v d_x f = 0, analytic f(x,v,t)=f(x-vt,v,0). +# f(x,v,0) = (1+cos x)/sqrt(2*pi) exp(-v^2/2) -> C_0(x,0)=(1+cos x)/sqrt(2), df=0. +# The Legendre representation of df grows to compensate Hermite recurrence (Fig 3), +# and the recurrence period matches the better of the two bases (Fig 4). + +solver: hermite-legendre-1d + +mlflow: + experiment: hermite-legendre-1d + run: linear-advection + +units: + normalizing_density: 1e20/cc + normalizing_temperature: 1keV + +physics: + Lx: 6.283185307179586 # 2*pi + alpha: 1.4142135623730951 # sqrt(2) + u: 0.0 + v_a: -5.0 + v_b: 5.0 + gamma: 0.5 + nu_H: 0.0 + nu_L: 0.0 + enforce_conservation: true + field: false # phi = 0 : pure linear advection + +grid: + Nx: 64 + Nh: 100 + Nl: 100 + tmax: 20.0 + dt: 0.01 + +initialization: + type: linear-advection + eps: 1.0 + mode: 1 + +save: + fields: + t: + nt: 201 + hermite: + t: + nt: 41 + legendre: + t: + nt: 41 diff --git a/configs/hermite-legendre-1d/two-stream.yaml b/configs/hermite-legendre-1d/two-stream.yaml new file mode 100644 index 00000000..c62a22b6 --- /dev/null +++ b/configs/hermite-legendre-1d/two-stream.yaml @@ -0,0 +1,51 @@ +# Mixed Hermite-Legendre 1D --- Two-stream instability (paper sec 4.3) +# +# f(x,v,0) = (1+0.01 cos(0.5 x))/sqrt(2*pi) v^2 exp(-v^2/2) +# -> C_0(x,0)=(1+0.01 cos(0.5 x))/sqrt(2), C_2 = sqrt(2) C_0, df=0. +# The phase-space vortex (strongly non-Maxwellian) is captured by the Legendre +# component; with J_{Nh,0..2}=0 enforced, mass/momentum/energy are conserved (Fig 7c/d). + +solver: hermite-legendre-1d + +mlflow: + experiment: hermite-legendre-1d + run: two-stream + +units: + normalizing_density: 1e20/cc + normalizing_temperature: 1keV + +physics: + Lx: 12.566370614359172 # 4*pi + alpha: 1.4142135623730951 # sqrt(2) + u: 0.0 + v_a: -2.5 + v_b: 2.5 + gamma: 0.5 + nu_H: 0.0 # keep nu_H -> 0 so f0 feeds df via the last Hermite moment + nu_L: 1.0 + enforce_conservation: true + field: true + +grid: + Nx: 64 + Nh: 85 + Nl: 171 + tmax: 35.0 + dt: 0.01 + +initialization: + type: two-stream + eps: 0.01 + mode: 1 + +save: + fields: + t: + nt: 351 + hermite: + t: + nt: 71 + legendre: + t: + nt: 71 diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index f1bebac6..1cbd2b37 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -17,3 +17,4 @@ Quick links to configuration references: - [Vlasov-2D Config](source/solvers/vlasov2d/config.md) - [LPSE-2D Config](source/solvers/lpse2d/config.md) - [Spectrax-1D Config](source/solvers/spectrax1d/config.md) +- [Hermite-Legendre-1D Config](source/solvers/hermite_legendre_1d/config.md) diff --git a/docs/RUNNING_A_SIM.md b/docs/RUNNING_A_SIM.md index ae50b6b3..7c834954 100644 --- a/docs/RUNNING_A_SIM.md +++ b/docs/RUNNING_A_SIM.md @@ -12,5 +12,6 @@ uv run run.py --cfg path_to_my_config - [Vlasov-2D](source/solvers/vlasov2d/config.md) - 2D2V Vlasov-Maxwell solver - [LPSE-2D (Envelope-2D)](source/solvers/lpse2d/config.md) - 2D laser-plasma envelope solver for TPD/SRS - [Spectrax-1D](source/solvers/spectrax1d/config.md) - 1D Hermite-Fourier Vlasov-Maxwell solver +- [Hermite-Legendre-1D](source/solvers/hermite_legendre_1d/config.md) - 1D-1V mixed Hermite-Legendre electrostatic Vlasov-Poisson solver See the [full documentation](https://ergodicio.github.io/adept/) for detailed guides and API reference. diff --git a/docs/source/solvers/hermite_legendre_1d/config.md b/docs/source/solvers/hermite_legendre_1d/config.md new file mode 100644 index 00000000..7c20257b --- /dev/null +++ b/docs/source/solvers/hermite_legendre_1d/config.md @@ -0,0 +1,136 @@ +# Mixed Hermite-Legendre 1D Configuration Reference + +This document describes how to construct a configuration file for the +`hermite-legendre-1d` solver, which implements the **mixed Hermite-Legendre +spectral method** for the 1D-1V electrostatic Vlasov-Poisson system (Issan, +Delzanno & Roytershteyn, arXiv:2606.12322). + +The electron distribution is split `f = f0 + df`: + +- `f0` (near-Maxwellian bulk) is expanded in the **asymmetrically-weighted (AW) + Hermite** basis in velocity, with coefficients `C_n(x, t)`, `n = 0 .. Nh-1`. +- `df` (strongly non-Maxwellian features: beams, plateaus, filamentation) is + expanded in the **Legendre** basis on a bounded velocity window `[v_a, v_b]`, + with coefficients `B_m(x, t)`, `m = 0 .. Nl-1`. + +The highest Hermite coefficient `C_{Nh-1}` feeds the Legendre modes (one-way +coupling), and both feed the self-consistent field through Poisson. The method is +most accurate, at fixed total velocity DOFs, when non-Maxwellian features are +localized in velocity. + +**Normalization** (paper sec 2.1): time by `1/ω_pe`, space by the Debye length +`λ_D`, velocity by the electron thermal velocity `v_the`. A single electron species +is evolved against an immobile neutralizing ion background of density 1. + +**Numerics.** Space is treated spectrally (Fourier, periodic domain); both +free-streaming operators are symmetric-tridiagonal in mode index and integrated +*exactly* via prediagonalized matrix exponentials. The E-field force, the Legendre +Dirichlet penalty, and the Hermite→Legendre coupling are advanced explicitly with +**Lawson-RK4**. (The paper uses an implicit-midpoint integrator for machine-precision +energy conservation; this module uses an explicit integrator — energy is then +conserved to the time-integrator's order, which converges with `dt`, while mass and +momentum remain conserved to machine precision.) + +## Top-Level Structure + +```yaml +solver: hermite-legendre-1d +mlflow: ... +units: ... +physics: ... +grid: ... +initialization: ... +save: ... +``` + +--- + +## physics + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `Lx` | float | — | Domain length in x (normalized to `λ_D`) | +| `alpha` | float | — | AW-Hermite velocity **scale** parameter `α` (the benchmarks use `√2`) | +| `u` | float | `0.0` | AW-Hermite velocity **shift** parameter `u` | +| `v_a`, `v_b` | float | — | Legendre velocity-window bounds (`df` is resolved on `[v_a, v_b]`) | +| `gamma` | float | `0.5` | Penalty coefficient `γ` for the weak Legendre Dirichlet BC (`df(v_a)=df(v_b)=0`). Applied only to modes `m ≥ 3` to preserve conservation. | +| `nu_H` | float | `0.0` | Artificial (Lenard-Bernstein) Hermite collision rate `ν_H`. Keep small/zero so `f0` can feed `df` through the last Hermite moment. | +| `nu_L` | float | `0.0` | Artificial Legendre collision rate `ν_L`. Controls filamentation/recurrence in `df`. | +| `enforce_conservation` | bool | `true` | Zero the coupling integrals `J_{Nh,0}=J_{Nh,1}=J_{Nh,2}=0` so the discrete method conserves mass, momentum, and energy independent of `Nh` parity and of `α, u` (paper sec 3.4/4). | +| `field` | bool | `true` | Self-consistent Poisson field. Set `false` for the pure linear-advection test (`φ = 0`); the linear Hermite→Legendre closure flux still acts. | + +The artificial collision operator (paper sec 2.5) uses the cubic spectrum +`col[n] = n(n-1)(n-2) / ((N-1)(N-2)(N-3))`, which is identically zero for +`n = 0, 1, 2` — so collisions never touch the mass/momentum/energy moments. + +--- + +## grid + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `Nx` | int | — | Number of Fourier modes in x | +| `Nh` | int | — | Number of AW-Hermite modes for `f0` (closure by truncation: `C_{Nh}=0`) | +| `Nl` | int | — | Number of Legendre modes for `df` (closure by truncation: `B_{Nl}=0`) | +| `tmax` | float | — | Final simulation time (normalized). Snapped to an exact multiple of `dt`. | +| `dt` | float | `0.01` | Timestep | + +--- + +## initialization + +Selects how the initial `C_n(x)` and `B_m(x)` coefficients are built. + +| `type` | Parameters | Description | +|--------|-----------|-------------| +| `linear-advection` | `eps`, `mode` | `f0 = (1 + eps·cos(k x))/√(2π)·exp(-v²/2)`; `df = 0`. (`C_0 = n(x)/α`.) | +| `two-stream` | `eps`, `mode` | `f0 ∝ (1 + eps·cos(k x))·v²·exp(-v²/2)`: `C_0 = n(x)/α`, `C_2 = √2·C_0`; `df = 0`. | +| `bump-on-tail` | `eps`, `mode`, `n_beam`, `v_drift`, `v_th` | Bulk Maxwellian in `f0`; a drifting Gaussian beam `n_beam/(√(2π) v_th)·exp(-(v-v_drift)²/2v_th²)` projected onto Legendre as `df`. | +| `custom` | `hermite: {n: {base, eps, mode}}`, `df: {beams: [{amp, v_drift, v_th}], eps, mode}` | Generic Hermite coefficient profiles plus a beam/sum-of-Gaussians `df` projected onto Legendre. | + +Here `k = 2π·mode/Lx`. The Legendre projection uses Gauss-Legendre quadrature. + +--- + +## save + +Standard ADEPT `save` block with `t: {nt: ...}` (or `tmin`/`tmax`/`nt`) sub-axes. + +| Key | Contents | +|-----|----------| +| `fields` | Electric field `e(x,t)` and potential `phi(x,t)` | +| `hermite` | AW-Hermite-Fourier coefficient timeseries `Ck` (shape `nt × Nh × Nx`) | +| `legendre` | Legendre-Fourier coefficient timeseries `Bk` (shape `nt × Nl × Nx`) | +| `default` | Scalar invariants `mass`, `momentum`, `energy` (paper eqns 26, 28, 30-31) plus field energy and density extrema. Always added; the primary correctness gate. | + +`post_process` writes netCDF binaries and spacetime/scalar plots, and reports the +relative drift of each invariant as the metrics `reldrift_{mass,momentum,energy}`. + +--- + +## Example: two-stream instability + +```yaml +solver: hermite-legendre-1d +mlflow: {experiment: hermite-legendre-1d, run: two-stream} +units: {normalizing_density: 1e20/cc, normalizing_temperature: 1keV} +physics: + Lx: 12.566370614359172 # 4π + alpha: 1.4142135623730951 + u: 0.0 + v_a: -2.5 + v_b: 2.5 + gamma: 0.5 + nu_H: 0.0 + nu_L: 1.0 + enforce_conservation: true + field: true +grid: {Nx: 64, Nh: 85, Nl: 171, tmax: 35.0, dt: 0.01} +initialization: {type: two-stream, eps: 0.01, mode: 1} +save: + fields: {t: {nt: 351}} + legendre: {t: {nt: 71}} +``` + +See `configs/hermite-legendre-1d/` for the linear-advection, two-stream, and +bump-on-tail benchmark configurations. diff --git a/tests/test_hermite_legendre_1d/__init__.py b/tests/test_hermite_legendre_1d/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_hermite_legendre_1d/test_conservation.py b/tests/test_hermite_legendre_1d/test_conservation.py new file mode 100644 index 00000000..758b7762 --- /dev/null +++ b/tests/test_hermite_legendre_1d/test_conservation.py @@ -0,0 +1,77 @@ +# Copyright (c) Ergodic LLC 2026 +# research@ergodic.io +"""Conservation gate for the mixed Hermite-Legendre solver (paper sec 3, Fig 7). + +With the constraint J_{Nh,0}=J_{Nh,1}=J_{Nh,2}=0 enforced, the self-consistent +(field-on) mixed method conserves total mass and momentum to machine precision and +total energy to the explicit time integrator's accuracy (the paper's machine- +precision energy relies on the implicit-midpoint integrator; the explicit Lawson-RK4 +used here is time-integration-limited but convergent). The run also stays finite. +""" + +from jax import config as jax_config + +jax_config.update("jax_enable_x64", True) + +import numpy as np + +from adept._hermite_legendre_1d.modules import BaseHermiteLegendre1D + + +def _run(cfg): + m = BaseHermiteLegendre1D(cfg) + m.write_units() + m.get_derived_quantities() + m.get_solver_quantities() + m.init_state_and_args() + m.init_diffeqsolve() + return m(trainable_modules={})["solver result"] + + +def _two_stream_cfg(enforce, Nh=48, Nl=48): + return { + "solver": "hermite-legendre-1d", + "physics": { + "Lx": 4.0 * np.pi, + "alpha": np.sqrt(2.0), + "u": 0.0, + "v_a": -2.5, + "v_b": 2.5, + "gamma": 0.5, + "nu_H": 0.0, + "nu_L": 1.0, + "enforce_conservation": enforce, + "field": True, + }, + "grid": {"Nx": 48, "Nh": Nh, "Nl": Nl, "tmax": 10.0, "dt": 0.01}, + "initialization": {"type": "two-stream", "eps": 0.05, "mode": 1}, + "save": {"default": {"t": {"nt": 100}}}, + "units": {}, + } + + +def test_two_stream_conserves_invariants_with_enforcement(): + sol = _run(_two_stream_cfg(enforce=True)) + d = sol.ys["default"] + drifts = {} + for k in ("mass", "momentum", "energy"): + a = np.asarray(d[k]) + assert np.all(np.isfinite(a)), f"{k} went non-finite" + base = abs(a[0]) if abs(a[0]) > 1e-30 else 1.0 + drifts[k] = float(np.max(np.abs(a - a[0])) / base) + assert drifts["mass"] < 1e-10, drifts + assert drifts["momentum"] < 1e-10, drifts + assert drifts["energy"] < 1e-5, drifts # time-integration-limited (dt=0.01) + + +def test_collisions_only_damp_high_modes(): + """nu_L damps Legendre modes m>=3 only; mass/momentum/energy stay conserved + because the collision profile vanishes on the first three moments.""" + sol = _run(_two_stream_cfg(enforce=True)) + d = sol.ys["default"] + # mass uses C_0, B_0; momentum C_0,C_1,B_0,B_1; energy up to C_2,B_2 -- all + # below the collisional cutoff, so collisions must not spoil them. + for k, tol in (("mass", 1e-10), ("momentum", 1e-10)): + a = np.asarray(d[k]) + base = abs(a[0]) if abs(a[0]) > 1e-30 else 1.0 + assert np.max(np.abs(a - a[0])) / base < tol diff --git a/tests/test_hermite_legendre_1d/test_linear_advection.py b/tests/test_hermite_legendre_1d/test_linear_advection.py new file mode 100644 index 00000000..90dcb800 --- /dev/null +++ b/tests/test_hermite_legendre_1d/test_linear_advection.py @@ -0,0 +1,105 @@ +# Copyright (c) Ergodic LLC 2026 +# research@ergodic.io +"""Linear-advection physics gate for the mixed Hermite-Legendre solver. + +With phi = 0 the Vlasov equation reduces to advection d_t f + v d_x f = 0, whose +exact solution is f(x, v, t) = f(x - v t, v, 0). Running the full lifecycle on a +coarse grid and reconstructing f = f0 + df from the spectral coefficients should +match the analytic solution at a time before the spectral velocity recurrence +(paper sec 4.2, Fig 3). Also gates that mass/momentum/energy are conserved to +machine precision under pure advection (all k=0 moments are time-invariant). +""" + +from jax import config as jax_config + +jax_config.update("jax_enable_x64", True) + +import numpy as np + +from adept._hermite_legendre_1d.modules import BaseHermiteLegendre1D +from adept._hermite_legendre_1d.vector_field import _hermite_function_values, _legendre_basis_values + + +def _run(cfg): + m = BaseHermiteLegendre1D(cfg) + m.write_units() + m.get_derived_quantities() + m.get_solver_quantities() + m.init_state_and_args() + m.init_diffeqsolve() + return m, m(trainable_modules={})["solver result"] + + +def test_linear_advection_matches_analytic(): + Lx = 2.0 * np.pi + alpha = np.sqrt(2.0) + Nh = Nl = 48 + Nx = 48 + v_a, v_b = -6.0, 6.0 + t_check = 2.0 + + cfg = { + "solver": "hermite-legendre-1d", + "physics": { + "Lx": Lx, + "alpha": alpha, + "u": 0.0, + "v_a": v_a, + "v_b": v_b, + "gamma": 0.5, + "nu_H": 0.0, + "nu_L": 0.0, + "enforce_conservation": True, + "field": False, + }, + "grid": {"Nx": Nx, "Nh": Nh, "Nl": Nl, "tmax": t_check, "dt": 0.01}, + "initialization": {"type": "linear-advection", "eps": 1.0, "mode": 1}, + "save": {"hermite": {"t": {"nt": 2}}, "legendre": {"t": {"nt": 2}}}, + "units": {}, + } + m, sol = _run(cfg) + x = np.asarray(m.cfg["grid"]["x"]) + + Ck = np.asarray(sol.ys["hermite"]["Ck"])[-1] # (Nh, Nx) k-space + Bk = np.asarray(sol.ys["legendre"]["Bk"])[-1] # (Nl, Nx) + C = np.fft.ifft(Ck, axis=-1, norm="forward").real # (Nh, Nx) + B = np.fft.ifft(Bk, axis=-1, norm="forward").real # (Nl, Nx) + + v = np.linspace(v_a + 0.1, v_b - 0.1, 120) + psi = _hermite_function_values(Nh, v, u=0.0, alpha=alpha) # (Nh, len(v)) + xi = _legendre_basis_values(Nl, v, v_a, v_b) # (Nl, len(v)) + + f = C.T @ psi + B.T @ xi # (Nx, len(v)) + XX, VV = np.meshgrid(x, v, indexing="ij") + f_exact = (1.0 + np.cos(XX - VV * t_check)) / np.sqrt(2.0 * np.pi) * np.exp(-(VV**2) / 2.0) + + rel_l2 = np.linalg.norm(f - f_exact) / np.linalg.norm(f_exact) + assert rel_l2 < 0.02, f"linear advection rel L2 = {rel_l2:.4f}" + + +def test_advection_conserves_moments_to_machine_precision(): + cfg = { + "solver": "hermite-legendre-1d", + "physics": { + "Lx": 2.0 * np.pi, + "alpha": np.sqrt(2.0), + "u": 0.0, + "v_a": -6.0, + "v_b": 6.0, + "gamma": 0.5, + "nu_H": 0.0, + "nu_L": 0.0, + "enforce_conservation": True, + "field": False, + }, + "grid": {"Nx": 48, "Nh": 48, "Nl": 48, "tmax": 5.0, "dt": 0.02}, + "initialization": {"type": "linear-advection", "eps": 1.0, "mode": 1}, + "save": {"default": {"t": {"nt": 60}}}, + "units": {}, + } + _, sol = _run(cfg) + d = sol.ys["default"] + for k in ("mass", "momentum", "energy"): + a = np.asarray(d[k]) + base = abs(a[0]) if abs(a[0]) > 1e-30 else 1.0 + assert np.max(np.abs(a - a[0])) / base < 1e-11, f"{k} not conserved under advection" diff --git a/tests/test_hermite_legendre_1d/test_streaming.py b/tests/test_hermite_legendre_1d/test_streaming.py new file mode 100644 index 00000000..8c0bc5cf --- /dev/null +++ b/tests/test_hermite_legendre_1d/test_streaming.py @@ -0,0 +1,103 @@ +# Copyright (c) Ergodic LLC 2026 +# research@ergodic.io +"""Unit tests for the mixed Hermite-Legendre building blocks. + +Covers the prediagonalized free-streaming exponentials (Golub-Welsch: streaming +matrix eigenvalues are the Gauss quadrature nodes), the streaming round-trip, and +the analytic parity of the Hermite->Legendre coupling integrals J_{Nh,m} (paper +eqns 27/29/34), which underpin the conservation properties. +""" + +from jax import config as jax_config + +jax_config.update("jax_enable_x64", True) + +import numpy as np +import pytest +from jax import numpy as jnp + +from adept._hermite_legendre_1d.vector_field import ( + StreamingExp1D, + hermite_legendre_coupling_vector, + hermite_streaming_matrix, + legendre_constants, + safe_col, +) + + +def test_hermite_matrix_eigvals_are_gauss_hermite_nodes(): + """Golub-Welsch: eigenvalues of T_H (u=0) equal the Gauss-Hermite nodes.""" + Nh = 24 + T = hermite_streaming_matrix(Nh, u=0.0, alpha=1.0) + eig = np.sort(np.linalg.eigvalsh(T)) + nodes = np.sort(np.polynomial.hermite.hermgauss(Nh)[0]) + np.testing.assert_allclose(eig, nodes, atol=1e-10) + + +def test_legendre_matrix_eigvals_are_gauss_legendre_nodes(): + """Eigenvalues of T_L equal the Gauss-Legendre nodes mapped to [v_a, v_b].""" + Nl, v_a, v_b = 20, -2.5, 2.5 + leg = legendre_constants(Nl, v_a, v_b) + eig = np.sort(np.linalg.eigvalsh(np.asarray(leg["T_L"]))) + nodes = np.sort(0.5 * (v_b - v_a) * np.polynomial.legendre.leggauss(Nl)[0] + 0.5 * (v_a + v_b)) + np.testing.assert_allclose(eig, nodes, atol=1e-10) + + +def test_streaming_exp_identity_and_roundtrip(): + """exp(L*0) is identity, and exp(L*s) exp(L*-s) returns the original array.""" + Nh, Nx = 16, 8 + kx = jnp.fft.fftfreq(Nx) * Nx * 2 * jnp.pi / (2 * jnp.pi) + T = hermite_streaming_matrix(Nh, u=0.3, alpha=1.4) + exp = StreamingExp1D(T, prefactor=-1j * 1.4, kx_1d=kx) + + rng = np.random.default_rng(0) + C = jnp.asarray(rng.standard_normal((Nh, Nx)) + 1j * rng.standard_normal((Nh, Nx))) + + np.testing.assert_allclose(np.asarray(exp.apply(C, 0.0)), np.asarray(C), atol=1e-12) + back = exp.apply(exp.apply(C, 0.37), -0.37) + np.testing.assert_allclose(np.asarray(back), np.asarray(C), atol=1e-10) + + +def test_legendre_derivative_matrix_sparsity(): + """sigma_{m,i} is strictly lower triangular and nonzero only when (m-i) is odd.""" + Nl, v_a, v_b = 12, -3.0, 3.0 + deriv = np.asarray(legendre_constants(Nl, v_a, v_b)["deriv"]) + for m in range(Nl): + for i in range(Nl): + if i >= m or (m - i) % 2 == 0: + assert deriv[m, i] == 0.0, (m, i) + else: + assert deriv[m, i] != 0.0, (m, i) + + +def test_collision_profile_conserves_first_three_moments(): + """safe_col is exactly zero for n=0,1,2 (mass/momentum/energy preserved).""" + col = np.asarray(safe_col(32)) + assert col[0] == 0.0 and col[1] == 0.0 and col[2] == 0.0 + assert np.all(col[3:] > 0.0) + assert col[-1] == pytest.approx(1.0) # normalized to 1 at n = N-1 + + +@pytest.mark.parametrize("Nh,vanish", [(25, (0, 2)), (24, (1,))]) +def test_coupling_integral_parity(Nh, vanish): + """J_{Nh,m} parity on a symmetric domain with u=0 (paper eqns 27/29/34): + odd Nh -> J_{Nh,0} = J_{Nh,2} = 0 (mass & energy); + even Nh -> J_{Nh,1} = 0 (momentum).""" + J = np.asarray( + hermite_legendre_coupling_vector( + Nh, Nl=10, alpha=np.sqrt(2.0), u=0.0, v_a=-3.0, v_b=3.0, enforce_conservation=False + ) + ) + for m in vanish: + assert abs(J[m]) < 1e-9, (Nh, m, J[m]) + # a non-vanishing entry should actually be present (sanity) + assert np.max(np.abs(J)) > 1e-6 + + +def test_enforce_conservation_zeros_first_three(): + J = np.asarray( + hermite_legendre_coupling_vector( + 30, Nl=10, alpha=np.sqrt(2.0), u=0.0, v_a=-2.5, v_b=2.5, enforce_conservation=True + ) + ) + assert np.all(J[:3] == 0.0) From e650abf38a985a0ae3f5f1c23f08bf74e20f027b Mon Sep 17 00:00:00 2001 From: archis Date: Mon, 15 Jun 2026 09:01:08 -0700 Subject: [PATCH 2/3] feat(hermite-legendre-1d): external Ex driver, IMEX integrator, Landau test Adds to the mixed Hermite-Legendre solver: - External longitudinal `ex` driver (ExternalExDriver), mirroring the vlasov1d / hermite_poisson `ex` driver: a prescribed E added to the velocity-space force (never enters Poisson), e.g. a resonant EPW kick. Saved as `de` in fields. - IMEX integrator (grid.integrator: imex). The explicit Lawson-RK4 step blows up once the E.d_v f Lorentz force gets stiff (operator norm ~Nl^2/width*|E| for the Legendre block, the dominant term). IMEX keeps streaming/collisions/closure-flux explicit and advances the Lorentz force with an unconditionally stable frozen-E Backward-Euler substep (first-order Lie split), mirroring _spectrax1d/imex_E.py. Two-stream then runs stably at dt=0.02 (vs explicit's 0.002), mass conserved to ~1e-12. Current impl uses a per-x dense solve; structured (bidiagonal/Woodbury) solve noted as the large-Nx optimization. - Driven Landau-damping test (test_landau_damping): drives a uniform Maxwellian at the resonant (k0, w0) and measures E_x(k1) ringing -- frequency matches the kinetic dispersion to <1%, damping to ~5-9% (finite-Nh Hermite). - IMEX stability test (test_imex): explicit blows up at dt=0.01, IMEX stays finite and conserves mass. - Fix: _hermite_function_values overflowed float64 at Nh>=171 (formed 2^n n! directly), silently zeroing the J_{Nh,m} coupling -- now a stable normalized recurrence. Seed the t=0 field diagnostics from the actual Poisson field. - two-stream.yaml dt 0.01 -> 0.002 (explicit CFL); add two-stream-imex.yaml. Suite: 16 passed. Two-stream works in both explicit and IMEX; bump-on-tail (the hardest case: asymmetric domain, weak instability, long time) still blows up near saturation in both -- needs the implicit-midpoint path (next commit). Co-Authored-By: Claude Opus 4.8 (1M context) --- adept/_hermite_legendre_1d/modules.py | 22 ++ adept/_hermite_legendre_1d/storage.py | 7 +- adept/_hermite_legendre_1d/vector_field.py | 199 ++++++++++++++++-- .../hermite-legendre-1d/two-stream-imex.yaml | 52 +++++ configs/hermite-legendre-1d/two-stream.yaml | 5 +- .../solvers/hermite_legendre_1d/config.md | 55 ++++- tests/test_hermite_legendre_1d/test_imex.py | 64 ++++++ .../test_landau_damping.py | 99 +++++++++ 8 files changed, 476 insertions(+), 27 deletions(-) create mode 100644 configs/hermite-legendre-1d/two-stream-imex.yaml create mode 100644 tests/test_hermite_legendre_1d/test_imex.py create mode 100644 tests/test_hermite_legendre_1d/test_landau_damping.py diff --git a/adept/_hermite_legendre_1d/modules.py b/adept/_hermite_legendre_1d/modules.py index 1dc9cc13..d9931f66 100644 --- a/adept/_hermite_legendre_1d/modules.py +++ b/adept/_hermite_legendre_1d/modules.py @@ -36,12 +36,15 @@ from adept._hermite_legendre_1d.vector_field import ( CombinedLinearExp1D, DiagonalCollisionExp1D, + ExternalExDriver, HermiteLegendre1DVectorField, PoissonSolver1D, StreamingExp1D, + hermite_force_operator, hermite_legendre_coupling_vector, hermite_streaming_matrix, legendre_constants, + legendre_force_operator, safe_col, ) @@ -230,6 +233,7 @@ def g(v): "Bk": Bk.view(jnp.float64), "e": e0, "phi": phi0, + "de": jnp.zeros(Nx), # external Ex driver field (diagnostic) } self.args = {} @@ -250,6 +254,8 @@ def init_diffeqsolve(self) -> None: nu_L = float(physics.get("nu_L", 0.0)) enforce = bool(physics.get("enforce_conservation", True)) field_on = bool(physics.get("field", True)) + integrator = str(grid.get("integrator", "lawson")).lower() + imex = integrator == "imex" dt = float(grid["dt"]) kx_1d = grid["kx_1d"] @@ -269,6 +275,10 @@ def init_diffeqsolve(self) -> None: poisson = PoissonSolver1D(one_over_kx=one_over_kx, kx_sq=kx_sq, alpha=alpha, width=width) + # External longitudinal (Ex) driver, e.g. a resonant EPW kick for Landau damping + ex_cfg = self.cfg.get("drivers", {}).get("ex", {}) + ex_driver = ExternalExDriver(grid["x"], ex_cfg) if ex_cfg else None + # Explicit-term constants n = jnp.arange(Nh, dtype=jnp.float64) sqrt_2n_over_alpha = jnp.sqrt(2.0 * n) / alpha @@ -277,6 +287,14 @@ def init_diffeqsolve(self) -> None: J = hermite_legendre_coupling_vector(Nh, Nl, alpha, u, v_a, v_b, enforce_conservation=enforce) coupling_vec = -(alpha / width) * jnp.sqrt(Nh / 2.0) * J # folds prefactor into J_{Nh,m} + # IMEX force operators (only built when integrator == "imex") + G_C = jnp.asarray(hermite_force_operator(Nh, alpha)) if imex else None + G_B = ( + jnp.asarray(legendre_force_operator(leg["deriv"], gamma_vec, leg["xi_a"], leg["xi_b"], width)) + if imex + else None + ) + vector_field = HermiteLegendre1DVectorField( combined_exp=combined_exp, poisson=poisson, @@ -292,6 +310,10 @@ def init_diffeqsolve(self) -> None: dt=dt, mask23=mask23, field_on=field_on, + ex_driver=ex_driver, + imex=imex, + G_C=G_C, + G_B=G_B, ) self.cfg = get_save_quantities(self.cfg) diff --git a/adept/_hermite_legendre_1d/storage.py b/adept/_hermite_legendre_1d/storage.py index c423ce9d..dfa82696 100644 --- a/adept/_hermite_legendre_1d/storage.py +++ b/adept/_hermite_legendre_1d/storage.py @@ -27,10 +27,13 @@ def get_fields_save_func(): - """Save electric field e and potential phi.""" + """Save electric field e, potential phi, and external Ex driver field de.""" def fields_save_func(t, y, args): - return {"e": y["e"], "phi": y["phi"]} + out = {"e": y["e"], "phi": y["phi"]} + if "de" in y: + out["de"] = y["de"] + return out return fields_save_func diff --git a/adept/_hermite_legendre_1d/vector_field.py b/adept/_hermite_legendre_1d/vector_field.py index 0324c71e..3e8ef2a4 100644 --- a/adept/_hermite_legendre_1d/vector_field.py +++ b/adept/_hermite_legendre_1d/vector_field.py @@ -25,6 +25,57 @@ import numpy as np from jax import Array +from adept._base_ import get_envelope + +# --------------------------------------------------------------------------- +# External longitudinal field driver +# --------------------------------------------------------------------------- + + +class ExternalExDriver: + """Prescribed longitudinal field E_drive(x, t) added to the velocity-space force. + + E_drive(x, t) = sum_pulses env(x, t) (w0+dw0) a0 sin(k0 x - (w0+dw0) t) + + The driver enters only the E.d_v f force term (it drives EPWs directly, e.g. a + resonant kick for a Landau-damping measurement) and never the Poisson solve, so + the self-consistent field energy diagnostic excludes it. Mirrors the convention + of adept._hermite_poisson_1d.vector_field.LongitudinalElectricFieldDriver and of + the vlasov1d ``ex`` driver. Reads cfg["drivers"]["ex"] (pulse_name -> pulse_dict). + Output shape (Nx,), on the interior x grid. + """ + + def __init__(self, x: Array, ex_driver_cfg: dict): + self.x = x + x_last = float(x[-1]) + parsed = [] + for _, pulse in ex_driver_cfg.items() if isinstance(ex_driver_cfg, dict) else []: + if not isinstance(pulse, dict): + continue + parsed.append( + ( + float(pulse["k0"]), + float(pulse["w0"]) + float(pulse.get("dw0", 0.0)), + float(pulse["a0"]), + float(pulse.get("t_center", 0.0)), + 0.5 * float(pulse.get("t_width", 1e10)), + float(pulse.get("t_rise", 0.0)), + float(pulse.get("x_center", 0.5 * x_last)), + 0.5 * float(pulse.get("x_width", 1e10)), + float(pulse.get("x_rise", 0.0)), + ) + ) + self.parsed_pulses = parsed + + def __call__(self, t: float, args) -> Array: + total = jnp.zeros_like(self.x) + for k0, w_total, a0, t_center, t_half, t_rise, x_center, x_half, x_rise in self.parsed_pulses: + env_t = get_envelope(t_rise, t_rise, t_center - t_half, t_center + t_half, t) + env_x = get_envelope(x_rise, x_rise, x_center - x_half, x_center + x_half, self.x) + total = total + env_t * env_x * w_total * a0 * jnp.sin(k0 * self.x - w_total * t) + return total + + # --------------------------------------------------------------------------- # Basis constants # --------------------------------------------------------------------------- @@ -111,18 +162,24 @@ def legendre_constants(Nl: int, v_a: float, v_b: float) -> dict: def _hermite_function_values(Nh_plus_1: int, v: np.ndarray, u: float, alpha: float) -> np.ndarray: """psi_n(v; u, alpha) for n = 0..Nh, shape (Nh+1, len(v)). Used for J integrals. - psi_n = (pi 2^n n!)^{-1/2} H_n(z) exp(-z^2), z = (v - u)/alpha, - built from the physicists' Hermite recurrence H_{n+1} = 2 z H_n - 2 n H_{n-1}. + psi_n = (pi 2^n n!)^{-1/2} H_n(z) exp(-z^2), z = (v - u)/alpha. Built from the + *normalized* three-term recurrence + + psi_0 = pi^{-1/2} exp(-z^2), psi_1 = sqrt(2) z psi_0, + psi_{n+1} = z sqrt(2/(n+1)) psi_n - sqrt(n/(n+1)) psi_{n-1}, + + which keeps every psi_n O(1). Forming H_n and 2^n n! separately (as a naive + implementation would) overflows float64 for n >= 171, silently zeroing the + high-order coupling integrals. """ z = (v - u) / alpha - H = np.zeros((Nh_plus_1, v.size), dtype=np.float64) - H[0] = 1.0 + psi = np.zeros((Nh_plus_1, v.size), dtype=np.float64) + psi[0] = np.pi**-0.5 * np.exp(-(z**2)) if Nh_plus_1 > 1: - H[1] = 2.0 * z - for k in range(1, Nh_plus_1 - 1): - H[k + 1] = 2.0 * z * H[k] - 2.0 * k * H[k - 1] - norm = 1.0 / np.sqrt(np.pi * (2.0 ** np.arange(Nh_plus_1)) * _factorials(Nh_plus_1)) - return norm[:, None] * H * np.exp(-(z**2))[None, :] + psi[1] = np.sqrt(2.0) * z * psi[0] + for n in range(1, Nh_plus_1 - 1): + psi[n + 1] = z * np.sqrt(2.0 / (n + 1)) * psi[n] - np.sqrt(n / (n + 1)) * psi[n - 1] + return psi def _legendre_basis_values(Nl: int, v: np.ndarray, v_a: float, v_b: float) -> np.ndarray: @@ -139,13 +196,6 @@ def _legendre_basis_values(Nl: int, v: np.ndarray, v_a: float, v_b: float) -> np return scale[:, None] * L -def _factorials(n: int) -> np.ndarray: - out = np.ones(n, dtype=np.float64) - for k in range(1, n): - out[k] = out[k - 1] * k - return out - - def hermite_legendre_coupling_vector( Nh: int, Nl: int, alpha: float, u: float, v_a: float, v_b: float, enforce_conservation: bool = True ) -> Array: @@ -368,6 +418,39 @@ def _cross_coupling( return coupling_vec[:, None] * bracket[None, :] # (Nl, Nx) +# --------------------------------------------------------------------------- +# Implicit (IMEX) force operators +# --------------------------------------------------------------------------- + + +def hermite_force_operator(Nh: int, alpha: float) -> np.ndarray: + """G_C with (G_C C)_n = -sqrt(2n)/alpha C_{n-1} (strictly lower-bidiagonal). + + The Hermite Lorentz force is dC/dt|force = E(x) G_C C. G_C is nilpotent, with + operator norm ~ sqrt(2 Nh)/alpha * |E|, so explicit RK4 has a CFL-like limit + that tightens with Nh. Backward Euler on it is unconditionally stable. + """ + G = np.zeros((Nh, Nh), dtype=np.float64) + n = np.arange(1, Nh) + G[n, n - 1] = -np.sqrt(2.0 * n) / alpha + return G + + +def legendre_force_operator(deriv: Array, gamma_vec: Array, xi_a: Array, xi_b: Array, width: float) -> np.ndarray: + """G_B = P - D for the Legendre Lorentz force dB/dt|force = E(x) G_B B. + + D = deriv (strictly lower-triangular d_v matrix, paper eqn 10) and P is the + rank-2 Dirichlet penalty P[m, j] = (gamma_m/width)(xi_b[m] xi_b[j] - xi_a[m] xi_a[j]). + The d_v operator norm scales as ~ Nl^2/width, making this the *dominant* + stiffness in the mixed method -- hence it is treated implicitly in IMEX. + """ + D = np.asarray(deriv) + g = np.asarray(gamma_vec) / width + xa, xb = np.asarray(xi_a), np.asarray(xi_b) + P = g[:, None] * (np.outer(xb, xb) - np.outer(xa, xa)) + return P - D + + # --------------------------------------------------------------------------- # Lawson-RK4 vector field (Stepper / discrete-map convention) # --------------------------------------------------------------------------- @@ -409,6 +492,10 @@ def __init__( dt: float, mask23: Array | None = None, field_on: bool = True, + ex_driver: "ExternalExDriver | None" = None, + imex: bool = False, + G_C: Array | None = None, + G_B: Array | None = None, ): self.combined_exp = combined_exp self.poisson = poisson @@ -424,6 +511,12 @@ def __init__( self.dt = float(dt) self.mask23 = mask23 self.field_on = bool(field_on) + self.ex_driver = ex_driver + # IMEX: treat the (stiff) E.d_v f Lorentz force implicitly via a frozen-E + # backward-Euler substep, keeping the rest of the RHS in the explicit Lawson step. + self.imex = bool(imex) + self.G_C = None if G_C is None else jnp.asarray(G_C, dtype=jnp.complex128) + self.G_B = None if G_B is None else jnp.asarray(G_B, dtype=jnp.complex128) def _nonlinear_rhs(self, t: float, state: dict, args: dict) -> dict: Ck = state["Ck"].view(jnp.complex128) @@ -432,16 +525,27 @@ def _nonlinear_rhs(self, t: float, state: dict, args: dict) -> dict: # field_on=False -> pure advection (phi=0); the linear Hermite->Legendre # closure flux (d_x C_{Nh-1}) still acts, only E-dependent terms vanish. E = self.poisson.electric_field(Ck, Bk) if self.field_on else jnp.zeros(Ck.shape[1]) # (Nx,) - - dCk = _hermite_force(Ck, E, self.sqrt_2n_over_alpha, self.mask23) - dBk = _legendre_force( - Bk, E, self.deriv, self.gamma_vec, self.xi_a, self.xi_b, self.width, self.mask23 - ) + _cross_coupling(Ck, E, self.kx_1d, self.coupling_vec, self.alpha, self.width, self.mask23) + # external longitudinal driver (evaluated at the substep time; never enters Poisson) + if self.ex_driver is not None: + E = E + self.ex_driver(t, args) + + cross = _cross_coupling(Ck, E, self.kx_1d, self.coupling_vec, self.alpha, self.width, self.mask23) + if self.imex: + # E.d_v f (Hermite + Legendre force) is handled by the implicit substep; + # only the non-stiff Hermite->Legendre closure flux stays explicit. + dCk = jnp.zeros_like(Ck) + dBk = cross + else: + dCk = _hermite_force(Ck, E, self.sqrt_2n_over_alpha, self.mask23) + dBk = ( + _legendre_force(Bk, E, self.deriv, self.gamma_vec, self.xi_a, self.xi_b, self.width, self.mask23) + + cross + ) out = dict(state) out["Ck"] = dCk.view(jnp.float64) out["Bk"] = dBk.view(jnp.float64) - for k in ("e", "phi"): + for k in ("e", "phi", "de"): if k in out: out[k] = jnp.zeros_like(state[k]) return out @@ -477,8 +581,56 @@ def _lawson_rk4(self, t: float, state: dict, args: dict) -> dict: ) return _tree_add(Ef_y, weighted) + def _implicit_E_substep(self, state: dict, E_real: Array, dt: float) -> dict: + """Backward-Euler substep for the E.d_v f Lorentz force with frozen E. + + Per spatial point x the force is dC/dt = E(x) G_C C and dB/dt = E(x) G_B B + (block-diagonal: Hermite force touches only C, Legendre only B). Backward + Euler gives (I - dt E(x) G) X_new = X, an unconditionally stable per-x linear + solve (G_C is nilpotent; G_B is lower-triangular + a rank-2 penalty). E(x) is + diagonal in real space, so we transform k->x, solve per x, and transform back. + + NOTE: this uses a dense per-x solve, O(Nx * N^3). Fine for moderate Nx, but for + large Nx the structured solve is far cheaper: a bidiagonal forward-substitution + for the nilpotent Hermite block (cf. _spectrax1d.imex_E) and a Woodbury solve + (lower-triangular + rank-2) for the Legendre block, both O(Nx * N^2). + """ + mask = self.mask23 + maskc = mask[None, :] if mask is not None else 1.0 + Ck = state["Ck"].view(jnp.complex128) + Bk = state["Bk"].view(jnp.complex128) + Nh, Nl = Ck.shape[0], Bk.shape[0] + + C = jnp.fft.ifft(Ck * maskc, axis=-1, norm="forward") # (Nh, Nx) + B = jnp.fft.ifft(Bk * maskc, axis=-1, norm="forward") # (Nl, Nx) + + scale = (dt * E_real).astype(jnp.complex128) # (Nx,) + M_C = jnp.eye(Nh, dtype=jnp.complex128)[None] - scale[:, None, None] * self.G_C[None] # (Nx, Nh, Nh) + M_B = jnp.eye(Nl, dtype=jnp.complex128)[None] - scale[:, None, None] * self.G_B[None] # (Nx, Nl, Nl) + + C_new = jnp.linalg.solve(M_C, C.T[..., None])[..., 0].T # (Nh, Nx) + B_new = jnp.linalg.solve(M_B, B.T[..., None])[..., 0].T # (Nl, Nx) + + Ck_new = jnp.fft.fft(C_new, axis=-1, norm="forward") * maskc + Bk_new = jnp.fft.fft(B_new, axis=-1, norm="forward") * maskc + out = dict(state) + out["Ck"] = Ck_new.view(jnp.float64) + out["Bk"] = Bk_new.view(jnp.float64) + return out + def __call__(self, t: float, y: dict, args: dict) -> dict: y_new = self._lawson_rk4(t, y, args) + + if self.imex: + # Frozen E from the post-explicit state (self-consistent + external driver + # at the end of the step), then one implicit Lorentz substep. + Ckp = y_new["Ck"].view(jnp.complex128) + Bkp = y_new["Bk"].view(jnp.complex128) + E_frozen = self.poisson.electric_field(Ckp, Bkp) if self.field_on else jnp.zeros(Ckp.shape[1]) + if self.ex_driver is not None: + E_frozen = E_frozen + self.ex_driver(t + self.dt, args) + y_new = self._implicit_E_substep(y_new, E_frozen, self.dt) + Ck = y_new["Ck"].view(jnp.complex128) Bk = y_new["Bk"].view(jnp.complex128) if self.field_on: @@ -487,4 +639,5 @@ def __call__(self, t: float, y: dict, args: dict) -> dict: else: e = jnp.zeros(Ck.shape[1]) # phi = 0 : pure advection phi = jnp.zeros(Ck.shape[1]) - return {"Ck": y_new["Ck"], "Bk": y_new["Bk"], "e": e, "phi": phi} + de = self.ex_driver(t, args) if self.ex_driver is not None else jnp.zeros(Ck.shape[1]) + return {"Ck": y_new["Ck"], "Bk": y_new["Bk"], "e": e, "phi": phi, "de": de} diff --git a/configs/hermite-legendre-1d/two-stream-imex.yaml b/configs/hermite-legendre-1d/two-stream-imex.yaml new file mode 100644 index 00000000..9ad59a0c --- /dev/null +++ b/configs/hermite-legendre-1d/two-stream-imex.yaml @@ -0,0 +1,52 @@ +# Mixed Hermite-Legendre 1D --- Two-stream instability, IMEX integrator (paper sec 4.3) +# +# Same physics as two-stream.yaml, but the stiff E.d_v f Lorentz force is advanced +# with an unconditionally stable frozen-E Backward-Euler substep (integrator: imex). +# This removes the explicit CFL limit, so the step can be ~5-10x larger than the +# fully-explicit two-stream.yaml (dt=0.002). + +solver: hermite-legendre-1d + +mlflow: + experiment: hermite-legendre-1d + run: two-stream-imex + +units: + normalizing_density: 1e20/cc + normalizing_temperature: 1keV + +physics: + Lx: 12.566370614359172 # 4*pi + alpha: 1.4142135623730951 # sqrt(2) + u: 0.0 + v_a: -2.5 + v_b: 2.5 + gamma: 0.5 + nu_H: 0.0 + nu_L: 1.0 + enforce_conservation: true + field: true + +grid: + Nx: 64 + Nh: 85 + Nl: 171 + tmax: 35.0 + dt: 0.02 + integrator: imex # Lawson-RK4 + implicit (Backward-Euler) Lorentz substep + +initialization: + type: two-stream + eps: 0.01 + mode: 1 + +save: + fields: + t: + nt: 351 + hermite: + t: + nt: 71 + legendre: + t: + nt: 71 diff --git a/configs/hermite-legendre-1d/two-stream.yaml b/configs/hermite-legendre-1d/two-stream.yaml index c62a22b6..5577b80b 100644 --- a/configs/hermite-legendre-1d/two-stream.yaml +++ b/configs/hermite-legendre-1d/two-stream.yaml @@ -32,7 +32,10 @@ grid: Nh: 85 Nl: 171 tmax: 35.0 - dt: 0.01 + # Explicit Lawson-RK4 has a stability (CFL) limit once the field saturates; for + # these parameters dt <= ~0.0025 is required (the paper's dt=0.01 relies on its + # unconditionally stable implicit-midpoint integrator). dt=0.002 is converged. + dt: 0.002 initialization: type: two-stream diff --git a/docs/source/solvers/hermite_legendre_1d/config.md b/docs/source/solvers/hermite_legendre_1d/config.md index 7c20257b..909f7a09 100644 --- a/docs/source/solvers/hermite_legendre_1d/config.md +++ b/docs/source/solvers/hermite_legendre_1d/config.md @@ -74,6 +74,31 @@ The artificial collision operator (paper sec 2.5) uses the cubic spectrum | `Nl` | int | — | Number of Legendre modes for `df` (closure by truncation: `B_{Nl}=0`) | | `tmax` | float | — | Final simulation time (normalized). Snapped to an exact multiple of `dt`. | | `dt` | float | `0.01` | Timestep | +| `integrator` | str | `"lawson"` | Time integrator: `"lawson"` (fully explicit Lawson-RK4) or `"imex"` (Lawson-RK4 + implicit Lorentz substep — see below). | + +### `integrator: imex` + +The stiffness that limits the explicit step is the `E·∂_v f` Lorentz force: in the +spectral velocity bases it is strictly lower-triangular (nilpotent for Hermite, +lower-triangular + a rank-2 penalty for Legendre) with operator norm `~Nl²/width·|E|` +— so explicit RK4's `|dt·‖L‖|≲2.8` limit tightens as modes/field grow. Setting +`integrator: imex` keeps free-streaming, collisions, and the Hermite→Legendre closure +flux in the explicit Lawson step, and advances the Lorentz force with an +**unconditionally stable frozen-E Backward-Euler substep** (a per-`x` triangular/dense +linear solve; first-order Lie split). This removes the CFL limit, letting two-stream +run at `dt ≈ 0.02` instead of `0.002`. Trade-offs: Backward Euler is mildly dissipative +and the split is first-order in `dt`, so for high-accuracy/conservation studies prefer +small-`dt` `lawson`; for robustness at large mode counts or large `Nx`, prefer `imex`. + +**Choosing `dt`.** Free-streaming and collisions are integrated exactly, but the +explicit Lawson-RK4 treatment of the E-field force has a stability (CFL) limit that +tightens as the self-consistent field grows. For small-amplitude/linear runs (e.g. +driven Landau damping) `dt = 0.05` is fine; for nonlinear instabilities that saturate +to a large field (two-stream) a smaller step is needed — `dt ≈ 0.002` is stable and +converged for the two-stream benchmark. (The paper's `dt = 0.01` relies on its +unconditionally stable implicit-midpoint integrator; this explicit module trades that +for a smaller step and a much smaller memory footprint.) A run that goes `NaN` partway +through is the signature of `dt` above the CFL limit — halve it. --- @@ -92,13 +117,41 @@ Here `k = 2π·mode/Lx`. The Legendre projection uses Gauss-Legendre quadrature. --- +## drivers (optional) + +An external longitudinal field `ex` can be applied to the velocity-space force (it +never enters the Poisson solve), e.g. to drive a resonant EPW for a Landau-damping +measurement — the analogue of the Vlasov-1D `ex` driver. Omit the `drivers` block for +self-consistent runs. + +```yaml +drivers: + ex: + '0': # one entry per pulse + k0: 0.4 # wavenumber + w0: 1.285 # angular frequency (e.g. Re(omega) from the dispersion relation) + dw0: 0.0 # frequency offset (added to w0) + a0: 1.0e-3 # amplitude + t_center: 20.0 # pulse: center / full width / rise(+fall) time + t_width: 20.0 + t_rise: 5.0 + x_center: 7.85 # spatial envelope: center / width / rise (defaults span the box) + x_width: 1.0e6 + x_rise: 1.0 +``` + +The driver field is `E_drive(x,t) = Σ env(x,t)·(w0+dw0)·a0·sin(k0 x − (w0+dw0) t)` and +is saved as `de` in the `fields` group. + +--- + ## save Standard ADEPT `save` block with `t: {nt: ...}` (or `tmin`/`tmax`/`nt`) sub-axes. | Key | Contents | |-----|----------| -| `fields` | Electric field `e(x,t)` and potential `phi(x,t)` | +| `fields` | Electric field `e(x,t)`, potential `phi(x,t)`, and external driver field `de(x,t)` | | `hermite` | AW-Hermite-Fourier coefficient timeseries `Ck` (shape `nt × Nh × Nx`) | | `legendre` | Legendre-Fourier coefficient timeseries `Bk` (shape `nt × Nl × Nx`) | | `default` | Scalar invariants `mass`, `momentum`, `energy` (paper eqns 26, 28, 30-31) plus field energy and density extrema. Always added; the primary correctness gate. | diff --git a/tests/test_hermite_legendre_1d/test_imex.py b/tests/test_hermite_legendre_1d/test_imex.py new file mode 100644 index 00000000..ed7f3489 --- /dev/null +++ b/tests/test_hermite_legendre_1d/test_imex.py @@ -0,0 +1,64 @@ +# Copyright (c) Ergodic LLC 2026 +# research@ergodic.io +"""IMEX stability gate for the mixed Hermite-Legendre solver. + +The explicit Lawson-RK4 step has a CFL-like limit set by the (stiff) E.d_v f Lorentz +force, whose spectral operator norm scales as ~Nl^2/width * |E|; for the two-stream +benchmark it blows up by t~20 at dt=0.01. The IMEX integrator advances that force with +a frozen-E Backward-Euler substep (unconditionally stable), so the same run stays +finite and well-behaved at a step several times larger. This test gates that: + + - explicit (lawson) at dt=0.01 goes non-finite before t=35 (the failure IMEX fixes), + - IMEX at dt=0.01 stays finite through t=35, conserves mass (the force leaves the + C_0/B_0 moments untouched), and saturates to a physically reasonable field energy. +""" + +from jax import config as jax_config + +jax_config.update("jax_enable_x64", True) + +import numpy as np + +from adept._hermite_legendre_1d.modules import BaseHermiteLegendre1D + + +def _run_two_stream(integrator: str, dt: float, tmax: float = 35.0, Nh: int = 85, Nl: int = 171): + cfg = { + "solver": "hermite-legendre-1d", + "physics": { + "Lx": 4.0 * np.pi, "alpha": np.sqrt(2.0), "u": 0.0, "v_a": -2.5, "v_b": 2.5, + "gamma": 0.5, "nu_H": 0.0, "nu_L": 1.0, "enforce_conservation": True, "field": True, + }, + "grid": {"Nx": 64, "Nh": Nh, "Nl": Nl, "tmax": tmax, "dt": dt, "integrator": integrator}, + "initialization": {"type": "two-stream", "eps": 0.01, "mode": 1}, + "save": {"default": {"t": {"nt": 80}}}, + "units": {}, + } + m = BaseHermiteLegendre1D(cfg) + m.write_units() + m.get_derived_quantities() + m.get_solver_quantities() + m.init_state_and_args() + m.init_diffeqsolve() + return m(trainable_modules={})["solver result"].ys["default"] + + +def test_explicit_blows_up_at_large_dt(): + """The failure mode that motivates IMEX: explicit two-stream is unstable at dt=0.01.""" + d = _run_two_stream("lawson", dt=0.01) + assert not np.all(np.isfinite(np.asarray(d["energy"]))), "expected explicit blow-up at dt=0.01" + + +def test_imex_stable_at_large_dt(): + d = _run_two_stream("imex", dt=0.01) + energy = np.asarray(d["energy"]) + ee = np.asarray(d["e_energy"]) + mass = np.asarray(d["mass"]) + + assert np.all(np.isfinite(energy)), "IMEX went non-finite at dt=0.01" + # mass is untouched by the Lorentz force (G_C/G_B leave the C_0/B_0 moments alone) + assert np.max(np.abs(mass - mass[0])) / abs(mass[0]) < 1e-10, "IMEX broke mass conservation" + # the instability grew and saturated to a finite field energy comparable to the + # converged explicit value (~0.33); generous bounds (Backward Euler is dissipative). + assert ee.max() > 1e-2, "two-stream field energy did not grow under IMEX" + assert 0.1 < ee[-1] < 1.0, f"saturated field energy out of range: {ee[-1]:.3f}" diff --git a/tests/test_hermite_legendre_1d/test_landau_damping.py b/tests/test_hermite_legendre_1d/test_landau_damping.py new file mode 100644 index 00000000..f00780b6 --- /dev/null +++ b/tests/test_hermite_legendre_1d/test_landau_damping.py @@ -0,0 +1,99 @@ +# Copyright (c) Ergodic LLC 2026 +# research@ergodic.io +"""Driven Landau-damping gate for the mixed Hermite-Legendre solver. + +Mirrors the BaseVlasov1D driven-resonance test (tests/test_vlasov1d/test_landau_damping.py): +a uniform Maxwellian is driven by a small external longitudinal field `ex` at the +resonant wavenumber `k0` and frequency `w0 = Re(omega)` of the kinetic electrostatic +dispersion relation, on a box of length `2*pi/k0` (so the driven mode is k=1). After +the driver ramps off, the electron plasma wave free-rings at its natural frequency +and decays at the Landau rate; we measure both from `E_x(k=1, t)` and compare to the +dispersion-relation root. + +The bulk is carried by the AW-Hermite expansion, which captures Landau damping; the +Legendre part stays negligible in this linear regime. +""" + +from jax import config as jax_config + +jax_config.update("jax_enable_x64", True) + +import numpy as np +import pytest + +from adept import electrostatic +from adept._hermite_legendre_1d.modules import BaseHermiteLegendre1D + + +def _run_driven(klambda_D: float, w0: float, Nh: int = 256, tmax: float = 80.0, dt: float = 0.05): + Lx = 2.0 * np.pi / klambda_D + cfg = { + "solver": "hermite-legendre-1d", + "physics": { + "Lx": Lx, + "alpha": np.sqrt(2.0), + "u": 0.0, + "v_a": -6.0, + "v_b": 6.0, + "gamma": 0.5, + "nu_H": 0.0, + "nu_L": 0.0, + "enforce_conservation": True, + "field": True, + }, + "grid": {"Nx": 64, "Nh": Nh, "Nl": 16, "tmax": tmax, "dt": dt}, + "initialization": {"type": "linear-advection", "eps": 0.0, "mode": 1}, # uniform Maxwellian + "drivers": { + "ex": { + "0": { + "k0": 2.0 * np.pi / Lx, + "w0": w0, + "dw0": 0.0, + "a0": 1.0e-3, + "t_center": 20.0, + "t_width": 20.0, + "t_rise": 5.0, + "x_center": 0.5 * Lx, + "x_width": 1.0e6, + "x_rise": 1.0, + } + } + }, + "save": {"fields": {"t": {"nt": int(tmax * 4)}}}, + "units": {}, + } + m = BaseHermiteLegendre1D(cfg) + m.write_units() + m.get_derived_quantities() + m.get_solver_quantities() + m.init_state_and_args() + m.init_diffeqsolve() + sol = m(trainable_modules={})["solver result"] + e = np.asarray(sol.ys["fields"]["e"]) + t = np.asarray(sol.ts["fields"]) + ek1 = np.fft.fft(e, axis=1)[:, 1] / e.shape[1] + return t, ek1 + + +@pytest.mark.parametrize("klambda_D", [0.30, 0.35]) +def test_driven_landau_damping(klambda_D): + root = electrostatic.get_roots_to_electrostatic_dispersion(1.0, 1.0, klambda_D, maxwellian_convention_factor=2.0) + expected_freq = float(np.real(root)) + expected_damp = float(np.imag(root)) + + t, ek1 = _run_driven(klambda_D, w0=expected_freq) + assert np.all(np.isfinite(ek1)), "driven run went non-finite" + + # free-decay window: driver is off by ~t=40, recurrence is well beyond t=70 at Nh=256 + win = (t > 45.0) & (t < 70.0) + mag = np.abs(ek1[win]) + measured_damp = float(np.polyfit(t[win], np.log(mag), 1)[0]) # d/dt ln|E_k1| + measured_freq = float(-np.polyfit(t[win], np.unwrap(np.angle(ek1[win])), 1)[0]) # -d(arg)/dt + + print( + f"\nklambda_D={klambda_D:.2f} freq {measured_freq:.4f} (exp {expected_freq:.4f}) " + f"damp {measured_damp:.5f} (exp {expected_damp:.5f})" + ) + # frequency is captured to <1%; finite-Nh Hermite slightly under-damps (~5-9%). + np.testing.assert_allclose(measured_freq, expected_freq, rtol=0.02) + np.testing.assert_allclose(measured_damp, expected_damp, rtol=0.15) From b05dc5d886dcc970dea9102763ce17c1c8e938f1 Mon Sep 17 00:00:00 2001 From: archis Date: Mon, 15 Jun 2026 10:39:11 -0700 Subject: [PATCH 3/3] feat(hermite-legendre-1d): implicit-midpoint integrator via AD-JFNK + preconditioner Adds `grid.integrator: implicit` -- the implicit-midpoint rule y1 = y0 + dt F((y0+y1)/2) solved by Jacobian-free Newton-Krylov, with the Jacobian applied as an EXACT autodiff JVP (jax.linearize) to a matrix-free GMRES (jax.scipy.sparse.linalg.gmres). The Jacobian is never assembled (memory is the state plus a few Krylov vectors), so it is the laptop-feasible method the paper uses, here realized with autodiff. Complex coefficients are carried as real (re,im) pytree leaves so the Krylov inner products are the standard real ones. Implicit midpoint is A-stable (no CFL) and conserves quadratic invariants: on two-stream it conserves mass exactly and energy to the solve tolerance (2nd order in dt), and is stable far past where explicit/IMEX die. Preconditioning (`grid.precondition: true`, default) is essential -- unpreconditioned GMRES stalls on the skew streaming spectrum and Newton then injects energy. The preconditioner M = I - dt/2(L_stream + L_force) is applied as a composition of cheap structured solves: streaming is block-diagonal in k and tridiagonal in mode (per-k tridiagonal solve); the Hermite force is lower-bidiagonal and the Legendre force is dominated by the lower-triangular derivative matrix (per-x triangular solves; the rank-2 Dirichlet penalty is left to GMRES). This takes bump-on-tail's linear-phase energy drift from 928% (unpreconditioned) to ~1e-9. Config knobs: integrator, precondition, newton_iters, gmres_restart/maxiter/tol. bump-on-tail.yaml -> integrator: implicit, precondition, dt=0.02. The bump saturation phase needs a small step for the nonlinear solve to converge (the paper uses dt=0.01); large-dt robustness at saturation is the motivation for a learned preconditioner (follow-up). test_implicit gates mass-exact / energy-conserving implicit midpoint on two-stream. Suite: 17 passed. Co-Authored-By: Claude Opus 4.8 (1M context) --- adept/_hermite_legendre_1d/modules.py | 18 +- adept/_hermite_legendre_1d/vector_field.py | 183 ++++++++++++++++++ configs/hermite-legendre-1d/bump-on-tail.yaml | 15 +- .../solvers/hermite_legendre_1d/config.md | 26 ++- .../test_hermite_legendre_1d/test_implicit.py | 64 ++++++ 5 files changed, 302 insertions(+), 4 deletions(-) create mode 100644 tests/test_hermite_legendre_1d/test_implicit.py diff --git a/adept/_hermite_legendre_1d/modules.py b/adept/_hermite_legendre_1d/modules.py index d9931f66..9f8663cd 100644 --- a/adept/_hermite_legendre_1d/modules.py +++ b/adept/_hermite_legendre_1d/modules.py @@ -256,6 +256,7 @@ def init_diffeqsolve(self) -> None: field_on = bool(physics.get("field", True)) integrator = str(grid.get("integrator", "lawson")).lower() imex = integrator == "imex" + implicit_mp = integrator == "implicit" dt = float(grid["dt"]) kx_1d = grid["kx_1d"] @@ -269,8 +270,9 @@ def init_diffeqsolve(self) -> None: hermite_stream = StreamingExp1D(T_H, prefactor=-1j * alpha, kx_1d=kx_1d) legendre_stream = StreamingExp1D(np.asarray(leg["T_L"]), prefactor=-1j, kx_1d=kx_1d) - hermite_coll = DiagonalCollisionExp1D(nu_H, safe_col(Nh)) - legendre_coll = DiagonalCollisionExp1D(nu_L, safe_col(Nl)) + col_e, col_l = safe_col(Nh), safe_col(Nl) + hermite_coll = DiagonalCollisionExp1D(nu_H, col_e) + legendre_coll = DiagonalCollisionExp1D(nu_L, col_l) combined_exp = CombinedLinearExp1D(hermite_stream, legendre_stream, hermite_coll, legendre_coll) poisson = PoissonSolver1D(one_over_kx=one_over_kx, kx_sq=kx_sq, alpha=alpha, width=width) @@ -314,6 +316,18 @@ def init_diffeqsolve(self) -> None: imex=imex, G_C=G_C, G_B=G_B, + implicit=implicit_mp, + T_H=jnp.asarray(T_H) if implicit_mp else None, + T_L=jnp.asarray(np.asarray(leg["T_L"])) if implicit_mp else None, + col_e=col_e if implicit_mp else None, + col_l=col_l if implicit_mp else None, + nu_H=nu_H, + nu_L=nu_L, + newton_iters=int(grid.get("newton_iters", 3)), + gmres_restart=int(grid.get("gmres_restart", 20)), + gmres_maxiter=int(grid.get("gmres_maxiter", 4)), + gmres_tol=float(grid.get("gmres_tol", 1e-8)), + precondition=bool(grid.get("precondition", True)), ) self.cfg = get_save_quantities(self.cfg) diff --git a/adept/_hermite_legendre_1d/vector_field.py b/adept/_hermite_legendre_1d/vector_field.py index 3e8ef2a4..ab21da17 100644 --- a/adept/_hermite_legendre_1d/vector_field.py +++ b/adept/_hermite_legendre_1d/vector_field.py @@ -496,6 +496,18 @@ def __init__( imex: bool = False, G_C: Array | None = None, G_B: Array | None = None, + implicit: bool = False, + T_H: Array | None = None, + T_L: Array | None = None, + col_e: Array | None = None, + col_l: Array | None = None, + nu_H: float = 0.0, + nu_L: float = 0.0, + newton_iters: int = 3, + gmres_restart: int = 20, + gmres_maxiter: int = 4, + gmres_tol: float = 1e-8, + precondition: bool = True, ): self.combined_exp = combined_exp self.poisson = poisson @@ -517,6 +529,88 @@ def __init__( self.imex = bool(imex) self.G_C = None if G_C is None else jnp.asarray(G_C, dtype=jnp.complex128) self.G_B = None if G_B is None else jnp.asarray(G_B, dtype=jnp.complex128) + # Implicit midpoint (AD-JFNK): needs the raw streaming/collision RHS operators. + self.implicit = bool(implicit) + self.T_H = None if T_H is None else jnp.asarray(T_H) + self.T_L = None if T_L is None else jnp.asarray(T_L) + self.col_e = None if col_e is None else jnp.asarray(col_e) + self.col_l = None if col_l is None else jnp.asarray(col_l) + self.nu_H = float(nu_H) + self.nu_L = float(nu_L) + self.newton_iters = int(newton_iters) + self.gmres_restart = int(gmres_restart) + self.gmres_maxiter = int(gmres_maxiter) + self.gmres_tol = float(gmres_tol) + self.precondition = bool(precondition) + if self.implicit and self.precondition: + self._setup_stream_preconditioner() + + def _setup_stream_preconditioner(self) -> None: + """Precompute the per-k tridiagonal bands of M = I - (dt/2)(L_stream + L_coll). + + L_stream is block-diagonal in k and tridiagonal in mode (the streaming matrices + T_H, T_L are tridiagonal), so M^{-1} is a batched tridiagonal solve. Including the + diagonal collision term improves conditioning for collisional runs. This M + captures the stiff (imaginary) streaming spectrum that otherwise cripples GMRES. + """ + half = 0.5 * self.dt + + def bands(T, prefac, nu, col): + T = jnp.asarray(T) + diagT = jnp.diagonal(T) # (N,) + off = jnp.diagonal(T, offset=1) # (N-1,) + N = diagT.shape[0] + c = prefac * half * self.kx_1d # (Nx,) ; prefac = -1 * (-i alpha) etc. -> see below + # M = I + (dt/2)*(i*coef*kx*T) + (dt/2)*nu*col (coef = alpha for H, 1 for L) + d = 1.0 + c[:, None] * diagT[None, :] + half * nu * col[None, :] + sub = jnp.concatenate([jnp.zeros(1, dtype=off.dtype), off]) # sub[n]=T[n,n-1] + sup = jnp.concatenate([off, jnp.zeros(1, dtype=off.dtype)]) # sup[n]=T[n,n+1] + dl = c[:, None] * sub[None, :] + du = c[:, None] * sup[None, :] + return dl.astype(jnp.complex128), d.astype(jnp.complex128), du.astype(jnp.complex128) + + # L_stream C = -i*alpha*kx*(T_H C) -> M_H = I + i*(dt/2)*alpha*kx*T_H + self._pc_h = bands(self.T_H, 1j * self.alpha, self.nu_H, self.col_e) + # L_stream B = -i*kx*(T_L B) -> M_L = I + i*(dt/2)*kx*T_L + self._pc_l = bands(self.T_L, 1j, self.nu_L, self.col_l) + + def _stream_precond_apply(self, v: dict) -> dict: + """Apply M^{-1} (streaming+collision preconditioner) to a real (re,im) pytree.""" + Ck = v["Cr"] + 1j * v["Ci"] # (Nh, Nx) + Bk = v["Br"] + 1j * v["Bi"] # (Nl, Nx) + dl_h, d_h, du_h = self._pc_h + dl_l, d_l, du_l = self._pc_l + x = jax.lax.linalg.tridiagonal_solve(dl_h, d_h, du_h, Ck.T[..., None])[..., 0].T + y = jax.lax.linalg.tridiagonal_solve(dl_l, d_l, du_l, Bk.T[..., None])[..., 0].T + return {"Cr": x.real, "Ci": x.imag, "Br": y.real, "Bi": y.imag} + + def _force_precond_apply(self, v: dict, sC: Array) -> dict: + """Apply (I - dt/2 E0 G_force)^{-1} (the Lorentz-force preconditioner), per x. + + sC = (dt/2) E0(x) is the frozen-field scale. The Hermite force operator is + lower-bidiagonal -> a per-x tridiagonal (forward-substitution) solve; the + Legendre force is dominated by the derivative matrix D (norm ~Nl^2/width), a + per-x lower-triangular solve. The rank-2 Dirichlet penalty is left to GMRES (a + preconditioner need not be exact). This is the piece that becomes stiff at + saturation, where streaming-only preconditioning stalls. + """ + Ck = v["Cr"] + 1j * v["Ci"] # (Nh, Nx) + Bk = v["Br"] + 1j * v["Bi"] # (Nl, Nx) + Nh, Nx = Ck.shape + Nl = Bk.shape[0] + sC_c = sC.astype(jnp.complex128) + + # Hermite: (I - sC G_C) is lower-bidiagonal; (I-sC G_C)[n,n-1] = sC*sqrt(2n)/alpha + dl = sC_c[:, None] * self.sqrt_2n_over_alpha[None, :].astype(jnp.complex128) # (Nx, Nh) + dl = dl.at[:, 0].set(0.0) + d = jnp.ones((Nx, Nh), dtype=jnp.complex128) + du = jnp.zeros((Nx, Nh), dtype=jnp.complex128) + Cn = jax.lax.linalg.tridiagonal_solve(dl, d, du, Ck.T[..., None])[..., 0].T + + # Legendre: (I + sC D) lower-triangular (unit diagonal); D = deriv matrix + A0 = jnp.eye(Nl, dtype=jnp.complex128)[None] + sC_c[:, None, None] * self.deriv.astype(jnp.complex128)[None] + Bn = jax.lax.linalg.triangular_solve(A0, Bk.T[..., None], left_side=True, lower=True)[..., 0].T + return {"Cr": Cn.real, "Ci": Cn.imag, "Br": Bn.real, "Bi": Bn.imag} def _nonlinear_rhs(self, t: float, state: dict, args: dict) -> dict: Ck = state["Ck"].view(jnp.complex128) @@ -581,6 +675,82 @@ def _lawson_rk4(self, t: float, state: dict, args: dict) -> dict: ) return _tree_add(Ef_y, weighted) + def _full_rhs_complex(self, t: float, Ck: Array, Bk: Array) -> tuple: + """Raw RHS dy/dt = streaming + collisions + Lorentz force + closure coupling. + + Unlike the Lawson path (which integrates streaming/collisions exactly via the + matrix exponential), the implicit-midpoint integrator needs the explicit RHS + operators: streaming d_t C = -i alpha kx (T_H C), d_t B = -i kx (T_L B), the + diagonal collision -nu col[n], and the (reused) E.d_v f force + coupling terms. + """ + E = self.poisson.electric_field(Ck, Bk) if self.field_on else jnp.zeros(Ck.shape[1]) + if self.ex_driver is not None: + E = E + self.ex_driver(t, None) + + dCk = -1j * self.alpha * self.kx_1d[None, :] * (self.T_H @ Ck) - self.nu_H * self.col_e[:, None] * Ck + dBk = -1j * self.kx_1d[None, :] * (self.T_L @ Bk) - self.nu_L * self.col_l[:, None] * Bk + + dCk = dCk + _hermite_force(Ck, E, self.sqrt_2n_over_alpha, self.mask23) + dBk = ( + dBk + + _legendre_force(Bk, E, self.deriv, self.gamma_vec, self.xi_a, self.xi_b, self.width, self.mask23) + + _cross_coupling(Ck, E, self.kx_1d, self.coupling_vec, self.alpha, self.width, self.mask23) + ) + return dCk, dBk + + def _implicit_midpoint_solve(self, t: float, Ck0: Array, Bk0: Array) -> tuple: + """One implicit-midpoint step via Jacobian-free Newton-Krylov (AD-JFNK). + + Solves y1 = y0 + dt F((y0+y1)/2). Implicit midpoint is A-stable and conserves + quadratic invariants (energy), so it has no CFL limit and no spurious energy + growth -- needed for the saturated/long-time regimes where the explicit and + IMEX paths blow up. The Newton linear solves use a matrix-free GMRES whose + Jacobian-vector products are EXACT autodiff JVPs (jax.linearize) -- the Jacobian + is never formed. Complex coefficients are carried as real (re, im) pytree leaves + so the Krylov inner products are the standard real ones. + """ + import jax.scipy.sparse.linalg as jsla + + dt = self.dt + t_mid = t + 0.5 * dt + y0 = {"Cr": Ck0.real, "Ci": Ck0.imag, "Br": Bk0.real, "Bi": Bk0.imag} + + # Frozen-field scale for the force preconditioner (E at the step-start state). + E0 = self.poisson.electric_field(Ck0, Bk0) if self.field_on else jnp.zeros(Ck0.shape[1]) + if self.ex_driver is not None: + E0 = E0 + self.ex_driver(t_mid, None) + sC = 0.5 * dt * E0 + + def rhs(yr): + Ck = yr["Cr"] + 1j * yr["Ci"] + Bk = yr["Br"] + 1j * yr["Bi"] + dCk, dBk = self._full_rhs_complex(t_mid, Ck, Bk) + return {"Cr": dCk.real, "Ci": dCk.imag, "Br": dBk.real, "Bi": dBk.imag} + + def residual(y1): + y_mid = jax.tree.map(lambda a, b: 0.5 * (a + b), y0, y1) + f = rhs(y_mid) + return jax.tree.map(lambda a, b, ff: a - b - dt * ff, y1, y0, f) + + # Combined preconditioner M^{-1} = M_force^{-1} . M_stream^{-1}: streaming + # (per-k tridiagonal) handles the linear/growth phase, the force part handles + # saturation where the Lorentz term dominates the Jacobian. + if self.precondition: + def M(vv): + return self._force_precond_apply(self._stream_precond_apply(vv), sC) + else: + M = None + y1 = y0 + for _ in range(self.newton_iters): + r, jvp_fn = jax.linearize(residual, y1) # jvp_fn(v) = J @ v, exact (AD) + neg_r = jax.tree.map(jnp.negative, r) + delta, _ = jsla.gmres( + jvp_fn, neg_r, M=M, tol=self.gmres_tol, atol=0.0, restart=self.gmres_restart, maxiter=self.gmres_maxiter + ) + y1 = jax.tree.map(jnp.add, y1, delta) + + return y1["Cr"] + 1j * y1["Ci"], y1["Br"] + 1j * y1["Bi"] + def _implicit_E_substep(self, state: dict, E_real: Array, dt: float) -> dict: """Backward-Euler substep for the E.d_v f Lorentz force with frozen E. @@ -619,6 +789,19 @@ def _implicit_E_substep(self, state: dict, E_real: Array, dt: float) -> dict: return out def __call__(self, t: float, y: dict, args: dict) -> dict: + if self.implicit: + Ck0 = y["Ck"].view(jnp.complex128) + Bk0 = y["Bk"].view(jnp.complex128) + Ck, Bk = self._implicit_midpoint_solve(t, Ck0, Bk0) + if self.field_on: + e = self.poisson.electric_field(Ck, Bk) + phi = self.poisson.potential(Ck, Bk) + else: + e = jnp.zeros(Ck.shape[1]) + phi = jnp.zeros(Ck.shape[1]) + de = self.ex_driver(t, args) if self.ex_driver is not None else jnp.zeros(Ck.shape[1]) + return {"Ck": Ck.view(jnp.float64), "Bk": Bk.view(jnp.float64), "e": e, "phi": phi, "de": de} + y_new = self._lawson_rk4(t, y, args) if self.imex: diff --git a/configs/hermite-legendre-1d/bump-on-tail.yaml b/configs/hermite-legendre-1d/bump-on-tail.yaml index e6d7a244..98425142 100644 --- a/configs/hermite-legendre-1d/bump-on-tail.yaml +++ b/configs/hermite-legendre-1d/bump-on-tail.yaml @@ -4,6 +4,14 @@ # df(x,v,0) = n_b/sqrt(2*pi) exp(-(v-10)^2/2) (beam projected onto Legendre) # The beam is well separated from the bulk, so f0<->df coupling is weak and a large # Hermite collision rate nu_H = 10 is acceptable. Parameters match Chapurin et al. +# +# This is the most demanding benchmark (asymmetric velocity window, long time). The +# explicit and IMEX integrators both blow up near saturation; it requires the A-stable, +# energy-conserving implicit-midpoint integrator (integrator: implicit) with the combined +# streaming+force preconditioner. At saturation the implicit nonlinear solve is hard, so a +# small step is needed for Newton/GMRES to converge (dt <= ~0.02; the paper uses dt=0.01). +# A larger dt (e.g. 0.1) converges in the linear phase but fails at saturation -- the +# learned/ML preconditioner is the planned route to large-dt robustness here. solver: hermite-legendre-1d @@ -32,7 +40,12 @@ grid: Nh: 128 Nl: 128 tmax: 120.0 - dt: 0.01 + dt: 0.02 # small step so the saturation-phase JFNK solve converges + integrator: implicit # A-stable implicit midpoint; conserves mass exactly, energy to solve tol + precondition: true # combined streaming+force preconditioner + newton_iters: 4 + gmres_restart: 30 + gmres_maxiter: 3 initialization: type: bump-on-tail diff --git a/docs/source/solvers/hermite_legendre_1d/config.md b/docs/source/solvers/hermite_legendre_1d/config.md index 909f7a09..70c9c1cb 100644 --- a/docs/source/solvers/hermite_legendre_1d/config.md +++ b/docs/source/solvers/hermite_legendre_1d/config.md @@ -74,7 +74,31 @@ The artificial collision operator (paper sec 2.5) uses the cubic spectrum | `Nl` | int | — | Number of Legendre modes for `df` (closure by truncation: `B_{Nl}=0`) | | `tmax` | float | — | Final simulation time (normalized). Snapped to an exact multiple of `dt`. | | `dt` | float | `0.01` | Timestep | -| `integrator` | str | `"lawson"` | Time integrator: `"lawson"` (fully explicit Lawson-RK4) or `"imex"` (Lawson-RK4 + implicit Lorentz substep — see below). | +| `integrator` | str | `"lawson"` | Time integrator: `"lawson"` (explicit Lawson-RK4), `"imex"` (Lawson-RK4 + implicit Lorentz substep), or `"implicit"` (implicit midpoint, AD-JFNK) — see below. | +| `newton_iters` | int | `3` | (`implicit`) Newton iterations per step. | +| `gmres_restart`, `gmres_maxiter`, `gmres_tol` | int/int/float | `20`/`4`/`1e-8` | (`implicit`) matrix-free GMRES controls for the Newton linear solves. | +| `precondition` | bool | `true` | (`implicit`) use the streaming+collision operator as a physics-based GMRES preconditioner (see below). | + +### `integrator: implicit` (implicit midpoint via AD-JFNK) + +Advances the full RHS with the implicit-midpoint rule `y1 = y0 + dt·F((y0+y1)/2)`, +solved by **Jacobian-free Newton-Krylov**: each Newton linear system uses a matrix-free +GMRES whose Jacobian-vector products are *exact autodiff JVPs* (`jax.linearize`) — the +Jacobian is never assembled (memory is the state plus a few Krylov vectors). Implicit +midpoint is A-stable (no CFL at all) and conserves quadratic invariants, so it conserves +mass exactly and energy to the solve tolerance, and stays stable into the saturated / +long-time regime where both `lawson` and `imex` blow up (e.g. bump-on-tail). Cost: each +step does `newton_iters × (GMRES iterations) × (RHS evals)`, so it is the most expensive +per step — use it for the hard cases, not the cheap ones. + +**Preconditioning** (`precondition: true`, default). The implicit operator's stiffness is +dominated by the skew streaming term, whose eigenvalues smear along the imaginary axis +(`~dt/2·α·k_max·√(2Nh)`) — the worst case for unpreconditioned GMRES, which then needs +many iterations and can fail to converge at large `dt`/`Nx` (Newton then injects energy). +The preconditioner `M = I − dt/2·(L_streaming + L_collision)` is block-diagonal in `k` and +tridiagonal in mode index, so `M⁻¹` is a cheap per-`k` tridiagonal solve that captures +exactly that stiff spectrum; GMRES on `M⁻¹A` then converges in a handful of iterations. +This is what makes large-`dt` implicit-midpoint runs practical. ### `integrator: imex` diff --git a/tests/test_hermite_legendre_1d/test_implicit.py b/tests/test_hermite_legendre_1d/test_implicit.py new file mode 100644 index 00000000..739bcbc4 --- /dev/null +++ b/tests/test_hermite_legendre_1d/test_implicit.py @@ -0,0 +1,64 @@ +# Copyright (c) Ergodic LLC 2026 +# research@ergodic.io +"""Implicit-midpoint (AD-JFNK) gate for the mixed Hermite-Legendre solver. + +`integrator: implicit` advances the FULL right-hand side with the implicit-midpoint +rule y1 = y0 + dt F((y0+y1)/2), solved by Jacobian-free Newton-Krylov: the Newton +linear systems use a matrix-free GMRES whose Jacobian-vector products are exact +autodiff JVPs (jax.linearize) -- the Jacobian is never assembled. Implicit midpoint is +A-stable (no CFL) and conserves quadratic invariants, so unlike the explicit and IMEX +paths it stays stable into the saturated/long-time regime and conserves mass exactly +and energy to the nonlinear-solve tolerance. + +This gates a two-stream run at dt=0.05 -- a step at which the explicit Lawson path is +violently unstable -- staying finite, conserving mass to machine precision, and holding +energy far better than an explicit step of the same size could. +""" + +from jax import config as jax_config + +jax_config.update("jax_enable_x64", True) + +import numpy as np + +from adept._hermite_legendre_1d.modules import BaseHermiteLegendre1D + + +def _run_two_stream(integrator: str, dt: float, tmax: float = 10.0, Nh: int = 48, Nl: int = 48): + cfg = { + "solver": "hermite-legendre-1d", + "physics": { + "Lx": 4.0 * np.pi, + "alpha": np.sqrt(2.0), + "u": 0.0, + "v_a": -2.5, + "v_b": 2.5, + "gamma": 0.5, + "nu_H": 0.0, + "nu_L": 1.0, + "enforce_conservation": True, + "field": True, + }, + "grid": {"Nx": 48, "Nh": Nh, "Nl": Nl, "tmax": tmax, "dt": dt, "integrator": integrator}, + "initialization": {"type": "two-stream", "eps": 0.05, "mode": 1}, + "save": {"default": {"t": {"nt": 40}}}, + "units": {}, + } + m = BaseHermiteLegendre1D(cfg) + m.write_units() + m.get_derived_quantities() + m.get_solver_quantities() + m.init_state_and_args() + m.init_diffeqsolve() + return m(trainable_modules={})["solver result"].ys["default"] + + +def test_implicit_midpoint_stable_and_conserving(): + d = _run_two_stream("implicit", dt=0.05) + energy = np.asarray(d["energy"]) + mass = np.asarray(d["mass"]) + + assert np.all(np.isfinite(energy)), "implicit midpoint went non-finite" + # implicit midpoint conserves mass exactly and energy to the JFNK solve tolerance + assert np.max(np.abs(mass - mass[0])) / abs(mass[0]) < 1e-11, "mass not conserved" + assert np.max(np.abs(energy - energy[0])) / abs(energy[0]) < 1e-4, "energy drift too large"