Skip to content
21 changes: 11 additions & 10 deletions adept/_vlasov1d/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from diffrax import Solution
from jax import numpy as jnp
from matplotlib import pyplot as plt
from scipy.special import gamma
from jax.scipy.special import gamma


from adept._base_ import get_envelope
from adept._vlasov1d.storage import store_f, store_fields
Expand Down Expand Up @@ -38,7 +39,7 @@ def _initialize_distribution_(
m=2.0,
T0=1.0,
vmax=6.0,
n_prof=np.ones(1),
n_prof=jnp.ones(1),
noise_val=0.0,
noise_seed=42,
noise_type="Uniform",
Expand All @@ -65,20 +66,20 @@ def _initialize_distribution_(
noise_generator = np.random.default_rng(seed=noise_seed)

dv = 2.0 * vmax / nv
vax = np.linspace(-vmax + dv / 2.0, vmax - dv / 2.0, nv)
vax = jnp.linspace(-vmax + dv / 2.0, vmax - dv / 2.0, nv)

alpha = np.sqrt(3.0 * gamma_3_over_m(m) / gamma_5_over_m(m))
# cst = m / (4 * np.pi * alpha**3.0 * gamma(3.0 / m))
alpha = jnp.sqrt(3.0 * gamma_3_over_m(m) / gamma_5_over_m(m))
# cst = m / (4 * jnp.pi * alpha**3.0 * gamma(3.0 / m))

single_dist = -(np.power(np.abs((vax[None, :] - v0) / alpha / np.sqrt(T0)), m))
single_dist = -(jnp.power(jnp.abs((vax[None, :] - v0) / alpha / jnp.sqrt(T0)), m))

single_dist = np.exp(single_dist)
# single_dist = np.exp(-(vaxs[0][None, None, :, None]**2.+vaxs[1][None, None, None, :]**2.)/2/T0)
single_dist = jnp.exp(single_dist)
# single_dist = jnp.exp(-(vaxs[0][None, None, :, None]**2.+vaxs[1][None, None, None, :]**2.)/2/T0)

# for ix in range(nx):
f = np.repeat(single_dist, nx, axis=0)
f = jnp.repeat(single_dist, nx, axis=0)
# normalize
f = f / np.trapz(f, dx=dv, axis=1)[:, None]
f = f / jnp.trapezoid(f, dx=dv, axis=1)[:, None]

if n_prof.size > 1:
# scale by density profile
Expand Down
110 changes: 110 additions & 0 deletions tests/test_vlasov1d/configs/twostream_opt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
units:
laser_wavelength: 351nm
normalizing_temperature: 2000eV
normalizing_density: 1.5e21/cc
Z: 10
Zp: 10


density:
quasineutrality: true
species-electron1:
noise_seed: 420
noise_type: gaussian
noise_val: 0.0
v0: -1.5
T0: 0.2
m: 2.0
basis: sine
baseline: 0.5
amplitude: 1.0e-4
wavenumber: 0.3
species-electron2:
noise_seed: 420
noise_type: gaussian
noise_val: 0.0
v0: 1.5
T0: 0.2
m: 2.0
basis: sine
baseline: 0.5
amplitude: -1.0e-4
wavenumber: 0.3

grid:
dt: 0.1
nv: 4096
nx: 64
tmin: 0.
tmax: 100.0
vmax: 6.4
xmax: 20.94
xmin: 0.0

save:
fields:
t:
tmin: 0.0
tmax: 100.0
nt: 51
electron:
t:
tmin: 0.0
tmax: 100.0
nt: 51

solver: vlasov-1d

mlflow:
experiment: twostream-optimize
run: opt-iter-0

drivers:
ex: {}
ey: {}

diagnostics:
diag-vlasov-dfdt: False
diag-fp-dfdt: False

terms:
field: poisson
edfdv: exponential
time: sixth
fokker_planck:
is_on: True
type: Dougherty
time:
baseline: 1.0e-5
bump_or_trough: bump
center: 0.0
rise: 25.0
slope: 0.0
bump_height: 0.0
width: 100000.0
space:
baseline: 1.0
bump_or_trough: bump
center: 0.0
rise: 25.0
slope: 0.0
bump_height: 0.0
width: 100000.0
krook:
is_on: True
time:
baseline: 1.0e-6
bump_or_trough: bump
center: 0.0
rise: 25.0
slope: 0.0
bump_height: 0.0
width: 100000.0
space:
baseline: 1.0
bump_or_trough: bump
center: 0.0
rise: 25.0
slope: 0.0
bump_height: 0.0
width: 100000.0
140 changes: 140 additions & 0 deletions tests/test_vlasov1d/test_twostream_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import xarray as xr, numpy as np, os, sys
import yaml
import mlflow
from tqdm import tqdm
import time

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
sys.path.append(os.getcwd()) # To load adept

from jax import config
import jax
import jax.numpy as jnp
config.update("jax_enable_x64", True)
import equinox as eqx
import optax

from diffrax import diffeqsolve, SaveAt

from adept import ergoExo
from adept._vlasov1d.modules import BaseVlasov1D
from adept._vlasov1d.helpers import _initialize_total_distribution_

import matplotlib
matplotlib.use("Agg")


def set_dict_leaves(src, dst, key=None):
for k, v in src.items():
if isinstance(dst[k], dict):
set_dict_leaves(v, dst[k])
else:
dst[k] = v


class OptVlasov1D(BaseVlasov1D):
def __init__(self, cfg):
super().__init__(cfg)

def reinitialize_distribution(self, cfg, state):
# return super().init_state_and_args()
_, f = _initialize_total_distribution_(cfg, cfg["grid"])
state["electron"] = f

return state

def __call__(self, params: dict, args: dict) -> dict:
if args is None:
args = self.args
# Overwrite cfg with passed args
cfg = self.cfg
set_dict_leaves(params, cfg)
# Reinitialize the distribution based on args
state = self.reinitialize_distribution(cfg, self.state)
# Solve the equations
solver_result = diffeqsolve(
terms=self.diffeqsolve_quants["terms"],
solver=self.diffeqsolve_quants["solver"],
t0=self.time_quantities["t0"],
t1=self.time_quantities["t1"],
max_steps=cfg["grid"]["max_steps"],
dt0=cfg["grid"]["dt"],
y0=state,
args=args,
saveat=SaveAt(**self.diffeqsolve_quants["saveat"]),
)
# Compute the mean growth rate at the start of the simulation
opt_quantity = jnp.mean(jnp.log10(solver_result.ys['default']['mean_e2'][10:200]))
return_val = (opt_quantity, {"solver result": solver_result})
return return_val

def vg(self, params: dict, args: dict) -> tuple[float, dict, dict]:
return eqx.filter_value_and_grad(self.__call__, has_aux=True)(params, args)


if __name__ == "__main__":

with open("tests/test_vlasov1d/configs/twostream_opt.yaml", 'r') as stream:
cfg = yaml.safe_load(stream)
cfg['mlflow']['experiment'] = "twostream-optimize"
mlflow.set_experiment("twostream-optimize")

params = {"density": {
"species-electron1": {
"v0": jnp.array(cfg["density"]["species-electron1"]["v0"]),
"T0": jnp.array(cfg["density"]["species-electron1"]["T0"]),
},
"species-electron2": {
"v0": jnp.array(cfg["density"]["species-electron2"]["v0"]),
"T0": jnp.array(cfg["density"]["species-electron2"]["T0"]),
}
}
}

mlflow.log_metrics({
"e1_v0": params["density"]["species-electron1"]["v0"].item(),
"e1_T0": params["density"]["species-electron1"]["T0"].item(),
"e2_v0": params["density"]["species-electron2"]["v0"].item(),
"e2_T0": params["density"]["species-electron2"]["T0"].item(),
}, step=0)

optimizer = optax.adam(0.1)
opt_state = optimizer.init(params)

loop_t0 = time.time()
mlflow.log_metrics({"time_loop": time.time() - loop_t0}, step=0)
for i in tqdm(range(5)):
iter_t0 = time.time()
cfg['mlflow']['run'] = f"opt-iter-{i}"

exo = ergoExo(mlflow_nested=True)
exo.setup(cfg=cfg, adept_module=OptVlasov1D)
# Potential optimization to shift post-processing to another thread
val, grad, (sim_out, post_processed_output, mlflow_run_id) = exo.val_and_grad(params)

mlflow.log_metrics({
"gamma-e2": val.item(),
"grad_l2": jnp.linalg.norm(jnp.array(jax.tree.flatten(grad)[0])).item(),
"e1_v0": params["density"]["species-electron1"]["v0"].item(),
"e1_T0": params["density"]["species-electron1"]["T0"].item(),
"e2_v0": params["density"]["species-electron2"]["v0"].item(),
"e2_T0": params["density"]["species-electron2"]["T0"].item(),
}, step=i+1)

updates, opt_state = optimizer.update(grad, opt_state, params)
params = optax.apply_updates(params, updates)

print(f"Mean-log e2 growth rate : {val}")

mlflow.log_metrics({
"time_iter": time.time() - iter_t0,
"time_loop": time.time() - loop_t0,
}, step=i+1)

# The final parameter values are not logged because they do not correspond to
# the final optimized quantity (the update step has been applied to them)
np.testing.assert_almost_equal(val, -8.572186655748087, decimal=2)
np.testing.assert_almost_equal(np.abs(params["density"]["species-electron1"]["v0"]),
np.abs(params["density"]["species-electron2"]["v0"]), decimal=2)
np.testing.assert_almost_equal(params["density"]["species-electron1"]["T0"],
params["density"]["species-electron2"]["T0"], decimal=2)
Loading