Vlasov-1D multi-GPU sharding: declare carry sharding, cover the full step, add tests #290

@joglekara

Summary

The Vlasov-1D solver has a multi-GPU sharding path (grid.parallel: ["x", "v"]) built on jax.shard_map. The per-operation sharding choices are numerically correct, but the implementation is fragile and only partially distributed. This issue tracks tightening it up so it's predictable and scales well.

Note on scope: we are not memory-constrained today, so distributed (sharded) initialization is filed below as a longer-term item rather than the headline. The near-term concerns are predictability, coverage, config ergonomics, and tests.

What's there today (and correct)

Each operator shards the axis it is not transforming, so each shard's math is self-contained: edfdv and collisions are x-sharded, while vdfdx is v-sharded.

Config + validation live in adept/_vlasov1d/datamodel.py:72-87; docs in docs/source/solvers/vlasov1d/config.md.

Problems

1. Carry sharding is inferred, not declared (predictability)

The initial f is plain jnp.array(f_s) with no device_put/NamedSharding (adept/_vlasov1d/modules.py:281), and the filter_jit wrapping diffeqsolve declares no in_shardings/out_shardings (adept/_base_.py:409). The sharding of the diffrax loop carry is therefore left entirely to XLA's GSPMD inference. This is correct-by-luck rather than by construction: the layout XLA picks for the carry can change across XLA versions, and there's a reshard at every shard_map boundary that nothing pins down. We should declare the carry sharding so behavior is deterministic.
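As a rough illustration of what "declared" could look like, the sketch below places the initial f with an explicit NamedSharding and pins the loop carry inside a jitted loop. It is not the adept code: it uses plain jax.jit and jax.lax.scan as stand-ins for filter_jit and diffeqsolve, and the names x_sharded, body, and run are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One 1D mesh over all visible devices; f has shape (nx, nv).
mesh = Mesh(np.array(jax.devices()), ("device",))
x_sharded = NamedSharding(mesh, P("device", None))  # shard x, replicate v

nx, nv = 8 * jax.device_count(), 16

# Place the initial condition explicitly instead of a bare jnp.array(...).
f0 = jax.device_put(jnp.ones((nx, nv)), x_sharded)


def body(f, _):
    f = 0.5 * f  # stand-in for one full time step
    # Pin the carry layout so it does not depend on GSPMD inference.
    f = jax.lax.with_sharding_constraint(f, x_sharded)
    return f, None


# Declare the layout at the jit boundary as well.
@partial(jax.jit, in_shardings=x_sharded, out_shardings=x_sharded)
def run(f):
    f, _ = jax.lax.scan(body, f, None, length=3)  # stand-in for the diffrax loop
    return f


f_final = run(f0)
```

With the real solver the constraint would presumably go wherever the diffrax state is updated; whether in_shardings/out_shardings can be threaded through filter_jit is something to check rather than assume.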

2. edfdv and vdfdx shard on different axes → all-to-all transpose every substep

edfdv is x-sharded; vdfdx is v-sharded. This is the standard "transpose" method for split-step spectral and is legitimate, but it means an all-to-all between the two pushes every substep. The 6th-order integrator alternates them ~6× per step (adept/_vlasov1d/solvers/vector_field.py:159-184), so the per-step collective count is high. Worth measuring before deciding whether to keep the transpose approach or move to a single persistent x-sharding + distributed x-FFT.
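One way to do the measuring is to compile a substep and look for collectives in the optimized HLO. The sketch below is an assumption-laden toy: jnp.roll stands in for the two spectral pushes, count_collectives is a hypothetical helper, and the substring count is crude (it also matches async variants and operand names). Which collective XLA actually emits for the layout change depends on backend and version.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("device",))
n = jax.device_count()
nx, nv = 8 * n, 4 * n

# Toy pushes: each acts along one axis and is sharded along the other.
edfdv = shard_map(lambda f: jnp.roll(f, 1, axis=1), mesh=mesh,
                  in_specs=P("device", None), out_specs=P("device", None))
vdfdx = shard_map(lambda f: jnp.roll(f, 1, axis=0), mesh=mesh,
                  in_specs=P(None, "device"), out_specs=P(None, "device"))


@jax.jit
def substep(f):
    # x-sharded push followed by v-sharded push: XLA has to
    # redistribute f between the two shard_map calls.
    return vdfdx(edfdv(f))


def count_collectives(fn, *args):
    """Crude count of collective op names in the compiled HLO text."""
    hlo = fn.lower(*args).compile().as_text()
    names = ("all-to-all", "all-gather", "collective-permute", "all-reduce")
    return {name: hlo.count(name) for name in names}


f = jnp.arange(nx * nv, dtype=jnp.float32).reshape(nx, nv)
out = substep(f)
collectives = count_collectives(substep, f)
print(collectives)
```

On a single device the counts should all be zero, so this only becomes informative with several devices (real or forced host devices). A profiler trace of the real integrator would be the more reliable measurement.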

3. parallel config is a footgun

["x"] alone shards edfdv+collisions but leaves vdfdx replicated; ["v"] alone shards only vdfdx. You need ["x","v"] to cover both pushes, and even then the field solve, wave solver, ponderomotive term, Hou–Li filter, and diagnostics are all outside shard_map and run replicated. The name reads like "axes to parallelize across," but it actually means "which push gets sharded on which of its own axes." Either rename/redocument, or (preferably) collapse to a single intent (e.g. parallel: true → shard the whole step on x).
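To make the "single intent" option concrete, here is a minimal sketch of a config normalizer. Everything in it is hypothetical (the function name normalize_parallel, the choice to warn rather than error on partial lists, and mapping legacy lists to a boolean); it only illustrates how a partial setting could be surfaced instead of silently accepted.

```python
import warnings


def normalize_parallel(value):
    """Map a grid.parallel setting to a single boolean intent.

    Hypothetical policy: booleans pass through, the legacy list form is
    accepted, and a partial list such as ["x"] triggers a warning.
    """
    if value is None:
        return False
    if isinstance(value, bool):
        return value

    axes = set(value)
    unknown = axes - {"x", "v"}
    if unknown:
        raise ValueError(f"unknown axes in grid.parallel: {sorted(unknown)}")
    if not axes:
        return False
    if axes != {"x", "v"}:
        warnings.warn(
            f"grid.parallel={sorted(axes)} used to shard only part of the "
            "step; treating it as parallel: true",
            stacklevel=2,
        )
    return True


print(normalize_parallel(True))        # boolean form
print(normalize_parallel(["x", "v"]))  # legacy full list
```

Raising an error on partial lists instead of warning would be stricter; which is preferable depends on how many existing configs use ["x"] or ["v"] alone.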

4. Field solve and friends are unsharded

Poisson/Ampere/Hampere, the wave solver, ponderomotive term, Hou–Li filter, and diagnostics all run outside shard_map (adept/_vlasov1d/solvers/pushers/field.py, adept/_vlasov1d/solvers/vector_field.py:88-92). For nx ≫ nv the Poisson x-FFT is one of the more expensive global ops and is currently a serial + reshard boundary.

5. No multi-device test

Nothing in tests/test_vlasov1d exercises the parallel path. On 1 device shard_map is a no-op, so single-host tests pass without catching resharding/coverage/correctness-under-real-sharding regressions.

6. Mesh constructed independently in several places

Mesh(np.array(jax.devices()), ("device",)) is built independently three times in vlasov.py and again in fokker_planck.py, so there is no single source of truth and no mesh_shape control. Compare the more capable adept/_vlasov2d/distributed.py.
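A single source of truth could be as small as one cached helper. The sketch below is hypothetical (get_mesh, x_sharding, and the mesh_shape argument are illustrative names, not the API of _vlasov2d/distributed.py) and assumes a 2D array f with axes (x, v).

```python
import math
from functools import lru_cache

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


@lru_cache(maxsize=None)
def get_mesh(mesh_shape=None, axis_names=("device",)):
    """Build the device mesh once; repeated calls return the same object."""
    devices = jax.devices()
    if mesh_shape is None:
        mesh_shape = (len(devices),)
    needed = math.prod(mesh_shape)
    if needed > len(devices):
        raise ValueError(f"mesh_shape {mesh_shape} needs {needed} devices, "
                         f"only {len(devices)} available")
    grid = np.array(devices[:needed]).reshape(mesh_shape)
    return Mesh(grid, axis_names)


def x_sharding():
    """Sharding for f(x, v): split x across devices, replicate v."""
    return NamedSharding(get_mesh(), P("device", None))


def v_sharding():
    """Sharding for f(x, v): replicate x, split v across devices."""
    return NamedSharding(get_mesh(), P(None, "device"))
```

The pushers would then import get_mesh (or the sharding helpers) rather than constructing their own Mesh, so changing the device layout happens in one place.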

Proposed direction

Near-term (predictability + coverage + tests):

  • Build the mesh + NamedSharding once (single helper, mirroring _vlasov2d/distributed.py), not per-pusher.
  • Declare the carry sharding: with_sharding_constraint on the loop carry and/or in/out_shardings on the top-level jit, so the x-sharding persists across the diffrax loop deterministically.
  • Bring the field solve (and Hou–Li filter / diagnostics) into the sharded path, or explicitly reshard around the global x-FFT (reshard_for_global_axis_fft / reshard_to_partitioned already exist in the 2D module).
  • Simplify/rename the parallel config so partial settings can't silently leave half the step serial; update docs/source/solvers/vlasov1d/config.md.
  • Add a single-host multi-device test (XLA_FLAGS=--xla_force_host_platform_device_count=4) asserting bit-identical results vs. the serial path.
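The test item above could start from something like the following sketch. It is a toy, not the solver: serial_step and sharded_step are hypothetical stand-ins for one push, and the environment variable has to be set before JAX initializes its backends, which in a real suite probably means conftest.py or the CI environment rather than the test module.

```python
import os

# Must run before JAX initializes; forces 4 host (CPU) devices.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases
    from jax.experimental.shard_map import shard_map


def push(f):
    # Stand-in for a push along v; add-then-scale keeps rounding identical.
    return 0.5 * (f + jnp.roll(f, 1, axis=1))


serial_step = jax.jit(push)


def make_sharded_step():
    mesh = Mesh(np.array(jax.devices()), ("device",))
    spec = P("device", None)  # shard x, keep v local
    return jax.jit(shard_map(push, mesh=mesh, in_specs=spec, out_specs=spec))


def test_sharded_matches_serial():
    n = jax.device_count()  # a real test might skip or fail when n == 1
    nx, nv = 8 * n, 16
    f = jnp.arange(nx * nv, dtype=jnp.float32).reshape(nx, nv)
    sharded = make_sharded_step()(f)
    serial = serial_step(f)
    np.testing.assert_array_equal(np.asarray(sharded), np.asarray(serial))
```

For the real solver, bit-identical agreement may be harder to guarantee once FFTs and reductions are involved, so a tight tolerance could be a fallback if exact equality proves flaky.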

Longer-term (when memory becomes the constraint):

  • Distributed/sharded initialization so each device only materializes its own shard of f (jax.make_array_from_callback), following the _vlasov2d/distributed.py pattern. Not needed yet at current problem sizes.
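For reference, a minimal sketch of per-shard initialization with jax.make_array_from_callback. The perturbed-Maxwellian initial condition, the grid extents, and the name init_f_sharded are assumptions for illustration and do not reproduce what modules.py or _vlasov2d/distributed.py actually do.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def init_f_sharded(nx, nv, sharding):
    """Build f(x, v) shard by shard; no device holds the full array."""
    x = np.linspace(0.0, 2.0 * np.pi, nx, endpoint=False)
    v = np.linspace(-6.0, 6.0, nv)

    def shard_data(index):
        # `index` is a tuple of slices selecting this shard of the global array.
        xs, vs = x[index[0]], v[index[1]]
        density = 1.0 + 0.01 * np.cos(xs)
        maxwellian = np.exp(-0.5 * vs**2) / np.sqrt(2.0 * np.pi)
        return (density[:, None] * maxwellian[None, :]).astype(np.float32)

    return jax.make_array_from_callback((nx, nv), sharding, shard_data)


mesh = Mesh(np.array(jax.devices()), ("device",))
x_sharded = NamedSharding(mesh, P("device", None))

nx, nv = 8 * jax.device_count(), 32
f = init_f_sharded(nx, nv, x_sharded)
```

The callback runs once per addressable shard, so the host only ever materializes one shard-sized NumPy array at a time.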
