feat(onnx): bayesflow → ONNX exporter (transform_bayesflow_to_onnx)#80
Conversation
Adds lanfactory.onnx.transform_bayesflow_to_onnx, the bayesflow sibling of transform_sbi_to_onnx (PR #79). Wraps a trained bayesflow ContinuousApproximator (NLE) or RatioApproximator (NRE) and writes a single-trial ONNX file consumable by HSSM's loglik_kind="approx_differentiable" path. Same I/O contract as the sbi exporter (rank-1 input [theta..., x...], rank-0 scalar log-likelihood, opset 17) so HSSM ingests both via the same loglik="*.onnx" gesture with zero HSSM-side changes. What's in this commit - src/lanfactory/onnx/bayesflow.py: exporter module mirroring sbi.py. Contains _BayesflowNLELogProbWrapper and _BayesflowNRELogRatioWrapper. Pre-evaluates the bayesflow Standardize layer's moving mean/std to torch buffer constants at wrapper construction time so the ONNX trace is fully static (avoids If, Size, Tile dynamic-shape ops that jaxonnxruntime can't run). Guards on KERAS_BACKEND=torch and identity Adapter; both raise actionable errors with concrete fix hints. - src/lanfactory/onnx/__init__.py: export the new function. - pyproject.toml: add [bayesflow] optional extra (bayesflow>=2.0.8, keras>=3.12), add to [all] and [dev]. Also refactors the existing sbi+nflows pair into its own [sbi] extra (mirroring the new [bayesflow]) while keeping them in [all]. - tests/test_bayesflow_nle_export.py: 6 tests. Three-way numerical agreement (torch reference wrapper <-> onnxruntime <-> jaxonnxruntime) at atol=1e-5, gradient agreement at atol=1e-4, log-prob ordering sanity, and three guard tests (wrong backend, non-identity adapter, wrong mode). - tests/test_bayesflow_nre_export.py: 4 tests. Same shape for the NRE path on a RatioApproximator. - tests/test_bayesflow_hssm_integration.py: end-to-end DDM smoke (pytest.importorskip("hssm")). Mirrors test_sbi_hssm_integration.py. - docs/exporting_bayesflow_models.md: full constraint catalog (KERAS_BACKEND, CouplingFlow knobs, silu vs hard_silu activation choice, identity-adapter requirement, JAX x64). Quick-starts for NLE and NRE. "Two paths into HSSM" framing alongside the JAX-callable path used in bayesflow_lre_integration.ipynb. v1 constraints (documented, enforced where introspectable) User must train with: - permutation=None (FixedPermutation -> aten::ravel, unsupported) - use_actnorm=False (untested in v1) - transform=AffineTransform(clamp=False) explicit instance (find_transform("affine") drops kwargs - bayesflow upstream bug) - subnet_kwargs.activation="silu" or another smooth activation (default hard_silu emits HardSwish, no jaxonnxruntime handler) - identity Adapter (numpy-only adapter ops cannot be baked into ONNX) Bayesflow continuous observations only. MNLE-style discrete + continuous deferred until upstream MNLE support lands. Numerical guarantees 19 passing tests across both bayesflow and sbi tracks; no regressions on the existing sbi exporter. Each export is verified for three-way numerical agreement at 1e-5 and gradient agreement at 1e-4. Companion PRs - HSSM: docs(tutorials): add bayesflow_nle_onnx_integration.ipynb on a fresh bayesflow-integration branch off main (sibling, not child, of the sbi-integration branch in PR #964). - HSSMSpine: bayesflow-onnx-integration.md design doc + upstream-bugs-from-bayesflow-onnx-work.md catalog of upstream defects surfaced during this work (jaxonnxruntime missing HardSwish/Size handlers; bayesflow find_transform kwarg-drop bug; bayesflow global torch.set_grad_enabled(False) cross-library leak; torch.onnx missing aten::ravel/asinh symbolic registrations). This branch is stacked on sbi-connector (PR #79). When #79 merges, this PR's base auto-retargets to main. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
There was a problem hiding this comment.
Pull request overview
Adds a BayesFlow-to-ONNX exporter so BayesFlow-trained NLE/NRE models can be consumed by HSSM via the same loglik="model.onnx", loglik_kind="approx_differentiable" pathway already used for LAN and sbi exports.
Changes:
- Introduces
lanfactory.onnx.transform_bayesflow_to_onnxwith NLE/NRE wrappers that bake BayesFlow standardization statistics into a static ONNX trace. - Adds BayesFlow-focused regression tests (three-way numerical agreement + gradient agreement) and an optional HSSM integration smoke test.
- Adds a
bayesflowoptional extra (and a dedicatedsbiextra) plus a new documentation guide for BayesFlow exports.
Reviewed changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
src/lanfactory/onnx/bayesflow.py |
New BayesFlow exporter implementation (NLE + NRE wrappers + ONNX export entry point). |
src/lanfactory/onnx/__init__.py |
Exposes transform_bayesflow_to_onnx at the package level. |
pyproject.toml |
Adds bayesflow and sbi extras; extends all and dev deps. |
tests/test_bayesflow_nle_export.py |
New NLE export tests: forward/grad agreement + guardrails. |
tests/test_bayesflow_nre_export.py |
New NRE export tests: forward/grad agreement + guardrails. |
tests/test_bayesflow_hssm_integration.py |
New (skip-if-missing-HSSM) end-to-end BayesFlow→ONNX→HSSM integration test. |
docs/exporting_bayesflow_models.md |
New BayesFlow ONNX export guide and constraint catalog. |
uv.lock |
Dependency lockfile updates to include BayesFlow/Keras and related resolutions. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| import pandas as pd # noqa: E402 | ||
|
|
||
| import bayesflow as bf # noqa: E402 |
| # Exporting bayesflow-trained networks to ONNX | ||
|
|
||
| LANfactory's [`transform_bayesflow_to_onnx`](api/onnx.md) is the bayesflow | ||
| sibling of [`transform_sbi_to_onnx`](exporting_sbi_models.md). It wraps a | ||
| trained [`bayesflow`](https://github.com/bayesflow-org/bayesflow) | ||
| `ContinuousApproximator` (NLE) or `RatioApproximator` (NRE) and writes a | ||
| single-trial ONNX file that HSSM's `loglik_kind="approx_differentiable"` | ||
| path can consume exactly like an sbi export. Same user gesture, same file | ||
| format, same HSSM-side loader — regardless of which training framework you | ||
| came from. |
| os.environ.setdefault("KERAS_BACKEND", "torch") | ||
| os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu") |
| os.environ.setdefault("KERAS_BACKEND", "torch") | ||
| os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu") |
| @@ -82,6 +86,8 @@ dev = [ | |||
| "jaxonnxruntime>=0.3", | |||
| "onnxruntime>=1.17", | |||
| "nflows>=0.14", | |||
Brings the merged sbi → ONNX work (PR #79) and v0.6.1 docs/logos into the bayesflow branch. Only pyproject.toml conflicted: both sides appended to the dev dependency group. Resolved as the union — keep bayesflow's `bayesflow`/`keras` and main's `sbi` — so the full sbi + bayesflow test matrix is installed by `uv sync --all-groups`. uv.lock regenerated (uv 0.6.5) to match the resolved manifest. All sbi-file changes (dim validation, doc-ref fix, ruff formatting, guard tests, seeded-RNG fix) come straight from main. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Same as the sbi files: collapses multi-line signatures/f-strings that fit within 88 chars. CI's `ruff format --check .` flagged these once the merge brought the bayesflow files alongside the dev-group fix. Cosmetic only. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
exporting_bayesflow_models.md shipped with the feature but was never added to the nav, so it was orphaned (unreachable in the built site). Add it under Guides next to the sbi export guide. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
…alidation Brings bayesflow.py to 100% patch coverage so codecov/patch passes on PR #80, mirroring the sbi exporter: - transform_bayesflow_to_onnx: validate example_theta_dim/example_x_dim are positive (parity with transform_sbi_to_onnx), raising a clear ValueError instead of failing deep in tracing. - Add guard tests: invalid mode, non-positive dims, NLE wrapper missing .inference_network. - Cover the standardizer if/else branches in both wrappers (the trained fixtures only standardize one slot) with lightweight stand-in approximators that exercise the opposite combination. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Plus Run ID: 📒 Files selected for processing (4)
✅ Files skipped from review due to trivial changes (1)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 WalkthroughWalkthroughAdds ChangesBayesflow ONNX Exporter
Sequence Diagram(s)sequenceDiagram
participant User
participant transform_bayesflow_to_onnx
participant WrapperModule
participant torch_onnx_export as torch.onnx.export
participant ONNX_Runtime
participant HSSM
User->>transform_bayesflow_to_onnx: approximator, path, mode="nle"
transform_bayesflow_to_onnx->>transform_bayesflow_to_onnx: validate dims, backend, adapter
transform_bayesflow_to_onnx->>WrapperModule: construct _BayesflowNLELogProbWrapper
transform_bayesflow_to_onnx->>torch_onnx_export: trace wrapper with dummy input
torch_onnx_export-->>User: ONNX file written
User->>HSSM: load ONNX with loglik_kind="approx_differentiable"
HSSM->>ONNX_Runtime: load model.onnx
ONNX_Runtime-->>HSSM: per-trial log-likelihood
HSSM-->>User: posterior samples via MCMC
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
Codecov Report✅ All modified and coverable lines are covered by tests.
🚀 New features to boost your workflow:
|
|
@coderabbitai review |
✅ Action performedReview finished.
|
|
@coderabbitai full review |
✅ Action performedFull review finished. |
There was a problem hiding this comment.
Actionable comments posted: 4
🧹 Nitpick comments (1)
tests/test_bayesflow_nre_export.py (1)
136-138: ⚡ Quick winPin ONNX Runtime to CPU provider for deterministic test behavior.
On Line 136, provider auto-selection is host-dependent. Explicitly setting
CPUExecutionProvideravoids environment-specific execution paths and reduces CI flakiness.Suggested fix
- sess = ort.InferenceSession(str(onnx_path)) + sess = ort.InferenceSession( + str(onnx_path), providers=["CPUExecutionProvider"] + )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_bayesflow_nre_export.py` around lines 136 - 138, The ort.InferenceSession initialization on line 136 does not specify an execution provider, causing it to auto-select based on the host environment which leads to non-deterministic behavior across different CI environments. Modify the InferenceSession constructor call to explicitly pass the providers parameter with CPUExecutionProvider to ensure consistent execution behavior across all test environments.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/exporting_bayesflow_models.md`:
- Around line 88-97: Add JAX 64-bit mode configuration before the quick-start
code blocks that use ONNX models. Since ONNX graphs contain int64 tensors and
JAX silently truncates to int32 by default (causing incorrect log-probability
calculations), include the JAX x64 configuration at the beginning of the code
snippet containing the HSSM instantiation and sampling call. Apply this fix to
both the ONNX sampling quick-start block and the NRE quick-start section
mentioned in the comment.
In `@tests/test_bayesflow_hssm_integration.py`:
- Around line 22-23: Replace the os.environ.setdefault calls for KERAS_BACKEND
and KERAS_TORCH_DEVICE with direct assignment using os.environ["KERAS_BACKEND"]
= "torch" and os.environ["KERAS_TORCH_DEVICE"] = "cpu" to ensure these values
are forcibly set regardless of any pre-existing environment variables,
guaranteeing test isolation and preventing nondeterministic failures from leaked
environment state.
In `@tests/test_bayesflow_nle_export.py`:
- Around line 15-16: The use of setdefault for KERAS_BACKEND and
KERAS_TORCH_DEVICE allows pre-existing environment variable values set elsewhere
to persist, which can break the test assumptions. Replace both setdefault calls
with direct assignment using the subscript notation so that KERAS_BACKEND is
forcefully set to "torch" and KERAS_TORCH_DEVICE is forcefully set to "cpu",
ensuring these values are not overridden by any previously set environment
variables.
In `@tests/test_bayesflow_nre_export.py`:
- Around line 13-15: The use of setdefault in lines setting KERAS_BACKEND and
KERAS_TORCH_DEVICE allows preexisting environment variables to persist, creating
test order dependency and non-hermetic behavior. Replace both setdefault calls
with direct assignment using os.environ[] bracket notation to force the backend
to always be torch and device to always be cpu, ensuring this test remains
independent of other test configurations.
---
Nitpick comments:
In `@tests/test_bayesflow_nre_export.py`:
- Around line 136-138: The ort.InferenceSession initialization on line 136 does
not specify an execution provider, causing it to auto-select based on the host
environment which leads to non-deterministic behavior across different CI
environments. Modify the InferenceSession constructor call to explicitly pass
the providers parameter with CPUExecutionProvider to ensure consistent execution
behavior across all test environments.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro Plus
Run ID: 6772c78e-d6c3-4fdf-979d-1d1df817aa19
⛔ Files ignored due to path filters (1)
uv.lockis excluded by!**/*.lock
📒 Files selected for processing (8)
docs/exporting_bayesflow_models.mdmkdocs.ymlpyproject.tomlsrc/lanfactory/onnx/__init__.pysrc/lanfactory/onnx/bayesflow.pytests/test_bayesflow_hssm_integration.pytests/test_bayesflow_nle_export.pytests/test_bayesflow_nre_export.py
| os.environ.setdefault("KERAS_BACKEND", "torch") | ||
| os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu") |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, find and examine the test file
find . -type f -name "test_bayesflow_hssm_integration.py" | head -5Repository: lnccbrown/LANfactory
Length of output: 107
🏁 Script executed:
# Also check the overall test structure
git ls-files tests/ | grep -E "\.py$" | head -20Repository: lnccbrown/LANfactory
Length of output: 615
🏁 Script executed:
# Search for patterns of environment variable usage in tests
rg "os\.environ\[|os\.environ\.setdefault|pytest\.fixture" tests/ -t py | head -30Repository: lnccbrown/LANfactory
Length of output: 1728
🏁 Script executed:
# Read the test file to verify lines 22-23
cat -n tests/test_bayesflow_hssm_integration.py | head -40Repository: lnccbrown/LANfactory
Length of output: 1824
🏁 Script executed:
# Check conftest.py for environment setup patterns
cat -n tests/conftest.py | head -100Repository: lnccbrown/LANfactory
Length of output: 3232
🏁 Script executed:
# Check similar test files that use the same setdefault pattern
cat -n tests/test_bayesflow_nle_export.py | head -40Repository: lnccbrown/LANfactory
Length of output: 1950
🏁 Script executed:
# Check if there's a pattern in other similar test files (e.g., test_sbi_hssm_integration)
cat -n tests/test_sbi_hssm_integration.py | head -40Repository: lnccbrown/LANfactory
Length of output: 1986
🏁 Script executed:
# Check for any direct assignment patterns in the codebase
rg "os\.environ\[\"KERAS" tests/ -A 1Repository: lnccbrown/LANfactory
Length of output: 46
🏁 Script executed:
# Search for any documentation or comments about environment variable handling
rg "KERAS_BACKEND|KERAS_TORCH_DEVICE" tests/ -B 3 -A 1 | head -60Repository: lnccbrown/LANfactory
Length of output: 3654
🏁 Script executed:
# Check if there are any CI/configuration files that might set these variables
find . -type f \( -name "*.yml" -o -name "*.yaml" -o -name ".env*" -o -name "pytest.ini" -o -name "pyproject.toml" \) | head -20Repository: lnccbrown/LANfactory
Length of output: 1168
🏁 Script executed:
# Check the pyproject.toml or setup files for pytest configuration
cat pyproject.toml 2>/dev/null | grep -A 20 "^\[tool.pytest" || echo "No pytest config found"Repository: lnccbrown/LANfactory
Length of output: 508
🏁 Script executed:
# Search for any monkeypatch or environment variable handling in conftest or tests
rg "monkeypatch.*KERAS|KERAS.*monkeypatch" tests/ -B 2 -A 2Repository: lnccbrown/LANfactory
Length of output: 46
🏁 Script executed:
# Check the CI workflow to see if environment variables are set before tests
cat .github/workflows/run_tests.yml | head -100Repository: lnccbrown/LANfactory
Length of output: 1294
🏁 Script executed:
# Check if there's any documentation about environment setup for tests
find . -type f \( -name "README*" -o -name "CONTRIBUTING*" \) -exec cat {} \; | grep -i "keras\|backend\|environ" | head -20Repository: lnccbrown/LANfactory
Length of output: 235
🏁 Script executed:
# Check if test_bayesflow_nle_export tests backend validation with monkeypatch
cat -n tests/test_bayesflow_nle_export.py | grep -A 20 "test_transform_rejects_wrong_backend"Repository: lnccbrown/LANfactory
Length of output: 961
🏁 Script executed:
# Check the full structure of test_bayesflow_hssm_integration to see if there are similar backend tests
cat -n tests/test_bayesflow_hssm_integration.py | tail -100Repository: lnccbrown/LANfactory
Length of output: 4040
Force Keras backend/device assignment to ensure test isolation.
Using setdefault allows pre-existing KERAS_BACKEND or KERAS_TORCH_DEVICE environment variables from the user's shell to leak in. This can cause nondeterministic test failures if the environment is configured for a different backend (e.g., jax). Since the code comment explicitly states these must be set before imports, use direct assignment to guarantee the correct values regardless of environment state.
Suggested patch
-os.environ.setdefault("KERAS_BACKEND", "torch")
-os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")
+os.environ["KERAS_BACKEND"] = "torch"
+os.environ["KERAS_TORCH_DEVICE"] = "cpu"📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| os.environ.setdefault("KERAS_BACKEND", "torch") | |
| os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu") | |
| os.environ["KERAS_BACKEND"] = "torch" | |
| os.environ["KERAS_TORCH_DEVICE"] = "cpu" |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/test_bayesflow_hssm_integration.py` around lines 22 - 23, Replace the
os.environ.setdefault calls for KERAS_BACKEND and KERAS_TORCH_DEVICE with direct
assignment using os.environ["KERAS_BACKEND"] = "torch" and
os.environ["KERAS_TORCH_DEVICE"] = "cpu" to ensure these values are forcibly set
regardless of any pre-existing environment variables, guaranteeing test
isolation and preventing nondeterministic failures from leaked environment
state.
| os.environ.setdefault("KERAS_BACKEND", "torch") | ||
| os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu") | ||
|
|
There was a problem hiding this comment.
Make backend preconditions explicit, not setdefault-dependent.
On Line 13 and Line 14, setdefault keeps any preexisting environment values, which can make this module order-dependent if another test session config sets a different backend. Fail fast (or force values) so this test remains hermetic.
Suggested fix
-os.environ.setdefault("KERAS_BACKEND", "torch")
-os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")
+os.environ["KERAS_BACKEND"] = "torch"
+os.environ["KERAS_TORCH_DEVICE"] = "cpu" import keras # noqa: E402
+if keras.backend.backend() != "torch":
+ pytest.skip("tests/test_bayesflow_nre_export.py requires KERAS_BACKEND=torch")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/test_bayesflow_nre_export.py` around lines 13 - 15, The use of
setdefault in lines setting KERAS_BACKEND and KERAS_TORCH_DEVICE allows
preexisting environment variables to persist, creating test order dependency and
non-hermetic behavior. Replace both setdefault calls with direct assignment
using os.environ[] bracket notation to force the backend to always be torch and
device to always be cpu, ensuring this test remains independent of other test
configurations.
- docs/exporting_bayesflow_models.md: enable jax_enable_x64 in the NLE
quick-start before HSSM imports jax — the "Known constraints" section
already documents that ONNX int64 tensors get truncated under JAX's default
32-bit mode, but the copy-paste quick-start omitted the setup.
- test_bayesflow_{nle,nre,hssm_integration}: force KERAS_BACKEND=torch instead
of setdefault. These tests hard-require the torch backend; setdefault would
silently honor a stray KERAS_BACKEND=jax in the environment and fail
confusingly. KERAS_TORCH_DEVICE stays a setdefault (a preference, not a
requirement — forcing cpu would override GPU users).
Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
| def _simulate_ddm_rows(theta: np.ndarray) -> np.ndarray: | ||
| """Simulate one (rt, choice) per row of theta. Returns (n, 2) float32.""" | ||
| rts = np.empty(theta.shape[0], dtype=np.float32) | ||
| choices = np.empty(theta.shape[0], dtype=np.float32) | ||
| for i, th in enumerate(theta): | ||
| out = simulator(theta=th[None, :], model="ddm", n_samples=1) | ||
| rts[i] = out["rts"].squeeze() | ||
| choices[i] = out["choices"].squeeze() | ||
| return np.stack([rts, choices], axis=-1) |
| def _build_observed_dataframe() -> pd.DataFrame: | ||
| """Generate N_OBS trials at the true theta as an HSSM-shaped DataFrame.""" | ||
| out = simulator(theta=_TRUE_THETA[None, :], model="ddm", n_samples=_N_OBS) | ||
| rts = out["rts"].squeeze().astype(np.float32) | ||
| choices = out["choices"].squeeze().astype(np.float32) | ||
| return pd.DataFrame({"rt": rts, "response": choices}) |
Summary
Adds
lanfactory.onnx.transform_bayesflow_to_onnx— the bayesflow sibling oftransform_sbi_to_onnxfrom #79. Wraps a trainedbayesflow.ContinuousApproximator(NLE) orRatioApproximator(NRE) and writes a single-trial ONNX file consumable by HSSM'sloglik_kind="approx_differentiable"path.The user gesture becomes the same regardless of training framework:
hssm.HSSM(loglik="model.onnx", loglik_kind="approx_differentiable")works for LAN, sbi, and now bayesflow exports through the identical HSSM-side code path (no HSSM changes required).Branch relationship
This branch is stacked on
sbi-connector(#79). The diff in this PR is exactly the bayesflow-specific additions; #79's commits are not duplicated. When #79 merges tomain, GitHub will offer to auto-retarget this PR's base tomain.What's in this PR
src/lanfactory/onnx/bayesflow.pytransform_bayesflow_to_onnx)src/lanfactory/onnx/__init__.pypyproject.toml[bayesflow]optional extra; refactor existing sbi+nflows pair into a new[sbi]extra (symmetric with[bayesflow]); both added to[all]and[dev]tests/test_bayesflow_nle_export.pytests/test_bayesflow_nre_export.pytests/test_bayesflow_hssm_integration.pypytest.importorskip("hssm"), mirrorstest_sbi_hssm_integration.py)docs/exporting_bayesflow_models.mdexporting_sbi_models.mdArchitectural contract
Same I/O contract as the sbi exporter:
(theta_dim + x_dim,), parameters first then observations.jaxonnxruntimereproducibility.Key implementation choice: the wrapper bakes the bayesflow
Standardizelayer's accumulatedmoving_mean/moving_stdas torch buffer constants at construction time. This sidesteps the dynamic-shape ops (If,Size,Tile) that the live Keras layer would emit at trace time —jaxonnxruntimedoesn't have aSizehandler. The constants are correct because training is complete by export time.v1 constraints (documented, enforced where introspectable)
User must train with:
permutation=None(FixedPermutation →aten::ravel, unsupported in opset 17/20)use_actnorm=False(untested in v1)transform=AffineTransform(clamp=False)as an explicit instance (find_transform("affine")silently dropstransform_kwargs— bayesflow upstream bug, catalogued in companion HSSMSpine PR)subnet_kwargs.activation="silu"(defaulthard_siluexports as the fused ONNX opHardSwish, no jaxonnxruntime handler; silu decomposes toSigmoid + Mul)Adapter(numpy-only Adapter ops can't be baked into ONNX)Each violation produces an actionable error message at export time.
Test status
torch.set_grad_enabled(True)after importing bayesflow to undo the global autograd disable that bayesflow's torch backend does at import time. This is a known upstream issue documented in the companion HSSMSpine PR's upstream-bugs catalog.Companion PRs
bayesflow-integrationbranch (new), addsdocs/tutorials/bayesflow_nle_onnx_integration.ipynb. Sibling, not child, ofsbi-integration(#964) — works against HSSMmainstandalone (Part 1 includes the manualjaxort_only_allow_initializers_as_static_args=Falseworkaround that #964 plans to auto-handle insideonnx2jax).bayesflow-onnx-plansbranch (stacked on NameError when wandb is not found #9 for the cross-reference tosbi-onnx-integration.md). Adds the design doc and an upstream-bugs catalog covering the five real upstream defects surfaced during this work.Test plan
pytest.importorskip("hssm")) once both packages can be installed in the same env🤖 Generated with Claude Code
Summary by CodeRabbit
Release Notes
New Features
transform_bayesflow_to_onnxto export BayesFlow NLE/NRE models into a single ONNX artifact for HSSM’sapprox_differentiablelog-likelihood workflow (with configurable ONNX opset).Documentation
Tests
Chores