EroNinja/ForgeXLA


ForgeXLA

ForgeXLA is a toolkit for JAX runtime analysis, compiler IR introspection, static graph analysis, and reproducible benchmarking of scientific computing and deep learning workloads.

Motivation

I wanted to understand why mathematically equivalent JAX programs can exhibit very different runtime behavior on accelerators. ForgeXLA explores the path from Python functions to traced JAX graphs, lowered compiler IR, measured execution, heuristic optimization analysis, and optional GPU kernel experiments.

What Works Today

  • Internal GraphModule IR with JSON serialization and validation.
  • Real JAX tracing through jax.make_jaxpr for regular, grad, value_and_grad, and jit functions.
  • HLO and StableHLO text extraction through jax.jit(fn).lower(...).compiler_ir(...) when the local JAX version supports it.
  • Shallow text-level compiler IR operation counting.
  • Static graph analysis: histograms, fan-in/fan-out, depth estimate, static memory estimate, fusion candidates, and rewrite suggestions.
  • Runtime environment reporting and benchmark timing with recursive block_until_ready().
  • CPU-safe benchmark workloads: matmul, batched matmul, MLP, attention, layernorm, softmax, 2D heat equation, and N-body update.
  • DOT/JSON/HTML graph export and simple benchmark timeline export.
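  • Distributed device reporting with an explicit single-device fallback: when only one device is available, the report states that multi-device execution was not exercised.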
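As an illustration of what a CPU-safe workload such as the 2D heat equation can look like, here is a minimal sketch. It is not ForgeXLA's workload code: the `heat_step` name, periodic boundaries via `jnp.roll`, unit grid spacing, and the stability bound on `alpha` are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def heat_step(u, alpha=0.1):
    """One explicit finite-difference update of the 2D heat equation.

    Assumptions (not from ForgeXLA): periodic boundaries via jnp.roll,
    unit grid spacing, and alpha = diffusivity * dt / dx**2 <= 0.25 for stability.
    """
    laplacian = (
        jnp.roll(u, 1, axis=0) + jnp.roll(u, -1, axis=0)
        + jnp.roll(u, 1, axis=1) + jnp.roll(u, -1, axis=1)
        - 4.0 * u
    )
    return u + alpha * laplacian


u = jnp.zeros((32, 32)).at[16, 16].set(1.0)  # a single hot cell
for _ in range(10):
    u = heat_step(u)
# With periodic boundaries the total heat is conserved up to float32 rounding.
print(float(u.sum()))
```

Because the update is a pure array expression under `jax.jit`, it traces to a small jaxpr and lowers cleanly, which is what makes it useful as a CPU smoke workload.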

Experimental Or Optional

  • Triton entry points are availability-checked and currently conservative. They raise clear errors until a validated host interop path is added and tested on compatible hardware.
  • The CUDA extension source under forgexla/bindings/cuda/ is an experimental build path and is not required for package import.
  • Distributed execution only runs a tiny pmap probe when multiple local devices are actually available.

Architecture

Python/JAX function
  -> jax.make_jaxpr
  -> GraphModule
  -> static analysis / DOT or JSON export
  -> optional compiler IR extraction
  -> measured benchmark records
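To make the first two arrows concrete, the sketch below flattens a jaxpr from `jax.make_jaxpr` into a JSON-serializable node list with a simple validation pass. `SimpleNode`, `jaxpr_to_nodes`, `validate`, `to_json`, and `from_json` are hypothetical names; ForgeXLA's actual GraphModule schema is not shown in this README and likely stores more information.

```python
import dataclasses
import json

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class SimpleNode:
    # Hypothetical node record; not ForgeXLA's GraphModule schema.
    name: str
    op: str
    inputs: list
    shapes: list
    dtypes: list


def jaxpr_to_nodes(fn, *args):
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    # Name every variable by object identity; graph inputs and constants first.
    names = {id(v): f"in{i}" for i, v in enumerate(jaxpr.invars)}
    names.update({id(v): f"const{i}" for i, v in enumerate(jaxpr.constvars)})
    nodes = []
    for i, eqn in enumerate(jaxpr.eqns):
        # Literal operands are never defined by an equation, so they map to "literal".
        inputs = [names.get(id(v), "literal") for v in eqn.invars]
        for k, out in enumerate(eqn.outvars):
            names[id(out)] = f"n{i}:{k}"
        nodes.append(SimpleNode(
            name=f"n{i}",
            op=eqn.primitive.name,
            inputs=inputs,
            shapes=[list(o.aval.shape) for o in eqn.outvars],
            dtypes=[str(o.aval.dtype) for o in eqn.outvars],
        ))
    return nodes


def validate(nodes):
    """Check that every input is a graph input, a literal, or an earlier node's output."""
    seen = set()
    for node in nodes:
        for ref in node.inputs:
            if ref.startswith("n") and ref.split(":")[0] not in seen:
                raise ValueError(f"{node.name} uses {ref} before it is defined")
        seen.add(node.name)


def to_json(nodes):
    return json.dumps([dataclasses.asdict(n) for n in nodes])


def from_json(text):
    return [SimpleNode(**d) for d in json.loads(text)]


def fn(x, y):
    return jnp.tanh(x @ y)


nodes = jaxpr_to_nodes(fn, jnp.ones((4, 8)), jnp.ones((8, 2)))
validate(nodes)
print([n.op for n in nodes])
```

Keying the name map on `id()` avoids depending on JAX's internal variable classes; since jaxpr equations are already topologically ordered, one forward pass is enough for validation.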

Key packages:

  • forgexla.graph: IR, tracing, analysis, export.
  • forgexla.compiler: HLO/StableHLO extraction and heuristic rewrite reports.
  • forgexla.runtime: environment collection, execution, timing, report IO.
  • forgexla.benchmarks: deterministic workloads and benchmark runner.
  • forgexla.profiling: memory and arithmetic-intensity estimates.
  • forgexla.distributed: device metadata and tiny pmap probe.
  • forgexla.kernels: optional Triton/CUDA experiment surfaces.
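The README describes `forgexla.profiling` as producing dimension-based arithmetic-intensity estimates but does not give its formula. A common textbook version for a dense `(m, k) @ (k, n)` matmul counts 2·m·k·n FLOPs and assumes each operand is read once and the result written once. The `matmul_intensity` helper below is a hypothetical sketch of that estimate, not ForgeXLA's implementation.

```python
import numpy as np


def matmul_intensity(m, k, n, dtype=np.float32):
    """Dimension-based arithmetic intensity (FLOPs per byte) of an (m, k) @ (k, n) matmul.

    Assumes each operand is read once and the result written once, with no
    cache reuse modeled; this is a roofline-style estimate, not a measurement.
    """
    itemsize = np.dtype(dtype).itemsize
    flops = 2 * m * k * n  # one multiply and one add per inner-product term
    bytes_moved = (m * k + k * n + m * n) * itemsize
    return flops, bytes_moved, flops / bytes_moved


print(matmul_intensity(4, 8, 2))  # flops=128, bytes=224, intensity about 0.571
```

Small shapes like the tracing example sit far below one FLOP per byte, while a square 1024 float32 matmul reaches roughly 171 FLOPs per byte under the same assumptions.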

Installation

ForgeXLA targets Python 3.11 or newer.

python -m pip install -e ".[dev]"

The base package depends on JAX, jaxlib, and NumPy. Optional accelerator and visualization extras are isolated:

python -m pip install -e ".[triton]"
python -m pip install -e ".[cuda]"
python -m pip install -e ".[graphviz]"

CPU-Only Quickstart

python scripts/collect_env.py
python examples/trace_mlp.py
python examples/inspect_hlo.py
python scripts/run_benchmarks.py --workload matmul --size small --iterations 10 --warmup 2 --output reports/matmul_smoke.json
python examples/distributed_device_report.py
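The exact fields printed by `scripts/collect_env.py` are not listed here. The following hypothetical `collect_env` function shows the kind of environment metadata that can be gathered with public JAX and standard-library calls; the field names are assumptions.

```python
import json
import platform

import jax
import numpy as np


def collect_env():
    """Hypothetical stand-in for the kind of data an environment report contains."""
    return {
        "python": platform.python_version(),
        "jax": jax.__version__,
        "numpy": np.__version__,
        "backend": jax.default_backend(),
        "device_count": jax.device_count(),
        "local_device_count": jax.local_device_count(),
        "devices": [
            {"id": d.id, "platform": d.platform, "kind": d.device_kind}
            for d in jax.devices()
        ],
    }


print(json.dumps(collect_env(), indent=2))
```

Keeping every value a plain string or integer means the report can be embedded directly in benchmark JSON records.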

Tracing Example

import jax.numpy as jnp
from forgexla.graph.trace import trace_jax_function

def fn(x, y):
    return jnp.tanh(x @ y)

graph = trace_jax_function(fn, (jnp.ones((4, 8)), jnp.ones((8, 2))))
print(graph.summary())
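Under the hood, tracing of regular and `grad`/`value_and_grad` functions goes through `jax.make_jaxpr`. The sketch below uses that public API directly rather than `trace_jax_function`, and counts primitive names in each traced jaxpr. `op_histogram` and the scalar `loss` wrapper are hypothetical; only top-level equations are counted, so a nested `jit` call shows up as a single call primitive.

```python
import collections

import jax
import jax.numpy as jnp


def fn(x, y):
    return jnp.tanh(x @ y)


def loss(x, y):
    # Scalar-valued wrapper so jax.grad applies (an assumption for this sketch).
    return jnp.sum(fn(x, y))


def op_histogram(f, *args):
    """Count primitive names among the top-level equations of a traced jaxpr."""
    closed = jax.make_jaxpr(f)(*args)
    return collections.Counter(eqn.primitive.name for eqn in closed.jaxpr.eqns)


x, y = jnp.ones((4, 8)), jnp.ones((8, 2))
for label, f in [("fn", fn), ("grad", jax.grad(loss)),
                 ("value_and_grad", jax.value_and_grad(loss))]:
    print(label, dict(op_histogram(f, x, y)))
```

The gradient jaxprs contain both the forward and the backward equations, which is why they are noticeably larger than the plain function's graph.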

HLO And StableHLO Inspection

ForgeXLA uses real JAX lowering APIs when available:

import jax.numpy as jnp
from forgexla.compiler.hlo import extract_compiler_ir, summarize_ir_text

def fn(x, y):
    return jnp.tanh(x @ y)

x, y = jnp.ones((4, 8)), jnp.ones((8, 2))
result = extract_compiler_ir(fn, (x, y), dialect="hlo")
if result.supported:
    print(summarize_ir_text(result.ir_text).to_dict())
else:
    print(result.error)

The IR inspection is shallow text-level counting, not a full parser.
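Text-level counting of this kind can be approximated with the public `jax.jit(fn).lower(...).as_text()` API, which returns StableHLO text in recent JAX versions. The regex-based `count_ir_ops` below is a hypothetical sketch, not ForgeXLA's `summarize_ir_text`; it only looks for dialect-qualified names at the start of a statement or after `=`, and does not parse MLIR.

```python
import collections
import re

import jax
import jax.numpy as jnp

# Matches e.g. "%0 = stablehlo.tanh ..." or a line starting with "func.func".
_OP_PATTERN = re.compile(r'(?:^|=)\s*"?((?:stablehlo|chlo|func)\.\w+)', re.MULTILINE)


def count_ir_ops(ir_text):
    """Shallow count of dialect-qualified op names in MLIR text (not a parser)."""
    return collections.Counter(_OP_PATTERN.findall(ir_text))


def fn(x, y):
    return jnp.tanh(x @ y)


x, y = jnp.ones((4, 8)), jnp.ones((8, 2))
text = jax.jit(fn).lower(x, y).as_text()  # StableHLO text in recent JAX versions
print(dict(count_ir_ops(text)))
```

Anchoring on `=` keeps attribute references such as `#stablehlo.dot<...>` out of the counts, but any op without a result and not at the start of a line would still be missed, which is the expected trade-off of shallow inspection.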

Benchmarking

Benchmark records include raw samples, summary statistics, environment metadata, workload shape, dtype, warmup count, measured iteration count, and correctness status.

python scripts/run_benchmarks.py --workload attention --size small --iterations 20 --warmup 5 --output reports/attention.json

Reports created by the runner contain measured local timing. The repository does not include hard-coded speedup claims. The benchmark JSON does not include accelerator counter data unless a future profiler records it explicitly.
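The "correctness status" field implies comparing a workload's output with a reference. One way to do that, assuming a NumPy reference and `np.allclose` tolerances chosen for float32, is sketched below; `check_correctness` and `softmax_reference` are hypothetical names and the tolerances are assumptions, not ForgeXLA defaults.

```python
import jax
import numpy as np


def check_correctness(fn, args, reference_fn, rtol=1e-5, atol=1e-6):
    """Compare a JAX workload with a NumPy reference (tolerances are assumptions)."""
    got = np.asarray(fn(*args))  # np.asarray waits for the device result
    want = np.asarray(reference_fn(*[np.asarray(a) for a in args]))
    return {
        "status": "pass" if np.allclose(got, want, rtol=rtol, atol=atol) else "fail",
        "max_abs_err": float(np.max(np.abs(got - want))),
    }


def softmax_reference(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


x = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
print(check_correctness(jax.nn.softmax, (x,), softmax_reference))
```

Recording the maximum absolute error alongside the pass/fail flag makes tolerance choices auditable when records are compared later.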

Graph Analysis

forgexla.graph.analysis.analyze_graph reports operation histograms, shape and dtype histograms, fan-in/fan-out, depth, static memory estimates, hotspots, fusion candidates, and rewrite suggestions. Suggestions are labeled heuristic and should be confirmed through compiler IR and measured benchmarks.
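Several of these metrics can be computed in one forward pass over a jaxpr. The hypothetical `analyze_jaxpr` below is a simplified sketch, not `analyze_graph`: fan-in counts all operands including literals, depth is the longest producer chain, and the static memory figure sums every intermediate as if all were live at once, matching the limitation noted later that allocator reuse, donation, and rematerialization are not modeled.

```python
import collections

import jax
import jax.numpy as jnp
import numpy as np


def analyze_jaxpr(fn, *args):
    """Hypothetical mini-analysis: fan-in, fan-out, depth, and a naive memory sum."""
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    producer = {}  # id(var) -> index of the equation that defines it
    depth, fan_in = [], []
    fan_out = collections.Counter()
    static_bytes = 0
    for i, eqn in enumerate(jaxpr.eqns):
        parents = [producer[id(v)] for v in eqn.invars if id(v) in producer]
        for p in parents:
            fan_out[p] += 1  # one more use of equation p's output
        fan_in.append(len(eqn.invars))  # counts literal operands too
        depth.append(1 + max((depth[p] for p in parents), default=0))
        for out in eqn.outvars:
            producer[id(out)] = i
            # Every intermediate counted as live at once: no reuse, donation, or remat.
            static_bytes += int(np.prod(out.aval.shape)) * np.dtype(out.aval.dtype).itemsize
    return {
        "num_eqns": len(jaxpr.eqns),
        "depth": max(depth, default=0),
        "max_fan_in": max(fan_in, default=0),
        "max_fan_out": max(fan_out.values(), default=0),
        "static_bytes": static_bytes,
    }


def fn(x, y):
    return jnp.tanh(x @ y)


print(analyze_jaxpr(fn, jnp.ones((4, 8)), jnp.ones((8, 2))))
```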

Triton Experiments

forgexla.kernels.triton_ops imports without Triton installed. Its public functions check availability and raise TritonUnavailableError when the dependency or compatible hardware path is absent.
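A common way to keep a module importable without an optional dependency is to check for it with `importlib.util.find_spec` and defer the real import. The sketch below illustrates that pattern; `TritonUnavailableError` here is a local stand-in class, and `triton_unavailable_reason` and `vector_add` are hypothetical names, not ForgeXLA's actual functions.

```python
import importlib.util

import jax


class TritonUnavailableError(RuntimeError):
    """Local stand-in for the exception named in the text; not imported from ForgeXLA."""


def triton_unavailable_reason():
    """Return why Triton cannot be used here, or None if the basic checks pass."""
    # find_spec locates the package without importing it, so this module stays Triton-free.
    if importlib.util.find_spec("triton") is None:
        return "Triton is not installed (try the [triton] extra)."
    if not any(d.platform == "gpu" for d in jax.devices()):
        return "No GPU device is visible to JAX."
    return None


def vector_add(x, y):
    """Hypothetical entry point that always raises until host interop is validated."""
    reason = triton_unavailable_reason()
    if reason is None:
        # Even with Triton and a GPU present, stay conservative.
        reason = "Triton host interop has not been validated on this hardware."
    raise TritonUnavailableError(reason)


try:
    vector_add(None, None)
except TritonUnavailableError as err:
    print(f"Triton path skipped: {err}")
```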

CUDA Extension Notes

forgexla/bindings/cuda/ contains an experimental CMake/pybind11 vector-add extension path. The base package does not import or build it. forgexla.kernels.cuda_ops raises CudaExtensionUnavailableError if the compiled module is absent.

Distributed Analysis

examples/distributed_device_report.py prints local device metadata. If only one device exists, the report states that multi-device execution was not exercised.
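A report of this shape, including the tiny `pmap` probe that only runs when more than one local device exists, could be sketched as follows. The `device_report` function and its field names are hypothetical; the probe simply doubles one value per device and sums the result on the host.

```python
import jax
import jax.numpy as jnp


def device_report():
    """Hypothetical local device report with a multi-device probe."""
    n = jax.local_device_count()
    report = {
        "backend": jax.default_backend(),
        "local_device_count": n,
        "devices": [str(d) for d in jax.local_devices()],
        "multi_device_exercised": False,
    }
    if n > 1:
        # Tiny probe: each device doubles its shard, then the host sums the results.
        out = jax.pmap(lambda v: v * 2.0)(jnp.arange(n, dtype=jnp.float32))
        report["probe_sum"] = float(out.sum())
        report["multi_device_exercised"] = True
    else:
        report["note"] = "Only one local device; multi-device execution was not exercised."
    return report


print(device_report())
```

On a CPU-only machine this takes the single-device branch, so the report records that nothing distributed was exercised instead of implying otherwise.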

Benchmark Methodology

  • Warmup calls run before measurement.
  • Each measured sample calls the workload and recursively blocks array-like outputs.
  • Timing uses time.perf_counter.
  • JSON records include raw samples and summary statistics.
  • Hardware metrics that are not measured are not guessed.

Validation

Tested locally with:

python -m pytest
python scripts/collect_env.py
python examples/trace_mlp.py
python examples/inspect_hlo.py
python scripts/run_benchmarks.py --workload matmul --size small --iterations 10 --warmup 2 --output reports/matmul_smoke.json
python examples/distributed_device_report.py

Current validation: 38 tests passed on the CPU backend.

Limitations

  • Compiler IR counting is shallow text inspection.
  • Static memory estimates do not model allocator reuse, compiler scheduling, buffer donation, or rematerialization.
  • Roofline-style estimates are dimension-based estimates, not hardware counter measurements.
  • Triton kernels are not yet validated against a concrete host interop path in this repository.
  • CUDA extension source is present but not built by default.
  • Distributed execution is only exercised when multiple local devices are actually available.

Roadmap

  • Add tested Triton host interop for selected GPU workloads.
  • Add richer StableHLO structure extraction when public APIs are available.
  • Add benchmark artifact comparison tools that consume committed JSON records.
  • Add more rewrite experiments where before/after correctness can be tested directly.
  • Add optional graph rendering when Graphviz is installed.

Development Commands

python -m pytest
python -m ruff check forgexla tests scripts examples
python scripts/collect_env.py
