ForgeXLA is a toolkit for JAX runtime analysis, compiler IR introspection, static graph analysis, and reproducible benchmarking of scientific computing and deep learning workloads.
I wanted to understand why mathematically equivalent JAX programs can exhibit very different runtime behavior on accelerators. ForgeXLA explores the path from Python functions to traced JAX graphs, lowered compiler IR, measured execution, heuristic optimization analysis, and optional GPU kernel experiments.
- Internal GraphModule IR with JSON serialization and validation.
- Real JAX tracing through jax.make_jaxpr for regular, grad, value_and_grad, and jit functions.
- HLO and StableHLO text extraction through jax.jit(fn).lower(...).compiler_ir(...) when the local JAX version supports it.
- Shallow text-level compiler IR operation counting.
- Static graph analysis: histograms, fan-in/fan-out, depth estimate, static memory estimate, fusion candidates, and rewrite suggestions.
- Runtime environment reporting and benchmark timing with recursive block_until_ready().
- CPU-safe benchmark workloads: matmul, batched matmul, MLP, attention, layernorm, softmax, 2D heat equation, and N-body update.
- DOT/JSON/HTML graph export and simple benchmark timeline export.
- Distributed device reporting with honest single-device fallback.
- Triton entry points are availability-checked and currently conservative. They raise clear errors until a validated host interop path is added and tested on compatible hardware.
- The CUDA extension source under forgexla/bindings/cuda/ is an experimental build path and is not required for package import.
- Distributed execution only runs a tiny pmap probe when multiple local devices are actually available.
Python/JAX function
-> jax.make_jaxpr
-> GraphModule
-> static analysis / DOT or JSON export
-> optional compiler IR extraction
-> measured benchmark records
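
The first steps of this pipeline can be sketched with plain JAX. This is a minimal illustration, not ForgeXLA's implementation: the node dictionary layout is hypothetical and is not the GraphModule schema, and the lowering step is wrapped in a try/except because the available lowering APIs vary between JAX versions.

```python
import jax
import jax.numpy as jnp


def fn(x, y):
    return jnp.tanh(x @ y)


x, y = jnp.ones((4, 8)), jnp.ones((8, 2))

# Step 1: trace the function into a jaxpr.
closed_jaxpr = jax.make_jaxpr(fn)(x, y)

# Step 2: flatten the top-level equations into plain records.
# This dict layout is hypothetical and is not ForgeXLA's GraphModule schema.
nodes = [
    {
        "index": index,
        "op": eqn.primitive.name,
        "num_inputs": len(eqn.invars),
        "out_shapes": [tuple(var.aval.shape) for var in eqn.outvars],
    }
    for index, eqn in enumerate(closed_jaxpr.jaxpr.eqns)
]
for node in nodes:
    print(node)

# Step 3: optionally ask JAX for lowered compiler IR text.
try:
    ir_text = jax.jit(fn).lower(x, y).as_text()
except Exception as exc:  # lowering APIs differ across JAX versions
    ir_text = None
    print(f"compiler IR unavailable: {exc}")

if ir_text is not None:
    print(f"lowered IR has {len(ir_text.splitlines())} lines")
```

Only top-level equations are listed here; nested jaxprs (for example inside a jit call) would need a recursive walk.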
Key packages:
- forgexla.graph: IR, tracing, analysis, export.
- forgexla.compiler: HLO/StableHLO extraction and heuristic rewrite reports.
- forgexla.runtime: environment collection, execution, timing, report IO.
- forgexla.benchmarks: deterministic workloads and benchmark runner.
- forgexla.profiling: memory and arithmetic-intensity estimates.
- forgexla.distributed: device metadata and tiny pmap probe.
- forgexla.kernels: optional Triton/CUDA experiment surfaces.
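
To make the "memory and arithmetic-intensity estimates" item concrete, here is a minimal dimension-based sketch for a single matmul. The function names and the returned dictionary are assumptions for illustration and are not the forgexla.profiling API. The estimate counts each operand and the result exactly once, so it ignores caching, tiling, and compiler fusion.

```python
from math import prod

DTYPE_BYTES = {"float16": 2, "bfloat16": 2, "float32": 4, "float64": 8}


def tensor_bytes(shape, dtype):
    """Size of one dense tensor, ignoring padding and allocator behavior."""
    return prod(shape) * DTYPE_BYTES[dtype]


def matmul_estimate(m, k, n, dtype="float32"):
    """Dimension-based estimate for an (m, k) @ (k, n) matmul."""
    flops = 2 * m * k * n  # one multiply and one add per inner-product term
    bytes_moved = (
        tensor_bytes((m, k), dtype)
        + tensor_bytes((k, n), dtype)
        + tensor_bytes((m, n), dtype)
    )
    return {
        "flops": flops,
        "bytes": bytes_moved,
        "arithmetic_intensity": flops / bytes_moved,
    }


estimate = matmul_estimate(4, 8, 2)
print(estimate)
```

For the (4, 8) @ (8, 2) float32 case this gives 128 FLOPs over 224 bytes, roughly 0.57 FLOPs per byte.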
ForgeXLA targets Python 3.11 or newer.
python -m pip install -e ".[dev]"

The base package depends on JAX, jaxlib, and NumPy. Optional accelerator and visualization extras are isolated:
python -m pip install -e ".[triton]"
python -m pip install -e ".[cuda]"
python -m pip install -e ".[graphviz]"

python scripts/collect_env.py
python examples/trace_mlp.py
python examples/inspect_hlo.py
python scripts/run_benchmarks.py --workload matmul --size small --iterations 10 --warmup 2 --output reports/matmul_smoke.json
python examples/distributed_device_report.py

import jax.numpy as jnp
from forgexla.graph.trace import trace_jax_function

def fn(x, y):
    return jnp.tanh(x @ y)

graph = trace_jax_function(fn, (jnp.ones((4, 8)), jnp.ones((8, 2))))
print(graph.summary())

ForgeXLA uses real JAX lowering APIs when available:
from forgexla.compiler.hlo import extract_compiler_ir, summarize_ir_text

# fn and jnp come from the tracing example above; x and y are example inputs.
x, y = jnp.ones((4, 8)), jnp.ones((8, 2))
result = extract_compiler_ir(fn, (x, y), dialect="hlo")
if result.supported:
    print(summarize_ir_text(result.ir_text).to_dict())
else:
    print(result.error)

The IR inspection is shallow text-level counting, not a full parser.
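
To show what "shallow text-level counting" can mean in practice, here is a minimal regex-based sketch. It is not the implementation of summarize_ir_text: the count_ops helper and the pattern are assumptions, and the sample IR is a hand-written approximation of StableHLO text rather than captured compiler output.

```python
import re
from collections import Counter

# Hand-written approximation of StableHLO text, for illustration only.
SAMPLE_IR = """
module @jit_fn {
  func.func public @main(%arg0: tensor<4x8xf32>, %arg1: tensor<8x2xf32>) -> tensor<4x2xf32> {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<4x8xf32>, tensor<8x2xf32>) -> tensor<4x2xf32>
    %1 = stablehlo.tanh %0 : tensor<4x2xf32>
    return %1 : tensor<4x2xf32>
  }
}
"""

# Match "dialect.op" names that appear right after an "=" sign.
OP_PATTERN = re.compile(r'=\s*"?([a-z_]+\.[a-z_0-9.]+)')


def count_ops(ir_text):
    """Count operation names by pattern matching, without parsing the IR."""
    return Counter(match.group(1) for match in OP_PATTERN.finditer(ir_text))


op_counts = count_ops(SAMPLE_IR)
print(dict(op_counts))
```

Because this only matches text, it misses operations that do not assign a result (such as the return line) and cannot see nesting or regions, which is why such counts should be treated as a rough summary.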
Benchmark records include raw samples, summary statistics, environment metadata, workload shape, dtype, warmup count, measured iteration count, and correctness status.
python scripts/run_benchmarks.py --workload attention --size small --iterations 20 --warmup 5 --output reports/attention.json

Reports created by the runner contain measured local timing. The repository does not include hard-coded speedup claims. The benchmark JSON does not include accelerator counter data unless a future profiler records it explicitly.
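
As a rough picture of what such a record might look like, the sketch below builds one and serializes it to JSON. The BenchmarkRecord class and every field name are hypothetical; the actual schema written by the runner may differ. The sample values are synthetic and are not measurements.

```python
import json
import platform
import statistics
from dataclasses import asdict, dataclass, field


@dataclass
class BenchmarkRecord:
    # Hypothetical field names; the real ForgeXLA schema may differ.
    workload: str
    shape: tuple
    dtype: str
    warmup: int
    iterations: int
    samples_s: list
    correct: bool
    environment: dict = field(default_factory=dict)

    def summary(self):
        """Summary statistics derived from the raw samples."""
        return {
            "mean_s": statistics.fmean(self.samples_s),
            "median_s": statistics.median(self.samples_s),
            "min_s": min(self.samples_s),
            "max_s": max(self.samples_s),
        }

    def to_json(self):
        payload = asdict(self)
        payload["summary"] = self.summary()
        return json.dumps(payload, indent=2)


record = BenchmarkRecord(
    workload="attention",
    shape=(2, 4, 16),
    dtype="float32",
    warmup=5,
    iterations=3,
    samples_s=[0.25, 0.5, 0.75],  # synthetic values, not measurements
    correct=True,
    environment={"python": platform.python_version()},
)
print(record.to_json())
```

Keeping the raw samples next to the summary lets a reader recompute any statistic later instead of trusting a single aggregate.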
forgexla.graph.analysis.analyze_graph reports operation histograms, shape and dtype histograms, fan-in/fan-out, depth, static memory estimates, hotspots, fusion candidates, and rewrite suggestions. Suggestions are labeled heuristic and should be confirmed through compiler IR and measured benchmarks.
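
A few of these metrics can be illustrated on a toy graph. The sketch below is not analyze_graph: the graph format (node name mapped to an operation and its input names) and the analyze function are assumptions made for the example.

```python
from collections import Counter

# Toy DAG: node name -> (operation, input node names). Hypothetical format.
GRAPH = {
    "x": ("input", []),
    "w": ("input", []),
    "mm": ("dot_general", ["x", "w"]),
    "act": ("tanh", ["mm"]),
    "sq": ("mul", ["act", "act"]),
    "out": ("reduce_sum", ["sq"]),
}


def analyze(graph):
    op_histogram = Counter(op for op, _ in graph.values())
    fan_in = {name: len(inputs) for name, (_, inputs) in graph.items()}

    # Fan-out counts every use of a node, including repeated uses by one consumer.
    uses = Counter()
    for _, inputs in graph.values():
        uses.update(inputs)
    fan_out = {name: uses[name] for name in graph}

    # Depth is the longest path from any input node, memoized per node.
    depth = {}

    def node_depth(name):
        if name not in depth:
            inputs = graph[name][1]
            depth[name] = 1 + max(map(node_depth, inputs)) if inputs else 0
        return depth[name]

    max_depth = max(node_depth(name) for name in graph)
    return {
        "op_histogram": dict(op_histogram),
        "fan_in": fan_in,
        "fan_out": fan_out,
        "depth": max_depth,
    }


report = analyze(GRAPH)
print(report)
```

Here the depth is 4 (input, dot_general, tanh, mul, reduce_sum along the longest path), and the tanh node has a fan-out of 2 because the multiply consumes it twice.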
forgexla.kernels.triton_ops imports without Triton installed. Its public functions check availability and raise TritonUnavailableError when the dependency or compatible hardware path is absent.
forgexla/bindings/cuda/ contains an experimental CMake/pybind11 vector-add extension path. The base package does not import or build it. forgexla.kernels.cuda_ops raises CudaExtensionUnavailableError if the compiled module is absent.
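
Both optional kernel surfaces above rely on the same idea: importing the package must always succeed, and the failure is deferred to the point of use. A minimal sketch of that pattern follows. The names OptionalDependencyError, is_available, and require are stand-ins invented for this example and are not ForgeXLA's TritonUnavailableError or CudaExtensionUnavailableError.

```python
import importlib
import importlib.util


class OptionalDependencyError(RuntimeError):
    """Stand-in for errors such as TritonUnavailableError."""


def is_available(module_name):
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(module_name) is not None


def require(module_name, feature):
    """Import an optional dependency or raise a clear, actionable error."""
    if not is_available(module_name):
        raise OptionalDependencyError(
            f"{feature} requires the optional '{module_name}' package, "
            "which is not installed."
        )
    return importlib.import_module(module_name)


# The check runs only when the feature is used, never at module import time.
try:
    require("forgexla_missing_dep_for_demo", "demo kernel")
except OptionalDependencyError as exc:
    print(exc)
```

A real check for a GPU kernel would also need to confirm that compatible hardware is present, which this sketch does not attempt.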
examples/distributed_device_report.py prints local device metadata. If only one device exists, the report states that multi-device execution was not exercised.
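
A report with this kind of single-device fallback could look roughly like the sketch below. It uses public JAX device APIs, but the device_report function and its dictionary keys are assumptions for illustration, not the contents of the example script.

```python
import jax
import jax.numpy as jnp


def device_report():
    devices = jax.local_devices()
    report = {
        "backend": jax.default_backend(),
        "local_device_count": len(devices),
        "devices": [
            {"id": d.id, "platform": d.platform, "kind": d.device_kind}
            for d in devices
        ],
    }
    if len(devices) > 1:
        # One element per device; each device adds 1 to its own element.
        probe = jax.pmap(lambda v: v + 1)(jnp.arange(len(devices)))
        report["multi_device_exercised"] = True
        report["pmap_probe_ok"] = probe.tolist() == list(range(1, len(devices) + 1))
    else:
        # Say so explicitly rather than implying multi-device coverage.
        report["multi_device_exercised"] = False
        report["note"] = "single local device: multi-device execution not exercised"
    return report


report = device_report()
print(report)
```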
- Warmup calls run before measurement.
- Each measured sample calls the workload and recursively blocks array-like outputs.
- Timing uses time.perf_counter.
- JSON records include raw samples and summary statistics.
- Hardware metrics that are not measured are not guessed.
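
The methodology above can be sketched in a few lines. This is an illustration under stated assumptions, not the ForgeXLA runner: block_all and time_workload are hypothetical names, and the blocking helper uses duck typing so that it works on JAX arrays and on anything else that exposes block_until_ready().

```python
import time


def block_all(value):
    """Recursively wait on anything that exposes block_until_ready()."""
    if hasattr(value, "block_until_ready"):
        value.block_until_ready()
    elif isinstance(value, dict):
        for item in value.values():
            block_all(item)
    elif isinstance(value, (list, tuple)):
        for item in value:
            block_all(item)
    return value


def time_workload(fn, *args, warmup=2, iterations=10):
    """Return raw per-call wall-clock samples in seconds."""
    # Warmup calls absorb one-time costs such as compilation.
    for _ in range(warmup):
        block_all(fn(*args))
    samples = []
    for _ in range(iterations):
        start = time.perf_counter()
        block_all(fn(*args))  # block so asynchronous dispatch is included
        samples.append(time.perf_counter() - start)
    return samples


# Plain-Python workload so the sketch runs without an accelerator.
samples = time_workload(lambda n: {"total": sum(range(n))}, 10_000, warmup=2, iterations=5)
print(len(samples), min(samples))
```

Blocking inside the timed region matters because JAX dispatches work asynchronously; without it, the timer would mostly measure how long it takes to enqueue the computation.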
Tested locally with:
python -m pytest
python scripts/collect_env.py
python examples/trace_mlp.py
python examples/inspect_hlo.py
python scripts/run_benchmarks.py --workload matmul --size small --iterations 10 --warmup 2 --output reports/matmul_smoke.json
python examples/distributed_device_report.py

Current validation: 38 tests passed on CPU backend.
- Compiler IR counting is shallow text inspection.
- Static memory estimates do not model allocator reuse, compiler scheduling, buffer donation, or rematerialization.
- Roofline-style estimates are dimension-based estimates, not hardware counter measurements.
- Triton kernels are not yet validated against a concrete host interop path in this repository.
- CUDA extension source is present but not built by default.
- Distributed execution is only exercised when multiple local devices are actually available.
- Add tested Triton host interop for selected GPU workloads.
- Add richer StableHLO structure extraction when public APIs are available.
- Add benchmark artifact comparison tools that consume committed JSON records.
- Add more rewrite experiments where before/after correctness can be tested directly.
- Add optional graph rendering when Graphviz is installed.
python -m pytest
python -m ruff check forgexla tests scripts examples
python scripts/collect_env.py