Skip to content

feat(onnx): sbi → ONNX exporter (transform_sbi_to_onnx) for NLE + NRE#79

Merged
AlexanderFengler merged 14 commits into
mainfrom
sbi-connector
Jun 20, 2026
Merged

feat(onnx): sbi → ONNX exporter (transform_sbi_to_onnx) for NLE + NRE#79
AlexanderFengler merged 14 commits into
mainfrom
sbi-connector

Conversation

@AlexanderFengler

@AlexanderFengler AlexanderFengler commented May 18, 2026

Copy link
Copy Markdown
Member

Summary

Adds lanfactory.onnx.transform_sbi_to_onnx, an exporter that converts a
trained sbi neural likelihood estimator
(NLE) or neural ratio estimator (NRE) to an ONNX file consumable by HSSM's
loglik_kind="approx_differentiable" path. The exported file behaves
identically to a LAN-trained ONNX from transform_to_onnx, so HSSM
consumes it with zero HSSM-side code changes for the basic case.

Lives as a sibling module to the existing LAN exporter
(lanfactory/onnx/transform_onnx.pylanfactory/onnx/sbi.py). The new
deps (sbi>=0.26, nflows>=0.14) are gated under the existing [all]
extra; no impact on the default install.

What's in this branch

# Commit What
C1 ac16eda Scaffolding: stub module, [all] extra extended with sbi + nflows
C2 f7c93c8 Spike tests — torch.onnx.export{onnxruntime, jaxonnxruntime} round-trip on a vanilla MLP and an nflows MAF. Kept as permanent regression guards.
C3 bdcabd3 Core exporter — NLE path. Wraps estimator.log_prob(x, condition=θ) as a torch.nn.Module that takes a 1D concatenated (theta, x) input. Three-way numerical agreement (torch / ORT / jaxonnxruntime) at atol=1e-5, gradient agreement at atol=1e-4.
C4 f4a54fe NRE path — wraps the classifier logit as the log-ratio. Same numerical-agreement contract.
C5 b1fd188 Embedding-net coverage tests for FCEmbedding and CNNEmbedding.
C6 87704bb Documentation: docs/exporting_sbi_models.md integration guide, README mention, nav update.
C7b 4990e85 End-to-end HSSM integration test (pytest.importorskip("hssm") skip-on-missing).
C9 222adf5 Critical fix — export rank-1 input contract for HSSM's vmap-over-trials. Previous 2D (1, D) dummy caused Slice axes=[1] which fails under vmap with rank-1 per-trial input.

Architectural contract

  • Single-trial graph: the exported ONNX takes a 1D vector of length
    theta_dim + x_dim, matching HSSM's make_jax_logp_funcs_from_onnx
    per-trial vmap contract.
  • Inside the wrapper: splits on axis 0, unsqueeze(0) to satisfy
    sbi's batched log_prob/forward API, returns scalar via reshape(()).
  • Opset pinned to 17 for reproducibility against jaxonnxruntime.

Test status

  • 13 LANfactory sbi tests passing locally
    (pytest tests/test_sbi_spike_*.py tests/test_sbi_nle_export.py tests/test_sbi_nre_export.py tests/test_sbi_embeddings.py).
  • C7b integration test is pytest.importorskip("hssm"); clean skip
    in the current LANfactory env. Designed to run in coordinated cross-
    repo CI where both LANfactory + HSSM are installed together.

Known limitations (documented in C6 guide)

  • NSF flows blocked by missing SearchSorted op in jaxonnxruntime
    (planned upstream PR, ~50 lines). Tracked in
    HSSMSpine/plans/sbi-onnx-integration.md.
  • MNLE same SearchSorted blocker (categorical lookup uses
    torch.searchsorted). Same upstream PR unlocks both.
  • NLE-MAF on DDM produces qualitatively wrong posteriors (rt is
    continuous but choice is discrete; MAF can't represent the mixed
    structure). The exporter still supports mode="nle" correctly —
    this is a sbi-method/data-shape issue, not an exporter bug. MNLE is
    the correct sbi method for DDM-like data; deferred until the
    upstream PR lands.

Companion PRs

Consumed by the HSSM-side keystone tutorial in
lnccbrown/HSSM#sbi-integration.
The tutorial uses an importlib fallback to load this exporter even
without LANfactory installed, so this PR can be merged independently.

Test plan

  • LANfactory sbi tests pass (13/13)
  • Cross-repo CI exercise of the integration test (test_sbi_hssm_integration.py)
    once both packages can be installed in the same env

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features
    • Added SBI-to-ONNX exporting for neural likelihood (NLE) and neural ratio (NRE) estimators to support HSSM’s approx_differentiable workflow via a new public exporter with mode-specific validation.
  • Documentation
    • Added an “Exporting sbi Models” guide, including quick-start examples and supported/excluded architectures.
  • Tests
    • Added export round-trip tests (numerical + gradient checks) across PyTorch, ONNX Runtime, and JAX, plus an HSSM MCMC integration test.
  • Chores
    • Expanded optional/dev dependencies for ONNX/JAX support and updated .gitignore for SBI TensorBoard logs.

AlexanderFengler and others added 8 commits May 13, 2026 17:31
Adds a stub transform_sbi_to_onnx in lanfactory/onnx/sbi.py as a sibling
of the existing LAN exporter. Extends the `all` extra to pull sbi and
nflows, and adds jaxonnxruntime to the dev group for round-trip testing.

First commit of the sbi -> HSSM integration plan
(plans/sbi-onnx-integration.md in HSSMSpine). Implementation lands in
C3 (NLE) and C4 (NRE).

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Adds two permanent regression-guard tests validating the
torch.onnx.export to {onnxruntime, jaxonnxruntime} toolchain that the
sbi exporter (C3) will sit on top of. Both tests assert three-way
numerical agreement to 1e-5 on fixed inputs.

The MAF spike surfaced a real friction: the nflows MAF exports a
Reshape whose shape argument is a Constant node rather than a model
initializer, which jaxonnxruntime default strict mode rejects. Setting
jaxort_only_allow_initializers_as_static_args = False works around it.

Architectural implication for C3: HSSM onnx2jax.py does not set this
flag today, so sbi-exported flow graphs will fail to load through the
HSSM make_jax_logp_funcs_from_onnx path as-is. C3 should either
constant-fold the exported graph (preferred, keeps HSSM untouched) or
we will need a small HSSM-side patch.

Also adds onnxruntime>=1.17 and nflows>=0.14 to the dev dependency
group so uv sync --group dev is sufficient to run these tests.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
…(C3)

Replaces the C1 stub with the real transform_sbi_to_onnx implementation
for mode=nle. The exporter wraps an sbi ConditionalDensityEstimator
(NLE_A trained estimator) as a torch.nn.Module whose forward(combined)
splits a concatenated (theta, x) input and returns log p(x | theta) with
sbi standardization Jacobian baked into the traced graph. Exports a
single-trial graph at opset 17, matching the LAN convention and HSSM
vmap-over-trials expectation.

Rejection paths:
  - Score-based, flow-matching, TabPFN estimators raise ValueError.
  - NLE mode requires .log_prob(input, condition); clear TypeError if
    absent.
  - NRE mode currently raises NotImplementedError (lands in C4).

Tests in test_sbi_nle_export.py train a tiny 2D Gaussian NLE_A with MAF
and verify:
  1. Three-way numerical agreement (torch / onnxruntime / jaxonnxruntime)
     to atol=1e-5 on a fixed test point.
  2. Gradient agreement (torch.autograd vs jax.grad of the translated
     graph) to atol=1e-4.
  3. Sanity check that log-prob ordering matches the analytical Gaussian
     (near-mean point ranks above far point).
  4. Three rejection-path tests for the error contracts above.

Two findings surfaced during C3 that affect later commits:

  - 1D MAFs in sbi collapse to a degenerate Gaussian path with zero-width
    Gemm contractions that jaxonnxruntime cannot handle. The exporter
    must be exercised with >=2D theta and x. Documented in the simulator
    docstring.

  - jaxonnxruntime silently truncates int64 indices in exported flow
    graphs to int32, causing ~0.5 drift in log-prob outputs. The fix is
    jax.config.update("jax_enable_x64", True) BEFORE any JAX import.
    The test file sets this. C7 will decide whether HSSM onnx2jax.py
    should also set it globally (mirrors the C2.5 flag patch) or whether
    it stays a user responsibility documented in C6/C8.

Also adds sbi-logs/ to .gitignore (sbi auto-writes tensorboard logs
during training).

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Extends transform_sbi_to_onnx to support mode="nre". The wrapper splits
the concatenated (theta, x) input and routes through the trained
RatioEstimator forward, returning the logit log r(x, theta). Up to a
theta-independent constant the logit IS log p(x | theta), so MCMC and
HSSM posterior path treat it as the likelihood. No Jacobian correction
is needed since ratios are invariant to z-score standardization.

Rejection: passing an estimator with .log_prob in mode="nre" raises
TypeError, since that signals a density estimator (NLE) rather than a
ratio classifier (NRE). The NLE path has the symmetric check.

New test file tests/test_sbi_nre_export.py trains a tiny 2D Gaussian
NRE_A and verifies the same three-way numerical agreement (atol=1e-5)
and gradient agreement (atol=1e-4) as the NLE path, plus a sanity
ordering check (log-ratio higher at near-theta than far-theta).

The C3 NRE-not-implemented test was repurposed into a cross-mode
rejection test: passing an NLE density estimator with mode="nre" now
raises a clear TypeError instead of NotImplementedError.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Adds tests/test_sbi_embeddings.py exercising NRE_A with two embedding
nets on x:
  - FCEmbedding (representative flat-MLP embedding)
  - CNNEmbedding (1D conv stack; validates Conv / MaxPool / etc. survive
    torch.onnx.export and translate cleanly into jaxonnxruntime)

Both tests train a tiny 2D-theta / 10-dim-x linear-Gaussian classifier
and assert three-way numerical agreement (torch / onnxruntime /
jaxonnxruntime, atol=1e-5).

Other sbi embeddings (PermutationInvariantEmbedding, ResNetEmbedding1D,
ResNetEmbedding2D, LRUEmbedding, TransformerEmbedding, CausalCNNEmbedding,
SpectralConvEmbedding) are out of v1 scope; can be added as follow-up
regressions if a user needs them.

C5 finding: sbi build_mlp_classifier defaults to nn.LayerNorm between
hidden layers, and jaxonnxruntime does NOT implement the
LayerNormalization op (raises NotImplementedError at translation time).
The fix is to pass norm_layer=nn.Identity through classifier_nn kwargs.
This constraint will be documented in C6.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Adds docs/exporting_sbi_models.md as a Guides entry alongside the
MLflow and HuggingFace integration guides. Wires it into mkdocs.yml
nav. Adds a one-line mention in README.md pointing users to the
guide.

The guide covers:
  - Installation (pip install lanfactory[all])
  - Quick-start examples for NLE and NRE
  - Supported architecture matrix (NLE+MAF, NRE+MLP/FC/CNN embeddings)
  - Explicitly-out-of-scope list (NSF, FMPE, NPSE, NPE, TabPFN) with
    one-sentence rationales each
  - Known constraints surfaced during C2-C5:
      * Use 2D+ for theta and x (1D MAFs degenerate in sbi)
      * Disable LayerNorm in NRE MLP classifiers (norm_layer=Identity)
      * Enable jax_enable_x64 before importing JAX in the consumer
  - Numerical guarantees from the regression tests (atol=1e-5 forward,
    atol=1e-4 gradients)
  - Float precision interaction with PyMC

The new function transform_sbi_to_onnx is auto-documented on
docs/api/onnx.md via the existing :::lanfactory.onnx mkdocstrings
directive — no manual API page changes needed.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Adds tests/test_sbi_hssm_integration.py exercising the full keystone
pipeline:
  1. Train tiny sbi NLE_A on synthetic DDM data (ssm-simulators).
  2. Export via lanfactory.onnx.transform_sbi_to_onnx.
  3. Build HSSM model with model="ddm",
     loglik_kind="approx_differentiable", loglik=<path>.
  4. Short MCMC (500 draws + 500 tune, 2 chains) and verify posterior
     mean recovery within +/- 2 sigma and r_hat < 1.05.

Two test functions:
  - test_hssm_model_builds_from_sbi_onnx: verifies the exported ONNX
    loads cleanly into hssm.HSSM (no sampling).
  - test_hssm_mcmc_recovers_ddm_parameters: full MCMC + recovery
    assertion.

Skip guard via pytest.importorskip("hssm") so the test no-ops when HSSM
is not in the env. Currently the test is a no-op in LANfactory's local
uv venv because LANfactory's flax>=0.10.6 pin pulls a JAX version
incompatible with HSSM's numpyro 0.21.0 pin. The test is intended to
run only in a coordinated cross-repo CI environment that resolves both
packages together. Plan tracks this as future ecosystem cleanup.

The C7a HSSM patch (commit d1d7ffe on HSSM sbi-integration branch)
makes jax_enable_x64 self-managed inside HSSM's onnx2jax, so this test
does not need to set it explicitly.

Marked @pytest.mark.flaky(reruns=2, reruns_delay=5) on both test
functions to match HSSM's existing ONNX-test convention.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Surfaced by running the C8 notebook in HSSM: pymc.sample raised
"IndexError: list assignment index out of range" inside
jaxonnxruntime/onnx_ops/slice.py:113 (sub_indx[axis] = slices[i]).

Root cause: HSSM's make_jax_logp_funcs_from_onnx vmaps the per-trial
loglik over a 1D concatenated input vector (param_vector + data) — see
HSSM repos/HSSM/src/hssm/distribution_utils/onnx.py around line 115:

    input_vector = jnp.concatenate((param_vector, data))
    return jax_func(input_vector)

But the C3/C4 exporter was tracing with a 2D dummy of shape
(1, theta_dim + x_dim), which made torch.onnx.export emit Slice ops with
axes=[1]. Under HSSM's vmap the per-trial input is rank-1, so axes=[1]
is out of bounds for the inner Slice handler.

LAN exports don't trip on this because the LAN graph is pure
MatMul/Add/activation — broadcast-rank-agnostic. Ours has explicit Slice
ops from `combined[..., :theta_dim]` and `combined[theta_dim:]`.

Fix:
  - Trace the wrapper with a rank-1 dummy (shape (theta_dim+x_dim,))
    so Slice ops emit axes=[0], which survives HSSM's vmap.
  - Inside _NLELogProbWrapper.forward and _NRELogRatioWrapper.forward,
    take a 1D combined input, split on axis 0, then .unsqueeze(0) the
    two halves to satisfy sbi's batched log_prob / classifier APIs.
    Reshape the (1, 1) output back to () so HSSM's downstream .squeeze()
    sees a clean scalar.
  - Updated module docstring to document the rank-1 contract and why.

Tests:
  - test_sbi_nle_export.py, test_sbi_nre_export.py, test_sbi_embeddings.py:
    pass rank-1 inputs through onnxruntime and jaxonnxruntime; rank-1
    theta_np_1d / x_np_1d for the gradient tests.
  - All 13 sbi tests still green at the same atol thresholds
    (1e-5 forward, 1e-4 gradients).

User impact: anyone who already exported a .onnx with the old C3/C4
code needs to re-export with this commit. The exported .onnx is the
durable artifact — no API change in the call site.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Copilot AI review requested due to automatic review settings May 18, 2026 03:01

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds an ONNX exporter for trained sbi NLE/NRE estimators so they can be consumed by HSSM’s differentiable likelihood path, alongside documentation and regression coverage for ONNX/JAX round-trips.

Changes:

  • Introduces lanfactory.onnx.transform_sbi_to_onnx for rank-1 single-trial NLE/NRE ONNX export.
  • Adds sbi/nflows-related optional/dev dependencies and lockfile updates.
  • Adds docs and tests covering MLP/MAF round-trips, NLE/NRE exports, embeddings, and optional HSSM integration.

Reviewed changes

Copilot reviewed 12 out of 14 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
src/lanfactory/onnx/sbi.py Implements the new sbi-to-ONNX exporter and wrappers.
src/lanfactory/onnx/__init__.py Exposes transform_sbi_to_onnx from the ONNX package.
pyproject.toml Adds optional and dev dependencies for sbi ONNX export/testing.
uv.lock Locks new transitive dependencies.
tests/test_sbi_spike_mlp_roundtrip.py Adds baseline MLP ONNX/JAX/ORT round-trip test.
tests/test_sbi_spike_maf_roundtrip.py Adds nflows MAF ONNX/JAX/ORT round-trip test.
tests/test_sbi_nle_export.py Adds NLE exporter numerical, gradient, and validation tests.
tests/test_sbi_nre_export.py Adds NRE exporter numerical, gradient, and sanity tests.
tests/test_sbi_embeddings.py Adds NRE embedding-net export round-trip tests.
tests/test_sbi_hssm_integration.py Adds optional end-to-end HSSM integration tests.
docs/exporting_sbi_models.md Documents installation, usage, supported architectures, and constraints.
README.md Mentions the new sbi ONNX exporter.
mkdocs.yml Adds the sbi exporter guide to docs navigation.
.gitignore Ignores sbi-generated TensorBoard logs.
Comments suppressed due to low confidence (1)

src/lanfactory/onnx/sbi.py:107

  • NLE mode accepts any module with .log_prob, which also includes posterior-shaped sbi estimators such as NPE/SNPE. Those estimators model p(theta | x) with the opposite input/condition semantics, so this wrapper can silently export a posterior density as if it were a likelihood (especially when theta_dim == x_dim). Add an explicit estimator-family check for true likelihood estimators before building the NLE wrapper, or require callers to pass a likelihood-specific estimator type that can be validated.
    if mode == "nle":
        if not hasattr(estimator, "log_prob"):
            raise TypeError(
                f"NLE mode requires an estimator with "
                f".log_prob(input, condition); got {estimator_cls} which lacks "
                f"it. If this is an NRE ratio classifier, use mode='nre' "
                f"instead."
            )
        wrapper: nn.Module = _NLELogProbWrapper(

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread pyproject.toml
@@ -74,6 +79,9 @@ dev = [
"ruff>=0.14.4",
"types-PyYAML",
"mlflow>=3.6.0",
Comment on lines +36 to +43
_UNSUPPORTED_ESTIMATORS: frozenset[str] = frozenset(
{
"ScoreEstimator",
"ConditionalScoreEstimator",
"FlowMatchingEstimator",
"ConditionalFlowMatchingEstimator",
"TabPFNEstimator",
}
AlexanderFengler and others added 4 commits June 20, 2026 01:48
The sbi test modules hard-import `sbi` at module top, but `sbi` was only
declared in the `[sbi]`/`[all]` optional extras. CI runs
`uv sync --all-groups`, which installs dependency-groups but not extras, so
pytest aborted at collection with `ModuleNotFoundError: No module named
'sbi'` (cancelling the whole matrix under fail-fast). Add `sbi>=0.26` to the
`dev` group — mirroring how `nflows`/`jaxonnxruntime`/`onnxruntime` are
already declared there — so the sbi tests are collected and run in CI. End
users remain free of a forced sbi install (it stays an optional extra).

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
The ValueError raised for score-based / flow-matching / TabPFN / NSF
estimators pointed users to `plans/sbi-onnx-integration.md in HSSMSpine`,
which does not exist. Repoint to the shipped guide
`docs/exporting_sbi_models.md` and its "Explicitly out of scope (v1)"
section, which carries the actual supported/unsupported matrix.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Brings the sbi → ONNX exporter branch up to date with main (v0.6.1:
ecosystem logos, docs theme, badges). README.md and mkdocs.yml auto-merged
cleanly — main's logo/theme additions and this branch's sbi-export
doc/nav links live in disjoint sections.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
These files predate the `ruff>=0.14.4` pin added to the dev group and were
never caught by CI because the suite aborted at collection first (missing
`sbi`). With collection fixed, CI's `ruff format --check .` step now runs and
flagged them. Changes are purely cosmetic — collapsing multi-line function
signatures and f-strings that fit within the 88-char limit; no behavior
change.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
@coderabbitai

coderabbitai Bot commented Jun 20, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro Plus

Run ID: ecd3726a-9c0b-486f-a1b2-d6a49d7ffe9c

📥 Commits

Reviewing files that changed from the base of the PR and between 9540c74 and b319efe.

📒 Files selected for processing (4)
  • docs/exporting_sbi_models.md
  • src/lanfactory/onnx/sbi.py
  • tests/test_sbi_hssm_integration.py
  • tests/test_sbi_nle_export.py
✅ Files skipped from review due to trivial changes (1)
  • docs/exporting_sbi_models.md
🚧 Files skipped from review as they are similar to previous changes (3)
  • src/lanfactory/onnx/sbi.py
  • tests/test_sbi_nle_export.py
  • tests/test_sbi_hssm_integration.py

📝 Walkthrough

Walkthrough

Adds transform_sbi_to_onnx to lanfactory.onnx, which wraps trained SBI NLE_A/NRE_A estimators in thin nn.Module adapters and traces them to ONNX via torch.onnx.export. New test modules validate three-way numerical and gradient agreement (PyTorch/ONNX Runtime/JAX) for NLE, NRE, embedding variants, and an HSSM end-to-end MCMC integration. Documentation, mkdocs.yml navigation, and project dependencies are also updated.

Changes

SBI ONNX Exporter Feature

Layer / File(s) Summary
Dependencies and gitignore
pyproject.toml, .gitignore
sbi>=0.26, nflows>=0.14, onnxruntime>=1.17, jaxonnxruntime>=0.3 added to [all] extra and dev deps; sbi-logs/ excluded from version control.
transform_sbi_to_onnx public API and wrapper classes
src/lanfactory/onnx/sbi.py, src/lanfactory/onnx/__init__.py
Adds transform_sbi_to_onnx with class-name–based unsupported-estimator rejection, mode-specific capability checks, _NLELogProbWrapper and _NRELogRatioWrapper that split a concatenated rank-1 input and reshape to scalar, and torch.onnx.export tracing. __init__.py re-exports the new function.
ONNX toolchain baseline spike tests
tests/test_sbi_spike_mlp_roundtrip.py, tests/test_sbi_spike_maf_roundtrip.py
Two standalone regression tests validate torch→onnxruntime→jaxonnxruntime round-trips for a plain MLP and an nflows MAF, including jaxonnxruntime Constant-node static-arg configuration for MAF graphs.
NLE export validation tests
tests/test_sbi_nle_export.py
Trains a tiny NLE_A on a 2D Gaussian toy, exports with mode="nle", asserts three-way numerical and gradient agreement, and validates that unsupported estimator types, missing .log_prob, NLE-in-NRE-mode, and non-positive example dimensions each raise the correct errors.
NRE export validation tests
tests/test_sbi_nre_export.py
Trains a small NRE_A classifier on a 2D Gaussian toy, exports with mode="nre", and asserts three-way numerical and gradient agreement; includes log-ratio ordering sanity check.
NRE export with embedding networks
tests/test_sbi_embeddings.py
Trains NRE_A classifiers using FCEmbedding and CNNEmbedding on a 10D embedded input, exports with mode="nre", and validates three-way numerical agreement; documents the norm_layer=nn.Identity workaround for ONNX LayerNorm incompatibility.
HSSM end-to-end integration test
tests/test_sbi_hssm_integration.py
Trains NLE_A on synthetic DDM data, exports to ONNX, loads into hssm.HSSM with loglik_kind="approx_differentiable", runs short MCMC, and asserts Gelman-Rubin convergence and posterior mean recovery within ±2σ of ground truth.
User documentation and navigation
docs/exporting_sbi_models.md, mkdocs.yml, README.md
Adds a full guide covering NLE/NRE quick-start workflows, supported architecture table, known constraints (LayerNorm, JAX x64, dimensionality), numerical guarantees, and float precision notes. Updates mkdocs.yml navigation and README.md with an ONNX exporter summary.

Sequence Diagram(s)

sequenceDiagram
  participant User
  participant transform_sbi_to_onnx
  participant _NLELogProbWrapper
  participant torch_onnx_export as torch.onnx.export
  participant ONNXRuntime
  participant JAXONNXRuntime as jaxonnxruntime
  participant HSSM

  User->>transform_sbi_to_onnx: estimator (NLE_A/NRE_A), path, mode, theta_dim, x_dim
  transform_sbi_to_onnx->>transform_sbi_to_onnx: validate class name + mode capabilities
  transform_sbi_to_onnx->>_NLELogProbWrapper: wrap(estimator, theta_dim, x_dim)
  _NLELogProbWrapper-->>transform_sbi_to_onnx: wrapper.forward(concat_input) → scalar
  transform_sbi_to_onnx->>torch_onnx_export: export(wrapper, dummy rank-1 input, opset=17)
  torch_onnx_export-->>User: .onnx file written to path

  User->>ONNXRuntime: InferenceSession(path).run(concat_input)
  ONNXRuntime-->>User: log_prob scalar
  User->>JAXONNXRuntime: call_onnx_model(onnx_model)(concat_input)
  JAXONNXRuntime-->>User: JAX log_prob + gradients
  User->>HSSM: HSSM(loglik=path, loglik_kind="approx_differentiable")
  HSSM-->>User: posterior samples via MCMC
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 Hippity-hop, the networks take flight,
Through ONNX portals into JAX's light.
NLE and NRE, wrapped with care,
Scalar log-probs floating through the air.
Three runtimes agree — what a sight!
The bunny exports everything just right. 🎉

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 60.53% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title clearly summarizes the main addition: introducing a new SBI-to-ONNX exporter for neural likelihood and ratio estimators, which is the primary focus of the changeset.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch sbi-connector

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🧹 Nitpick comments (2)
docs/exporting_sbi_models.md (1)

96-134: 💤 Low value

Known constraints section is thorough and actionable.

The three constraints (≥2D, LayerNorm workaround, JAX x64) are well-motivated and include working code examples:

  • Constraint 2 (LayerNorm) is supported by the PR objectives and aligns with the documented sbi-method incompatibility.
  • Constraint 3 (JAX x64) is validated by the HSSM integration test context (Context snippet 4), which notes that HSSM's onnx2jax self-manages the jaxort_only_allow_initializers_as_static_args flag but users must opt into x64 themselves.
  • Each constraint includes a code snippet or clear reproduction steps.

Minor: Consider adding a cross-reference from "Supported architectures" to this section for readers who may wonder why certain combinations are excluded.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/exporting_sbi_models.md` around lines 96 - 134, The "Known constraints"
section is well-documented, but readers consulting the "Supported architectures"
section may not understand why certain architecture combinations are excluded or
unsupported. Add a cross-reference from the "Supported architectures" section to
the "Known constraints" section, explaining that certain combinations (like 1D θ
or x, or LayerNorm in classifiers) have constraints that limit exportability.
This helps readers connect architectural choices to the documented limitations.
tests/test_sbi_spike_maf_roundtrip.py (1)

38-38: Move config mutation into a pytest fixture to ensure proper cleanup.

Module-level mutations of config.update("jaxort_only_allow_initializers_as_static_args", False) at import time (line 38) leak across test modules and never restore the previous state. This appears in at least 4 test files with no visible restoration logic, creating order-dependent test behavior. Wrap the mutation in a fixture that saves and restores the flag value:

`@pytest.fixture`(scope="function", autouse=True)
def restore_jaxort_config():
    from jaxonnxruntime import config
    prev = config._flags.get("jaxort_only_allow_initializers_as_static_args", True)
    config.update("jaxort_only_allow_initializers_as_static_args", False)
    yield
    config.update("jaxort_only_allow_initializers_as_static_args", prev)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_sbi_spike_maf_roundtrip.py` at line 38, The module-level
config.update call at line 38 mutates the
jaxort_only_allow_initializers_as_static_args configuration at import time
without any cleanup or restoration mechanism, causing test state to leak across
test modules. Move this configuration mutation out of module scope and into a
pytest fixture with function scope and autouse enabled, so that the fixture
saves the previous configuration value before updating it, yields control to run
the test, and then restores the original configuration value afterward to
prevent test pollution and order-dependent behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/exporting_sbi_models.md`:
- Around line 101-104: The documentation at the specified location presents the
≥2D constraint ambiguously as a universal requirement when it actually applies
only to NLE with density_estimator="maf", not all density estimators or the
LANfactory exporter itself. Revise the explanation to clarify that this
constraint is specific to sbi's MAF density estimator and is enforced by sbi's
training-time behavior in nflows, not by the transform_sbi_to_onnx exporter.
Additionally, note that other density estimators such as MDN and MoG may not
have this 1D limitation, and that NRE successfully handles 1D inputs, making
clear this is a MAF-specific pre-flight limitation rather than a LANfactory
enforcement.

In `@src/lanfactory/onnx/sbi.py`:
- Around line 47-55: The transform_sbi_to_onnx function accepts
example_theta_dim and example_x_dim parameters but does not validate that they
are positive integers, which can lead to confusing errors later. Add input
validation at the start of the function to check that both example_theta_dim and
example_x_dim are positive integers (greater than zero), and raise a ValueError
with a clear message if either parameter fails validation.

In `@tests/test_sbi_embeddings.py`:
- Around line 79-85: Add a third assertion to complete the three-way pairwise
comparison between all three backends. After the existing assertions comparing
y_torch against y_ort and y_jax, add another assertion that directly compares
y_ort against y_jax using np.allclose with the same atol tolerance parameter.
Follow the same pattern as the existing assertions by including an error message
that displays the maximum absolute difference between y_ort and y_jax to ensure
full pairwise agreement among all three backends.

In `@tests/test_sbi_hssm_integration.py`:
- Around line 59-64: The _build_observed_dataframe function accepts an rng
parameter but never uses it when calling simulator, resulting in uncontrolled
random data generation. Pass the rng parameter to the simulator function call to
ensure the seeded random number generator is actually used for reproducible test
behavior.
- Around line 135-140: The hasattr call on line 135 attempts to access
hssm.utils.summary before verifying that the utils module exists on the hssm
object. If hssm lacks the utils attribute, this raises AttributeError
immediately before the hasattr check can protect against it. Guard the
hssm.utils reference by first checking that hssm has the utils attribute before
checking for the summary attribute within it. Restructure the hasattr condition
to verify hssm.utils exists as a prerequisite before attempting to access
hssm.utils.summary, ensuring the ArviZ fallback executes properly when either
module is missing.

---

Nitpick comments:
In `@docs/exporting_sbi_models.md`:
- Around line 96-134: The "Known constraints" section is well-documented, but
readers consulting the "Supported architectures" section may not understand why
certain architecture combinations are excluded or unsupported. Add a
cross-reference from the "Supported architectures" section to the "Known
constraints" section, explaining that certain combinations (like 1D θ or x, or
LayerNorm in classifiers) have constraints that limit exportability. This helps
readers connect architectural choices to the documented limitations.

In `@tests/test_sbi_spike_maf_roundtrip.py`:
- Line 38: The module-level config.update call at line 38 mutates the
jaxort_only_allow_initializers_as_static_args configuration at import time
without any cleanup or restoration mechanism, causing test state to leak across
test modules. Move this configuration mutation out of module scope and into a
pytest fixture with function scope and autouse enabled, so that the fixture
saves the previous configuration value before updating it, yields control to run
the test, and then restores the original configuration value afterward to
prevent test pollution and order-dependent behavior.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro Plus

Run ID: 799ad103-eb67-453a-ae45-e8cf9f0dd877

📥 Commits

Reviewing files that changed from the base of the PR and between 357a1c3 and fe5d806.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (13)
  • .gitignore
  • README.md
  • docs/exporting_sbi_models.md
  • mkdocs.yml
  • pyproject.toml
  • src/lanfactory/onnx/__init__.py
  • src/lanfactory/onnx/sbi.py
  • tests/test_sbi_embeddings.py
  • tests/test_sbi_hssm_integration.py
  • tests/test_sbi_nle_export.py
  • tests/test_sbi_nre_export.py
  • tests/test_sbi_spike_maf_roundtrip.py
  • tests/test_sbi_spike_mlp_roundtrip.py

Comment thread docs/exporting_sbi_models.md Outdated
Comment thread src/lanfactory/onnx/sbi.py
Comment on lines +79 to +85
assert np.allclose(y_torch, y_ort, atol=atol), (
f"torch vs onnxruntime: max |Δ| = {np.abs(y_torch - y_ort).max()}"
)
assert np.allclose(y_torch, y_jax, atol=atol), (
f"torch vs jaxonnxruntime: max |Δ| = {np.abs(y_torch - y_jax).max()}"
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add direct onnxruntime↔jaxonnxruntime assertion to complete the three-way check.

Lines 79–85 only anchor each backend against Torch. Add a direct y_ort vs y_jax assertion so the helper enforces full pairwise agreement.

Proposed patch
@@
     assert np.allclose(y_torch, y_ort, atol=atol), (
         f"torch vs onnxruntime: max |Δ| = {np.abs(y_torch - y_ort).max()}"
     )
     assert np.allclose(y_torch, y_jax, atol=atol), (
         f"torch vs jaxonnxruntime: max |Δ| = {np.abs(y_torch - y_jax).max()}"
     )
+    assert np.allclose(y_ort, y_jax, atol=atol), (
+        f"onnxruntime vs jaxonnxruntime: max |Δ| = {np.abs(y_ort - y_jax).max()}"
+    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
assert np.allclose(y_torch, y_ort, atol=atol), (
f"torch vs onnxruntime: max |Δ| = {np.abs(y_torch - y_ort).max()}"
)
assert np.allclose(y_torch, y_jax, atol=atol), (
f"torch vs jaxonnxruntime: max |Δ| = {np.abs(y_torch - y_jax).max()}"
)
assert np.allclose(y_torch, y_ort, atol=atol), (
f"torch vs onnxruntime: max |Δ| = {np.abs(y_torch - y_ort).max()}"
)
assert np.allclose(y_torch, y_jax, atol=atol), (
f"torch vs jaxonnxruntime: max |Δ| = {np.abs(y_torch - y_jax).max()}"
)
assert np.allclose(y_ort, y_jax, atol=atol), (
f"onnxruntime vs jaxonnxruntime: max |Δ| = {np.abs(y_ort - y_jax).max()}"
)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_sbi_embeddings.py` around lines 79 - 85, Add a third assertion to
complete the three-way pairwise comparison between all three backends. After the
existing assertions comparing y_torch against y_ort and y_jax, add another
assertion that directly compares y_ort against y_jax using np.allclose with the
same atol tolerance parameter. Follow the same pattern as the existing
assertions by including an error message that displays the maximum absolute
difference between y_ort and y_jax to ensure full pairwise agreement among all
three backends.

Comment thread tests/test_sbi_hssm_integration.py
Comment on lines +135 to +140
summary = hssm.utils.summary(idata) if hasattr(hssm.utils, "summary") else None
# Fall back to arviz if the convenience method is not exposed.
if summary is None:
import arviz as az

summary = az.summary(idata, var_names=_DDM_PARAM_NAMES)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
import importlib
h = importlib.import_module("hssm")
print("has hssm.utils:", hasattr(h, "utils"))
if hasattr(h, "utils"):
    print("has hssm.utils.summary:", hasattr(h.utils, "summary"))
PY

Repository: lnccbrown/LANfactory

Length of output: 601


🏁 Script executed:

cat -n tests/test_sbi_hssm_integration.py | sed -n '130,145p'

Repository: lnccbrown/LANfactory

Length of output: 724


🏁 Script executed:

# Look for setup files that might specify HSSM version constraints
find . -type f \( -name "setup.py" -o -name "setup.cfg" -o -name "pyproject.toml" -o -name "requirements*.txt" \) | head -20

Repository: lnccbrown/LANfactory

Length of output: 506


🏁 Script executed:

# Check if there's documentation about HSSM versions or changelog
find . -type f -name "*requirements*" -o -name "*version*" | grep -E "\.(txt|py|md)$" | head -20

Repository: lnccbrown/LANfactory

Length of output: 489


🏁 Script executed:

cat pyproject.toml

Repository: lnccbrown/LANfactory

Length of output: 3901


🏁 Script executed:

head -50 tests/test_sbi_hssm_integration.py

Repository: lnccbrown/LANfactory

Length of output: 2082


🏁 Script executed:

grep -n "import hssm\|from hssm" tests/test_sbi_hssm_integration.py

Repository: lnccbrown/LANfactory

Length of output: 46


🏁 Script executed:

# Check the actual structure around line 135 to understand the context better
sed -n '120,150p' tests/test_sbi_hssm_integration.py

Repository: lnccbrown/LANfactory

Length of output: 1024


🌐 Web query:

HSSM Python library utils module version history

💡 Result:

The HSSM (Hierarchical Sequential Sampling Modeling) Python library organizes utility functions across several submodules rather than a single utils module [1][2][3]. The most prominent utility-focused submodule is hssm.distribution_utils, which provides functions for building PyMC distribution classes and handling advanced model construction [2][3]. Other utility functions are found in specialized submodules, such as hssm.utils (e.g., decorate_atomic_simulator) [1] and hssm.rl.utils (introduced in version 0.3.0 for balanced-panel validation in reinforcement learning models) [4]. Regarding the version history, HSSM has undergone continuous development since its initial releases [5][6]. Significant recent updates include: Version 0.3.0: A major release featuring support for choice-only models, the Racing Diffusion Model (RDM3), Poisson Race model, BayesFlow LRE integration, and a new reinforcement learning (RLSSM) configuration system [5][7]. This version also introduced hssm.rl.utils [4]. Version 0.2.12: Introduced compatibility with Python 3.13 and Bambi 0.17.1+, added the make_distribution_for_supported_model utility function, and began private implementations of RLSSM features [5][7]. Version 0.2.x series: These versions focused on incremental improvements, including expanded model support, bug fixes, and documentation updates, building upon the foundations laid in earlier 0.1.x releases [5][6]. For the most up-to-date information on modules and release-specific changes, the official HSSM changelog and API documentation should be consulted [8][5][7].

Citations:


Guard hssm.utils before dereferencing it.

Line 135 evaluates hssm.utils before the hasattr() check can protect against its absence. If a given HSSM version lacks the utils module, this raises AttributeError immediately and the ArviZ fallback never executes. Add a guard for the module's existence:

Suggested fix
-    summary = hssm.utils.summary(idata) if hasattr(hssm.utils, "summary") else None
+    summary = (
+        hssm.utils.summary(idata)
+        if hasattr(hssm, "utils") and hasattr(hssm.utils, "summary")
+        else None
+    )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_sbi_hssm_integration.py` around lines 135 - 140, The hasattr call
on line 135 attempts to access hssm.utils.summary before verifying that the
utils module exists on the hssm object. If hssm lacks the utils attribute, this
raises AttributeError immediately before the hasattr check can protect against
it. Guard the hssm.utils reference by first checking that hssm has the utils
attribute before checking for the summary attribute within it. Restructure the
hasattr condition to verify hssm.utils exists as a prerequisite before
attempting to access hssm.utils.summary, ensuring the ArviZ fallback executes
properly when either module is missing.

@codecov

codecov Bot commented Jun 20, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines Coverage Δ
src/lanfactory/onnx/sbi.py 100.00% <100.00%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

codecov/patch flagged line 125 (the `else: raise ValueError` for an
unrecognized `mode`) as the one uncovered line in sbi.py. Add a guard test
that passes an invalid mode and asserts the ValueError, bringing sbi.py to
100% patch coverage. The test verifies a real input-validation contract,
mirroring the existing test_transform_rejects_* guard tests.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Copilot AI review requested due to automatic review settings June 20, 2026 00:34

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 12 out of 14 changed files in this pull request and generated 7 comments.

Comment on lines +127 to +131
wrapper.eval()
combined_input_dim = example_theta_dim + example_x_dim
# Trace with a rank-1 dummy so the resulting Slice ops use axes=[0],
# which survives HSSM's per-trial vmap (where the input arrives as 1D).
dummy_input = torch.randn(combined_input_dim, requires_grad=True)
Comment on lines +12 to +18
import jax
import numpy as np
import onnx
import onnxruntime as ort
import pytest
import torch
from jaxonnxruntime import call_onnx, config
Comment on lines +5 to +6
regress on a vanilla MLP, this test catches it before debugging the real
exporter. Kept as a permanent regression guard per plans/sbi-onnx-integration.md.
Comment on lines +5 to +7
accumulation, and the affine-autoregressive ops translate cleanly into
jaxonnxruntime. Kept as a permanent regression guard per
plans/sbi-onnx-integration.md.
Comment on lines +8 to +9
4. Run a short MCMC and verify posterior means recover the ground truth
(within ±2σ) and r_hat < 1.01.
Comment on lines +12 to +15
environment. LANfactory's regular CI does not currently install HSSM, so the
test is a no-op locally — it is intended to run in a coordinated cross-repo
CI matrix where both packages are available with compatible JAX pins. See
plans/sbi-onnx-integration.md C7b for the environment-resolution note.
Comment on lines +132 to +134
HSSM's `onnx2jax` consumer sets the related `jaxort_only_allow_initializers_as_static_args = False`
flag automatically, but the x64 setting is process-wide and must be opted
into by the caller.
Three reviewer findings on PR #79:

- sbi.py: validate example_theta_dim/example_x_dim are positive, raising a
  clear ValueError instead of failing deep in ONNX tracing. Adds a guard test
  (keeps sbi.py at 100% patch coverage). (CodeRabbit placed its suggested
  check inside the docstring; this puts it after.)
- docs/exporting_sbi_models.md: clarify the ">=2D" constraint is specific to
  sbi's MAF NLE training (nflows), not enforced by transform_sbi_to_onnx —
  NRE classifiers and other density estimators are not bound by it.
- test_sbi_hssm_integration.py: _build_observed_dataframe accepted an `rng`
  it never used; thread it through ssms simulator's `random_state` so observed
  data is actually reproducible.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
@AlexanderFengler AlexanderFengler merged commit f59f0a8 into main Jun 20, 2026
8 of 14 checks passed
AlexanderFengler added a commit that referenced this pull request Jun 20, 2026
Brings the merged sbi → ONNX work (PR #79) and v0.6.1 docs/logos into the
bayesflow branch. Only pyproject.toml conflicted: both sides appended to the
dev dependency group. Resolved as the union — keep bayesflow's
`bayesflow`/`keras` and main's `sbi` — so the full sbi + bayesflow test
matrix is installed by `uv sync --all-groups`. uv.lock regenerated (uv 0.6.5)
to match the resolved manifest. All sbi-file changes (dim validation, doc-ref
fix, ruff formatting, guard tests, seeded-RNG fix) come straight from main.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants