diff --git a/adept/__init__.py b/adept/__init__.py index 1c311e14..3b5effaf 100644 --- a/adept/__init__.py +++ b/adept/__init__.py @@ -1,3 +1,3 @@ from ._base_ import ADEPTModule, ergoExo # noqa: I001 from .mlflow_logging import MlflowLoggingModule -from . import hermite_poisson_1d, lpse2d, vlasov1d, vlasov2d +from . import hermite_legendre_1d, hermite_poisson_1d, lpse2d, vlasov1d, vlasov2d diff --git a/adept/_base_.py b/adept/_base_.py index 949bb5a5..aa21bb83 100644 --- a/adept/_base_.py +++ b/adept/_base_.py @@ -326,6 +326,9 @@ def _get_adept_module_(self, cfg: dict) -> ADEPTModule: elif cfg["solver"] == "pic-1d": from adept.pic1d import BasePIC1D as this_module + elif cfg["solver"] == "hermite-legendre-1d": + from adept.hermite_legendre_1d import BaseHermiteLegendre1D as this_module + else: raise NotImplementedError("This solver approach has not been implemented yet") diff --git a/adept/_hermite_legendre_1d/__init__.py b/adept/_hermite_legendre_1d/__init__.py new file mode 100644 index 00000000..4e8ef1dc --- /dev/null +++ b/adept/_hermite_legendre_1d/__init__.py @@ -0,0 +1,3 @@ +from .modules import BaseHermiteLegendre1D + +__all__ = ["BaseHermiteLegendre1D"] diff --git a/adept/_hermite_legendre_1d/modules.py b/adept/_hermite_legendre_1d/modules.py new file mode 100644 index 00000000..9f8663cd --- /dev/null +++ b/adept/_hermite_legendre_1d/modules.py @@ -0,0 +1,440 @@ +"""Base ADEPTModule for the 1D mixed Hermite-Legendre Vlasov-Poisson solver. + +Implements the method of Issan, Delzanno & Roytershteyn (arXiv:2606.12322). The +electron distribution is split f = f0 + df with an AW-Hermite expansion for the +near-Maxwellian bulk f0 and a Legendre expansion (on a bounded velocity window) for +the strongly non-Maxwellian part df. Electrostatic, single electron species with an +immobile neutralizing ion background; periodic in x (Fourier); explicit Lawson-RK4. + +Normalization (paper sec 2.1): t*wpe, x/lambda_D, v/vthe. + +Config keys in cfg["physics"]: + Lx (float), alpha (float, Hermite scale), u (float, Hermite shift, default 0), + v_a, v_b (float, Legendre velocity window), gamma (penalty, default 0.5), + nu_H, nu_L (artificial collision rates, default 0), enforce_conservation (bool, default True) + +Config keys in cfg["grid"]: + Nx (int), Nh (int, Hermite modes), Nl (int, Legendre modes), tmax (float), dt (float, default 0.01) + +Config cfg["initialization"]: + type: "linear-advection" | "two-stream" | "bump-on-tail" | "custom" (+ type-specific params) +""" + +import os +import sys + +import numpy as np +from diffrax import ConstantStepSize, NoProgressMeter, ODETerm, SaveAt, SubSaveAt, TqdmProgressMeter, diffeqsolve +from jax import numpy as jnp + +from adept._base_ import ADEPTModule, Stepper +from adept._hermite_legendre_1d.storage import ( + get_save_quantities, + store_coeff_timeseries, + store_fields_timeseries, +) +from adept._hermite_legendre_1d.vector_field import ( + CombinedLinearExp1D, + DiagonalCollisionExp1D, + ExternalExDriver, + HermiteLegendre1DVectorField, + PoissonSolver1D, + StreamingExp1D, + hermite_force_operator, + hermite_legendre_coupling_vector, + hermite_streaming_matrix, + legendre_constants, + legendre_force_operator, + safe_col, +) + + +def _density_profile(x: np.ndarray, Lx: float, base: float, eps: float, mode: int) -> np.ndarray: + """base * (1 + eps cos(2*pi*mode*x/Lx)).""" + return base * (1.0 + eps * np.cos(2.0 * np.pi * mode * x / Lx)) + + +def _project_legendre(g_v, Nl: int, v_a: float, v_b: float) -> np.ndarray: + """Project a velocity profile g(v) onto the Legendre basis: B_m = (1/width) int g xi_m dv.""" + from adept._hermite_legendre_1d.vector_field import _legendre_basis_values + + width = v_b - v_a + deg = max(4 * Nl, 400) + nodes, weights = np.polynomial.legendre.leggauss(deg) + v = 0.5 * width * nodes + 0.5 * (v_b + v_a) + w = 0.5 * width * weights + xi = _legendre_basis_values(Nl, v, v_a, v_b) # (Nl, len(v)) + return (xi @ (g_v(v) * w)) / width + + +class BaseHermiteLegendre1D(ADEPTModule): + """1D mixed Hermite-Legendre Vlasov-Poisson base solver (normalized units).""" + + def __init__(self, cfg: dict) -> None: + super().__init__(cfg) + + # ------------------------------------------------------------------ + def write_units(self) -> dict: + """Normalized units throughout; pass through any precomputed derived block.""" + return self.cfg.get("units", {}).get("derived", {}) + + # ------------------------------------------------------------------ + def get_derived_quantities(self) -> None: + physics = self.cfg["physics"] + grid = self.cfg["grid"] + + Lx = float(physics["Lx"]) + Nx = int(grid["Nx"]) + grid["dx"] = Lx / Nx + + tmax = float(grid["tmax"]) + dt = float(grid.get("dt", 0.01)) + nt = round(tmax / dt) + grid["dt"] = dt + grid["nt"] = nt + grid["tmax"] = dt * nt # snap to exact multiple + grid["max_steps"] = nt + 4 + + for save_cfg in self.cfg.get("save", {}).values(): + if isinstance(save_cfg, dict) and "t" in save_cfg: + t_cfg = save_cfg["t"] + t_cfg.setdefault("tmin", 0.0) + t_cfg.setdefault("tmax", grid["tmax"]) + + self.cfg["grid"] = grid + + # ------------------------------------------------------------------ + def get_solver_quantities(self) -> None: + physics = self.cfg["physics"] + grid = self.cfg["grid"] + + Lx = float(physics["Lx"]) + Nx = int(grid["Nx"]) + Nh = int(grid["Nh"]) + Nl = int(grid["Nl"]) + alpha = float(physics["alpha"]) + u = float(physics.get("u", 0.0)) + v_a = float(physics["v_a"]) + v_b = float(physics["v_b"]) + width = v_b - v_a + + kx_1d = jnp.fft.fftfreq(Nx) * Nx * 2.0 * jnp.pi / Lx + one_over_kx = np.zeros(Nx, dtype=np.float64) + one_over_kx[1:] = 1.0 / np.asarray(kx_1d[1:]) + kx_sq = kx_1d**2 + + x = jnp.linspace(0.0, Lx, Nx, endpoint=False) + modes = jnp.fft.fftfreq(Nx) * Nx + mask23 = jnp.abs(modes) <= (Nx // 3) + + grid.update( + { + "x": x, + "kx_1d": kx_1d, + "one_over_kx": jnp.asarray(one_over_kx), + "kx_sq": kx_sq, + "mask23": mask23, + "width": width, + } + ) + self.cfg["grid"] = grid + + # ------------------------------------------------------------------ + def init_state_and_args(self) -> None: + physics = self.cfg["physics"] + grid = self.cfg["grid"] + + Lx = float(physics["Lx"]) + Nx = int(grid["Nx"]) + Nh = int(grid["Nh"]) + Nl = int(grid["Nl"]) + alpha = float(physics["alpha"]) + v_a = float(physics["v_a"]) + v_b = float(physics["v_b"]) + + x = np.asarray(grid["x"]) + C = np.zeros((Nh, Nx), dtype=np.float64) # real-space C_n(x) + B = np.zeros((Nl, Nx), dtype=np.float64) # real-space B_m(x) + + init = self.cfg.get("initialization", {"type": "linear-advection"}) + itype = init.get("type", "linear-advection") + + if itype == "linear-advection": + eps = float(init.get("eps", 1.0)) + mode = int(init.get("mode", 1)) + n_bulk = _density_profile(x, Lx, 1.0, eps, mode) + C[0] = n_bulk / alpha + + elif itype == "two-stream": + eps = float(init.get("eps", 0.01)) + mode = int(init.get("mode", 1)) + n_bulk = _density_profile(x, Lx, 1.0, eps, mode) + C[0] = n_bulk / alpha + # v^2 Maxwellian -> C_2 = sqrt(2) C_0 in the AW-Hermite basis (alpha=sqrt(2)) + if Nh > 2: + C[2] = np.sqrt(2.0) * C[0] + + elif itype == "bump-on-tail": + eps = float(init.get("eps", 1e-4)) + mode = int(init.get("mode", 1)) + n_beam = float(init.get("n_beam", 0.01)) + v_drift = float(init.get("v_drift", 10.0)) + v_th = float(init.get("v_th", 1.0)) + n_bulk = _density_profile(x, Lx, 1.0 - n_beam, eps, mode) + C[0] = n_bulk / alpha + amp = n_beam / (np.sqrt(2.0 * np.pi) * v_th) + B_beam = _project_legendre(lambda v: amp * np.exp(-((v - v_drift) ** 2) / (2.0 * v_th**2)), Nl, v_a, v_b) + B[:] = B_beam[:, None] # spatially uniform beam + + elif itype == "custom": + # hermite: {n: {base, eps, mode}} ; df: {beams: [{amp, v_drift, v_th}], eps, mode} + for n_str, spec in init.get("hermite", {}).items(): + n = int(n_str) + C[n] = _density_profile( + x, Lx, float(spec.get("base", 0.0)), float(spec.get("eps", 0.0)), int(spec.get("mode", 1)) + ) + df = init.get("df", None) + if df: + beams = df.get("beams", []) + + def g(v): + out = np.zeros_like(v) + for b in beams: + out = out + float(b["amp"]) * np.exp( + -((v - float(b.get("v_drift", 0.0))) ** 2) / (2.0 * float(b.get("v_th", 1.0)) ** 2) + ) + return out + + B_beam = _project_legendre(g, Nl, v_a, v_b) + spatial = _density_profile(x, Lx, 1.0, float(df.get("eps", 0.0)), int(df.get("mode", 1))) + B[:] = B_beam[:, None] * spatial[None, :] + else: + raise ValueError(f"Unknown initialization.type {itype!r}") + + Ck = jnp.fft.fft(jnp.asarray(C), axis=-1, norm="forward").astype(jnp.complex128) + Bk = jnp.fft.fft(jnp.asarray(B), axis=-1, norm="forward").astype(jnp.complex128) + + # Seed the field diagnostics with the actual t=0 Poisson field so the energy + # diagnostic is self-consistent at t=0 (the perturbed initial state carries a + # nonzero E ~ eps); leaving e=0 would record a spurious O(eps^2) jump at step 1. + field_on = bool(physics.get("field", True)) + if field_on: + poisson = PoissonSolver1D( + one_over_kx=grid["one_over_kx"], kx_sq=grid["kx_sq"], alpha=alpha, width=(v_b - v_a) + ) + e0 = poisson.electric_field(Ck, Bk) + phi0 = poisson.potential(Ck, Bk) + else: + e0 = jnp.zeros(Nx) + phi0 = jnp.zeros(Nx) + + self.state = { + "Ck": Ck.view(jnp.float64), + "Bk": Bk.view(jnp.float64), + "e": e0, + "phi": phi0, + "de": jnp.zeros(Nx), # external Ex driver field (diagnostic) + } + self.args = {} + + # ------------------------------------------------------------------ + def init_diffeqsolve(self) -> None: + physics = self.cfg["physics"] + grid = self.cfg["grid"] + + Nh = int(grid["Nh"]) + Nl = int(grid["Nl"]) + alpha = float(physics["alpha"]) + u = float(physics.get("u", 0.0)) + v_a = float(physics["v_a"]) + v_b = float(physics["v_b"]) + width = v_b - v_a + gamma = float(physics.get("gamma", 0.5)) + nu_H = float(physics.get("nu_H", 0.0)) + nu_L = float(physics.get("nu_L", 0.0)) + enforce = bool(physics.get("enforce_conservation", True)) + field_on = bool(physics.get("field", True)) + integrator = str(grid.get("integrator", "lawson")).lower() + imex = integrator == "imex" + implicit_mp = integrator == "implicit" + dt = float(grid["dt"]) + + kx_1d = grid["kx_1d"] + one_over_kx = grid["one_over_kx"] + kx_sq = grid["kx_sq"] + mask23 = grid["mask23"] + + # Streaming exponentials (exact, prediagonalized symmetric tridiagonals) + T_H = hermite_streaming_matrix(Nh, u, alpha) + leg = legendre_constants(Nl, v_a, v_b) + hermite_stream = StreamingExp1D(T_H, prefactor=-1j * alpha, kx_1d=kx_1d) + legendre_stream = StreamingExp1D(np.asarray(leg["T_L"]), prefactor=-1j, kx_1d=kx_1d) + + col_e, col_l = safe_col(Nh), safe_col(Nl) + hermite_coll = DiagonalCollisionExp1D(nu_H, col_e) + legendre_coll = DiagonalCollisionExp1D(nu_L, col_l) + combined_exp = CombinedLinearExp1D(hermite_stream, legendre_stream, hermite_coll, legendre_coll) + + poisson = PoissonSolver1D(one_over_kx=one_over_kx, kx_sq=kx_sq, alpha=alpha, width=width) + + # External longitudinal (Ex) driver, e.g. a resonant EPW kick for Landau damping + ex_cfg = self.cfg.get("drivers", {}).get("ex", {}) + ex_driver = ExternalExDriver(grid["x"], ex_cfg) if ex_cfg else None + + # Explicit-term constants + n = jnp.arange(Nh, dtype=jnp.float64) + sqrt_2n_over_alpha = jnp.sqrt(2.0 * n) / alpha + gamma_vec = jnp.where(jnp.arange(Nl) >= 3, gamma, 0.0) + + J = hermite_legendre_coupling_vector(Nh, Nl, alpha, u, v_a, v_b, enforce_conservation=enforce) + coupling_vec = -(alpha / width) * jnp.sqrt(Nh / 2.0) * J # folds prefactor into J_{Nh,m} + + # IMEX force operators (only built when integrator == "imex") + G_C = jnp.asarray(hermite_force_operator(Nh, alpha)) if imex else None + G_B = ( + jnp.asarray(legendre_force_operator(leg["deriv"], gamma_vec, leg["xi_a"], leg["xi_b"], width)) + if imex + else None + ) + + vector_field = HermiteLegendre1DVectorField( + combined_exp=combined_exp, + poisson=poisson, + kx_1d=kx_1d, + sqrt_2n_over_alpha=sqrt_2n_over_alpha, + deriv=leg["deriv"], + gamma_vec=gamma_vec, + xi_a=leg["xi_a"], + xi_b=leg["xi_b"], + coupling_vec=coupling_vec, + alpha=alpha, + width=width, + dt=dt, + mask23=mask23, + field_on=field_on, + ex_driver=ex_driver, + imex=imex, + G_C=G_C, + G_B=G_B, + implicit=implicit_mp, + T_H=jnp.asarray(T_H) if implicit_mp else None, + T_L=jnp.asarray(np.asarray(leg["T_L"])) if implicit_mp else None, + col_e=col_e if implicit_mp else None, + col_l=col_l if implicit_mp else None, + nu_H=nu_H, + nu_L=nu_L, + newton_iters=int(grid.get("newton_iters", 3)), + gmres_restart=int(grid.get("gmres_restart", 20)), + gmres_maxiter=int(grid.get("gmres_maxiter", 4)), + gmres_tol=float(grid.get("gmres_tol", 1e-8)), + precondition=bool(grid.get("precondition", True)), + ) + + self.cfg = get_save_quantities(self.cfg) + tmax = float(grid["tmax"]) + max_steps = int(grid["max_steps"]) + + subsaves = {} + for k, v in self.cfg["save"].items(): + if isinstance(v, dict) and "t" in v and "func" in v: + subsaves[k] = SubSaveAt(ts=v["t"]["ax"], fn=v["func"]) + + self.time_quantities = {"t0": 0.0, "t1": tmax, "max_steps": max_steps} + self.diffeqsolve_quants = { + "terms": ODETerm(vector_field), + "solver": Stepper(), + "saveat": {"subs": subsaves}, + } + + # ------------------------------------------------------------------ + def __call__(self, trainable_modules: dict, args: dict | None = None) -> dict: + if args is None: + args = self.args + grid = self.cfg["grid"] + sol = diffeqsolve( + terms=self.diffeqsolve_quants["terms"], + solver=self.diffeqsolve_quants["solver"], + stepsize_controller=ConstantStepSize(), + t0=self.time_quantities["t0"], + t1=self.time_quantities["t1"], + dt0=float(grid["dt"]), + y0=self.state, + args=args, + saveat=SaveAt(**self.diffeqsolve_quants["saveat"]), + max_steps=self.time_quantities["max_steps"], + progress_meter=TqdmProgressMeter() if sys.stdout.isatty() else NoProgressMeter(), + ) + return {"solver result": sol} + + # ------------------------------------------------------------------ + def post_process(self, run_output: dict, td: str) -> dict: + import matplotlib.pyplot as plt + + sol = run_output["solver result"] + binary_dir = os.path.join(td, "binary") + plots_dir = os.path.join(td, "plots") + os.makedirs(binary_dir, exist_ok=True) + os.makedirs(plots_dir, exist_ok=True) + + x = np.asarray(self.cfg["grid"]["x"]) + datasets = {} + metrics = {"simulation_completed": True} + + for k in sol.ys.keys(): + t_arr = np.asarray(sol.ts[k]) + data = sol.ys[k] + + if k == "fields": + fields_dict = {name: np.asarray(arr) for name, arr in data.items()} + datasets["fields"] = store_fields_timeseries(fields_dict, t_arr, binary_dir, x) + for field_name, arr in fields_dict.items(): + if arr.ndim == 2: + try: + arr_plot = np.where(np.isfinite(arr), arr, np.nan) + fig, ax = plt.subplots(figsize=(9, 5), tight_layout=True) + im = ax.pcolormesh(x, t_arr, arr_plot, shading="auto", cmap="RdBu_r") + plt.colorbar(im, ax=ax) + ax.set_xlabel("x [norm]") + ax.set_ylabel("t [norm]") + ax.set_title(field_name) + fig.savefig(os.path.join(plots_dir, f"spacetime-{field_name}.png"), bbox_inches="tight") + except Exception as e: + print(f"post_process: spacetime plot for {field_name} failed: {e}") + finally: + plt.close("all") + + elif k in ("hermite", "legendre"): + name = "Ck" if k == "hermite" else "Bk" + arr = np.asarray(data[name]) + datasets[k] = store_coeff_timeseries(name, arr, t_arr, binary_dir) + + elif k == "default": + scalars = {name: np.asarray(arr) for name, arr in data.items()} + for name, arr in scalars.items(): + if arr.ndim == 1: + final_val = float(arr[-1]) + metrics[f"final_{name}"] = final_val if np.isfinite(final_val) else float("nan") + # relative drift for conserved invariants + if name in ("mass", "momentum", "energy") and np.isfinite(arr[0]) and abs(arr[0]) > 0: + metrics[f"reldrift_{name}"] = float(np.max(np.abs(arr - arr[0])) / abs(arr[0])) + try: + arr_plot = np.where(np.isfinite(arr), arr, np.nan) + fig, axes = plt.subplots(1, 2, figsize=(10, 4), tight_layout=True) + axes[0].plot(t_arr, arr_plot) + axes[0].set_xlabel("t") + axes[0].set_ylabel(name) + axes[0].grid(alpha=0.3) + axes[1].semilogy(t_arr, np.abs(arr_plot) + 1e-30) + axes[1].set_xlabel("t") + axes[1].set_ylabel(f"|{name}|") + axes[1].grid(alpha=0.3) + fig.savefig(os.path.join(plots_dir, f"scalar-{name}.png"), bbox_inches="tight") + except Exception as e: + print(f"post_process: scalar plot for {name} failed: {e}") + finally: + plt.close("all") + + if "default" in sol.ts: + metrics["n_timesteps"] = len(sol.ts["default"]) + + return {"metrics": metrics, "datasets": datasets} diff --git a/adept/_hermite_legendre_1d/storage.py b/adept/_hermite_legendre_1d/storage.py new file mode 100644 index 00000000..dfa82696 --- /dev/null +++ b/adept/_hermite_legendre_1d/storage.py @@ -0,0 +1,192 @@ +""" +Storage / save functions for the 1D mixed Hermite-Legendre solver. + +State keys: Ck (Nh, Nx) complex viewed as float64, Bk (Nl, Nx) complex viewed as +float64, e (Nx,), phi (Nx,). + +Supported cfg["save"] keys: + "fields" -> {e, phi} spacetime arrays + "hermite" -> Ck Hermite-Fourier coefficient timeseries + "legendre" -> Bk Legendre-Fourier coefficient timeseries + "default" -> scalar invariants (mass, momentum, energy) + field energy. Always added. + +The scalar invariants follow the analytic definitions of the mixed method (paper +eqns 26, 28, 30-31). With the conservation constraint J_{Nh,0..2}=0 enforced they +are conserved to solver tolerance, which is the primary correctness gate. +""" + +import os + +import numpy as np +import xarray as xr +from jax import numpy as jnp + +# --------------------------------------------------------------------------- +# Save functions (called by diffrax SubSaveAt at specified timesteps) +# --------------------------------------------------------------------------- + + +def get_fields_save_func(): + """Save electric field e, potential phi, and external Ex driver field de.""" + + def fields_save_func(t, y, args): + out = {"e": y["e"], "phi": y["phi"]} + if "de" in y: + out["de"] = y["de"] + return out + + return fields_save_func + + +def get_hermite_save_func(): + def hermite_save_func(t, y, args): + return {"Ck": y["Ck"].view(jnp.complex128)} + + return hermite_save_func + + +def get_legendre_save_func(): + def legendre_save_func(t, y, args): + return {"Bk": y["Bk"].view(jnp.complex128)} + + return legendre_save_func + + +def get_default_save_func( + alpha: float, u: float, width: float, sigma1: float, sigma2: float, sigma_bar: float, Lx: float +): + """Save scalar invariants: total mass, momentum, energy, and field energy. + + Spatial integrals reduce to the domain length times the k=0 Fourier mode, e.g. + integral C_n dx = Lx * Re(Ck[n, 0]) since norm="forward" puts the mean in [0]. + """ + alpha = float(alpha) + u = float(u) + width = float(width) + sigma1 = float(sigma1) + sigma2 = float(sigma2) + sigma_bar = float(sigma_bar) + Lx = float(Lx) + + def default_save_func(t, y, args): + Ck = y["Ck"].view(jnp.complex128) + Bk = y["Bk"].view(jnp.complex128) + e = y["e"] + nx = e.shape[0] + dx_local = Lx / nx + + def intg(arr_k, n): + return Lx * jnp.real(arr_k[n, 0]) if arr_k.shape[0] > n else 0.0 + + C0, C1, C2 = intg(Ck, 0), intg(Ck, 1), intg(Ck, 2) + B0, B1, B2 = intg(Bk, 0), intg(Bk, 1), intg(Bk, 2) + + mass = alpha * C0 + width * B0 + momentum = (alpha**2 / jnp.sqrt(2.0)) * C1 + u * alpha * C0 + width * sigma1 * B1 + sigma_bar * width * B0 + e_kin = (alpha / 2.0) * ( + (alpha**2 / jnp.sqrt(2.0)) * C2 + jnp.sqrt(2.0) * u * alpha * C1 + (alpha**2 / 2.0 + u**2) * C0 + ) + (width / 2.0) * (sigma2 * sigma1 * B2 + 2.0 * sigma1 * sigma_bar * B1 + (sigma1**2 + sigma_bar**2) * B0) + e_pot = 0.5 * jnp.sum(e**2) * dx_local + energy = e_kin + e_pot + + # density extrema for blow-up monitoring + n_f0 = alpha * jnp.fft.ifft(Ck[0], norm="forward").real + n_df = width * jnp.fft.ifft(Bk[0], norm="forward").real + n_e = n_f0 + n_df + + return { + "mass": mass, + "momentum": momentum, + "energy": energy, + "e_energy": e_pot, + "n_e_max": jnp.max(n_e), + "n_e_min": jnp.min(n_e), + "Bk_max": jnp.max(jnp.abs(Bk)), + "Ck_max": jnp.max(jnp.abs(Ck)), + } + + return default_save_func + + +# --------------------------------------------------------------------------- +# Configure save axes and attach save functions +# --------------------------------------------------------------------------- + + +def get_save_quantities(cfg: dict) -> dict: + """Attach time axes and save functions to cfg["save"]. Modifies cfg in place.""" + grid = cfg["grid"] + physics = cfg["physics"] + + tmax = float(grid["tmax"]) + nt = int(grid["nt"]) + + alpha = float(physics["alpha"]) + u = float(physics.get("u", 0.0)) + v_a = float(physics["v_a"]) + v_b = float(physics["v_b"]) + width = v_b - v_a + sigma1 = (width / 2.0) * 1.0 / np.sqrt(3.0 * 1.0) # n=1 + sigma2 = (width / 2.0) * 2.0 / np.sqrt(5.0 * 3.0) # n=2 + sigma_bar = 0.5 * (v_a + v_b) + Lx = float(physics["Lx"]) + + for save_key, save_cfg in cfg.get("save", {}).items(): + if not isinstance(save_cfg, dict): + continue + if "t" in save_cfg and isinstance(save_cfg["t"], dict): + t_cfg = save_cfg["t"] + if "ax" not in t_cfg: + t_cfg["ax"] = np.linspace( + float(t_cfg.get("tmin", 0.0)), float(t_cfg.get("tmax", tmax)), int(t_cfg.get("nt", nt)) + ) + if "func" in save_cfg: + continue + if save_key == "fields": + save_cfg["func"] = get_fields_save_func() + elif save_key == "hermite": + save_cfg["func"] = get_hermite_save_func() + elif save_key == "legendre": + save_cfg["func"] = get_legendre_save_func() + + if "save" not in cfg: + cfg["save"] = {} + if "default" not in cfg["save"]: + cfg["save"]["default"] = {"t": {"ax": np.linspace(0.0, tmax, nt)}} + elif "t" not in cfg["save"]["default"]: + cfg["save"]["default"]["t"] = {"ax": np.linspace(0.0, tmax, nt)} + elif "ax" not in cfg["save"]["default"]["t"]: + cfg["save"]["default"]["t"]["ax"] = np.linspace(0.0, tmax, nt) + + cfg["save"]["default"]["func"] = get_default_save_func(alpha, u, width, sigma1, sigma2, sigma_bar, Lx) + return cfg + + +# --------------------------------------------------------------------------- +# Post-processing storage helpers +# --------------------------------------------------------------------------- + + +def store_fields_timeseries(fields_dict: dict, t_array: np.ndarray, binary_dir: str, x: np.ndarray) -> xr.Dataset: + das = { + k: xr.DataArray(np.asarray(v), coords=[("t", t_array), ("x", x)], name=k) + for k, v in fields_dict.items() + if np.asarray(v).ndim == 2 + } + ds = xr.Dataset(das) + ds.to_netcdf(os.path.join(binary_dir, f"fields-t={round(float(t_array[-1]), 4)}.nc")) + return ds + + +def store_coeff_timeseries(name: str, arr: np.ndarray, t_array: np.ndarray, binary_dir: str) -> xr.Dataset: + """Save (nt, Nmodes, Nx) complex coefficient timeseries to netCDF.""" + Nx = arr.shape[-1] + Nmodes = arr.shape[-2] + kx = np.fft.fftshift(np.fft.fftfreq(Nx, d=1.0 / Nx)) + arr_shifted = np.fft.fftshift(arr, axes=-1) + ds = xr.Dataset( + {name: (["t", "mode", "kx"], arr_shifted)}, + coords={"t": t_array, "mode": np.arange(Nmodes), "kx": kx}, + ) + ds.to_netcdf(os.path.join(binary_dir, f"{name}-t={round(float(t_array[-1]), 4)}.nc")) + return ds diff --git a/adept/_hermite_legendre_1d/vector_field.py b/adept/_hermite_legendre_1d/vector_field.py new file mode 100644 index 00000000..ab21da17 --- /dev/null +++ b/adept/_hermite_legendre_1d/vector_field.py @@ -0,0 +1,826 @@ +""" +Core vector field for the 1D mixed Hermite-Legendre Vlasov-Poisson solver. + +Implements the mixed method of Issan, Delzanno & Roytershteyn (arXiv:2606.12322, +"Mixed Hermite-Legendre spectral method for kinetic plasma simulations"). + +The electron distribution is split f = f0 + df, with + f0(x, v, t) = sum_{n=0}^{Nh-1} C_n(x, t) psi_n(v; alpha, u) [AW Hermite] + df(x, v, t) = sum_{m=0}^{Nl-1} B_m(x, t) xi_m(v; v_a, v_b) [Legendre] + +evolved by the coupled system (paper eqns 23-25) on a periodic spatial domain. +Spatial dependence is carried in Fourier space (d_x -> i*kx); velocity is spectral. + +Normalization (paper sec 2.1): t*wpe, x/lambda_D, v/vthe, phi*e*lambda_D/Te, so +the electric field is E = -d_x phi and the ion background density is 1. + +Integration: Lawson-RK4 (free-streaming exact for both bases via prediagonalized +symmetric-tridiagonal streaming matrices; E-field force, Dirichlet penalty, and the +Hermite->Legendre coupling are explicit). Uses the Stepper (discrete-map) convention +from adept._base_, mirroring adept._hermite_poisson_1d.vector_field. +""" + +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array + +from adept._base_ import get_envelope + +# --------------------------------------------------------------------------- +# External longitudinal field driver +# --------------------------------------------------------------------------- + + +class ExternalExDriver: + """Prescribed longitudinal field E_drive(x, t) added to the velocity-space force. + + E_drive(x, t) = sum_pulses env(x, t) (w0+dw0) a0 sin(k0 x - (w0+dw0) t) + + The driver enters only the E.d_v f force term (it drives EPWs directly, e.g. a + resonant kick for a Landau-damping measurement) and never the Poisson solve, so + the self-consistent field energy diagnostic excludes it. Mirrors the convention + of adept._hermite_poisson_1d.vector_field.LongitudinalElectricFieldDriver and of + the vlasov1d ``ex`` driver. Reads cfg["drivers"]["ex"] (pulse_name -> pulse_dict). + Output shape (Nx,), on the interior x grid. + """ + + def __init__(self, x: Array, ex_driver_cfg: dict): + self.x = x + x_last = float(x[-1]) + parsed = [] + for _, pulse in ex_driver_cfg.items() if isinstance(ex_driver_cfg, dict) else []: + if not isinstance(pulse, dict): + continue + parsed.append( + ( + float(pulse["k0"]), + float(pulse["w0"]) + float(pulse.get("dw0", 0.0)), + float(pulse["a0"]), + float(pulse.get("t_center", 0.0)), + 0.5 * float(pulse.get("t_width", 1e10)), + float(pulse.get("t_rise", 0.0)), + float(pulse.get("x_center", 0.5 * x_last)), + 0.5 * float(pulse.get("x_width", 1e10)), + float(pulse.get("x_rise", 0.0)), + ) + ) + self.parsed_pulses = parsed + + def __call__(self, t: float, args) -> Array: + total = jnp.zeros_like(self.x) + for k0, w_total, a0, t_center, t_half, t_rise, x_center, x_half, x_rise in self.parsed_pulses: + env_t = get_envelope(t_rise, t_rise, t_center - t_half, t_center + t_half, t) + env_x = get_envelope(x_rise, x_rise, x_center - x_half, x_center + x_half, self.x) + total = total + env_t * env_x * w_total * a0 * jnp.sin(k0 * self.x - w_total * t) + return total + + +# --------------------------------------------------------------------------- +# Basis constants +# --------------------------------------------------------------------------- + + +def safe_col(N: int) -> Array: + """Artificial-collision profile col[n] = n(n-1)(n-2) / ((N-1)(N-2)(N-3)). + + This is the Lenard-Bernstein-based hyper-collision spectrum of paper sec 2.5: + cubic in the mode index, normalized to 1 at n = N-1, and identically zero for + n = 0, 1, 2 so the operator conserves mass, momentum, and energy exactly. The + damping is applied per Lawson substep as exp(-nu * col[n] * s). + """ + n = jnp.arange(N, dtype=jnp.float64) + term = n * (n - 1) * (n - 2) + denom = (N - 1) * (N - 2) * (N - 3) if N > 3 else 1.0 + return jnp.where(N > 3, term / denom, jnp.zeros(N, dtype=jnp.float64)) + + +def hermite_streaming_matrix(Nh: int, u: float, alpha: float) -> np.ndarray: + """Symmetric tridiagonal T_H for AW-Hermite free streaming (paper eqn 7). + + d_t C_n + alpha*sqrt((n+1)/2) d_x C_{n+1} + alpha*sqrt(n/2) d_x C_{n-1} + + u d_x C_n = 0 ==> d_t C = -alpha * T_H * (d_x C), + with off-diagonal sqrt((n+1)/2) and diagonal u/alpha. The scalar prefactor + -i*alpha*kx is applied by StreamingExp1D. + """ + n = np.arange(Nh, dtype=np.float64) + T = np.zeros((Nh, Nh), dtype=np.float64) + if Nh > 1: + off = np.sqrt((n[:-1] + 1.0) / 2.0) # sqrt((n+1)/2), n = 0..Nh-2 + T += np.diag(off, 1) + np.diag(off, -1) + T += np.diag(np.full(Nh, u / alpha)) + return T + + +def legendre_constants(Nl: int, v_a: float, v_b: float) -> dict: + """Constants for the shifted/scaled Legendre basis xi_m on [v_a, v_b]. + + Returns: + T_L: symmetric tridiagonal velocity matrix (paper eqn 11), + off-diagonal sigma_{n+1}, diagonal sigma_bar = (v_a+v_b)/2. + sigma: sub/super-diagonal coefficients sigma_n (sigma_0 = 0). + sigma_bar: scalar (v_a+v_b)/2. + deriv: strictly-lower-triangular derivative matrix sigma_{m,i} (paper eqn 10), + d xi_m / dv = sum_{i 1: + sigma[1:] = (width / 2.0) * n[1:] / np.sqrt((2.0 * n[1:] + 1.0) * (2.0 * n[1:] - 1.0)) + + T_L = np.diag(np.full(Nl, sigma_bar)) + if Nl > 1: + off = sigma[1:] # T_L[m, m+1] = T_L[m+1, m] = sigma_{m+1} + T_L += np.diag(off, 1) + np.diag(off, -1) + + # Derivative matrix sigma_{m,i}: nonzero only when (m - i) is odd, i < m. + deriv = np.zeros((Nl, Nl), dtype=np.float64) + for m in range(Nl): + for i in range(m): + if (m - i) % 2 == 1: + deriv[m, i] = 2.0 * np.sqrt((2.0 * m + 1.0) * (2.0 * i + 1.0)) / width + + sqrt_2m1 = np.sqrt(2.0 * n + 1.0) + xi_b = sqrt_2m1 + xi_a = sqrt_2m1 * ((-1.0) ** n) + + return { + "T_L": T_L, + "sigma": jnp.asarray(sigma), + "sigma_bar": float(sigma_bar), + "deriv": jnp.asarray(deriv), + "xi_b": jnp.asarray(xi_b), + "xi_a": jnp.asarray(xi_a), + } + + +def _hermite_function_values(Nh_plus_1: int, v: np.ndarray, u: float, alpha: float) -> np.ndarray: + """psi_n(v; u, alpha) for n = 0..Nh, shape (Nh+1, len(v)). Used for J integrals. + + psi_n = (pi 2^n n!)^{-1/2} H_n(z) exp(-z^2), z = (v - u)/alpha. Built from the + *normalized* three-term recurrence + + psi_0 = pi^{-1/2} exp(-z^2), psi_1 = sqrt(2) z psi_0, + psi_{n+1} = z sqrt(2/(n+1)) psi_n - sqrt(n/(n+1)) psi_{n-1}, + + which keeps every psi_n O(1). Forming H_n and 2^n n! separately (as a naive + implementation would) overflows float64 for n >= 171, silently zeroing the + high-order coupling integrals. + """ + z = (v - u) / alpha + psi = np.zeros((Nh_plus_1, v.size), dtype=np.float64) + psi[0] = np.pi**-0.5 * np.exp(-(z**2)) + if Nh_plus_1 > 1: + psi[1] = np.sqrt(2.0) * z * psi[0] + for n in range(1, Nh_plus_1 - 1): + psi[n + 1] = z * np.sqrt(2.0 / (n + 1)) * psi[n] - np.sqrt(n / (n + 1)) * psi[n - 1] + return psi + + +def _legendre_basis_values(Nl: int, v: np.ndarray, v_a: float, v_b: float) -> np.ndarray: + """xi_m(v; v_a, v_b) = sqrt(2m+1) L_m(s), s = (2v-(v_a+v_b))/(v_b-v_a), + for m = 0..Nl-1, shape (Nl, len(v)). Built from the Legendre recurrence.""" + s = (2.0 * v - (v_a + v_b)) / (v_b - v_a) + L = np.zeros((Nl, v.size), dtype=np.float64) + L[0] = 1.0 + if Nl > 1: + L[1] = s + for k in range(1, Nl - 1): + L[k + 1] = ((2.0 * k + 1.0) * s * L[k] - k * L[k - 1]) / (k + 1.0) + scale = np.sqrt(2.0 * np.arange(Nl) + 1.0) + return scale[:, None] * L + + +def hermite_legendre_coupling_vector( + Nh: int, Nl: int, alpha: float, u: float, v_a: float, v_b: float, enforce_conservation: bool = True +) -> Array: + """J_{Nh, m} = integral_{v_a}^{v_b} psi_{Nh}(v; alpha, u) xi_m(v) dv, m = 0..Nl-1. + + This is the one-way coupling from the (closed-off) highest Hermite coefficient + C_{Nh-1} into the Legendre modes (paper eqn 24). Evaluated by Gauss-Legendre + quadrature on [v_a, v_b]. When ``enforce_conservation`` is set, J_{Nh,0..2} are + zeroed (paper sec 3.4/4): removing the coupling to the first three Legendre + coefficients makes the discrete method conserve mass, momentum, and energy + independent of Nh parity and the spectral shift/scale parameters. + """ + deg = max(4 * (Nh + Nl), 200) + nodes, weights = np.polynomial.legendre.leggauss(deg) + v = 0.5 * (v_b - v_a) * nodes + 0.5 * (v_b + v_a) + w = 0.5 * (v_b - v_a) * weights + + psi_Nh = _hermite_function_values(Nh + 1, v, u, alpha)[Nh] # (len(v),) + xi = _legendre_basis_values(Nl, v, v_a, v_b) # (Nl, len(v)) + J = xi @ (psi_Nh * w) # (Nl,) + + if enforce_conservation: + J[: min(3, Nl)] = 0.0 + return jnp.asarray(J) + + +# --------------------------------------------------------------------------- +# Exact exponential operators +# --------------------------------------------------------------------------- + + +class StreamingExp1D: + """Exact free-streaming exponential exp(prefactor * kx * T * s) for one basis. + + Prediagonalizes the symmetric tridiagonal streaming matrix T (Hermite or + Legendre) once, then applies exp(L*s) to a coefficient array (Nmodes, Nx) by + rotating into the eigenbasis, scaling by exp(prefactor * kx * eigval * s), and + rotating back. Mirrors adept._hermite_poisson_1d.vector_field.FreeStreamingExp1D + (Hermite: prefactor = -i*alpha; Legendre: prefactor = -i). + """ + + def __init__(self, T: np.ndarray, prefactor: complex, kx_1d: Array): + self.prefactor = prefactor + self.kx_1d = kx_1d + eigvals, V = np.linalg.eigh(T) + self.V = jnp.asarray(V) + self.eigenvalues = jnp.asarray(eigvals) + + def apply(self, Ck: Array, s: float) -> Array: + C_eig = self.V.T @ Ck # (Nmodes, Nx) + exp_fac = jnp.exp(self.prefactor * s * self.eigenvalues[:, None] * self.kx_1d[None, :]) + return self.V @ (C_eig * exp_fac) + + +class DiagonalCollisionExp1D: + """Exact exponential exp(-nu * col[n] * s), diagonal in the mode index.""" + + def __init__(self, nu: float, col: Array): + self.nu = float(nu) + self.col = col + + def apply(self, Ck: Array, s: float) -> Array: + if self.nu == 0.0: + return Ck + return Ck * jnp.exp(-self.nu * self.col[:, None] * s) + + +class CombinedLinearExp1D: + """Applies exp(L*s) to the full {Ck, Bk} state; real diagnostics pass through.""" + + def __init__( + self, + hermite_stream: StreamingExp1D, + legendre_stream: StreamingExp1D, + hermite_coll: DiagonalCollisionExp1D, + legendre_coll: DiagonalCollisionExp1D, + ): + self.hermite_stream = hermite_stream + self.legendre_stream = legendre_stream + self.hermite_coll = hermite_coll + self.legendre_coll = legendre_coll + + def apply(self, state: dict, s: float) -> dict: + Ck = state["Ck"].view(jnp.complex128) + Bk = state["Bk"].view(jnp.complex128) + Ck = self.hermite_coll.apply(self.hermite_stream.apply(Ck, s), s) + Bk = self.legendre_coll.apply(self.legendre_stream.apply(Bk, s), s) + out = dict(state) + out["Ck"] = Ck.view(jnp.float64) + out["Bk"] = Bk.view(jnp.float64) + return out + + +# --------------------------------------------------------------------------- +# Poisson solver +# --------------------------------------------------------------------------- + + +class PoissonSolver1D: + """Spectral Poisson solve for the mixed-method charge density (paper eqn 25). + + rho(x) = 1 - alpha*C_0(x) - (v_b - v_a)*B_0(x) (immobile ion background = 1) + -d_x^2 phi = rho ==> phi_k = rho_k / kx^2, E = -d_x phi ==> E_k = -i rho_k / kx. + The k=0 component is set to zero (quasineutral domain). + """ + + def __init__(self, one_over_kx: Array, kx_sq: Array, alpha: float, width: float): + self.one_over_kx = one_over_kx + self.kx_sq = kx_sq + self.alpha = float(alpha) + self.width = float(width) + + def _rho_k(self, Ck: Array, Bk: Array) -> Array: + # C_0(x), B_0(x) from their Fourier representations (norm="forward" -> [0] is mean) + n_f0 = self.alpha * jnp.fft.ifft(Ck[0], norm="forward").real + n_df = self.width * jnp.fft.ifft(Bk[0], norm="forward").real + rho = 1.0 - n_f0 - n_df + return jnp.fft.fft(rho, norm="forward") + + def electric_field(self, Ck: Array, Bk: Array) -> Array: + E_k = -1j * self.one_over_kx * self._rho_k(Ck, Bk) + return jnp.fft.ifft(E_k, norm="forward").real + + def potential(self, Ck: Array, Bk: Array) -> Array: + rho_k = self._rho_k(Ck, Bk) + phi_k = jnp.where(self.kx_sq > 0, rho_k / jnp.where(self.kx_sq > 0, self.kx_sq, 1.0), 0.0) + return jnp.fft.ifft(phi_k, norm="forward").real + + +# --------------------------------------------------------------------------- +# Explicit nonlinear terms (E-field force, penalty, Hermite->Legendre coupling) +# --------------------------------------------------------------------------- + + +def _to_real(Ck: Array, mask23: Array | None) -> Array: + """IFFT (Nmodes, Nx) k-space coeffs to real space, optionally 2/3-dealiased.""" + if mask23 is not None: + Ck = Ck * mask23[None, :] + return jnp.fft.ifft(Ck, axis=-1, norm="forward") + + +def _hermite_force(Ck: Array, E: Array, sqrt_2n_over_alpha: Array, mask23: Array | None) -> Array: + """d_t C_n |force = -(sqrt(2n)/alpha) E(x) C_{n-1}(x) (electron, paper eqn 19/23). + + E = -d_x phi is the self-consistent field; the n-th coefficient is forced by the + (n-1)-th (Hermite differentiation raises the index). Returns k-space (Nh, Nx). + """ + Nh, Nx = Ck.shape + if mask23 is not None: + E = jnp.fft.ifft(jnp.fft.fft(E, norm="forward") * mask23, norm="forward").real + C = _to_real(Ck, mask23) + C_up = jnp.concatenate([jnp.zeros((1, Nx), dtype=C.dtype), C[:-1, :]], axis=0) # C_{n-1} + integrand = -sqrt_2n_over_alpha[:, None] * E[None, :] * C_up + out = jnp.fft.fft(integrand, axis=-1, norm="forward") + if mask23 is not None: + out = out * mask23[None, :] + return out + + +def _legendre_force( + Bk: Array, + E: Array, + deriv: Array, + gamma_vec: Array, + xi_a: Array, + xi_b: Array, + width: float, + mask23: Array | None, +) -> Array: + """d_t B_m |force = -E(x) [ sum_{i Array: + """Hermite -> Legendre coupling (paper eqn 24, RHS): + + d_t B_m |coupling = -(alpha/width) J_{Nh,m} sqrt(Nh/2) [ d_x C_{Nh-1} + (2/alpha^2) E C_{Nh-1} ]. + + The d_x C_{Nh-1} part is linear (i*kx in Fourier); the E*C_{Nh-1} part is the + nonlinear product. Returns k-space (Nl, Nx). coupling_vec already folds in + -(alpha/width) sqrt(Nh/2) J_{Nh,m} (and the J_{Nh,0..2}=0 conservation gate). + """ + Nh = Ck.shape[0] + C_last_k = Ck[Nh - 1] # (Nx,) k-space + dx_C = 1j * kx_1d * C_last_k # d_x C_{Nh-1} in k-space + + C_last_real = _to_real(Ck[Nh - 1 : Nh], mask23)[0] # (Nx,) + if mask23 is not None: + E = jnp.fft.ifft(jnp.fft.fft(E, norm="forward") * mask23, norm="forward").real + nl_real = (2.0 / alpha**2) * E * C_last_real + nl_k = jnp.fft.fft(nl_real, norm="forward") + if mask23 is not None: + nl_k = nl_k * mask23 + + bracket = dx_C + nl_k # (Nx,) k-space + return coupling_vec[:, None] * bracket[None, :] # (Nl, Nx) + + +# --------------------------------------------------------------------------- +# Implicit (IMEX) force operators +# --------------------------------------------------------------------------- + + +def hermite_force_operator(Nh: int, alpha: float) -> np.ndarray: + """G_C with (G_C C)_n = -sqrt(2n)/alpha C_{n-1} (strictly lower-bidiagonal). + + The Hermite Lorentz force is dC/dt|force = E(x) G_C C. G_C is nilpotent, with + operator norm ~ sqrt(2 Nh)/alpha * |E|, so explicit RK4 has a CFL-like limit + that tightens with Nh. Backward Euler on it is unconditionally stable. + """ + G = np.zeros((Nh, Nh), dtype=np.float64) + n = np.arange(1, Nh) + G[n, n - 1] = -np.sqrt(2.0 * n) / alpha + return G + + +def legendre_force_operator(deriv: Array, gamma_vec: Array, xi_a: Array, xi_b: Array, width: float) -> np.ndarray: + """G_B = P - D for the Legendre Lorentz force dB/dt|force = E(x) G_B B. + + D = deriv (strictly lower-triangular d_v matrix, paper eqn 10) and P is the + rank-2 Dirichlet penalty P[m, j] = (gamma_m/width)(xi_b[m] xi_b[j] - xi_a[m] xi_a[j]). + The d_v operator norm scales as ~ Nl^2/width, making this the *dominant* + stiffness in the mixed method -- hence it is treated implicitly in IMEX. + """ + D = np.asarray(deriv) + g = np.asarray(gamma_vec) / width + xa, xb = np.asarray(xi_a), np.asarray(xi_b) + P = g[:, None] * (np.outer(xb, xb) - np.outer(xa, xa)) + return P - D + + +# --------------------------------------------------------------------------- +# Lawson-RK4 vector field (Stepper / discrete-map convention) +# --------------------------------------------------------------------------- + + +def _tree_add(a: dict, b: dict) -> dict: + return jax.tree.map(lambda x, y: x + y, a, b) + + +def _tree_scale(a: dict, c: float) -> dict: + return jax.tree.map(lambda x: c * x, a) + + +class HermiteLegendre1DVectorField: + """Advances the mixed Hermite-Legendre state by one timestep dt. + + State dict (complex arrays stored as float64 views, like _hermite_poisson_1d): + Ck: (Nh, Nx) Hermite-Fourier coefficients of f0 + Bk: (Nl, Nx) Legendre-Fourier coefficients of df + e: (Nx,) electric field (diagnostic) + phi: (Nx,) potential (diagnostic) + + Called by adept._base_.Stepper, which treats the return value as the new state. + """ + + def __init__( + self, + combined_exp: CombinedLinearExp1D, + poisson: PoissonSolver1D, + kx_1d: Array, + sqrt_2n_over_alpha: Array, + deriv: Array, + gamma_vec: Array, + xi_a: Array, + xi_b: Array, + coupling_vec: Array, + alpha: float, + width: float, + dt: float, + mask23: Array | None = None, + field_on: bool = True, + ex_driver: "ExternalExDriver | None" = None, + imex: bool = False, + G_C: Array | None = None, + G_B: Array | None = None, + implicit: bool = False, + T_H: Array | None = None, + T_L: Array | None = None, + col_e: Array | None = None, + col_l: Array | None = None, + nu_H: float = 0.0, + nu_L: float = 0.0, + newton_iters: int = 3, + gmres_restart: int = 20, + gmres_maxiter: int = 4, + gmres_tol: float = 1e-8, + precondition: bool = True, + ): + self.combined_exp = combined_exp + self.poisson = poisson + self.kx_1d = kx_1d + self.sqrt_2n_over_alpha = sqrt_2n_over_alpha + self.deriv = deriv + self.gamma_vec = gamma_vec + self.xi_a = xi_a + self.xi_b = xi_b + self.coupling_vec = coupling_vec + self.alpha = float(alpha) + self.width = float(width) + self.dt = float(dt) + self.mask23 = mask23 + self.field_on = bool(field_on) + self.ex_driver = ex_driver + # IMEX: treat the (stiff) E.d_v f Lorentz force implicitly via a frozen-E + # backward-Euler substep, keeping the rest of the RHS in the explicit Lawson step. + self.imex = bool(imex) + self.G_C = None if G_C is None else jnp.asarray(G_C, dtype=jnp.complex128) + self.G_B = None if G_B is None else jnp.asarray(G_B, dtype=jnp.complex128) + # Implicit midpoint (AD-JFNK): needs the raw streaming/collision RHS operators. + self.implicit = bool(implicit) + self.T_H = None if T_H is None else jnp.asarray(T_H) + self.T_L = None if T_L is None else jnp.asarray(T_L) + self.col_e = None if col_e is None else jnp.asarray(col_e) + self.col_l = None if col_l is None else jnp.asarray(col_l) + self.nu_H = float(nu_H) + self.nu_L = float(nu_L) + self.newton_iters = int(newton_iters) + self.gmres_restart = int(gmres_restart) + self.gmres_maxiter = int(gmres_maxiter) + self.gmres_tol = float(gmres_tol) + self.precondition = bool(precondition) + if self.implicit and self.precondition: + self._setup_stream_preconditioner() + + def _setup_stream_preconditioner(self) -> None: + """Precompute the per-k tridiagonal bands of M = I - (dt/2)(L_stream + L_coll). + + L_stream is block-diagonal in k and tridiagonal in mode (the streaming matrices + T_H, T_L are tridiagonal), so M^{-1} is a batched tridiagonal solve. Including the + diagonal collision term improves conditioning for collisional runs. This M + captures the stiff (imaginary) streaming spectrum that otherwise cripples GMRES. + """ + half = 0.5 * self.dt + + def bands(T, prefac, nu, col): + T = jnp.asarray(T) + diagT = jnp.diagonal(T) # (N,) + off = jnp.diagonal(T, offset=1) # (N-1,) + N = diagT.shape[0] + c = prefac * half * self.kx_1d # (Nx,) ; prefac = -1 * (-i alpha) etc. -> see below + # M = I + (dt/2)*(i*coef*kx*T) + (dt/2)*nu*col (coef = alpha for H, 1 for L) + d = 1.0 + c[:, None] * diagT[None, :] + half * nu * col[None, :] + sub = jnp.concatenate([jnp.zeros(1, dtype=off.dtype), off]) # sub[n]=T[n,n-1] + sup = jnp.concatenate([off, jnp.zeros(1, dtype=off.dtype)]) # sup[n]=T[n,n+1] + dl = c[:, None] * sub[None, :] + du = c[:, None] * sup[None, :] + return dl.astype(jnp.complex128), d.astype(jnp.complex128), du.astype(jnp.complex128) + + # L_stream C = -i*alpha*kx*(T_H C) -> M_H = I + i*(dt/2)*alpha*kx*T_H + self._pc_h = bands(self.T_H, 1j * self.alpha, self.nu_H, self.col_e) + # L_stream B = -i*kx*(T_L B) -> M_L = I + i*(dt/2)*kx*T_L + self._pc_l = bands(self.T_L, 1j, self.nu_L, self.col_l) + + def _stream_precond_apply(self, v: dict) -> dict: + """Apply M^{-1} (streaming+collision preconditioner) to a real (re,im) pytree.""" + Ck = v["Cr"] + 1j * v["Ci"] # (Nh, Nx) + Bk = v["Br"] + 1j * v["Bi"] # (Nl, Nx) + dl_h, d_h, du_h = self._pc_h + dl_l, d_l, du_l = self._pc_l + x = jax.lax.linalg.tridiagonal_solve(dl_h, d_h, du_h, Ck.T[..., None])[..., 0].T + y = jax.lax.linalg.tridiagonal_solve(dl_l, d_l, du_l, Bk.T[..., None])[..., 0].T + return {"Cr": x.real, "Ci": x.imag, "Br": y.real, "Bi": y.imag} + + def _force_precond_apply(self, v: dict, sC: Array) -> dict: + """Apply (I - dt/2 E0 G_force)^{-1} (the Lorentz-force preconditioner), per x. + + sC = (dt/2) E0(x) is the frozen-field scale. The Hermite force operator is + lower-bidiagonal -> a per-x tridiagonal (forward-substitution) solve; the + Legendre force is dominated by the derivative matrix D (norm ~Nl^2/width), a + per-x lower-triangular solve. The rank-2 Dirichlet penalty is left to GMRES (a + preconditioner need not be exact). This is the piece that becomes stiff at + saturation, where streaming-only preconditioning stalls. + """ + Ck = v["Cr"] + 1j * v["Ci"] # (Nh, Nx) + Bk = v["Br"] + 1j * v["Bi"] # (Nl, Nx) + Nh, Nx = Ck.shape + Nl = Bk.shape[0] + sC_c = sC.astype(jnp.complex128) + + # Hermite: (I - sC G_C) is lower-bidiagonal; (I-sC G_C)[n,n-1] = sC*sqrt(2n)/alpha + dl = sC_c[:, None] * self.sqrt_2n_over_alpha[None, :].astype(jnp.complex128) # (Nx, Nh) + dl = dl.at[:, 0].set(0.0) + d = jnp.ones((Nx, Nh), dtype=jnp.complex128) + du = jnp.zeros((Nx, Nh), dtype=jnp.complex128) + Cn = jax.lax.linalg.tridiagonal_solve(dl, d, du, Ck.T[..., None])[..., 0].T + + # Legendre: (I + sC D) lower-triangular (unit diagonal); D = deriv matrix + A0 = jnp.eye(Nl, dtype=jnp.complex128)[None] + sC_c[:, None, None] * self.deriv.astype(jnp.complex128)[None] + Bn = jax.lax.linalg.triangular_solve(A0, Bk.T[..., None], left_side=True, lower=True)[..., 0].T + return {"Cr": Cn.real, "Ci": Cn.imag, "Br": Bn.real, "Bi": Bn.imag} + + def _nonlinear_rhs(self, t: float, state: dict, args: dict) -> dict: + Ck = state["Ck"].view(jnp.complex128) + Bk = state["Bk"].view(jnp.complex128) + + # field_on=False -> pure advection (phi=0); the linear Hermite->Legendre + # closure flux (d_x C_{Nh-1}) still acts, only E-dependent terms vanish. + E = self.poisson.electric_field(Ck, Bk) if self.field_on else jnp.zeros(Ck.shape[1]) # (Nx,) + # external longitudinal driver (evaluated at the substep time; never enters Poisson) + if self.ex_driver is not None: + E = E + self.ex_driver(t, args) + + cross = _cross_coupling(Ck, E, self.kx_1d, self.coupling_vec, self.alpha, self.width, self.mask23) + if self.imex: + # E.d_v f (Hermite + Legendre force) is handled by the implicit substep; + # only the non-stiff Hermite->Legendre closure flux stays explicit. + dCk = jnp.zeros_like(Ck) + dBk = cross + else: + dCk = _hermite_force(Ck, E, self.sqrt_2n_over_alpha, self.mask23) + dBk = ( + _legendre_force(Bk, E, self.deriv, self.gamma_vec, self.xi_a, self.xi_b, self.width, self.mask23) + + cross + ) + + out = dict(state) + out["Ck"] = dCk.view(jnp.float64) + out["Bk"] = dBk.view(jnp.float64) + for k in ("e", "phi", "de"): + if k in out: + out[k] = jnp.zeros_like(state[k]) + return out + + def _lawson_rk4(self, t: float, state: dict, args: dict) -> dict: + dt = self.dt + exp_L = self.combined_exp.apply + + Eh_y = exp_L(state, dt / 2) + Ef_y = exp_L(state, dt) + + N1 = self._nonlinear_rhs(t, state, args) + Eh_N1 = exp_L(N1, dt / 2) + y_star = _tree_add(Eh_y, _tree_scale(Eh_N1, dt / 2)) + N2 = self._nonlinear_rhs(t + dt / 2, y_star, args) + + y_dstar = _tree_add(Eh_y, _tree_scale(N2, dt / 2)) + N3 = self._nonlinear_rhs(t + dt / 2, y_dstar, args) + + Eh_N3 = exp_L(N3, dt / 2) + y_tstar = _tree_add(Ef_y, _tree_scale(Eh_N3, dt)) + N4 = self._nonlinear_rhs(t + dt, y_tstar, args) + + Ef_N1 = exp_L(N1, dt) + Eh_N2 = exp_L(N2, dt / 2) + + weighted = _tree_scale( + _tree_add( + _tree_add(Ef_N1, _tree_scale(Eh_N2, 2.0)), + _tree_add(_tree_scale(Eh_N3, 2.0), N4), + ), + dt / 6.0, + ) + return _tree_add(Ef_y, weighted) + + def _full_rhs_complex(self, t: float, Ck: Array, Bk: Array) -> tuple: + """Raw RHS dy/dt = streaming + collisions + Lorentz force + closure coupling. + + Unlike the Lawson path (which integrates streaming/collisions exactly via the + matrix exponential), the implicit-midpoint integrator needs the explicit RHS + operators: streaming d_t C = -i alpha kx (T_H C), d_t B = -i kx (T_L B), the + diagonal collision -nu col[n], and the (reused) E.d_v f force + coupling terms. + """ + E = self.poisson.electric_field(Ck, Bk) if self.field_on else jnp.zeros(Ck.shape[1]) + if self.ex_driver is not None: + E = E + self.ex_driver(t, None) + + dCk = -1j * self.alpha * self.kx_1d[None, :] * (self.T_H @ Ck) - self.nu_H * self.col_e[:, None] * Ck + dBk = -1j * self.kx_1d[None, :] * (self.T_L @ Bk) - self.nu_L * self.col_l[:, None] * Bk + + dCk = dCk + _hermite_force(Ck, E, self.sqrt_2n_over_alpha, self.mask23) + dBk = ( + dBk + + _legendre_force(Bk, E, self.deriv, self.gamma_vec, self.xi_a, self.xi_b, self.width, self.mask23) + + _cross_coupling(Ck, E, self.kx_1d, self.coupling_vec, self.alpha, self.width, self.mask23) + ) + return dCk, dBk + + def _implicit_midpoint_solve(self, t: float, Ck0: Array, Bk0: Array) -> tuple: + """One implicit-midpoint step via Jacobian-free Newton-Krylov (AD-JFNK). + + Solves y1 = y0 + dt F((y0+y1)/2). Implicit midpoint is A-stable and conserves + quadratic invariants (energy), so it has no CFL limit and no spurious energy + growth -- needed for the saturated/long-time regimes where the explicit and + IMEX paths blow up. The Newton linear solves use a matrix-free GMRES whose + Jacobian-vector products are EXACT autodiff JVPs (jax.linearize) -- the Jacobian + is never formed. Complex coefficients are carried as real (re, im) pytree leaves + so the Krylov inner products are the standard real ones. + """ + import jax.scipy.sparse.linalg as jsla + + dt = self.dt + t_mid = t + 0.5 * dt + y0 = {"Cr": Ck0.real, "Ci": Ck0.imag, "Br": Bk0.real, "Bi": Bk0.imag} + + # Frozen-field scale for the force preconditioner (E at the step-start state). + E0 = self.poisson.electric_field(Ck0, Bk0) if self.field_on else jnp.zeros(Ck0.shape[1]) + if self.ex_driver is not None: + E0 = E0 + self.ex_driver(t_mid, None) + sC = 0.5 * dt * E0 + + def rhs(yr): + Ck = yr["Cr"] + 1j * yr["Ci"] + Bk = yr["Br"] + 1j * yr["Bi"] + dCk, dBk = self._full_rhs_complex(t_mid, Ck, Bk) + return {"Cr": dCk.real, "Ci": dCk.imag, "Br": dBk.real, "Bi": dBk.imag} + + def residual(y1): + y_mid = jax.tree.map(lambda a, b: 0.5 * (a + b), y0, y1) + f = rhs(y_mid) + return jax.tree.map(lambda a, b, ff: a - b - dt * ff, y1, y0, f) + + # Combined preconditioner M^{-1} = M_force^{-1} . M_stream^{-1}: streaming + # (per-k tridiagonal) handles the linear/growth phase, the force part handles + # saturation where the Lorentz term dominates the Jacobian. + if self.precondition: + def M(vv): + return self._force_precond_apply(self._stream_precond_apply(vv), sC) + else: + M = None + y1 = y0 + for _ in range(self.newton_iters): + r, jvp_fn = jax.linearize(residual, y1) # jvp_fn(v) = J @ v, exact (AD) + neg_r = jax.tree.map(jnp.negative, r) + delta, _ = jsla.gmres( + jvp_fn, neg_r, M=M, tol=self.gmres_tol, atol=0.0, restart=self.gmres_restart, maxiter=self.gmres_maxiter + ) + y1 = jax.tree.map(jnp.add, y1, delta) + + return y1["Cr"] + 1j * y1["Ci"], y1["Br"] + 1j * y1["Bi"] + + def _implicit_E_substep(self, state: dict, E_real: Array, dt: float) -> dict: + """Backward-Euler substep for the E.d_v f Lorentz force with frozen E. + + Per spatial point x the force is dC/dt = E(x) G_C C and dB/dt = E(x) G_B B + (block-diagonal: Hermite force touches only C, Legendre only B). Backward + Euler gives (I - dt E(x) G) X_new = X, an unconditionally stable per-x linear + solve (G_C is nilpotent; G_B is lower-triangular + a rank-2 penalty). E(x) is + diagonal in real space, so we transform k->x, solve per x, and transform back. + + NOTE: this uses a dense per-x solve, O(Nx * N^3). Fine for moderate Nx, but for + large Nx the structured solve is far cheaper: a bidiagonal forward-substitution + for the nilpotent Hermite block (cf. _spectrax1d.imex_E) and a Woodbury solve + (lower-triangular + rank-2) for the Legendre block, both O(Nx * N^2). + """ + mask = self.mask23 + maskc = mask[None, :] if mask is not None else 1.0 + Ck = state["Ck"].view(jnp.complex128) + Bk = state["Bk"].view(jnp.complex128) + Nh, Nl = Ck.shape[0], Bk.shape[0] + + C = jnp.fft.ifft(Ck * maskc, axis=-1, norm="forward") # (Nh, Nx) + B = jnp.fft.ifft(Bk * maskc, axis=-1, norm="forward") # (Nl, Nx) + + scale = (dt * E_real).astype(jnp.complex128) # (Nx,) + M_C = jnp.eye(Nh, dtype=jnp.complex128)[None] - scale[:, None, None] * self.G_C[None] # (Nx, Nh, Nh) + M_B = jnp.eye(Nl, dtype=jnp.complex128)[None] - scale[:, None, None] * self.G_B[None] # (Nx, Nl, Nl) + + C_new = jnp.linalg.solve(M_C, C.T[..., None])[..., 0].T # (Nh, Nx) + B_new = jnp.linalg.solve(M_B, B.T[..., None])[..., 0].T # (Nl, Nx) + + Ck_new = jnp.fft.fft(C_new, axis=-1, norm="forward") * maskc + Bk_new = jnp.fft.fft(B_new, axis=-1, norm="forward") * maskc + out = dict(state) + out["Ck"] = Ck_new.view(jnp.float64) + out["Bk"] = Bk_new.view(jnp.float64) + return out + + def __call__(self, t: float, y: dict, args: dict) -> dict: + if self.implicit: + Ck0 = y["Ck"].view(jnp.complex128) + Bk0 = y["Bk"].view(jnp.complex128) + Ck, Bk = self._implicit_midpoint_solve(t, Ck0, Bk0) + if self.field_on: + e = self.poisson.electric_field(Ck, Bk) + phi = self.poisson.potential(Ck, Bk) + else: + e = jnp.zeros(Ck.shape[1]) + phi = jnp.zeros(Ck.shape[1]) + de = self.ex_driver(t, args) if self.ex_driver is not None else jnp.zeros(Ck.shape[1]) + return {"Ck": Ck.view(jnp.float64), "Bk": Bk.view(jnp.float64), "e": e, "phi": phi, "de": de} + + y_new = self._lawson_rk4(t, y, args) + + if self.imex: + # Frozen E from the post-explicit state (self-consistent + external driver + # at the end of the step), then one implicit Lorentz substep. + Ckp = y_new["Ck"].view(jnp.complex128) + Bkp = y_new["Bk"].view(jnp.complex128) + E_frozen = self.poisson.electric_field(Ckp, Bkp) if self.field_on else jnp.zeros(Ckp.shape[1]) + if self.ex_driver is not None: + E_frozen = E_frozen + self.ex_driver(t + self.dt, args) + y_new = self._implicit_E_substep(y_new, E_frozen, self.dt) + + Ck = y_new["Ck"].view(jnp.complex128) + Bk = y_new["Bk"].view(jnp.complex128) + if self.field_on: + e = self.poisson.electric_field(Ck, Bk) + phi = self.poisson.potential(Ck, Bk) + else: + e = jnp.zeros(Ck.shape[1]) # phi = 0 : pure advection + phi = jnp.zeros(Ck.shape[1]) + de = self.ex_driver(t, args) if self.ex_driver is not None else jnp.zeros(Ck.shape[1]) + return {"Ck": y_new["Ck"], "Bk": y_new["Bk"], "e": e, "phi": phi, "de": de} diff --git a/adept/hermite_legendre_1d.py b/adept/hermite_legendre_1d.py new file mode 100644 index 00000000..967ba447 --- /dev/null +++ b/adept/hermite_legendre_1d.py @@ -0,0 +1,5 @@ +"""Public entry point for the mixed Hermite-Legendre 1D ADEPT module.""" + +from ._hermite_legendre_1d.modules import BaseHermiteLegendre1D + +__all__ = ["BaseHermiteLegendre1D"] diff --git a/configs/hermite-legendre-1d/bump-on-tail.yaml b/configs/hermite-legendre-1d/bump-on-tail.yaml new file mode 100644 index 00000000..98425142 --- /dev/null +++ b/configs/hermite-legendre-1d/bump-on-tail.yaml @@ -0,0 +1,67 @@ +# Mixed Hermite-Legendre 1D --- Bump-on-tail instability (paper sec 4.4) +# +# f0(x,v,0) = (1-n_b)/sqrt(2*pi) [1 + eps cos(x/10)] exp(-v^2/2) +# df(x,v,0) = n_b/sqrt(2*pi) exp(-(v-10)^2/2) (beam projected onto Legendre) +# The beam is well separated from the bulk, so f0<->df coupling is weak and a large +# Hermite collision rate nu_H = 10 is acceptable. Parameters match Chapurin et al. +# +# This is the most demanding benchmark (asymmetric velocity window, long time). The +# explicit and IMEX integrators both blow up near saturation; it requires the A-stable, +# energy-conserving implicit-midpoint integrator (integrator: implicit) with the combined +# streaming+force preconditioner. At saturation the implicit nonlinear solve is hard, so a +# small step is needed for Newton/GMRES to converge (dt <= ~0.02; the paper uses dt=0.01). +# A larger dt (e.g. 0.1) converges in the linear phase but fails at saturation -- the +# learned/ML preconditioner is the planned route to large-dt robustness here. + +solver: hermite-legendre-1d + +mlflow: + experiment: hermite-legendre-1d + run: bump-on-tail + +units: + normalizing_density: 1e20/cc + normalizing_temperature: 1keV + +physics: + Lx: 62.83185307179586 # 20*pi + alpha: 1.4142135623730951 # sqrt(2) + u: 0.0 + v_a: 4.0 + v_b: 15.0 + gamma: 0.5 + nu_H: 10.0 + nu_L: 1.0 + enforce_conservation: true + field: true + +grid: + Nx: 128 + Nh: 128 + Nl: 128 + tmax: 120.0 + dt: 0.02 # small step so the saturation-phase JFNK solve converges + integrator: implicit # A-stable implicit midpoint; conserves mass exactly, energy to solve tol + precondition: true # combined streaming+force preconditioner + newton_iters: 4 + gmres_restart: 30 + gmres_maxiter: 3 + +initialization: + type: bump-on-tail + eps: 1.0e-4 + mode: 1 + n_beam: 0.01 + v_drift: 10.0 + v_th: 1.0 + +save: + fields: + t: + nt: 481 + hermite: + t: + nt: 61 + legendre: + t: + nt: 61 diff --git a/configs/hermite-legendre-1d/linear-advection.yaml b/configs/hermite-legendre-1d/linear-advection.yaml new file mode 100644 index 00000000..933cd89b --- /dev/null +++ b/configs/hermite-legendre-1d/linear-advection.yaml @@ -0,0 +1,51 @@ +# Mixed Hermite-Legendre 1D --- Linear advection benchmark (paper sec 4.2) +# +# Pure advection (phi = 0): d_t f + v d_x f = 0, analytic f(x,v,t)=f(x-vt,v,0). +# f(x,v,0) = (1+cos x)/sqrt(2*pi) exp(-v^2/2) -> C_0(x,0)=(1+cos x)/sqrt(2), df=0. +# The Legendre representation of df grows to compensate Hermite recurrence (Fig 3), +# and the recurrence period matches the better of the two bases (Fig 4). + +solver: hermite-legendre-1d + +mlflow: + experiment: hermite-legendre-1d + run: linear-advection + +units: + normalizing_density: 1e20/cc + normalizing_temperature: 1keV + +physics: + Lx: 6.283185307179586 # 2*pi + alpha: 1.4142135623730951 # sqrt(2) + u: 0.0 + v_a: -5.0 + v_b: 5.0 + gamma: 0.5 + nu_H: 0.0 + nu_L: 0.0 + enforce_conservation: true + field: false # phi = 0 : pure linear advection + +grid: + Nx: 64 + Nh: 100 + Nl: 100 + tmax: 20.0 + dt: 0.01 + +initialization: + type: linear-advection + eps: 1.0 + mode: 1 + +save: + fields: + t: + nt: 201 + hermite: + t: + nt: 41 + legendre: + t: + nt: 41 diff --git a/configs/hermite-legendre-1d/two-stream-imex.yaml b/configs/hermite-legendre-1d/two-stream-imex.yaml new file mode 100644 index 00000000..9ad59a0c --- /dev/null +++ b/configs/hermite-legendre-1d/two-stream-imex.yaml @@ -0,0 +1,52 @@ +# Mixed Hermite-Legendre 1D --- Two-stream instability, IMEX integrator (paper sec 4.3) +# +# Same physics as two-stream.yaml, but the stiff E.d_v f Lorentz force is advanced +# with an unconditionally stable frozen-E Backward-Euler substep (integrator: imex). +# This removes the explicit CFL limit, so the step can be ~5-10x larger than the +# fully-explicit two-stream.yaml (dt=0.002). + +solver: hermite-legendre-1d + +mlflow: + experiment: hermite-legendre-1d + run: two-stream-imex + +units: + normalizing_density: 1e20/cc + normalizing_temperature: 1keV + +physics: + Lx: 12.566370614359172 # 4*pi + alpha: 1.4142135623730951 # sqrt(2) + u: 0.0 + v_a: -2.5 + v_b: 2.5 + gamma: 0.5 + nu_H: 0.0 + nu_L: 1.0 + enforce_conservation: true + field: true + +grid: + Nx: 64 + Nh: 85 + Nl: 171 + tmax: 35.0 + dt: 0.02 + integrator: imex # Lawson-RK4 + implicit (Backward-Euler) Lorentz substep + +initialization: + type: two-stream + eps: 0.01 + mode: 1 + +save: + fields: + t: + nt: 351 + hermite: + t: + nt: 71 + legendre: + t: + nt: 71 diff --git a/configs/hermite-legendre-1d/two-stream.yaml b/configs/hermite-legendre-1d/two-stream.yaml new file mode 100644 index 00000000..5577b80b --- /dev/null +++ b/configs/hermite-legendre-1d/two-stream.yaml @@ -0,0 +1,54 @@ +# Mixed Hermite-Legendre 1D --- Two-stream instability (paper sec 4.3) +# +# f(x,v,0) = (1+0.01 cos(0.5 x))/sqrt(2*pi) v^2 exp(-v^2/2) +# -> C_0(x,0)=(1+0.01 cos(0.5 x))/sqrt(2), C_2 = sqrt(2) C_0, df=0. +# The phase-space vortex (strongly non-Maxwellian) is captured by the Legendre +# component; with J_{Nh,0..2}=0 enforced, mass/momentum/energy are conserved (Fig 7c/d). + +solver: hermite-legendre-1d + +mlflow: + experiment: hermite-legendre-1d + run: two-stream + +units: + normalizing_density: 1e20/cc + normalizing_temperature: 1keV + +physics: + Lx: 12.566370614359172 # 4*pi + alpha: 1.4142135623730951 # sqrt(2) + u: 0.0 + v_a: -2.5 + v_b: 2.5 + gamma: 0.5 + nu_H: 0.0 # keep nu_H -> 0 so f0 feeds df via the last Hermite moment + nu_L: 1.0 + enforce_conservation: true + field: true + +grid: + Nx: 64 + Nh: 85 + Nl: 171 + tmax: 35.0 + # Explicit Lawson-RK4 has a stability (CFL) limit once the field saturates; for + # these parameters dt <= ~0.0025 is required (the paper's dt=0.01 relies on its + # unconditionally stable implicit-midpoint integrator). dt=0.002 is converged. + dt: 0.002 + +initialization: + type: two-stream + eps: 0.01 + mode: 1 + +save: + fields: + t: + nt: 351 + hermite: + t: + nt: 71 + legendre: + t: + nt: 71 diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index f1bebac6..1cbd2b37 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -17,3 +17,4 @@ Quick links to configuration references: - [Vlasov-2D Config](source/solvers/vlasov2d/config.md) - [LPSE-2D Config](source/solvers/lpse2d/config.md) - [Spectrax-1D Config](source/solvers/spectrax1d/config.md) +- [Hermite-Legendre-1D Config](source/solvers/hermite_legendre_1d/config.md) diff --git a/docs/RUNNING_A_SIM.md b/docs/RUNNING_A_SIM.md index ae50b6b3..7c834954 100644 --- a/docs/RUNNING_A_SIM.md +++ b/docs/RUNNING_A_SIM.md @@ -12,5 +12,6 @@ uv run run.py --cfg path_to_my_config - [Vlasov-2D](source/solvers/vlasov2d/config.md) - 2D2V Vlasov-Maxwell solver - [LPSE-2D (Envelope-2D)](source/solvers/lpse2d/config.md) - 2D laser-plasma envelope solver for TPD/SRS - [Spectrax-1D](source/solvers/spectrax1d/config.md) - 1D Hermite-Fourier Vlasov-Maxwell solver +- [Hermite-Legendre-1D](source/solvers/hermite_legendre_1d/config.md) - 1D-1V mixed Hermite-Legendre electrostatic Vlasov-Poisson solver See the [full documentation](https://ergodicio.github.io/adept/) for detailed guides and API reference. diff --git a/docs/source/solvers/hermite_legendre_1d/config.md b/docs/source/solvers/hermite_legendre_1d/config.md new file mode 100644 index 00000000..70c9c1cb --- /dev/null +++ b/docs/source/solvers/hermite_legendre_1d/config.md @@ -0,0 +1,213 @@ +# Mixed Hermite-Legendre 1D Configuration Reference + +This document describes how to construct a configuration file for the +`hermite-legendre-1d` solver, which implements the **mixed Hermite-Legendre +spectral method** for the 1D-1V electrostatic Vlasov-Poisson system (Issan, +Delzanno & Roytershteyn, arXiv:2606.12322). + +The electron distribution is split `f = f0 + df`: + +- `f0` (near-Maxwellian bulk) is expanded in the **asymmetrically-weighted (AW) + Hermite** basis in velocity, with coefficients `C_n(x, t)`, `n = 0 .. Nh-1`. +- `df` (strongly non-Maxwellian features: beams, plateaus, filamentation) is + expanded in the **Legendre** basis on a bounded velocity window `[v_a, v_b]`, + with coefficients `B_m(x, t)`, `m = 0 .. Nl-1`. + +The highest Hermite coefficient `C_{Nh-1}` feeds the Legendre modes (one-way +coupling), and both feed the self-consistent field through Poisson. The method is +most accurate, at fixed total velocity DOFs, when non-Maxwellian features are +localized in velocity. + +**Normalization** (paper sec 2.1): time by `1/ω_pe`, space by the Debye length +`λ_D`, velocity by the electron thermal velocity `v_the`. A single electron species +is evolved against an immobile neutralizing ion background of density 1. + +**Numerics.** Space is treated spectrally (Fourier, periodic domain); both +free-streaming operators are symmetric-tridiagonal in mode index and integrated +*exactly* via prediagonalized matrix exponentials. The E-field force, the Legendre +Dirichlet penalty, and the Hermite→Legendre coupling are advanced explicitly with +**Lawson-RK4**. (The paper uses an implicit-midpoint integrator for machine-precision +energy conservation; this module uses an explicit integrator — energy is then +conserved to the time-integrator's order, which converges with `dt`, while mass and +momentum remain conserved to machine precision.) + +## Top-Level Structure + +```yaml +solver: hermite-legendre-1d +mlflow: ... +units: ... +physics: ... +grid: ... +initialization: ... +save: ... +``` + +--- + +## physics + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `Lx` | float | — | Domain length in x (normalized to `λ_D`) | +| `alpha` | float | — | AW-Hermite velocity **scale** parameter `α` (the benchmarks use `√2`) | +| `u` | float | `0.0` | AW-Hermite velocity **shift** parameter `u` | +| `v_a`, `v_b` | float | — | Legendre velocity-window bounds (`df` is resolved on `[v_a, v_b]`) | +| `gamma` | float | `0.5` | Penalty coefficient `γ` for the weak Legendre Dirichlet BC (`df(v_a)=df(v_b)=0`). Applied only to modes `m ≥ 3` to preserve conservation. | +| `nu_H` | float | `0.0` | Artificial (Lenard-Bernstein) Hermite collision rate `ν_H`. Keep small/zero so `f0` can feed `df` through the last Hermite moment. | +| `nu_L` | float | `0.0` | Artificial Legendre collision rate `ν_L`. Controls filamentation/recurrence in `df`. | +| `enforce_conservation` | bool | `true` | Zero the coupling integrals `J_{Nh,0}=J_{Nh,1}=J_{Nh,2}=0` so the discrete method conserves mass, momentum, and energy independent of `Nh` parity and of `α, u` (paper sec 3.4/4). | +| `field` | bool | `true` | Self-consistent Poisson field. Set `false` for the pure linear-advection test (`φ = 0`); the linear Hermite→Legendre closure flux still acts. | + +The artificial collision operator (paper sec 2.5) uses the cubic spectrum +`col[n] = n(n-1)(n-2) / ((N-1)(N-2)(N-3))`, which is identically zero for +`n = 0, 1, 2` — so collisions never touch the mass/momentum/energy moments. + +--- + +## grid + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `Nx` | int | — | Number of Fourier modes in x | +| `Nh` | int | — | Number of AW-Hermite modes for `f0` (closure by truncation: `C_{Nh}=0`) | +| `Nl` | int | — | Number of Legendre modes for `df` (closure by truncation: `B_{Nl}=0`) | +| `tmax` | float | — | Final simulation time (normalized). Snapped to an exact multiple of `dt`. | +| `dt` | float | `0.01` | Timestep | +| `integrator` | str | `"lawson"` | Time integrator: `"lawson"` (explicit Lawson-RK4), `"imex"` (Lawson-RK4 + implicit Lorentz substep), or `"implicit"` (implicit midpoint, AD-JFNK) — see below. | +| `newton_iters` | int | `3` | (`implicit`) Newton iterations per step. | +| `gmres_restart`, `gmres_maxiter`, `gmres_tol` | int/int/float | `20`/`4`/`1e-8` | (`implicit`) matrix-free GMRES controls for the Newton linear solves. | +| `precondition` | bool | `true` | (`implicit`) use the streaming+collision operator as a physics-based GMRES preconditioner (see below). | + +### `integrator: implicit` (implicit midpoint via AD-JFNK) + +Advances the full RHS with the implicit-midpoint rule `y1 = y0 + dt·F((y0+y1)/2)`, +solved by **Jacobian-free Newton-Krylov**: each Newton linear system uses a matrix-free +GMRES whose Jacobian-vector products are *exact autodiff JVPs* (`jax.linearize`) — the +Jacobian is never assembled (memory is the state plus a few Krylov vectors). Implicit +midpoint is A-stable (no CFL at all) and conserves quadratic invariants, so it conserves +mass exactly and energy to the solve tolerance, and stays stable into the saturated / +long-time regime where both `lawson` and `imex` blow up (e.g. bump-on-tail). Cost: each +step does `newton_iters × (GMRES iterations) × (RHS evals)`, so it is the most expensive +per step — use it for the hard cases, not the cheap ones. + +**Preconditioning** (`precondition: true`, default). The implicit operator's stiffness is +dominated by the skew streaming term, whose eigenvalues smear along the imaginary axis +(`~dt/2·α·k_max·√(2Nh)`) — the worst case for unpreconditioned GMRES, which then needs +many iterations and can fail to converge at large `dt`/`Nx` (Newton then injects energy). +The preconditioner `M = I − dt/2·(L_streaming + L_collision)` is block-diagonal in `k` and +tridiagonal in mode index, so `M⁻¹` is a cheap per-`k` tridiagonal solve that captures +exactly that stiff spectrum; GMRES on `M⁻¹A` then converges in a handful of iterations. +This is what makes large-`dt` implicit-midpoint runs practical. + +### `integrator: imex` + +The stiffness that limits the explicit step is the `E·∂_v f` Lorentz force: in the +spectral velocity bases it is strictly lower-triangular (nilpotent for Hermite, +lower-triangular + a rank-2 penalty for Legendre) with operator norm `~Nl²/width·|E|` +— so explicit RK4's `|dt·‖L‖|≲2.8` limit tightens as modes/field grow. Setting +`integrator: imex` keeps free-streaming, collisions, and the Hermite→Legendre closure +flux in the explicit Lawson step, and advances the Lorentz force with an +**unconditionally stable frozen-E Backward-Euler substep** (a per-`x` triangular/dense +linear solve; first-order Lie split). This removes the CFL limit, letting two-stream +run at `dt ≈ 0.02` instead of `0.002`. Trade-offs: Backward Euler is mildly dissipative +and the split is first-order in `dt`, so for high-accuracy/conservation studies prefer +small-`dt` `lawson`; for robustness at large mode counts or large `Nx`, prefer `imex`. + +**Choosing `dt`.** Free-streaming and collisions are integrated exactly, but the +explicit Lawson-RK4 treatment of the E-field force has a stability (CFL) limit that +tightens as the self-consistent field grows. For small-amplitude/linear runs (e.g. +driven Landau damping) `dt = 0.05` is fine; for nonlinear instabilities that saturate +to a large field (two-stream) a smaller step is needed — `dt ≈ 0.002` is stable and +converged for the two-stream benchmark. (The paper's `dt = 0.01` relies on its +unconditionally stable implicit-midpoint integrator; this explicit module trades that +for a smaller step and a much smaller memory footprint.) A run that goes `NaN` partway +through is the signature of `dt` above the CFL limit — halve it. + +--- + +## initialization + +Selects how the initial `C_n(x)` and `B_m(x)` coefficients are built. + +| `type` | Parameters | Description | +|--------|-----------|-------------| +| `linear-advection` | `eps`, `mode` | `f0 = (1 + eps·cos(k x))/√(2π)·exp(-v²/2)`; `df = 0`. (`C_0 = n(x)/α`.) | +| `two-stream` | `eps`, `mode` | `f0 ∝ (1 + eps·cos(k x))·v²·exp(-v²/2)`: `C_0 = n(x)/α`, `C_2 = √2·C_0`; `df = 0`. | +| `bump-on-tail` | `eps`, `mode`, `n_beam`, `v_drift`, `v_th` | Bulk Maxwellian in `f0`; a drifting Gaussian beam `n_beam/(√(2π) v_th)·exp(-(v-v_drift)²/2v_th²)` projected onto Legendre as `df`. | +| `custom` | `hermite: {n: {base, eps, mode}}`, `df: {beams: [{amp, v_drift, v_th}], eps, mode}` | Generic Hermite coefficient profiles plus a beam/sum-of-Gaussians `df` projected onto Legendre. | + +Here `k = 2π·mode/Lx`. The Legendre projection uses Gauss-Legendre quadrature. + +--- + +## drivers (optional) + +An external longitudinal field `ex` can be applied to the velocity-space force (it +never enters the Poisson solve), e.g. to drive a resonant EPW for a Landau-damping +measurement — the analogue of the Vlasov-1D `ex` driver. Omit the `drivers` block for +self-consistent runs. + +```yaml +drivers: + ex: + '0': # one entry per pulse + k0: 0.4 # wavenumber + w0: 1.285 # angular frequency (e.g. Re(omega) from the dispersion relation) + dw0: 0.0 # frequency offset (added to w0) + a0: 1.0e-3 # amplitude + t_center: 20.0 # pulse: center / full width / rise(+fall) time + t_width: 20.0 + t_rise: 5.0 + x_center: 7.85 # spatial envelope: center / width / rise (defaults span the box) + x_width: 1.0e6 + x_rise: 1.0 +``` + +The driver field is `E_drive(x,t) = Σ env(x,t)·(w0+dw0)·a0·sin(k0 x − (w0+dw0) t)` and +is saved as `de` in the `fields` group. + +--- + +## save + +Standard ADEPT `save` block with `t: {nt: ...}` (or `tmin`/`tmax`/`nt`) sub-axes. + +| Key | Contents | +|-----|----------| +| `fields` | Electric field `e(x,t)`, potential `phi(x,t)`, and external driver field `de(x,t)` | +| `hermite` | AW-Hermite-Fourier coefficient timeseries `Ck` (shape `nt × Nh × Nx`) | +| `legendre` | Legendre-Fourier coefficient timeseries `Bk` (shape `nt × Nl × Nx`) | +| `default` | Scalar invariants `mass`, `momentum`, `energy` (paper eqns 26, 28, 30-31) plus field energy and density extrema. Always added; the primary correctness gate. | + +`post_process` writes netCDF binaries and spacetime/scalar plots, and reports the +relative drift of each invariant as the metrics `reldrift_{mass,momentum,energy}`. + +--- + +## Example: two-stream instability + +```yaml +solver: hermite-legendre-1d +mlflow: {experiment: hermite-legendre-1d, run: two-stream} +units: {normalizing_density: 1e20/cc, normalizing_temperature: 1keV} +physics: + Lx: 12.566370614359172 # 4π + alpha: 1.4142135623730951 + u: 0.0 + v_a: -2.5 + v_b: 2.5 + gamma: 0.5 + nu_H: 0.0 + nu_L: 1.0 + enforce_conservation: true + field: true +grid: {Nx: 64, Nh: 85, Nl: 171, tmax: 35.0, dt: 0.01} +initialization: {type: two-stream, eps: 0.01, mode: 1} +save: + fields: {t: {nt: 351}} + legendre: {t: {nt: 71}} +``` + +See `configs/hermite-legendre-1d/` for the linear-advection, two-stream, and +bump-on-tail benchmark configurations. diff --git a/tests/test_hermite_legendre_1d/__init__.py b/tests/test_hermite_legendre_1d/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_hermite_legendre_1d/test_conservation.py b/tests/test_hermite_legendre_1d/test_conservation.py new file mode 100644 index 00000000..758b7762 --- /dev/null +++ b/tests/test_hermite_legendre_1d/test_conservation.py @@ -0,0 +1,77 @@ +# Copyright (c) Ergodic LLC 2026 +# research@ergodic.io +"""Conservation gate for the mixed Hermite-Legendre solver (paper sec 3, Fig 7). + +With the constraint J_{Nh,0}=J_{Nh,1}=J_{Nh,2}=0 enforced, the self-consistent +(field-on) mixed method conserves total mass and momentum to machine precision and +total energy to the explicit time integrator's accuracy (the paper's machine- +precision energy relies on the implicit-midpoint integrator; the explicit Lawson-RK4 +used here is time-integration-limited but convergent). The run also stays finite. +""" + +from jax import config as jax_config + +jax_config.update("jax_enable_x64", True) + +import numpy as np + +from adept._hermite_legendre_1d.modules import BaseHermiteLegendre1D + + +def _run(cfg): + m = BaseHermiteLegendre1D(cfg) + m.write_units() + m.get_derived_quantities() + m.get_solver_quantities() + m.init_state_and_args() + m.init_diffeqsolve() + return m(trainable_modules={})["solver result"] + + +def _two_stream_cfg(enforce, Nh=48, Nl=48): + return { + "solver": "hermite-legendre-1d", + "physics": { + "Lx": 4.0 * np.pi, + "alpha": np.sqrt(2.0), + "u": 0.0, + "v_a": -2.5, + "v_b": 2.5, + "gamma": 0.5, + "nu_H": 0.0, + "nu_L": 1.0, + "enforce_conservation": enforce, + "field": True, + }, + "grid": {"Nx": 48, "Nh": Nh, "Nl": Nl, "tmax": 10.0, "dt": 0.01}, + "initialization": {"type": "two-stream", "eps": 0.05, "mode": 1}, + "save": {"default": {"t": {"nt": 100}}}, + "units": {}, + } + + +def test_two_stream_conserves_invariants_with_enforcement(): + sol = _run(_two_stream_cfg(enforce=True)) + d = sol.ys["default"] + drifts = {} + for k in ("mass", "momentum", "energy"): + a = np.asarray(d[k]) + assert np.all(np.isfinite(a)), f"{k} went non-finite" + base = abs(a[0]) if abs(a[0]) > 1e-30 else 1.0 + drifts[k] = float(np.max(np.abs(a - a[0])) / base) + assert drifts["mass"] < 1e-10, drifts + assert drifts["momentum"] < 1e-10, drifts + assert drifts["energy"] < 1e-5, drifts # time-integration-limited (dt=0.01) + + +def test_collisions_only_damp_high_modes(): + """nu_L damps Legendre modes m>=3 only; mass/momentum/energy stay conserved + because the collision profile vanishes on the first three moments.""" + sol = _run(_two_stream_cfg(enforce=True)) + d = sol.ys["default"] + # mass uses C_0, B_0; momentum C_0,C_1,B_0,B_1; energy up to C_2,B_2 -- all + # below the collisional cutoff, so collisions must not spoil them. + for k, tol in (("mass", 1e-10), ("momentum", 1e-10)): + a = np.asarray(d[k]) + base = abs(a[0]) if abs(a[0]) > 1e-30 else 1.0 + assert np.max(np.abs(a - a[0])) / base < tol diff --git a/tests/test_hermite_legendre_1d/test_imex.py b/tests/test_hermite_legendre_1d/test_imex.py new file mode 100644 index 00000000..ed7f3489 --- /dev/null +++ b/tests/test_hermite_legendre_1d/test_imex.py @@ -0,0 +1,64 @@ +# Copyright (c) Ergodic LLC 2026 +# research@ergodic.io +"""IMEX stability gate for the mixed Hermite-Legendre solver. + +The explicit Lawson-RK4 step has a CFL-like limit set by the (stiff) E.d_v f Lorentz +force, whose spectral operator norm scales as ~Nl^2/width * |E|; for the two-stream +benchmark it blows up by t~20 at dt=0.01. The IMEX integrator advances that force with +a frozen-E Backward-Euler substep (unconditionally stable), so the same run stays +finite and well-behaved at a step several times larger. This test gates that: + + - explicit (lawson) at dt=0.01 goes non-finite before t=35 (the failure IMEX fixes), + - IMEX at dt=0.01 stays finite through t=35, conserves mass (the force leaves the + C_0/B_0 moments untouched), and saturates to a physically reasonable field energy. +""" + +from jax import config as jax_config + +jax_config.update("jax_enable_x64", True) + +import numpy as np + +from adept._hermite_legendre_1d.modules import BaseHermiteLegendre1D + + +def _run_two_stream(integrator: str, dt: float, tmax: float = 35.0, Nh: int = 85, Nl: int = 171): + cfg = { + "solver": "hermite-legendre-1d", + "physics": { + "Lx": 4.0 * np.pi, "alpha": np.sqrt(2.0), "u": 0.0, "v_a": -2.5, "v_b": 2.5, + "gamma": 0.5, "nu_H": 0.0, "nu_L": 1.0, "enforce_conservation": True, "field": True, + }, + "grid": {"Nx": 64, "Nh": Nh, "Nl": Nl, "tmax": tmax, "dt": dt, "integrator": integrator}, + "initialization": {"type": "two-stream", "eps": 0.01, "mode": 1}, + "save": {"default": {"t": {"nt": 80}}}, + "units": {}, + } + m = BaseHermiteLegendre1D(cfg) + m.write_units() + m.get_derived_quantities() + m.get_solver_quantities() + m.init_state_and_args() + m.init_diffeqsolve() + return m(trainable_modules={})["solver result"].ys["default"] + + +def test_explicit_blows_up_at_large_dt(): + """The failure mode that motivates IMEX: explicit two-stream is unstable at dt=0.01.""" + d = _run_two_stream("lawson", dt=0.01) + assert not np.all(np.isfinite(np.asarray(d["energy"]))), "expected explicit blow-up at dt=0.01" + + +def test_imex_stable_at_large_dt(): + d = _run_two_stream("imex", dt=0.01) + energy = np.asarray(d["energy"]) + ee = np.asarray(d["e_energy"]) + mass = np.asarray(d["mass"]) + + assert np.all(np.isfinite(energy)), "IMEX went non-finite at dt=0.01" + # mass is untouched by the Lorentz force (G_C/G_B leave the C_0/B_0 moments alone) + assert np.max(np.abs(mass - mass[0])) / abs(mass[0]) < 1e-10, "IMEX broke mass conservation" + # the instability grew and saturated to a finite field energy comparable to the + # converged explicit value (~0.33); generous bounds (Backward Euler is dissipative). + assert ee.max() > 1e-2, "two-stream field energy did not grow under IMEX" + assert 0.1 < ee[-1] < 1.0, f"saturated field energy out of range: {ee[-1]:.3f}" diff --git a/tests/test_hermite_legendre_1d/test_implicit.py b/tests/test_hermite_legendre_1d/test_implicit.py new file mode 100644 index 00000000..739bcbc4 --- /dev/null +++ b/tests/test_hermite_legendre_1d/test_implicit.py @@ -0,0 +1,64 @@ +# Copyright (c) Ergodic LLC 2026 +# research@ergodic.io +"""Implicit-midpoint (AD-JFNK) gate for the mixed Hermite-Legendre solver. + +`integrator: implicit` advances the FULL right-hand side with the implicit-midpoint +rule y1 = y0 + dt F((y0+y1)/2), solved by Jacobian-free Newton-Krylov: the Newton +linear systems use a matrix-free GMRES whose Jacobian-vector products are exact +autodiff JVPs (jax.linearize) -- the Jacobian is never assembled. Implicit midpoint is +A-stable (no CFL) and conserves quadratic invariants, so unlike the explicit and IMEX +paths it stays stable into the saturated/long-time regime and conserves mass exactly +and energy to the nonlinear-solve tolerance. + +This gates a two-stream run at dt=0.05 -- a step at which the explicit Lawson path is +violently unstable -- staying finite, conserving mass to machine precision, and holding +energy far better than an explicit step of the same size could. +""" + +from jax import config as jax_config + +jax_config.update("jax_enable_x64", True) + +import numpy as np + +from adept._hermite_legendre_1d.modules import BaseHermiteLegendre1D + + +def _run_two_stream(integrator: str, dt: float, tmax: float = 10.0, Nh: int = 48, Nl: int = 48): + cfg = { + "solver": "hermite-legendre-1d", + "physics": { + "Lx": 4.0 * np.pi, + "alpha": np.sqrt(2.0), + "u": 0.0, + "v_a": -2.5, + "v_b": 2.5, + "gamma": 0.5, + "nu_H": 0.0, + "nu_L": 1.0, + "enforce_conservation": True, + "field": True, + }, + "grid": {"Nx": 48, "Nh": Nh, "Nl": Nl, "tmax": tmax, "dt": dt, "integrator": integrator}, + "initialization": {"type": "two-stream", "eps": 0.05, "mode": 1}, + "save": {"default": {"t": {"nt": 40}}}, + "units": {}, + } + m = BaseHermiteLegendre1D(cfg) + m.write_units() + m.get_derived_quantities() + m.get_solver_quantities() + m.init_state_and_args() + m.init_diffeqsolve() + return m(trainable_modules={})["solver result"].ys["default"] + + +def test_implicit_midpoint_stable_and_conserving(): + d = _run_two_stream("implicit", dt=0.05) + energy = np.asarray(d["energy"]) + mass = np.asarray(d["mass"]) + + assert np.all(np.isfinite(energy)), "implicit midpoint went non-finite" + # implicit midpoint conserves mass exactly and energy to the JFNK solve tolerance + assert np.max(np.abs(mass - mass[0])) / abs(mass[0]) < 1e-11, "mass not conserved" + assert np.max(np.abs(energy - energy[0])) / abs(energy[0]) < 1e-4, "energy drift too large" diff --git a/tests/test_hermite_legendre_1d/test_landau_damping.py b/tests/test_hermite_legendre_1d/test_landau_damping.py new file mode 100644 index 00000000..f00780b6 --- /dev/null +++ b/tests/test_hermite_legendre_1d/test_landau_damping.py @@ -0,0 +1,99 @@ +# Copyright (c) Ergodic LLC 2026 +# research@ergodic.io +"""Driven Landau-damping gate for the mixed Hermite-Legendre solver. + +Mirrors the BaseVlasov1D driven-resonance test (tests/test_vlasov1d/test_landau_damping.py): +a uniform Maxwellian is driven by a small external longitudinal field `ex` at the +resonant wavenumber `k0` and frequency `w0 = Re(omega)` of the kinetic electrostatic +dispersion relation, on a box of length `2*pi/k0` (so the driven mode is k=1). After +the driver ramps off, the electron plasma wave free-rings at its natural frequency +and decays at the Landau rate; we measure both from `E_x(k=1, t)` and compare to the +dispersion-relation root. + +The bulk is carried by the AW-Hermite expansion, which captures Landau damping; the +Legendre part stays negligible in this linear regime. +""" + +from jax import config as jax_config + +jax_config.update("jax_enable_x64", True) + +import numpy as np +import pytest + +from adept import electrostatic +from adept._hermite_legendre_1d.modules import BaseHermiteLegendre1D + + +def _run_driven(klambda_D: float, w0: float, Nh: int = 256, tmax: float = 80.0, dt: float = 0.05): + Lx = 2.0 * np.pi / klambda_D + cfg = { + "solver": "hermite-legendre-1d", + "physics": { + "Lx": Lx, + "alpha": np.sqrt(2.0), + "u": 0.0, + "v_a": -6.0, + "v_b": 6.0, + "gamma": 0.5, + "nu_H": 0.0, + "nu_L": 0.0, + "enforce_conservation": True, + "field": True, + }, + "grid": {"Nx": 64, "Nh": Nh, "Nl": 16, "tmax": tmax, "dt": dt}, + "initialization": {"type": "linear-advection", "eps": 0.0, "mode": 1}, # uniform Maxwellian + "drivers": { + "ex": { + "0": { + "k0": 2.0 * np.pi / Lx, + "w0": w0, + "dw0": 0.0, + "a0": 1.0e-3, + "t_center": 20.0, + "t_width": 20.0, + "t_rise": 5.0, + "x_center": 0.5 * Lx, + "x_width": 1.0e6, + "x_rise": 1.0, + } + } + }, + "save": {"fields": {"t": {"nt": int(tmax * 4)}}}, + "units": {}, + } + m = BaseHermiteLegendre1D(cfg) + m.write_units() + m.get_derived_quantities() + m.get_solver_quantities() + m.init_state_and_args() + m.init_diffeqsolve() + sol = m(trainable_modules={})["solver result"] + e = np.asarray(sol.ys["fields"]["e"]) + t = np.asarray(sol.ts["fields"]) + ek1 = np.fft.fft(e, axis=1)[:, 1] / e.shape[1] + return t, ek1 + + +@pytest.mark.parametrize("klambda_D", [0.30, 0.35]) +def test_driven_landau_damping(klambda_D): + root = electrostatic.get_roots_to_electrostatic_dispersion(1.0, 1.0, klambda_D, maxwellian_convention_factor=2.0) + expected_freq = float(np.real(root)) + expected_damp = float(np.imag(root)) + + t, ek1 = _run_driven(klambda_D, w0=expected_freq) + assert np.all(np.isfinite(ek1)), "driven run went non-finite" + + # free-decay window: driver is off by ~t=40, recurrence is well beyond t=70 at Nh=256 + win = (t > 45.0) & (t < 70.0) + mag = np.abs(ek1[win]) + measured_damp = float(np.polyfit(t[win], np.log(mag), 1)[0]) # d/dt ln|E_k1| + measured_freq = float(-np.polyfit(t[win], np.unwrap(np.angle(ek1[win])), 1)[0]) # -d(arg)/dt + + print( + f"\nklambda_D={klambda_D:.2f} freq {measured_freq:.4f} (exp {expected_freq:.4f}) " + f"damp {measured_damp:.5f} (exp {expected_damp:.5f})" + ) + # frequency is captured to <1%; finite-Nh Hermite slightly under-damps (~5-9%). + np.testing.assert_allclose(measured_freq, expected_freq, rtol=0.02) + np.testing.assert_allclose(measured_damp, expected_damp, rtol=0.15) diff --git a/tests/test_hermite_legendre_1d/test_linear_advection.py b/tests/test_hermite_legendre_1d/test_linear_advection.py new file mode 100644 index 00000000..90dcb800 --- /dev/null +++ b/tests/test_hermite_legendre_1d/test_linear_advection.py @@ -0,0 +1,105 @@ +# Copyright (c) Ergodic LLC 2026 +# research@ergodic.io +"""Linear-advection physics gate for the mixed Hermite-Legendre solver. + +With phi = 0 the Vlasov equation reduces to advection d_t f + v d_x f = 0, whose +exact solution is f(x, v, t) = f(x - v t, v, 0). Running the full lifecycle on a +coarse grid and reconstructing f = f0 + df from the spectral coefficients should +match the analytic solution at a time before the spectral velocity recurrence +(paper sec 4.2, Fig 3). Also gates that mass/momentum/energy are conserved to +machine precision under pure advection (all k=0 moments are time-invariant). +""" + +from jax import config as jax_config + +jax_config.update("jax_enable_x64", True) + +import numpy as np + +from adept._hermite_legendre_1d.modules import BaseHermiteLegendre1D +from adept._hermite_legendre_1d.vector_field import _hermite_function_values, _legendre_basis_values + + +def _run(cfg): + m = BaseHermiteLegendre1D(cfg) + m.write_units() + m.get_derived_quantities() + m.get_solver_quantities() + m.init_state_and_args() + m.init_diffeqsolve() + return m, m(trainable_modules={})["solver result"] + + +def test_linear_advection_matches_analytic(): + Lx = 2.0 * np.pi + alpha = np.sqrt(2.0) + Nh = Nl = 48 + Nx = 48 + v_a, v_b = -6.0, 6.0 + t_check = 2.0 + + cfg = { + "solver": "hermite-legendre-1d", + "physics": { + "Lx": Lx, + "alpha": alpha, + "u": 0.0, + "v_a": v_a, + "v_b": v_b, + "gamma": 0.5, + "nu_H": 0.0, + "nu_L": 0.0, + "enforce_conservation": True, + "field": False, + }, + "grid": {"Nx": Nx, "Nh": Nh, "Nl": Nl, "tmax": t_check, "dt": 0.01}, + "initialization": {"type": "linear-advection", "eps": 1.0, "mode": 1}, + "save": {"hermite": {"t": {"nt": 2}}, "legendre": {"t": {"nt": 2}}}, + "units": {}, + } + m, sol = _run(cfg) + x = np.asarray(m.cfg["grid"]["x"]) + + Ck = np.asarray(sol.ys["hermite"]["Ck"])[-1] # (Nh, Nx) k-space + Bk = np.asarray(sol.ys["legendre"]["Bk"])[-1] # (Nl, Nx) + C = np.fft.ifft(Ck, axis=-1, norm="forward").real # (Nh, Nx) + B = np.fft.ifft(Bk, axis=-1, norm="forward").real # (Nl, Nx) + + v = np.linspace(v_a + 0.1, v_b - 0.1, 120) + psi = _hermite_function_values(Nh, v, u=0.0, alpha=alpha) # (Nh, len(v)) + xi = _legendre_basis_values(Nl, v, v_a, v_b) # (Nl, len(v)) + + f = C.T @ psi + B.T @ xi # (Nx, len(v)) + XX, VV = np.meshgrid(x, v, indexing="ij") + f_exact = (1.0 + np.cos(XX - VV * t_check)) / np.sqrt(2.0 * np.pi) * np.exp(-(VV**2) / 2.0) + + rel_l2 = np.linalg.norm(f - f_exact) / np.linalg.norm(f_exact) + assert rel_l2 < 0.02, f"linear advection rel L2 = {rel_l2:.4f}" + + +def test_advection_conserves_moments_to_machine_precision(): + cfg = { + "solver": "hermite-legendre-1d", + "physics": { + "Lx": 2.0 * np.pi, + "alpha": np.sqrt(2.0), + "u": 0.0, + "v_a": -6.0, + "v_b": 6.0, + "gamma": 0.5, + "nu_H": 0.0, + "nu_L": 0.0, + "enforce_conservation": True, + "field": False, + }, + "grid": {"Nx": 48, "Nh": 48, "Nl": 48, "tmax": 5.0, "dt": 0.02}, + "initialization": {"type": "linear-advection", "eps": 1.0, "mode": 1}, + "save": {"default": {"t": {"nt": 60}}}, + "units": {}, + } + _, sol = _run(cfg) + d = sol.ys["default"] + for k in ("mass", "momentum", "energy"): + a = np.asarray(d[k]) + base = abs(a[0]) if abs(a[0]) > 1e-30 else 1.0 + assert np.max(np.abs(a - a[0])) / base < 1e-11, f"{k} not conserved under advection" diff --git a/tests/test_hermite_legendre_1d/test_streaming.py b/tests/test_hermite_legendre_1d/test_streaming.py new file mode 100644 index 00000000..8c0bc5cf --- /dev/null +++ b/tests/test_hermite_legendre_1d/test_streaming.py @@ -0,0 +1,103 @@ +# Copyright (c) Ergodic LLC 2026 +# research@ergodic.io +"""Unit tests for the mixed Hermite-Legendre building blocks. + +Covers the prediagonalized free-streaming exponentials (Golub-Welsch: streaming +matrix eigenvalues are the Gauss quadrature nodes), the streaming round-trip, and +the analytic parity of the Hermite->Legendre coupling integrals J_{Nh,m} (paper +eqns 27/29/34), which underpin the conservation properties. +""" + +from jax import config as jax_config + +jax_config.update("jax_enable_x64", True) + +import numpy as np +import pytest +from jax import numpy as jnp + +from adept._hermite_legendre_1d.vector_field import ( + StreamingExp1D, + hermite_legendre_coupling_vector, + hermite_streaming_matrix, + legendre_constants, + safe_col, +) + + +def test_hermite_matrix_eigvals_are_gauss_hermite_nodes(): + """Golub-Welsch: eigenvalues of T_H (u=0) equal the Gauss-Hermite nodes.""" + Nh = 24 + T = hermite_streaming_matrix(Nh, u=0.0, alpha=1.0) + eig = np.sort(np.linalg.eigvalsh(T)) + nodes = np.sort(np.polynomial.hermite.hermgauss(Nh)[0]) + np.testing.assert_allclose(eig, nodes, atol=1e-10) + + +def test_legendre_matrix_eigvals_are_gauss_legendre_nodes(): + """Eigenvalues of T_L equal the Gauss-Legendre nodes mapped to [v_a, v_b].""" + Nl, v_a, v_b = 20, -2.5, 2.5 + leg = legendre_constants(Nl, v_a, v_b) + eig = np.sort(np.linalg.eigvalsh(np.asarray(leg["T_L"]))) + nodes = np.sort(0.5 * (v_b - v_a) * np.polynomial.legendre.leggauss(Nl)[0] + 0.5 * (v_a + v_b)) + np.testing.assert_allclose(eig, nodes, atol=1e-10) + + +def test_streaming_exp_identity_and_roundtrip(): + """exp(L*0) is identity, and exp(L*s) exp(L*-s) returns the original array.""" + Nh, Nx = 16, 8 + kx = jnp.fft.fftfreq(Nx) * Nx * 2 * jnp.pi / (2 * jnp.pi) + T = hermite_streaming_matrix(Nh, u=0.3, alpha=1.4) + exp = StreamingExp1D(T, prefactor=-1j * 1.4, kx_1d=kx) + + rng = np.random.default_rng(0) + C = jnp.asarray(rng.standard_normal((Nh, Nx)) + 1j * rng.standard_normal((Nh, Nx))) + + np.testing.assert_allclose(np.asarray(exp.apply(C, 0.0)), np.asarray(C), atol=1e-12) + back = exp.apply(exp.apply(C, 0.37), -0.37) + np.testing.assert_allclose(np.asarray(back), np.asarray(C), atol=1e-10) + + +def test_legendre_derivative_matrix_sparsity(): + """sigma_{m,i} is strictly lower triangular and nonzero only when (m-i) is odd.""" + Nl, v_a, v_b = 12, -3.0, 3.0 + deriv = np.asarray(legendre_constants(Nl, v_a, v_b)["deriv"]) + for m in range(Nl): + for i in range(Nl): + if i >= m or (m - i) % 2 == 0: + assert deriv[m, i] == 0.0, (m, i) + else: + assert deriv[m, i] != 0.0, (m, i) + + +def test_collision_profile_conserves_first_three_moments(): + """safe_col is exactly zero for n=0,1,2 (mass/momentum/energy preserved).""" + col = np.asarray(safe_col(32)) + assert col[0] == 0.0 and col[1] == 0.0 and col[2] == 0.0 + assert np.all(col[3:] > 0.0) + assert col[-1] == pytest.approx(1.0) # normalized to 1 at n = N-1 + + +@pytest.mark.parametrize("Nh,vanish", [(25, (0, 2)), (24, (1,))]) +def test_coupling_integral_parity(Nh, vanish): + """J_{Nh,m} parity on a symmetric domain with u=0 (paper eqns 27/29/34): + odd Nh -> J_{Nh,0} = J_{Nh,2} = 0 (mass & energy); + even Nh -> J_{Nh,1} = 0 (momentum).""" + J = np.asarray( + hermite_legendre_coupling_vector( + Nh, Nl=10, alpha=np.sqrt(2.0), u=0.0, v_a=-3.0, v_b=3.0, enforce_conservation=False + ) + ) + for m in vanish: + assert abs(J[m]) < 1e-9, (Nh, m, J[m]) + # a non-vanishing entry should actually be present (sanity) + assert np.max(np.abs(J)) > 1e-6 + + +def test_enforce_conservation_zeros_first_three(): + J = np.asarray( + hermite_legendre_coupling_vector( + 30, Nl=10, alpha=np.sqrt(2.0), u=0.0, v_a=-2.5, v_b=2.5, enforce_conservation=True + ) + ) + assert np.all(J[:3] == 0.0)