docs: runnable sbi + bayesflow → ONNX marimo tutorials #84
A marimo notebook that runs the full train → export → verify loop for the sbi exporter end to end. Fills a real gap: the reference guide's snippets aren't runnable, and HSSM's ONNX tutorial starts from an existing .onnx. Trains a tiny NLE_A on a 2D Gaussian toy, exports via transform_sbi_to_onnx, and verifies torch / onnxruntime / jaxonnxruntime agree (with interactive eval-point sliders), then hands off to HSSM's blackbox_contribution_onnx_example for the consumption side.

- notebooks/exporting_sbi_to_onnx.py: marimo source of truth (ruff-excluded; validated with `marimo check`).
- docs/tutorials/exporting_sbi_to_onnx.ipynb: rendered export with outputs, wired into the mkdocs nav + mkdocs-jupyter execute_ignore so the docs build renders baked outputs without re-training.
- marimo added to the dev group; __marimo__/ session cache gitignored.

Verified: marimo check clean; three-way agreement 2.4e-7; mkdocs build renders the page.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
The bayesflow sibling of the sbi tutorial: trains a tiny ContinuousApproximator (the v1 ONNX-exportable CouplingFlow: permutation=None, AffineTransform(clamp=False), silu, no actnorm), exports via transform_bayesflow_to_onnx, and verifies the torch wrapper / onnxruntime / jaxonnxruntime agree (interactive eval-point sliders), then hands off to HSSM.

Mirrors the sbi tutorial with bayesflow's Keras specifics: KERAS_BACKEND=torch set before importing keras/bayesflow, autograd re-enabled at import, and the torch reference taken via the wrapper (approximator.log_prob runs the numpy adapter). Wired into the mkdocs nav + execute_ignore.

Verified: marimo check clean; three-way agreement 4.8e-7; mkdocs build renders.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
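The "three-way agreement" figures quoted in the two commit messages can be read as the largest absolute gap between any two backends' log-probabilities at the same evaluation point. The sketch below assumes that reading; the notebooks may compute the number differently. The function name, the tolerance, and the sample values are hypothetical and are not outputs from this PR.

```python
from itertools import combinations


def max_pairwise_gap(log_probs: dict[str, float]) -> float:
    """Largest absolute difference between any two backends' outputs."""
    return max(abs(a - b) for a, b in combinations(log_probs.values(), 2))


# Illustrative values only (hypothetical, not taken from the notebooks).
example = {
    "torch": -2.5000000,
    "onnxruntime": -2.5000002,
    "jaxonnxruntime": -2.4999998,
}

gap = max_pairwise_gap(example)
agrees = gap < 1e-5  # hypothetical float32-scale tolerance
print(f"max pairwise gap: {gap:.1e}, agrees: {agrees}")
```

Comparing all pairs rather than each backend against torch alone means a disagreement between the two ONNX runtimes would also show up.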
📝 Walkthrough
Two new interactive Marimo tutorial notebooks are added: one for exporting a trained
Changes
ONNX Export Tutorials and Configuration
Sequence Diagram(s)
sequenceDiagram
participant User
participant Estimator as sbi NLE / BayesFlow ContinuousApproximator
participant lanfactory as lanfactory.onnx.transform_*_to_onnx
participant onnxruntime
participant jaxonnxruntime
participant hssm_HSSM as hssm.HSSM
User->>Estimator: train on 2D Gaussian toy
Estimator-->>User: trained estimator
User->>lanfactory: transform(estimator, mode="nle", theta_dim, x_dim)
lanfactory-->>User: .onnx file path
User->>onnxruntime: InferenceSession(.onnx)
User->>jaxonnxruntime: call_onnx_model(.onnx)
Note over User,jaxonnxruntime: eval_backends(theta, x) — slider-driven interactive recomputation
onnxruntime-->>User: log_prob
jaxonnxruntime-->>User: log_prob
User->>hssm_HSSM: HSSM(loglik_kind="approx_differentiable", loglik=".onnx")
Pull request overview
Adds end-to-end, runnable tutorials demonstrating how to train small sbi / bayesflow estimators, export them to single-trial ONNX via LANfactory, and verify numerical agreement across torch / onnxruntime / jaxonnxruntime—then links the artifact handoff to HSSM.
Changes:
- Add two marimo notebook sources (`notebooks/exporting_*.py`) covering train → export → verify for `sbi` and `bayesflow`.
- Add rendered tutorial notebooks under `docs/tutorials/` and wire them into MkDocs nav + mkdocs-jupyter `execute_ignore`.
- Add `marimo` as a dev dependency and ignore `__marimo__/` cache.
Reviewed changes
Copilot reviewed 6 out of 8 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| `uv.lock` | Locks new dev dependency (marimo) and its transitive deps. |
| `pyproject.toml` | Adds marimo>=0.23 to the dev dependency group. |
| `notebooks/exporting_sbi_to_onnx.py` | New marimo source tutorial for sbi → ONNX export and backend verification. |
| `notebooks/exporting_bayesflow_to_onnx.py` | New marimo source tutorial for bayesflow → ONNX export and backend verification. |
| `mkdocs.yml` | Adds both tutorials to nav and excludes them from mkdocs-jupyter execution. |
| `docs/tutorials/exporting_sbi_to_onnx.ipynb` | Rendered tutorial notebook with baked outputs. |
| `docs/tutorials/exporting_bayesflow_to_onnx.ipynb` | Rendered tutorial notebook with baked outputs. |
| `.gitignore` | Ignores `__marimo__/` cache directory. |
    THETA_DIM, X_DIM = 2, 2
    return (
        AffineTransform,
        OfflineDataset,
        THETA_DIM,
        X_DIM,
        bf,
        call_onnx,
        jax,
        keras,
        np,
        onnx,
        ort,
        torch,
        transform_bayesflow_to_onnx,
    )

    def _(
        THETA_DIM,
        X_DIM,
        approximator,
        call_onnx,
        jax,
        np,
        onnx,
        onnx_path,
        ort,
        torch,
    ):

    # MUST precede any keras / bayesflow import.
    os.environ["KERAS_BACKEND"] = "torch"
    os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")

    import jax

    jax.config.update("jax_enable_x64", True)

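Keras reads `KERAS_BACKEND` once, when it is first imported, which is why the snippet above sets it before any keras or bayesflow import. A small guard can make that ordering explicit. This is only a sketch: `select_keras_backend` and its `loaded_modules` parameter are hypothetical helpers, not part of the notebooks or of Keras.

```python
import os
import sys


def select_keras_backend(name: str = "torch", loaded_modules=None) -> None:
    """Set KERAS_BACKEND, refusing if keras has already been imported."""
    modules = sys.modules if loaded_modules is None else loaded_modules
    if "keras" in modules:
        raise RuntimeError("set KERAS_BACKEND before importing keras/bayesflow")
    os.environ["KERAS_BACKEND"] = name
    # setdefault keeps any device the user already chose in their environment.
    os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")


# An empty mapping stands in for "nothing imported yet" in this demo.
select_keras_backend("torch", loaded_modules={})
print(os.environ["KERAS_BACKEND"])
```

In a real notebook the call would be made with the default `loaded_modules=None` as the first statement of the setup cell, so that an accidental earlier import fails loudly instead of silently selecting a different backend.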
    @app.cell
    def _(THETA_DIM, X_DIM, approximator, transform_bayesflow_to_onnx):
        import tempfile

        _onnx_dir = tempfile.mkdtemp(prefix="bf_onnx_")
        onnx_path = f"{_onnx_dir}/ddm_bayesflow_nle.onnx"

        transform_bayesflow_to_onnx(
            approximator,
            onnx_path,
            mode="nle",
            example_theta_dim=THETA_DIM,
            example_x_dim=X_DIM,
        )
        print(f"wrote {onnx_path}")
        return (onnx_path,)

    @app.cell
    def _(THETA_DIM, X_DIM, estimator, transform_sbi_to_onnx):
        import os
        import tempfile

        _onnx_dir = tempfile.mkdtemp(prefix="sbi_onnx_")
        onnx_path = os.path.join(_onnx_dir, "ddm_nle.onnx")

        transform_sbi_to_onnx(
            estimator,
            onnx_path,
            mode="nle",
            example_theta_dim=THETA_DIM,
            example_x_dim=X_DIM,
        )
        print(f"wrote {onnx_path}")
        return (onnx_path,)

    @app.cell
    def _(call_onnx, jax, np, onnx, onnx_path, ort):
        # Load the exported graph into onnxruntime and the jax-translated runner once.
        _ort_session = ort.InferenceSession(onnx_path)
        _input_name = _ort_session.get_inputs()[0].name

        _onnx_model = onnx.load(onnx_path)
        _trace_input = np.zeros(4, dtype=np.float32)  # [θ0, θ1, x0, x1]
        _model_func, _weights = call_onnx.call_onnx_model(
            _onnx_model, {_input_name: _trace_input}
        )
        jax_run = jax.tree_util.Partial(_model_func, _weights)

        def eval_backends(theta, x):
            """Return (ort, jax) scalar log-probs for a θ/x point (length-2 each)."""
            combined = np.asarray([*theta, *x], dtype=np.float32)
            y_ort = float(np.asarray(_ort_session.run(None, {_input_name: combined})[0]).flatten()[0])
            y_jax = float(np.asarray(jax_run({_input_name: combined})[0]).flatten()[0])
            return y_ort, y_jax

        return (eval_backends,)

| "/Users/afengler/Projects/proj_hssmspine/HSSMSpine/repos/LANfactory/.venv/lib/python3.12/site-packages/nflows/distributions/base.py:33: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", | ||
| " inputs = torch.as_tensor(inputs)\n", | ||
| "/Users/afengler/Projects/proj_hssmspine/HSSMSpine/repos/LANfactory/.venv/lib/python3.12/site-packages/nflows/distributions/base.py:35: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", | ||
| " context = torch.as_tensor(context)\n", | ||
| "/Users/afengler/Projects/proj_hssmspine/HSSMSpine/repos/LANfactory/.venv/lib/python3.12/site-packages/nflows/distributions/base.py:36: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", |
| "/Users/afengler/Projects/proj_hssmspine/HSSMSpine/repos/LANfactory/.venv/lib/python3.12/site-packages/keras/src/backend/common/variables.py:634: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", | ||
| " d = int(d)\n", | ||
| "/Users/afengler/Projects/proj_hssmspine/HSSMSpine/repos/LANfactory/.venv/lib/python3.12/site-packages/keras/src/backend/torch/numpy.py:1855: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", | ||
| " if dim % indices_or_sections != 0:\n", | ||
| "/Users/afengler/Projects/proj_hssmspine/HSSMSpine/repos/LANfactory/.venv/lib/python3.12/site-packages/keras/src/backend/torch/numpy.py:1869: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", |
| "INFO:2026-06-21 00:53:26,785:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/afengler/.local/share/uv/python/cpython-3.12.13-macos-aarch64-none/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)\n", | ||
| "INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/afengler/.local/share/uv/python/cpython-3.12.13-macos-aarch64-none/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)\n", | ||
| "INFO:jaxonnxruntime.call_onnx:Start tracing the jax_func model to get some static info\n" |
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@notebooks/exporting_bayesflow_to_onnx.py`:
- Around line 173-187: The transform_bayesflow_to_onnx function call is emitting
environment-specific stderr messages (including local paths) that clutter the
rendered notebook output. Wrap the transform_bayesflow_to_onnx call in a stderr
capture mechanism (such as using contextlib.redirect_stderr or io.StringIO) to
suppress or filter the noisy stderr output before the notebook is re-exported.
Ensure the captured stderr is either discarded or filtered to only display
relevant messages, so the final rendered output appears clean without local path
information.
In `@notebooks/exporting_sbi_to_onnx.py`:
- Around line 151-159: The transform_sbi_to_onnx function call is allowing
stderr output to leak through, which exposes absolute local filesystem paths in
the rendered notebook output. Suppress or capture the stderr stream during the
execution of the transform_sbi_to_onnx function call to prevent these local
environment details from being included in the tutorial output. After
suppressing stderr, regenerate the docs/tutorials/exporting_sbi_to_onnx.ipynb
notebook to ensure the output no longer contains filesystem paths.
- Around line 163-171: The _trace_input initialization in the notebook cell
hardcodes the size as 4, which makes it brittle when THETA_DIM or X_DIM
constants change. Replace the hardcoded value 4 in np.zeros(4, dtype=np.float32)
with a dynamic calculation that uses the THETA_DIM and X_DIM constants (summing
them together to get the total input dimension needed for the trace input).
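The last finding above (derive the trace-input length from `THETA_DIM` and `X_DIM` instead of hardcoding 4) can be sketched without any array library. In the notebook the equivalent line would be `np.zeros(THETA_DIM + X_DIM, dtype=np.float32)`; the `INPUT_DIM` constant and the `pack_inputs` helper below are hypothetical names used only for illustration.

```python
THETA_DIM, X_DIM = 2, 2
INPUT_DIM = THETA_DIM + X_DIM  # single source of truth for the flat input length


def pack_inputs(theta, x):
    """Concatenate theta and x into one flat list, checking both lengths."""
    if len(theta) != THETA_DIM or len(x) != X_DIM:
        raise ValueError(
            f"expected {THETA_DIM} theta values and {X_DIM} x values, "
            f"got {len(theta)} and {len(x)}"
        )
    return [float(v) for v in (*theta, *x)]


# Dummy input used only to trace the graph; its length follows the constants.
trace_input = [0.0] * INPUT_DIM
point = pack_inputs([0.1, 0.2], [1, 2])
print(len(trace_input), point)
```

With this shape, changing `THETA_DIM` or `X_DIM` in one place updates both the trace input and the validation of every evaluation point.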
    @app.cell
    def _(THETA_DIM, X_DIM, approximator, transform_bayesflow_to_onnx):
        import tempfile

        _onnx_dir = tempfile.mkdtemp(prefix="bf_onnx_")
        onnx_path = f"{_onnx_dir}/ddm_bayesflow_nle.onnx"

        transform_bayesflow_to_onnx(
            approximator,
            onnx_path,
            mode="nle",
            example_theta_dim=THETA_DIM,
            example_x_dim=X_DIM,
        )
        print(f"wrote {onnx_path}")

Sanitize export/tracing stderr before publishing rendered notebook outputs.
These cells emit environment-specific stderr (including local absolute paths in current rendered output). Please capture/filter noisy stderr during ONNX export and JAX tracing, then re-export docs/tutorials/exporting_bayesflow_to_onnx.ipynb.
Also applies to: 210-214
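One way to act on this comment is to capture Python-level stderr around the noisy call and keep only lines that carry no path. The sketch below uses a hypothetical `noisy_export` stand-in rather than the real exporter, and the "contains a slash" filter is deliberately crude. Note that `contextlib.redirect_stderr` only intercepts writes that go through `sys.stderr` at call time; logging handlers that grabbed the stream earlier, or native code writing to file descriptor 2, are not affected.

```python
import contextlib
import io
import sys
import warnings


def noisy_export(path: str) -> str:
    """Hypothetical stand-in for an export call that is noisy on stderr."""
    warnings.warn(f"tracing constant at /home/user/project/{path}")
    print("exporter chatter", file=sys.stderr)
    return path


captured = io.StringIO()
with contextlib.redirect_stderr(captured), warnings.catch_warnings():
    warnings.simplefilter("always")  # make sure the warning is not deduplicated
    result = noisy_export("model.onnx")

# Keep only captured lines that contain no filesystem path separator.
kept = [line for line in captured.getvalue().splitlines() if "/" not in line]
print("\n".join(kept))
```

Discarding `captured` entirely is the simpler option when none of the messages matter to the reader.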
Codecov Report
✅ All modified and coverable lines are covered by tests.
Don't bake machine-specific absolute paths into the published notebooks:
- the export cells printed the temp .onnx path → print a path-free message;
- a JAX xla_bridge stderr log ("Unable to initialize backend 'tpu'") leaked a
local libtpu.so path → silence the jax._src.xla_bridge logger + filter
warnings for clean tutorial output.
Re-exported both .ipynb: zero absolute-path leaks now.
Also: exporting_sbi_to_onnx uses THETA_DIM + X_DIM for the trace-input size
instead of a hardcoded 4 (the bayesflow notebook already did this).
Verified: marimo check clean; no error cells; three-way agreement preserved
(sbi 2.4e-7, bayesflow 4.8e-7).
Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
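The commit message does not show how the logger was silenced or which warnings were filtered, so the following is only one plausible shape using the standard library. The logger name comes from the log line quoted in the commit; `TracerWarningStandIn` is a hypothetical category standing in for whatever warning class the notebooks actually filter.

```python
import logging
import warnings

# Raise the threshold on the one chatty logger so its INFO records are dropped.
logging.getLogger("jax._src.xla_bridge").setLevel(logging.ERROR)


class TracerWarningStandIn(UserWarning):
    """Hypothetical stand-in for the tracer warning category."""


with warnings.catch_warnings(record=True) as seen:
    warnings.simplefilter("always")
    # Filters added later are checked first, so this "ignore" takes priority.
    warnings.filterwarnings("ignore", category=TracerWarningStandIn)
    warnings.warn("trace might be incorrect", TracerWarningStandIn)  # dropped
    warnings.warn("something else", UserWarning)  # still recorded

print(len(seen), "warning(s) recorded")
```

Targeting a single logger and a single warning category keeps unrelated diagnostics visible, which matters in a tutorial where readers may hit real problems.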
Copilot flagged that the bayesflow notebook imported _BayesflowNLELogProbWrapper in the setup cell but used it in the sessions cell. Because the name is _-prefixed, marimo treats it as cell-local and does not share it across cells, so the sessions cell would raise a NameError under `marimo edit` (reactive execution). `marimo export` only masked it by running cells in a shared namespace. Move the import into the sessions cell where it's used. Re-exported the docs notebook.

Verified: marimo check clean; no error cells; three-way agreement 4.8e-7; no local-path leaks in rendered output.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
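The cell-local rule behind this fix can be mimicked with a toy model in plain Python. This is not marimo's implementation: `run_cell` is a hypothetical helper that runs a snippet in its own namespace and hands on only the names that do not start with an underscore, and `OrderedDict` merely stands in for the private wrapper class.

```python
def run_cell(source: str, shared: dict) -> dict:
    """Toy model of a cell: run it, then share only non-underscore names."""
    namespace = dict(shared)
    exec(source, namespace)
    # Dropping "_"-prefixed names also drops the __builtins__ entry exec adds.
    return {k: v for k, v in namespace.items() if not k.startswith("_")}


setup_cell = "from collections import OrderedDict as _Wrapper\nTHETA_DIM = 2\n"
broken_sessions_cell = "wrapper = _Wrapper()\n"
fixed_sessions_cell = (
    "from collections import OrderedDict as _Wrapper\nwrapper = _Wrapper()\n"
)

shared = run_cell(setup_cell, {})
try:
    run_cell(broken_sessions_cell, shared)
    outcome = "ok"
except NameError as err:
    outcome = f"NameError: {err}"

fixed = run_cell(fixed_sessions_cell, shared)
print(outcome)
```

The broken variant fails because `_Wrapper` never leaves the setup cell, while the fixed variant imports it where it is used, which mirrors the change described in the commit.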
    def _(THETA_DIM, X_DIM, estimator, transform_sbi_to_onnx):
        import os
        import tempfile

        _onnx_dir = tempfile.mkdtemp(prefix="sbi_onnx_")
        onnx_path = os.path.join(_onnx_dir, "ddm_nle.onnx")

        transform_sbi_to_onnx(
            estimator,
            onnx_path,
            mode="nle",
            example_theta_dim=THETA_DIM,
            example_x_dim=X_DIM,
        )
        print("✓ exported ddm_nle.onnx")
        return (onnx_path,)

    def _(THETA_DIM, X_DIM, approximator, transform_bayesflow_to_onnx):
        import tempfile

        _onnx_dir = tempfile.mkdtemp(prefix="bf_onnx_")
        onnx_path = f"{_onnx_dir}/ddm_bayesflow_nle.onnx"

        transform_bayesflow_to_onnx(
            approximator,
            onnx_path,
            mode="nle",
            example_theta_dim=THETA_DIM,
            example_x_dim=X_DIM,
        )
        print("✓ exported ddm_bayesflow_nle.onnx")
        return (onnx_path,)

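The updated cells above hardcode the file name in the success message. A variant that derives the message from the path while still hiding the temporary directory could look like the sketch below; it assumes `pathlib` is acceptable in the notebooks, and the `write_bytes` call is only a placeholder for the real export function.

```python
import tempfile
from pathlib import Path

onnx_dir = Path(tempfile.mkdtemp(prefix="sbi_onnx_"))
onnx_path = onnx_dir / "ddm_nle.onnx"
onnx_path.write_bytes(b"")  # placeholder for the real export call

# Path.name is the final component only, so the temp directory never appears.
message = f"exported {onnx_path.name}"
print(message)
```

Keeping the name and the message tied to one variable avoids the two drifting apart if the file is renamed later.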
    estimator = _inference.append_simulations(_theta, _x).train(
        training_batch_size=200,
        max_num_epochs=15,

    "\n",
    "estimator = _inference.append_simulations(_theta, _x).train(\n",
    "    training_batch_size=200,\n",
    "    max_num_epochs=15,\n",

    "source": [
     "import os\n",
     "import tempfile\n",
     "\n",
     "_onnx_dir = tempfile.mkdtemp(prefix=\"sbi_onnx_\")\n",
     "onnx_path = os.path.join(_onnx_dir, \"ddm_nle.onnx\")\n",
     "\n",
     "transform_sbi_to_onnx(\n",
     "    estimator,\n",
     "    onnx_path,\n",
     "    mode=\"nle\",\n",
     "    example_theta_dim=THETA_DIM,\n",
     "    example_x_dim=X_DIM,\n",
     ")\n",
     "print(\"✓ exported ddm_nle.onnx\")"
    ]

    "source": [
     "import tempfile\n",
     "\n",
     "_onnx_dir = tempfile.mkdtemp(prefix=\"bf_onnx_\")\n",
     "onnx_path = f\"{_onnx_dir}/ddm_bayesflow_nle.onnx\"\n",
     "\n",
     "transform_bayesflow_to_onnx(\n",
     "    approximator,\n",
     "    onnx_path,\n",
     "    mode=\"nle\",\n",
     "    example_theta_dim=THETA_DIM,\n",
     "    example_x_dim=X_DIM,\n",
     ")\n",
     "print(\"✓ exported ddm_bayesflow_nle.onnx\")"
    ]

Why

The reference guides (`exporting_sbi_models.md`, `exporting_bayesflow_models.md`) have only placeholder snippets, and HSSM's `blackbox_contribution_onnx_example.ipynb` starts from an existing `.onnx`. Nobody's runnable tutorial covers the LANfactory half: train an sbi/bayesflow estimator → export to ONNX → verify. These two marimo notebooks fill that gap and hand off to HSSM for consumption.

What's in it

Two marimo notebooks, each running the full loop end to end on a 2D Gaussian toy:

- `exporting_sbi_to_onnx`: `NLE_A` (MAF) → `transform_sbi_to_onnx`
- `exporting_bayesflow_to_onnx`: `ContinuousApproximator` (v1 CouplingFlow) → `transform_bayesflow_to_onnx`

Both have interactive `mo.ui` eval-point sliders (the three-way check re-runs reactively, no retraining) and end with an HSSM handoff. The bayesflow one documents the Keras specifics (`KERAS_BACKEND=torch` before import, the exportable CouplingFlow knobs).

How it's wired

- `notebooks/exporting_*.py` (marimo `.py`, git-friendly, `marimo check`-clean; `notebooks/*` is ruff-excluded so marimo's formatting stands).
- `docs/tutorials/exporting_*.ipynb` exported with `--include-outputs`, added to the mkdocs nav + mkdocs-jupyter `execute_ignore` (renders baked outputs without re-training at build). Regen command is in each `.py` header.
- `marimo` added to the dev group; `__marimo__/` session cache gitignored.

Verification

`marimo check` clean on both; both export with no error cells; `mkdocs build` renders both pages (`execute=False`). CI unaffected: `notebooks/` and `docs/` are ruff-excluded and the tutorials aren't pytest targets (the underlying logic is already covered by `tests/test_{sbi,bayesflow}_nle_export.py`).

🤖 Generated with Claude Code