Summary
The Vlasov-1D solver has a multi-GPU sharding path (grid.parallel: ["x", "v"]) built on jax.shard_map. The per-operation sharding choices are numerically correct, but the implementation is fragile and only partially distributed. This issue tracks tightening it up so it's predictable and scales well.
Note on scope: we are not memory-constrained today, so distributed (sharded) initialization is filed below as a longer-term item rather than the headline. The near-term concerns are predictability, coverage, config ergonomics, and tests.
What's there today (and correct)
Each operator shards the axis it is not transforming, so each shard's math is self-contained:
- edfdv velocity push: sharded over x (P("device", None)), FFT over v local (adept/_vlasov1d/solvers/pushers/vlasov.py:93-103)
- vdfdx space push: sharded over v (P(None, "device")), FFT over x local (adept/_vlasov1d/solvers/pushers/vlasov.py:216-227)
- [...] v) (adept/_vlasov1d/solvers/pushers/fokker_planck.py:222-230)
Config + validation live in adept/_vlasov1d/datamodel.py:72-87; docs in docs/source/solvers/vlasov1d/config.md.
Problems
1. Carry sharding is inferred, not declared (predictability)
The initial f is plain jnp.array(f_s) with no device_put/NamedSharding (adept/_vlasov1d/modules.py:281), and the filter_jit wrapping diffeqsolve declares no in_shardings/out_shardings (adept/_base_.py:409). The sharding of the diffrax loop carry is therefore left entirely to XLA's GSPMD inference. This is correct-by-luck rather than by construction: the layout XLA picks for the carry can change across XLA versions, and there's a reshard at every shard_map boundary that nothing pins down. We should declare the carry sharding so behavior is deterministic.
2. edfdv and vdfdx shard on different axes → all-to-all transpose every substep
edfdv is x-sharded; vdfdx is v-sharded. This is the standard "transpose" method for split-step spectral and is legitimate, but it means an all-to-all between the two pushes every substep. The 6th-order integrator alternates them ~6× per step (adept/_vlasov1d/solvers/vector_field.py:159-184), so the per-step collective count is high. Worth measuring before deciding whether to keep the transpose approach or move to a single persistent x-sharding + distributed x-FFT.
3. parallel config is a footgun
["x"] alone shards edfdv + collisions but leaves vdfdx replicated; ["v"] alone shards only vdfdx. You need ["x","v"] to cover both pushes, and even then the field solve, wave solver, ponderomotive term, Hou–Li filter, and diagnostics are all outside shard_map and run replicated. The name reads like "axes to parallelize across," but it actually means "which push gets sharded on which of its own axes." Either rename/redocument, or (preferably) collapse to a single intent (e.g. parallel: true → shard the whole step on x).
4. Field solve and friends are unsharded
Poisson/Ampere/Hampere, the wave solver, ponderomotive term, Hou–Li filter, and diagnostics all run outside shard_map (adept/_vlasov1d/solvers/pushers/field.py, adept/_vlasov1d/solvers/vector_field.py:88-92). For nx ≫ nv the Poisson x-FFT is one of the more expensive global ops and is currently a serial + reshard boundary.
5. No multi-device test
Nothing in tests/test_vlasov1d exercises the parallel path. On 1 device shard_map is a no-op, so single-host tests pass without catching resharding/coverage/correctness-under-real-sharding regressions.
6. Mesh constructed in three places
Mesh(np.array(jax.devices()), ("device",)) is built independently in vlasov.py (×3) and fokker_planck.py, with no single source of truth and no mesh_shape control. Compare the more capable adept/_vlasov2d/distributed.py.
Proposed direction
Near-term (predictability + coverage + tests):
- [...] NamedSharding once (single helper, mirroring _vlasov2d/distributed.py), not per-pusher.
- [...] with_sharding_constraint on the loop carry and/or in/out_shardings on the top-level jit, so the x-sharding persists across the diffrax loop deterministically.
- [...] reshard_for_global_axis_fft/reshard_to_partitioned already exist in the 2D module).
- [...] parallel config so partial settings can't silently leave half the step serial; update docs/source/solvers/vlasov1d/config.md.
- [...] XLA_FLAGS=--xla_force_host_platform_device_count=4) asserting bit-identical results vs. the serial path.
Longer-term (when memory becomes the constraint):
- [...] f (jax.make_array_from_callback), following the _vlasov2d/distributed.py pattern. Not needed yet at current problem sizes.
Context
- [...] Nx-direction parallelism (Nx > Nv), per design intent.
- [...] adept/_vlasov2d/distributed.py.
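Illustration
To make the near-term items under Proposed direction concrete, here is a minimal sketch of how a single mesh helper, a declared carry sharding, and a forced-multi-device comparison against the serial path could fit together. This is not ADEPT code: make_mesh, toy_velocity_push, and run are hypothetical names, the push is a toy phase-shift stand-in for edfdv, and the grid sizes and step count are arbitrary. It assumes the XLA_FLAGS device-count flag is set before JAX is first imported.

```python
import os

# Assumption: emulate 4 CPU devices. The flag is read when JAX initializes,
# so it has to be set before the first `import jax`.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases
    from jax.experimental.shard_map import shard_map


def make_mesh():
    """Hypothetical helper: the one place where the 1D device mesh is built."""
    return Mesh(np.array(jax.devices()), ("device",))


def toy_velocity_push(f, e, *, dv, dt):
    """Toy stand-in for edfdv: phase-shift f(x, v) along v using an FFT over v."""
    kv = 2.0 * jnp.pi * jnp.fft.fftfreq(f.shape[1], d=dv)
    phase = jnp.exp(-1j * kv[None, :] * e[:, None] * dt)
    return jnp.real(jnp.fft.ifft(jnp.fft.fft(f, axis=1) * phase, axis=1))


def run(push, f, e, n_steps, carry_sharding=None):
    """Apply `push` n_steps times, optionally pinning the loop-carry sharding."""

    def body(_, carry):
        carry = push(carry, e)
        if carry_sharding is not None:
            carry = jax.lax.with_sharding_constraint(carry, carry_sharding)
        return carry

    return jax.lax.fori_loop(0, n_steps, body, f)


mesh = make_mesh()
nx, nv = 8 * mesh.devices.size, 32  # nx must divide evenly across devices
dv, dt = 12.0 / nv, 0.1
x = jnp.linspace(0.0, 2.0 * jnp.pi, nx, endpoint=False)
v = -6.0 + dv * jnp.arange(nv)
f0 = (1.0 + 0.1 * jnp.cos(x))[:, None] * jnp.exp(-0.5 * v**2)[None, :]
e = 0.1 * jnp.sin(x)

local_push = partial(toy_velocity_push, dv=dv, dt=dt)
# x is sharded, v stays local, so each shard's FFT over v is self-contained.
sharded_push = shard_map(
    local_push,
    mesh=mesh,
    in_specs=(P("device", None), P("device")),
    out_specs=P("device", None),
)

# Declare the carry layout up front instead of leaving it to inference.
carry_sharding = NamedSharding(mesh, P("device", None))
f_sh = jax.device_put(f0, carry_sharding)
e_sh = jax.device_put(e, NamedSharding(mesh, P("device")))

serial = jax.jit(lambda f, e: run(local_push, f, e, 3))
sharded = jax.jit(lambda f, e: run(sharded_push, f, e, 3, carry_sharding))

f_serial = serial(f0, e)
f_sharded = sharded(f_sh, e_sh)
max_abs_diff = float(np.max(np.abs(np.asarray(f_sharded) - np.asarray(f_serial))))
print("devices:", mesh.devices.size, "max |sharded - serial|:", max_abs_diff)
```

A real test would presumably replace the printed difference with np.testing.assert_array_equal to match the bit-identical requirement above; whether exact equality holds for the actual pushers is something to confirm rather than assume.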