feat(onnx): sbi → ONNX exporter (transform_sbi_to_onnx) for NLE + NRE#79
Conversation
Adds a stub transform_sbi_to_onnx in lanfactory/onnx/sbi.py as a sibling of the existing LAN exporter. Extends the `all` extra to pull sbi and nflows, and adds jaxonnxruntime to the dev group for round-trip testing. First commit of the sbi -> HSSM integration plan (plans/sbi-onnx-integration.md in HSSMSpine). Implementation lands in C3 (NLE) and C4 (NRE). Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Adds two permanent regression-guard tests validating the
torch.onnx.export to {onnxruntime, jaxonnxruntime} toolchain that the
sbi exporter (C3) will sit on top of. Both tests assert three-way
numerical agreement to 1e-5 on fixed inputs.
The MAF spike surfaced a real friction: the nflows MAF exports a
Reshape whose shape argument is a Constant node rather than a model
initializer, which jaxonnxruntime default strict mode rejects. Setting
jaxort_only_allow_initializers_as_static_args = False works around it.
Architectural implication for C3: HSSM onnx2jax.py does not set this
flag today, so sbi-exported flow graphs will fail to load through the
HSSM make_jax_logp_funcs_from_onnx path as-is. C3 should either
constant-fold the exported graph (preferred, keeps HSSM untouched) or
we will need a small HSSM-side patch.
Also adds onnxruntime>=1.17 and nflows>=0.14 to the dev dependency
group so uv sync --group dev is sufficient to run these tests.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
…(C3)
Replaces the C1 stub with the real transform_sbi_to_onnx implementation
for mode=nle. The exporter wraps an sbi ConditionalDensityEstimator
(NLE_A trained estimator) as a torch.nn.Module whose forward(combined)
splits a concatenated (theta, x) input and returns log p(x | theta) with
sbi standardization Jacobian baked into the traced graph. Exports a
single-trial graph at opset 17, matching the LAN convention and HSSM
vmap-over-trials expectation.
Rejection paths:
- Score-based, flow-matching, TabPFN estimators raise ValueError.
- NLE mode requires .log_prob(input, condition); clear TypeError if
absent.
- NRE mode currently raises NotImplementedError (lands in C4).
Tests in test_sbi_nle_export.py train a tiny 2D Gaussian NLE_A with MAF
and verify:
1. Three-way numerical agreement (torch / onnxruntime / jaxonnxruntime)
to atol=1e-5 on a fixed test point.
2. Gradient agreement (torch.autograd vs jax.grad of the translated
graph) to atol=1e-4.
3. Sanity check that log-prob ordering matches the analytical Gaussian
(near-mean point ranks above far point).
4. Three rejection-path tests for the error contracts above.
Two findings surfaced during C3 that affect later commits:
- 1D MAFs in sbi collapse to a degenerate Gaussian path with zero-width
Gemm contractions that jaxonnxruntime cannot handle. The exporter
must be exercised with >=2D theta and x. Documented in the simulator
docstring.
- jaxonnxruntime silently truncates int64 indices in exported flow
graphs to int32, causing ~0.5 drift in log-prob outputs. The fix is
jax.config.update("jax_enable_x64", True) BEFORE any JAX import.
The test file sets this. C7 will decide whether HSSM onnx2jax.py
should also set it globally (mirrors the C2.5 flag patch) or whether
it stays a user responsibility documented in C6/C8.
Also adds sbi-logs/ to .gitignore (sbi auto-writes tensorboard logs
during training).
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Extends transform_sbi_to_onnx to support mode="nre". The wrapper splits the concatenated (theta, x) input and routes through the trained RatioEstimator forward, returning the logit log r(x, theta). Up to a theta-independent constant the logit IS log p(x | theta), so MCMC and HSSM posterior path treat it as the likelihood. No Jacobian correction is needed since ratios are invariant to z-score standardization. Rejection: passing an estimator with .log_prob in mode="nre" raises TypeError, since that signals a density estimator (NLE) rather than a ratio classifier (NRE). The NLE path has the symmetric check. New test file tests/test_sbi_nre_export.py trains a tiny 2D Gaussian NRE_A and verifies the same three-way numerical agreement (atol=1e-5) and gradient agreement (atol=1e-4) as the NLE path, plus a sanity ordering check (log-ratio higher at near-theta than far-theta). The C3 NRE-not-implemented test was repurposed into a cross-mode rejection test: passing an NLE density estimator with mode="nre" now raises a clear TypeError instead of NotImplementedError. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Adds tests/test_sbi_embeddings.py exercising NRE_A with two embedding
nets on x:
- FCEmbedding (representative flat-MLP embedding)
- CNNEmbedding (1D conv stack; validates Conv / MaxPool / etc. survive
torch.onnx.export and translate cleanly into jaxonnxruntime)
Both tests train a tiny 2D-theta / 10-dim-x linear-Gaussian classifier
and assert three-way numerical agreement (torch / onnxruntime /
jaxonnxruntime, atol=1e-5).
Other sbi embeddings (PermutationInvariantEmbedding, ResNetEmbedding1D,
ResNetEmbedding2D, LRUEmbedding, TransformerEmbedding, CausalCNNEmbedding,
SpectralConvEmbedding) are out of v1 scope; can be added as follow-up
regressions if a user needs them.
C5 finding: sbi build_mlp_classifier defaults to nn.LayerNorm between
hidden layers, and jaxonnxruntime does NOT implement the
LayerNormalization op (raises NotImplementedError at translation time).
The fix is to pass norm_layer=nn.Identity through classifier_nn kwargs.
This constraint will be documented in C6.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Adds docs/exporting_sbi_models.md as a Guides entry alongside the
MLflow and HuggingFace integration guides. Wires it into mkdocs.yml
nav. Adds a one-line mention in README.md pointing users to the
guide.
The guide covers:
- Installation (pip install lanfactory[all])
- Quick-start examples for NLE and NRE
- Supported architecture matrix (NLE+MAF, NRE+MLP/FC/CNN embeddings)
- Explicitly-out-of-scope list (NSF, FMPE, NPSE, NPE, TabPFN) with
one-sentence rationales each
- Known constraints surfaced during C2-C5:
* Use 2D+ for theta and x (1D MAFs degenerate in sbi)
* Disable LayerNorm in NRE MLP classifiers (norm_layer=Identity)
* Enable jax_enable_x64 before importing JAX in the consumer
- Numerical guarantees from the regression tests (atol=1e-5 forward,
atol=1e-4 gradients)
- Float precision interaction with PyMC
The new function transform_sbi_to_onnx is auto-documented on
docs/api/onnx.md via the existing :::lanfactory.onnx mkdocstrings
directive — no manual API page changes needed.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Adds tests/test_sbi_hssm_integration.py exercising the full keystone
pipeline:
1. Train tiny sbi NLE_A on synthetic DDM data (ssm-simulators).
2. Export via lanfactory.onnx.transform_sbi_to_onnx.
3. Build HSSM model with model="ddm",
loglik_kind="approx_differentiable", loglik=<path>.
4. Short MCMC (500 draws + 500 tune, 2 chains) and verify posterior
mean recovery within +/- 2 sigma and r_hat < 1.05.
Two test functions:
- test_hssm_model_builds_from_sbi_onnx: verifies the exported ONNX
loads cleanly into hssm.HSSM (no sampling).
- test_hssm_mcmc_recovers_ddm_parameters: full MCMC + recovery
assertion.
Skip guard via pytest.importorskip("hssm") so the test no-ops when HSSM
is not in the env. Currently the test is a no-op in LANfactory's local
uv venv because LANfactory's flax>=0.10.6 pin pulls a JAX version
incompatible with HSSM's numpyro 0.21.0 pin. The test is intended to
run only in a coordinated cross-repo CI environment that resolves both
packages together. Plan tracks this as future ecosystem cleanup.
The C7a HSSM patch (commit d1d7ffe on HSSM sbi-integration branch)
makes jax_enable_x64 self-managed inside HSSM's onnx2jax, so this test
does not need to set it explicitly.
Marked @pytest.mark.flaky(reruns=2, reruns_delay=5) on both test
functions to match HSSM's existing ONNX-test convention.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Surfaced by running the C8 notebook in HSSM: pymc.sample raised
"IndexError: list assignment index out of range" inside
jaxonnxruntime/onnx_ops/slice.py:113 (sub_indx[axis] = slices[i]).
Root cause: HSSM's make_jax_logp_funcs_from_onnx vmaps the per-trial
loglik over a 1D concatenated input vector (param_vector + data) — see
HSSM repos/HSSM/src/hssm/distribution_utils/onnx.py around line 115:
input_vector = jnp.concatenate((param_vector, data))
return jax_func(input_vector)
But the C3/C4 exporter was tracing with a 2D dummy of shape
(1, theta_dim + x_dim), which made torch.onnx.export emit Slice ops with
axes=[1]. Under HSSM's vmap the per-trial input is rank-1, so axes=[1]
is out of bounds for the inner Slice handler.
LAN exports don't trip on this because the LAN graph is pure
MatMul/Add/activation — broadcast-rank-agnostic. Ours has explicit Slice
ops from `combined[..., :theta_dim]` and `combined[theta_dim:]`.
Fix:
- Trace the wrapper with a rank-1 dummy (shape (theta_dim+x_dim,))
so Slice ops emit axes=[0], which survives HSSM's vmap.
- Inside _NLELogProbWrapper.forward and _NRELogRatioWrapper.forward,
take a 1D combined input, split on axis 0, then .unsqueeze(0) the
two halves to satisfy sbi's batched log_prob / classifier APIs.
Reshape the (1, 1) output back to () so HSSM's downstream .squeeze()
sees a clean scalar.
- Updated module docstring to document the rank-1 contract and why.
Tests:
- test_sbi_nle_export.py, test_sbi_nre_export.py, test_sbi_embeddings.py:
pass rank-1 inputs through onnxruntime and jaxonnxruntime; rank-1
theta_np_1d / x_np_1d for the gradient tests.
- All 13 sbi tests still green at the same atol thresholds
(1e-5 forward, 1e-4 gradients).
User impact: anyone who already exported a .onnx with the old C3/C4
code needs to re-export with this commit. The exported .onnx is the
durable artifact — no API change in the call site.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
There was a problem hiding this comment.
Pull request overview
Adds an ONNX exporter for trained sbi NLE/NRE estimators so they can be consumed by HSSM’s differentiable likelihood path, alongside documentation and regression coverage for ONNX/JAX round-trips.
Changes:
- Introduces
lanfactory.onnx.transform_sbi_to_onnxfor rank-1 single-trial NLE/NRE ONNX export. - Adds sbi/nflows-related optional/dev dependencies and lockfile updates.
- Adds docs and tests covering MLP/MAF round-trips, NLE/NRE exports, embeddings, and optional HSSM integration.
Reviewed changes
Copilot reviewed 12 out of 14 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
src/lanfactory/onnx/sbi.py |
Implements the new sbi-to-ONNX exporter and wrappers. |
src/lanfactory/onnx/__init__.py |
Exposes transform_sbi_to_onnx from the ONNX package. |
pyproject.toml |
Adds optional and dev dependencies for sbi ONNX export/testing. |
uv.lock |
Locks new transitive dependencies. |
tests/test_sbi_spike_mlp_roundtrip.py |
Adds baseline MLP ONNX/JAX/ORT round-trip test. |
tests/test_sbi_spike_maf_roundtrip.py |
Adds nflows MAF ONNX/JAX/ORT round-trip test. |
tests/test_sbi_nle_export.py |
Adds NLE exporter numerical, gradient, and validation tests. |
tests/test_sbi_nre_export.py |
Adds NRE exporter numerical, gradient, and sanity tests. |
tests/test_sbi_embeddings.py |
Adds NRE embedding-net export round-trip tests. |
tests/test_sbi_hssm_integration.py |
Adds optional end-to-end HSSM integration tests. |
docs/exporting_sbi_models.md |
Documents installation, usage, supported architectures, and constraints. |
README.md |
Mentions the new sbi ONNX exporter. |
mkdocs.yml |
Adds the sbi exporter guide to docs navigation. |
.gitignore |
Ignores sbi-generated TensorBoard logs. |
Comments suppressed due to low confidence (1)
src/lanfactory/onnx/sbi.py:107
- NLE mode accepts any module with
.log_prob, which also includes posterior-shaped sbi estimators such as NPE/SNPE. Those estimators modelp(theta | x)with the opposite input/condition semantics, so this wrapper can silently export a posterior density as if it were a likelihood (especially whentheta_dim == x_dim). Add an explicit estimator-family check for true likelihood estimators before building the NLE wrapper, or require callers to pass a likelihood-specific estimator type that can be validated.
if mode == "nle":
if not hasattr(estimator, "log_prob"):
raise TypeError(
f"NLE mode requires an estimator with "
f".log_prob(input, condition); got {estimator_cls} which lacks "
f"it. If this is an NRE ratio classifier, use mode='nre' "
f"instead."
)
wrapper: nn.Module = _NLELogProbWrapper(
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| @@ -74,6 +79,9 @@ dev = [ | |||
| "ruff>=0.14.4", | |||
| "types-PyYAML", | |||
| "mlflow>=3.6.0", | |||
| _UNSUPPORTED_ESTIMATORS: frozenset[str] = frozenset( | ||
| { | ||
| "ScoreEstimator", | ||
| "ConditionalScoreEstimator", | ||
| "FlowMatchingEstimator", | ||
| "ConditionalFlowMatchingEstimator", | ||
| "TabPFNEstimator", | ||
| } |
The sbi test modules hard-import `sbi` at module top, but `sbi` was only declared in the `[sbi]`/`[all]` optional extras. CI runs `uv sync --all-groups`, which installs dependency-groups but not extras, so pytest aborted at collection with `ModuleNotFoundError: No module named 'sbi'` (cancelling the whole matrix under fail-fast). Add `sbi>=0.26` to the `dev` group — mirroring how `nflows`/`jaxonnxruntime`/`onnxruntime` are already declared there — so the sbi tests are collected and run in CI. End users remain free of a forced sbi install (it stays an optional extra). Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
The ValueError raised for score-based / flow-matching / TabPFN / NSF estimators pointed users to `plans/sbi-onnx-integration.md in HSSMSpine`, which does not exist. Repoint to the shipped guide `docs/exporting_sbi_models.md` and its "Explicitly out of scope (v1)" section, which carries the actual supported/unsupported matrix. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Brings the sbi → ONNX exporter branch up to date with main (v0.6.1: ecosystem logos, docs theme, badges). README.md and mkdocs.yml auto-merged cleanly — main's logo/theme additions and this branch's sbi-export doc/nav links live in disjoint sections. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
These files predate the `ruff>=0.14.4` pin added to the dev group and were never caught by CI because the suite aborted at collection first (missing `sbi`). With collection fixed, CI's `ruff format --check .` step now runs and flagged them. Changes are purely cosmetic — collapsing multi-line function signatures and f-strings that fit within the 88-char limit; no behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Plus Run ID: 📒 Files selected for processing (4)
✅ Files skipped from review due to trivial changes (1)
🚧 Files skipped from review as they are similar to previous changes (3)
📝 WalkthroughWalkthroughAdds ChangesSBI ONNX Exporter Feature
Sequence Diagram(s)sequenceDiagram
participant User
participant transform_sbi_to_onnx
participant _NLELogProbWrapper
participant torch_onnx_export as torch.onnx.export
participant ONNXRuntime
participant JAXONNXRuntime as jaxonnxruntime
participant HSSM
User->>transform_sbi_to_onnx: estimator (NLE_A/NRE_A), path, mode, theta_dim, x_dim
transform_sbi_to_onnx->>transform_sbi_to_onnx: validate class name + mode capabilities
transform_sbi_to_onnx->>_NLELogProbWrapper: wrap(estimator, theta_dim, x_dim)
_NLELogProbWrapper-->>transform_sbi_to_onnx: wrapper.forward(concat_input) → scalar
transform_sbi_to_onnx->>torch_onnx_export: export(wrapper, dummy rank-1 input, opset=17)
torch_onnx_export-->>User: .onnx file written to path
User->>ONNXRuntime: InferenceSession(path).run(concat_input)
ONNXRuntime-->>User: log_prob scalar
User->>JAXONNXRuntime: call_onnx_model(onnx_model)(concat_input)
JAXONNXRuntime-->>User: JAX log_prob + gradients
User->>HSSM: HSSM(loglik=path, loglik_kind="approx_differentiable")
HSSM-->>User: posterior samples via MCMC
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 5
🧹 Nitpick comments (2)
docs/exporting_sbi_models.md (1)
96-134: 💤 Low valueKnown constraints section is thorough and actionable.
The three constraints (≥2D, LayerNorm workaround, JAX x64) are well-motivated and include working code examples:
- Constraint 2 (LayerNorm) is supported by the PR objectives and aligns with the documented sbi-method incompatibility.
- Constraint 3 (JAX x64) is validated by the HSSM integration test context (Context snippet 4), which notes that HSSM's
onnx2jaxself-manages thejaxort_only_allow_initializers_as_static_argsflag but users must opt into x64 themselves.- Each constraint includes a code snippet or clear reproduction steps.
Minor: Consider adding a cross-reference from "Supported architectures" to this section for readers who may wonder why certain combinations are excluded.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@docs/exporting_sbi_models.md` around lines 96 - 134, The "Known constraints" section is well-documented, but readers consulting the "Supported architectures" section may not understand why certain architecture combinations are excluded or unsupported. Add a cross-reference from the "Supported architectures" section to the "Known constraints" section, explaining that certain combinations (like 1D θ or x, or LayerNorm in classifiers) have constraints that limit exportability. This helps readers connect architectural choices to the documented limitations.tests/test_sbi_spike_maf_roundtrip.py (1)
38-38: Move config mutation into a pytest fixture to ensure proper cleanup.Module-level mutations of
config.update("jaxort_only_allow_initializers_as_static_args", False)at import time (line 38) leak across test modules and never restore the previous state. This appears in at least 4 test files with no visible restoration logic, creating order-dependent test behavior. Wrap the mutation in a fixture that saves and restores the flag value:`@pytest.fixture`(scope="function", autouse=True) def restore_jaxort_config(): from jaxonnxruntime import config prev = config._flags.get("jaxort_only_allow_initializers_as_static_args", True) config.update("jaxort_only_allow_initializers_as_static_args", False) yield config.update("jaxort_only_allow_initializers_as_static_args", prev)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_sbi_spike_maf_roundtrip.py` at line 38, The module-level config.update call at line 38 mutates the jaxort_only_allow_initializers_as_static_args configuration at import time without any cleanup or restoration mechanism, causing test state to leak across test modules. Move this configuration mutation out of module scope and into a pytest fixture with function scope and autouse enabled, so that the fixture saves the previous configuration value before updating it, yields control to run the test, and then restores the original configuration value afterward to prevent test pollution and order-dependent behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/exporting_sbi_models.md`:
- Around line 101-104: The documentation at the specified location presents the
≥2D constraint ambiguously as a universal requirement when it actually applies
only to NLE with density_estimator="maf", not all density estimators or the
LANfactory exporter itself. Revise the explanation to clarify that this
constraint is specific to sbi's MAF density estimator and is enforced by sbi's
training-time behavior in nflows, not by the transform_sbi_to_onnx exporter.
Additionally, note that other density estimators such as MDN and MoG may not
have this 1D limitation, and that NRE successfully handles 1D inputs, making
clear this is a MAF-specific pre-flight limitation rather than a LANfactory
enforcement.
In `@src/lanfactory/onnx/sbi.py`:
- Around line 47-55: The transform_sbi_to_onnx function accepts
example_theta_dim and example_x_dim parameters but does not validate that they
are positive integers, which can lead to confusing errors later. Add input
validation at the start of the function to check that both example_theta_dim and
example_x_dim are positive integers (greater than zero), and raise a ValueError
with a clear message if either parameter fails validation.
In `@tests/test_sbi_embeddings.py`:
- Around line 79-85: Add a third assertion to complete the three-way pairwise
comparison between all three backends. After the existing assertions comparing
y_torch against y_ort and y_jax, add another assertion that directly compares
y_ort against y_jax using np.allclose with the same atol tolerance parameter.
Follow the same pattern as the existing assertions by including an error message
that displays the maximum absolute difference between y_ort and y_jax to ensure
full pairwise agreement among all three backends.
In `@tests/test_sbi_hssm_integration.py`:
- Around line 59-64: The _build_observed_dataframe function accepts an rng
parameter but never uses it when calling simulator, resulting in uncontrolled
random data generation. Pass the rng parameter to the simulator function call to
ensure the seeded random number generator is actually used for reproducible test
behavior.
- Around line 135-140: The hasattr call on line 135 attempts to access
hssm.utils.summary before verifying that the utils module exists on the hssm
object. If hssm lacks the utils attribute, this raises AttributeError
immediately before the hasattr check can protect against it. Guard the
hssm.utils reference by first checking that hssm has the utils attribute before
checking for the summary attribute within it. Restructure the hasattr condition
to verify hssm.utils exists as a prerequisite before attempting to access
hssm.utils.summary, ensuring the ArviZ fallback executes properly when either
module is missing.
---
Nitpick comments:
In `@docs/exporting_sbi_models.md`:
- Around line 96-134: The "Known constraints" section is well-documented, but
readers consulting the "Supported architectures" section may not understand why
certain architecture combinations are excluded or unsupported. Add a
cross-reference from the "Supported architectures" section to the "Known
constraints" section, explaining that certain combinations (like 1D θ or x, or
LayerNorm in classifiers) have constraints that limit exportability. This helps
readers connect architectural choices to the documented limitations.
In `@tests/test_sbi_spike_maf_roundtrip.py`:
- Line 38: The module-level config.update call at line 38 mutates the
jaxort_only_allow_initializers_as_static_args configuration at import time
without any cleanup or restoration mechanism, causing test state to leak across
test modules. Move this configuration mutation out of module scope and into a
pytest fixture with function scope and autouse enabled, so that the fixture
saves the previous configuration value before updating it, yields control to run
the test, and then restores the original configuration value afterward to
prevent test pollution and order-dependent behavior.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro Plus
Run ID: 799ad103-eb67-453a-ae45-e8cf9f0dd877
⛔ Files ignored due to path filters (1)
uv.lockis excluded by!**/*.lock
📒 Files selected for processing (13)
.gitignoreREADME.mddocs/exporting_sbi_models.mdmkdocs.ymlpyproject.tomlsrc/lanfactory/onnx/__init__.pysrc/lanfactory/onnx/sbi.pytests/test_sbi_embeddings.pytests/test_sbi_hssm_integration.pytests/test_sbi_nle_export.pytests/test_sbi_nre_export.pytests/test_sbi_spike_maf_roundtrip.pytests/test_sbi_spike_mlp_roundtrip.py
| assert np.allclose(y_torch, y_ort, atol=atol), ( | ||
| f"torch vs onnxruntime: max |Δ| = {np.abs(y_torch - y_ort).max()}" | ||
| ) | ||
| assert np.allclose(y_torch, y_jax, atol=atol), ( | ||
| f"torch vs jaxonnxruntime: max |Δ| = {np.abs(y_torch - y_jax).max()}" | ||
| ) | ||
|
|
There was a problem hiding this comment.
Add direct onnxruntime↔jaxonnxruntime assertion to complete the three-way check.
Lines 79–85 only anchor each backend against Torch. Add a direct y_ort vs y_jax assertion so the helper enforces full pairwise agreement.
Proposed patch
@@
assert np.allclose(y_torch, y_ort, atol=atol), (
f"torch vs onnxruntime: max |Δ| = {np.abs(y_torch - y_ort).max()}"
)
assert np.allclose(y_torch, y_jax, atol=atol), (
f"torch vs jaxonnxruntime: max |Δ| = {np.abs(y_torch - y_jax).max()}"
)
+ assert np.allclose(y_ort, y_jax, atol=atol), (
+ f"onnxruntime vs jaxonnxruntime: max |Δ| = {np.abs(y_ort - y_jax).max()}"
+ )📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| assert np.allclose(y_torch, y_ort, atol=atol), ( | |
| f"torch vs onnxruntime: max |Δ| = {np.abs(y_torch - y_ort).max()}" | |
| ) | |
| assert np.allclose(y_torch, y_jax, atol=atol), ( | |
| f"torch vs jaxonnxruntime: max |Δ| = {np.abs(y_torch - y_jax).max()}" | |
| ) | |
| assert np.allclose(y_torch, y_ort, atol=atol), ( | |
| f"torch vs onnxruntime: max |Δ| = {np.abs(y_torch - y_ort).max()}" | |
| ) | |
| assert np.allclose(y_torch, y_jax, atol=atol), ( | |
| f"torch vs jaxonnxruntime: max |Δ| = {np.abs(y_torch - y_jax).max()}" | |
| ) | |
| assert np.allclose(y_ort, y_jax, atol=atol), ( | |
| f"onnxruntime vs jaxonnxruntime: max |Δ| = {np.abs(y_ort - y_jax).max()}" | |
| ) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/test_sbi_embeddings.py` around lines 79 - 85, Add a third assertion to
complete the three-way pairwise comparison between all three backends. After the
existing assertions comparing y_torch against y_ort and y_jax, add another
assertion that directly compares y_ort against y_jax using np.allclose with the
same atol tolerance parameter. Follow the same pattern as the existing
assertions by including an error message that displays the maximum absolute
difference between y_ort and y_jax to ensure full pairwise agreement among all
three backends.
| summary = hssm.utils.summary(idata) if hasattr(hssm.utils, "summary") else None | ||
| # Fall back to arviz if the convenience method is not exposed. | ||
| if summary is None: | ||
| import arviz as az | ||
|
|
||
| summary = az.summary(idata, var_names=_DDM_PARAM_NAMES) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
python - <<'PY'
import importlib
h = importlib.import_module("hssm")
print("has hssm.utils:", hasattr(h, "utils"))
if hasattr(h, "utils"):
print("has hssm.utils.summary:", hasattr(h.utils, "summary"))
PYRepository: lnccbrown/LANfactory
Length of output: 601
🏁 Script executed:
cat -n tests/test_sbi_hssm_integration.py | sed -n '130,145p'Repository: lnccbrown/LANfactory
Length of output: 724
🏁 Script executed:
# Look for setup files that might specify HSSM version constraints
find . -type f \( -name "setup.py" -o -name "setup.cfg" -o -name "pyproject.toml" -o -name "requirements*.txt" \) | head -20Repository: lnccbrown/LANfactory
Length of output: 506
🏁 Script executed:
# Check if there's documentation about HSSM versions or changelog
find . -type f -name "*requirements*" -o -name "*version*" | grep -E "\.(txt|py|md)$" | head -20Repository: lnccbrown/LANfactory
Length of output: 489
🏁 Script executed:
cat pyproject.tomlRepository: lnccbrown/LANfactory
Length of output: 3901
🏁 Script executed:
head -50 tests/test_sbi_hssm_integration.pyRepository: lnccbrown/LANfactory
Length of output: 2082
🏁 Script executed:
grep -n "import hssm\|from hssm" tests/test_sbi_hssm_integration.pyRepository: lnccbrown/LANfactory
Length of output: 46
🏁 Script executed:
# Check the actual structure around line 135 to understand the context better
sed -n '120,150p' tests/test_sbi_hssm_integration.pyRepository: lnccbrown/LANfactory
Length of output: 1024
🌐 Web query:
HSSM Python library utils module version history
💡 Result:
The HSSM (Hierarchical Sequential Sampling Modeling) Python library organizes utility functions across several submodules rather than a single utils module [1][2][3]. The most prominent utility-focused submodule is hssm.distribution_utils, which provides functions for building PyMC distribution classes and handling advanced model construction [2][3]. Other utility functions are found in specialized submodules, such as hssm.utils (e.g., decorate_atomic_simulator) [1] and hssm.rl.utils (introduced in version 0.3.0 for balanced-panel validation in reinforcement learning models) [4]. Regarding the version history, HSSM has undergone continuous development since its initial releases [5][6]. Significant recent updates include: Version 0.3.0: A major release featuring support for choice-only models, the Racing Diffusion Model (RDM3), Poisson Race model, BayesFlow LRE integration, and a new reinforcement learning (RLSSM) configuration system [5][7]. This version also introduced hssm.rl.utils [4]. Version 0.2.12: Introduced compatibility with Python 3.13 and Bambi 0.17.1+, added the make_distribution_for_supported_model utility function, and began private implementations of RLSSM features [5][7]. Version 0.2.x series: These versions focused on incremental improvements, including expanded model support, bug fixes, and documentation updates, building upon the foundations laid in earlier 0.1.x releases [5][6]. For the most up-to-date information on modules and release-specific changes, the official HSSM changelog and API documentation should be consulted [8][5][7].
Citations:
- 1: https://lnccbrown.github.io/HSSM/tutorials/add_custom_rlssm_model/
- 2: https://lnccbrown.github.io/HSSM/api/distribution_utils/
- 3: https://github.com/lnccbrown/HSSM/blob/main/docs/api/distribution_utils.md
- 4: Rlssm class make model dist HSSM#915
- 5: https://lnccbrown.github.io/HSSM/changelog/
- 6: https://pypi.org/project/HSSM/0.2.8/
- 7: https://github.com/lnccbrown/HSSM/blob/main/docs/changelog.md
- 8: https://github.com/lnccbrown/HSSM
Guard hssm.utils before dereferencing it.
Line 135 evaluates hssm.utils before the hasattr() check can protect against its absence. If a given HSSM version lacks the utils module, this raises AttributeError immediately and the ArviZ fallback never executes. Add a guard for the module's existence:
Suggested fix
- summary = hssm.utils.summary(idata) if hasattr(hssm.utils, "summary") else None
+ summary = (
+ hssm.utils.summary(idata)
+ if hasattr(hssm, "utils") and hasattr(hssm.utils, "summary")
+ else None
+ )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/test_sbi_hssm_integration.py` around lines 135 - 140, The hasattr call
on line 135 attempts to access hssm.utils.summary before verifying that the
utils module exists on the hssm object. If hssm lacks the utils attribute, this
raises AttributeError immediately before the hasattr check can protect against
it. Guard the hssm.utils reference by first checking that hssm has the utils
attribute before checking for the summary attribute within it. Restructure the
hasattr condition to verify hssm.utils exists as a prerequisite before
attempting to access hssm.utils.summary, ensuring the ArviZ fallback executes
properly when either module is missing.
Codecov Report✅ All modified and coverable lines are covered by tests.
🚀 New features to boost your workflow:
|
codecov/patch flagged line 125 (the `else: raise ValueError` for an unrecognized `mode`) as the one uncovered line in sbi.py. Add a guard test that passes an invalid mode and asserts the ValueError, bringing sbi.py to 100% patch coverage. The test verifies a real input-validation contract, mirroring the existing test_transform_rejects_* guard tests. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
| wrapper.eval() | ||
| combined_input_dim = example_theta_dim + example_x_dim | ||
| # Trace with a rank-1 dummy so the resulting Slice ops use axes=[0], | ||
| # which survives HSSM's per-trial vmap (where the input arrives as 1D). | ||
| dummy_input = torch.randn(combined_input_dim, requires_grad=True) |
| import jax | ||
| import numpy as np | ||
| import onnx | ||
| import onnxruntime as ort | ||
| import pytest | ||
| import torch | ||
| from jaxonnxruntime import call_onnx, config |
| regress on a vanilla MLP, this test catches it before debugging the real | ||
| exporter. Kept as a permanent regression guard per plans/sbi-onnx-integration.md. |
| accumulation, and the affine-autoregressive ops translate cleanly into | ||
| jaxonnxruntime. Kept as a permanent regression guard per | ||
| plans/sbi-onnx-integration.md. |
| 4. Run a short MCMC and verify posterior means recover the ground truth | ||
| (within ±2σ) and r_hat < 1.01. |
| environment. LANfactory's regular CI does not currently install HSSM, so the | ||
| test is a no-op locally — it is intended to run in a coordinated cross-repo | ||
| CI matrix where both packages are available with compatible JAX pins. See | ||
| plans/sbi-onnx-integration.md C7b for the environment-resolution note. |
| HSSM's `onnx2jax` consumer sets the related `jaxort_only_allow_initializers_as_static_args = False` | ||
| flag automatically, but the x64 setting is process-wide and must be opted | ||
| into by the caller. |
Three reviewer findings on PR #79: - sbi.py: validate example_theta_dim/example_x_dim are positive, raising a clear ValueError instead of failing deep in ONNX tracing. Adds a guard test (keeps sbi.py at 100% patch coverage). (CodeRabbit placed its suggested check inside the docstring; this puts it after.) - docs/exporting_sbi_models.md: clarify the ">=2D" constraint is specific to sbi's MAF NLE training (nflows), not enforced by transform_sbi_to_onnx — NRE classifiers and other density estimators are not bound by it. - test_sbi_hssm_integration.py: _build_observed_dataframe accepted an `rng` it never used; thread it through ssms simulator's `random_state` so observed data is actually reproducible. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Brings the merged sbi → ONNX work (PR #79) and v0.6.1 docs/logos into the bayesflow branch. Only pyproject.toml conflicted: both sides appended to the dev dependency group. Resolved as the union — keep bayesflow's `bayesflow`/`keras` and main's `sbi` — so the full sbi + bayesflow test matrix is installed by `uv sync --all-groups`. uv.lock regenerated (uv 0.6.5) to match the resolved manifest. All sbi-file changes (dim validation, doc-ref fix, ruff formatting, guard tests, seeded-RNG fix) come straight from main. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Summary
Adds
lanfactory.onnx.transform_sbi_to_onnx, an exporter that converts atrained
sbineural likelihood estimator(NLE) or neural ratio estimator (NRE) to an ONNX file consumable by HSSM's
loglik_kind="approx_differentiable"path. The exported file behavesidentically to a LAN-trained ONNX from
transform_to_onnx, so HSSMconsumes it with zero HSSM-side code changes for the basic case.
Lives as a sibling module to the existing LAN exporter
(
lanfactory/onnx/transform_onnx.py↔lanfactory/onnx/sbi.py). The newdeps (
sbi>=0.26,nflows>=0.14) are gated under the existing[all]extra; no impact on the default install.
What's in this branch
ac16eda[all]extra extended withsbi+nflowsf7c93c8torch.onnx.export→{onnxruntime, jaxonnxruntime}round-trip on a vanilla MLP and annflowsMAF. Kept as permanent regression guards.bdcabd3estimator.log_prob(x, condition=θ)as atorch.nn.Modulethat takes a 1D concatenated(theta, x)input. Three-way numerical agreement (torch / ORT / jaxonnxruntime) atatol=1e-5, gradient agreement atatol=1e-4.f4a54feb1fd188FCEmbeddingandCNNEmbedding.87704bbdocs/exporting_sbi_models.mdintegration guide, README mention, nav update.4990e85pytest.importorskip("hssm")skip-on-missing).222adf5(1, D)dummy causedSlice axes=[1]which fails under vmap with rank-1 per-trial input.Architectural contract
theta_dim + x_dim, matching HSSM'smake_jax_logp_funcs_from_onnxper-trial vmap contract.
unsqueeze(0)to satisfysbi's batched log_prob/forward API, returns scalar via
reshape(()).jaxonnxruntime.Test status
(
pytest tests/test_sbi_spike_*.py tests/test_sbi_nle_export.py tests/test_sbi_nre_export.py tests/test_sbi_embeddings.py).pytest.importorskip("hssm"); clean skipin the current LANfactory env. Designed to run in coordinated cross-
repo CI where both LANfactory + HSSM are installed together.
Known limitations (documented in C6 guide)
SearchSortedop injaxonnxruntime(planned upstream PR, ~50 lines). Tracked in
HSSMSpine/plans/sbi-onnx-integration.md.SearchSortedblocker (categorical lookup usestorch.searchsorted). Same upstream PR unlocks both.continuous but choice is discrete; MAF can't represent the mixed
structure). The exporter still supports
mode="nle"correctly —this is a sbi-method/data-shape issue, not an exporter bug. MNLE is
the correct sbi method for DDM-like data; deferred until the
upstream PR lands.
Companion PRs
Consumed by the HSSM-side keystone tutorial in
lnccbrown/HSSM#sbi-integration.
The tutorial uses an
importlibfallback to load this exporter evenwithout LANfactory installed, so this PR can be merged independently.
Test plan
test_sbi_hssm_integration.py)once both packages can be installed in the same env
🤖 Generated with Claude Code
Summary by CodeRabbit
approx_differentiableworkflow via a new public exporter with mode-specific validation..gitignorefor SBI TensorBoard logs.