feat(onnx): bayesflow → ONNX exporter (transform_bayesflow_to_onnx)#80

Merged
AlexanderFengler merged 6 commits into main from bayesflow-connector on Jun 20, 2026
Conversation

@AlexanderFengler AlexanderFengler commented May 18, 2026

Member

Summary

Adds lanfactory.onnx.transform_bayesflow_to_onnx — the bayesflow sibling of transform_sbi_to_onnx from #79. Wraps a trained bayesflow.ContinuousApproximator (NLE) or RatioApproximator (NRE) and writes a single-trial ONNX file consumable by HSSM's loglik_kind="approx_differentiable" path.

The user gesture becomes the same regardless of training framework: hssm.HSSM(loglik="model.onnx", loglik_kind="approx_differentiable") works for LAN, sbi, and now bayesflow exports through the identical HSSM-side code path (no HSSM changes required).

Branch relationship

This branch is stacked on sbi-connector (#79). The diff in this PR is exactly the bayesflow-specific additions; #79's commits are not duplicated. When #79 merges to main, GitHub will offer to auto-retarget this PR's base to main.

What's in this PR

| File | Action | Purpose |
| --- | --- | --- |
| `src/lanfactory/onnx/bayesflow.py` | new | Exporter module (NLE + NRE wrappers + transform_bayesflow_to_onnx) |
| `src/lanfactory/onnx/__init__.py` | edit | Export the new function |
| `pyproject.toml` | edit | New [bayesflow] optional extra; refactor existing sbi+nflows pair into a new [sbi] extra (symmetric with [bayesflow]); both added to [all] and [dev] |
| `tests/test_bayesflow_nle_export.py` | new | 6 tests: three-way numerical agreement (atol=1e-5), gradient agreement (atol=1e-4), log-prob ordering sanity, three guard tests |
| `tests/test_bayesflow_nre_export.py` | new | 4 tests: same shape for the NRE path |
| `tests/test_bayesflow_hssm_integration.py` | new | End-to-end DDM smoke (pytest.importorskip("hssm"), mirrors test_sbi_hssm_integration.py) |
| `docs/exporting_bayesflow_models.md` | new | Sibling of exporting_sbi_models.md |

Architectural contract

Same I/O contract as the sbi exporter:

  • Input: rank-1 tensor of shape (theta_dim + x_dim,), parameters first then observations.
  • Output: rank-0 scalar log-likelihood.
  • Opset: pinned to 17 for jaxonnxruntime reproducibility.
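
As a rough illustration of this contract, the sketch below packs one trial into the rank-1 input layout (parameters first, then observations). The helper name `pack_single_trial`, the float32 dtype, and the example values are assumptions made for illustration; they are not taken from lanfactory.

```python
import numpy as np


def pack_single_trial(theta, x):
    """Concatenate parameters and observations into one rank-1 float32 vector."""
    theta = np.asarray(theta, dtype=np.float32).reshape(-1)
    x = np.asarray(x, dtype=np.float32).reshape(-1)
    # Parameters come first, observations second: shape (theta_dim + x_dim,).
    return np.concatenate([theta, x])


# Hypothetical DDM-like trial: four parameters, then (rt, choice).
trial = pack_single_trial(theta=[0.5, 1.2, 0.5, 0.3], x=[0.84, 1.0])
print(trial.shape)  # (6,)
```

The exported graph would then map a vector like `trial` to a single scalar log-likelihood, which is what lets HSSM treat every exporter the same way.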

Key implementation choice: the wrapper bakes the bayesflow Standardize layer's accumulated moving_mean / moving_std as torch buffer constants at construction time. This sidesteps the dynamic-shape ops (If, Size, Tile) that the live Keras layer would emit at trace time — jaxonnxruntime doesn't have a Size handler. The constants are correct because training is complete by export time.
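
A minimal sketch of the buffer idea, assuming plain PyTorch only. The class name `FrozenStandardize` is hypothetical and this is not the exporter's actual wrapper; it only shows statistics being fixed at construction time so the forward pass is simple arithmetic on constants.

```python
import torch


class FrozenStandardize(torch.nn.Module):
    """Standardize with statistics fixed at construction time (hypothetical helper)."""

    def __init__(self, mean, std):
        super().__init__()
        # Buffers are not trainable parameters; they travel with the module
        # as fixed tensors, so forward() needs no shape-dependent logic.
        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float32))
        self.register_buffer("std", torch.as_tensor(std, dtype=torch.float32))

    def forward(self, x):
        return (x - self.mean) / self.std


layer = FrozenStandardize(mean=[1.0, 2.0], std=[2.0, 4.0])
z = layer(torch.tensor([3.0, 6.0]))
```

Because the mean and standard deviation are read once and stored, the forward pass reduces to a subtraction and a division.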

v1 constraints (documented, enforced where introspectable)

User must train with:

  • permutation=None (FixedPermutation → aten::ravel, unsupported in opset 17/20)
  • use_actnorm=False (untested in v1)
  • transform=AffineTransform(clamp=False) as an explicit instance (find_transform("affine") silently drops transform_kwargs — bayesflow upstream bug, catalogued in companion HSSMSpine PR)
  • subnet_kwargs.activation="silu" (default hard_silu exports as the fused ONNX op HardSwish, no jaxonnxruntime handler; silu decomposes to Sigmoid + Mul)
  • Identity Adapter (numpy-only Adapter ops can't be baked into ONNX)

Each violation produces an actionable error message at export time.
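
To make the activation constraint concrete, here is a small pure-Python comparison of the two functions. It is a sketch of the standard definitions, not code from bayesflow or Keras: `silu` is one sigmoid followed by one multiply, while `hard_silu` is the piecewise-linear form that corresponds to the fused HardSwish op.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def silu(x: float) -> float:
    # Smooth: x * sigmoid(x), i.e. one Sigmoid and one Mul.
    return x * sigmoid(x)


def hard_silu(x: float) -> float:
    # Piecewise-linear approximation: x * clip(x / 6 + 0.5, 0, 1).
    return x * min(1.0, max(0.0, x / 6.0 + 0.5))


for value in (-4.0, -1.0, 0.0, 1.0, 4.0):
    print(f"{value:+.1f}  silu={silu(value):+.4f}  hard_silu={hard_silu(value):+.4f}")
```

The two curves are close but not identical, so switching the activation is a training-time decision rather than something to patch at export time.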

Test status

  • bayesflow NLE: 6/6 passing
  • bayesflow NRE: 4/4 passing
  • bayesflow ↔ HSSM integration: skip-on-missing-HSSM in this env (designed for the coordinated cross-repo CI matrix)
  • sbi regression check: existing sbi tests still pass (no regression). Note: the bayesflow test modules call torch.set_grad_enabled(True) after importing bayesflow to undo the global autograd disable that bayesflow's torch backend does at import time. This is a known upstream issue documented in the companion HSSMSpine PR's upstream-bugs catalog.

Companion PRs

  • HSSM: bayesflow-integration branch (new), adds docs/tutorials/bayesflow_nle_onnx_integration.ipynb. Sibling, not child, of sbi-integration (#964) — works against HSSM main standalone (Part 1 includes the manual jaxort_only_allow_initializers_as_static_args=False workaround that #964 plans to auto-handle inside onnx2jax).
  • HSSMSpine: bayesflow-onnx-plans branch (stacked on #9 for the cross-reference to sbi-onnx-integration.md). Adds the design doc and an upstream-bugs catalog covering the five real upstream defects surfaced during this work.

Test plan

  • bayesflow NLE tests pass (6/6)
  • bayesflow NRE tests pass (4/4)
  • No regression on existing sbi tests
  • Cross-repo HSSM integration test (pytest.importorskip("hssm")) once both packages can be installed in the same env

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features

    • Added transform_bayesflow_to_onnx to export BayesFlow NLE/NRE models into a single ONNX artifact for HSSM’s approx_differentiable log-likelihood workflow (with configurable ONNX opset).
    • Re-exported the new exporter from the ONNX module entry point.
  • Documentation

    • Added a guide covering installation, setup, quick-start export flows, and ONNX-specific constraints.
  • Tests

    • Added unit tests for NLE/NRE numerical and gradient agreement (via ONNX Runtime/JAX) plus an end-to-end BayesFlow → ONNX → HSSM MCMC integration test.
  • Chores

    • Added BayesFlow/Keras optional dependency extras and updated docs navigation.

Adds lanfactory.onnx.transform_bayesflow_to_onnx, the bayesflow sibling of
transform_sbi_to_onnx (PR #79). Wraps a trained bayesflow
ContinuousApproximator (NLE) or RatioApproximator (NRE) and writes a
single-trial ONNX file consumable by HSSM's loglik_kind="approx_differentiable"
path. Same I/O contract as the sbi exporter (rank-1 input
[theta..., x...], rank-0 scalar log-likelihood, opset 17) so HSSM ingests
both via the same loglik="*.onnx" gesture with zero HSSM-side changes.

What's in this commit

- src/lanfactory/onnx/bayesflow.py: exporter module mirroring sbi.py.
  Contains _BayesflowNLELogProbWrapper and _BayesflowNRELogRatioWrapper.
  Pre-evaluates the bayesflow Standardize layer's moving mean/std to
  torch buffer constants at wrapper construction time so the ONNX trace
  is fully static (avoids If, Size, Tile dynamic-shape ops that
  jaxonnxruntime can't run). Guards on KERAS_BACKEND=torch and identity
  Adapter; both raise actionable errors with concrete fix hints.

- src/lanfactory/onnx/__init__.py: export the new function.

- pyproject.toml: add [bayesflow] optional extra
  (bayesflow>=2.0.8, keras>=3.12), add to [all] and [dev]. Also
  refactors the existing sbi+nflows pair into its own [sbi] extra
  (mirroring the new [bayesflow]) while keeping them in [all].

- tests/test_bayesflow_nle_export.py: 6 tests. Three-way numerical
  agreement (torch reference wrapper <-> onnxruntime <-> jaxonnxruntime)
  at atol=1e-5, gradient agreement at atol=1e-4, log-prob ordering
  sanity, and three guard tests (wrong backend, non-identity adapter,
  wrong mode).

- tests/test_bayesflow_nre_export.py: 4 tests. Same shape for the NRE
  path on a RatioApproximator.

- tests/test_bayesflow_hssm_integration.py: end-to-end DDM smoke
  (pytest.importorskip("hssm")). Mirrors test_sbi_hssm_integration.py.

- docs/exporting_bayesflow_models.md: full constraint catalog
  (KERAS_BACKEND, CouplingFlow knobs, silu vs hard_silu activation
  choice, identity-adapter requirement, JAX x64). Quick-starts for NLE
  and NRE. "Two paths into HSSM" framing alongside the JAX-callable
  path used in bayesflow_lre_integration.ipynb.
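
The backend guard mentioned above could look roughly like the following. This is a hypothetical sketch: the function name is invented, and reading the `KERAS_BACKEND` environment variable is an assumption about how the backend might be detected, since the exporter's real check is not shown in this thread.

```python
import os


def require_torch_backend(env=None) -> None:
    """Raise an actionable error unless KERAS_BACKEND is 'torch' (hypothetical guard)."""
    env = os.environ if env is None else env
    backend = env.get("KERAS_BACKEND")
    if backend != "torch":
        raise RuntimeError(
            f"ONNX export needs the torch Keras backend, got {backend!r}. "
            'Set os.environ["KERAS_BACKEND"] = "torch" before importing keras.'
        )


require_torch_backend({"KERAS_BACKEND": "torch"})  # passes silently
try:
    require_torch_backend({"KERAS_BACKEND": "jax"})
except RuntimeError as err:
    message = str(err)
```

The point of the pattern is that the error names both the offending value and the concrete fix, rather than failing later inside tracing.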

v1 constraints (documented, enforced where introspectable)

User must train with:
- permutation=None (FixedPermutation -> aten::ravel, unsupported)
- use_actnorm=False (untested in v1)
- transform=AffineTransform(clamp=False) explicit instance
  (find_transform("affine") drops kwargs - bayesflow upstream bug)
- subnet_kwargs.activation="silu" or another smooth activation
  (default hard_silu emits HardSwish, no jaxonnxruntime handler)
- identity Adapter (numpy-only adapter ops cannot be baked into ONNX)

Bayesflow continuous observations only. MNLE-style discrete + continuous
deferred until upstream MNLE support lands.

Numerical guarantees

19 passing tests across both bayesflow and sbi tracks; no regressions on
the existing sbi exporter. Each export is verified for three-way numerical
agreement at 1e-5 and gradient agreement at 1e-4.
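
A sketch of what a three-way agreement check amounts to, using only the standard library. The helper name and the numbers are made up for illustration; the real tests compare actual outputs from the torch wrapper, onnxruntime, and jaxonnxruntime.

```python
import math
from itertools import combinations


def agree_three_way(values: dict, atol: float) -> bool:
    """Return True if every pair of backend outputs differs by at most atol."""
    return all(
        math.isclose(values[a], values[b], rel_tol=0.0, abs_tol=atol)
        for a, b in combinations(values, 2)
    )


# Made-up log-likelihood values, for illustration only.
outputs = {"torch": -1.234567, "onnxruntime": -1.234569, "jaxonnxruntime": -1.234571}
ok_forward = agree_three_way(outputs, atol=1e-5)
```

Checking all pairs, rather than each backend against a single reference, means a disagreement between the two ONNX runtimes is caught even when both are close to torch.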

Companion PRs

- HSSM: docs(tutorials): add bayesflow_nle_onnx_integration.ipynb
  on a fresh bayesflow-integration branch off main (sibling, not
  child, of the sbi-integration branch in PR #964).
- HSSMSpine: bayesflow-onnx-integration.md design doc +
  upstream-bugs-from-bayesflow-onnx-work.md catalog of upstream
  defects surfaced during this work (jaxonnxruntime missing
  HardSwish/Size handlers; bayesflow find_transform kwarg-drop bug;
  bayesflow global torch.set_grad_enabled(False) cross-library leak;
  torch.onnx missing aten::ravel/asinh symbolic registrations).

This branch is stacked on sbi-connector (PR #79). When #79 merges,
this PR's base auto-retargets to main.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

Copilot AI left a comment

Contributor

Pull request overview

Adds a BayesFlow-to-ONNX exporter so BayesFlow-trained NLE/NRE models can be consumed by HSSM via the same loglik="model.onnx", loglik_kind="approx_differentiable" pathway already used for LAN and sbi exports.

Changes:

  • Introduces lanfactory.onnx.transform_bayesflow_to_onnx with NLE/NRE wrappers that bake BayesFlow standardization statistics into a static ONNX trace.
  • Adds BayesFlow-focused regression tests (three-way numerical agreement + gradient agreement) and an optional HSSM integration smoke test.
  • Adds a bayesflow optional extra (and a dedicated sbi extra) plus a new documentation guide for BayesFlow exports.

Reviewed changes

Copilot reviewed 7 out of 8 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| `src/lanfactory/onnx/bayesflow.py` | New BayesFlow exporter implementation (NLE + NRE wrappers + ONNX export entry point). |
| `src/lanfactory/onnx/__init__.py` | Exposes transform_bayesflow_to_onnx at the package level. |
| `pyproject.toml` | Adds bayesflow and sbi extras; extends all and dev deps. |
| `tests/test_bayesflow_nle_export.py` | New NLE export tests: forward/grad agreement + guardrails. |
| `tests/test_bayesflow_nre_export.py` | New NRE export tests: forward/grad agreement + guardrails. |
| `tests/test_bayesflow_hssm_integration.py` | New (skip-if-missing-HSSM) end-to-end BayesFlow→ONNX→HSSM integration test. |
| `docs/exporting_bayesflow_models.md` | New BayesFlow ONNX export guide and constraint catalog. |
| `uv.lock` | Dependency lockfile updates to include BayesFlow/Keras and related resolutions. |

Comment on lines +32 to +34
import pandas as pd # noqa: E402

import bayesflow as bf # noqa: E402
Comment on lines +1 to +10
# Exporting bayesflow-trained networks to ONNX

LANfactory's [`transform_bayesflow_to_onnx`](api/onnx.md) is the bayesflow
sibling of [`transform_sbi_to_onnx`](exporting_sbi_models.md). It wraps a
trained [`bayesflow`](https://github.com/bayesflow-org/bayesflow)
`ContinuousApproximator` (NLE) or `RatioApproximator` (NRE) and writes a
single-trial ONNX file that HSSM's `loglik_kind="approx_differentiable"`
path can consume exactly like an sbi export. Same user gesture, same file
format, same HSSM-side loader — regardless of which training framework you
came from.
Comment thread tests/test_bayesflow_nle_export.py Outdated
Comment on lines +15 to +16
os.environ.setdefault("KERAS_BACKEND", "torch")
os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")
Comment thread tests/test_bayesflow_nre_export.py Outdated
Comment on lines +22 to +23
os.environ.setdefault("KERAS_BACKEND", "torch")
os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")
Comment thread pyproject.toml
@@ -82,6 +86,8 @@ dev = [
"jaxonnxruntime>=0.3",
"onnxruntime>=1.17",
"nflows>=0.14",
AlexanderFengler and others added 4 commits June 20, 2026 13:26
Brings the merged sbi → ONNX work (PR #79) and v0.6.1 docs/logos into the
bayesflow branch. Only pyproject.toml conflicted: both sides appended to the
dev dependency group. Resolved as the union — keep bayesflow's
`bayesflow`/`keras` and main's `sbi` — so the full sbi + bayesflow test
matrix is installed by `uv sync --all-groups`. uv.lock regenerated (uv 0.6.5)
to match the resolved manifest. All sbi-file changes (dim validation, doc-ref
fix, ruff formatting, guard tests, seeded-RNG fix) come straight from main.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>

Same as the sbi files: collapses multi-line signatures/f-strings that fit
within 88 chars. CI's `ruff format --check .` flagged these once the merge
brought the bayesflow files alongside the dev-group fix. Cosmetic only.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>

exporting_bayesflow_models.md shipped with the feature but was never added
to the nav, so it was orphaned (unreachable in the built site). Add it under
Guides next to the sbi export guide.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>

…alidation

Brings bayesflow.py to 100% patch coverage so codecov/patch passes on PR #80,
mirroring the sbi exporter:

- transform_bayesflow_to_onnx: validate example_theta_dim/example_x_dim are
  positive (parity with transform_sbi_to_onnx), raising a clear ValueError
  instead of failing deep in tracing.
- Add guard tests: invalid mode, non-positive dims, NLE wrapper missing
  .inference_network.
- Cover the standardizer if/else branches in both wrappers (the trained
  fixtures only standardize one slot) with lightweight stand-in approximators
  that exercise the opposite combination.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
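
For readers unfamiliar with the dimension check described in this commit, a minimal sketch follows. The helper name `validate_example_dims` is hypothetical and the message wording is an assumption; only the parameter names `example_theta_dim` and `example_x_dim` and the use of `ValueError` come from the commit message.

```python
def validate_example_dims(example_theta_dim: int, example_x_dim: int) -> None:
    """Reject non-positive dimensions before any tracing starts (hypothetical helper)."""
    for name, value in (
        ("example_theta_dim", example_theta_dim),
        ("example_x_dim", example_x_dim),
    ):
        if not isinstance(value, int) or value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}.")


validate_example_dims(4, 2)  # valid: returns None
try:
    validate_example_dims(0, 2)
except ValueError as err:
    dim_error = str(err)
```

Failing here keeps the error close to the caller's mistake instead of surfacing as a shape problem deep inside the export trace.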
@coderabbitai

coderabbitai Bot commented Jun 20, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro Plus

Run ID: 2f16bc95-b7fa-4ae1-8797-be537747aa48

📥 Commits

Reviewing files that changed from the base of the PR and between d8712dc and 0fe67dd.

📒 Files selected for processing (4)
  • docs/exporting_bayesflow_models.md
  • tests/test_bayesflow_hssm_integration.py
  • tests/test_bayesflow_nle_export.py
  • tests/test_bayesflow_nre_export.py
✅ Files skipped from review due to trivial changes (1)
  • docs/exporting_bayesflow_models.md
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/test_bayesflow_nre_export.py
  • tests/test_bayesflow_nle_export.py

📝 Walkthrough

Walkthrough

Adds transform_bayesflow_to_onnx, a new public function exporting Bayesflow ContinuousApproximator (NLE) and RatioApproximator (NRE) models to single-trial ONNX graphs consumable by HSSM. The implementation includes torch wrapper modules for both modes, frozen standardization, validation guards, three-way (torch/ONNX/JAX) numerical and gradient tests, an HSSM end-to-end DDM integration test, and a new documentation page.

Changes

Bayesflow ONNX Exporter

| Layer / File(s) | Summary |
| --- | --- |
| Dependencies and public API surface: `pyproject.toml`, `src/lanfactory/onnx/__init__.py` | Adds bayesflow>=2.0.8 and keras>=3.12 to all and dev dependency groups; re-exports transform_bayesflow_to_onnx in the onnx package `__all__`. |
| Core exporter, entry point and wrapper modules: `src/lanfactory/onnx/bayesflow.py` | Implements transform_bayesflow_to_onnx with dim validation, torch-backend enforcement, identity-adapter assertion, mode dispatch, rank-1 dummy tensor construction, and torch.onnx.export call. Adds _assert_identity_adapter, _frozen_mean_std, _BayesflowNLELogProbWrapper (standardizes theta/x, calls inference_network.log_prob, applies Jacobian correction), and _BayesflowNRELogRatioWrapper (standardizes, concatenates theta/x, runs inference_network, applies projector). |
| NLE export test suite: `tests/test_bayesflow_nle_export.py` | Trains an ONNX-friendly CouplingFlow NLE on a 2D Gaussian toy dataset; verifies three-way numerical and gradient agreement; tests log-prob ordering and rejects invalid configs (wrong backend, non-identity adapter, wrong mode, invalid dims, missing inference_network); unit-tests standardization branch with _FakeStandardizeLayer. |
| NRE export test suite: `tests/test_bayesflow_nre_export.py` | Trains an ONNX-friendly RatioApproximator on a 2D Gaussian toy dataset; verifies three-way numerical and gradient agreement; tests log-ratio ordering and rejects NRE in NLE mode; unit-tests standardization branch with _FakeStandardizeLayer. |
| HSSM end-to-end integration tests: `tests/test_bayesflow_hssm_integration.py` | Trains a tiny ContinuousApproximator on synthetic DDM data, exports to ONNX, asserts HSSM model construction succeeds, and validates that short MCMC sampling recovers DDM parameters within 2 standard deviations with r_hat < 1.05. |
| Documentation page and navigation: `docs/exporting_bayesflow_models.md`, `mkdocs.yml` | Adds full documentation page covering installation, KERAS_BACKEND=torch requirement, NLE/NRE quick-start examples, known constraints, out-of-scope models, numerical guarantees, HSSM integration paths, and related API links; registers the page under Guides. |
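
The Jacobian correction mentioned for the NLE wrapper follows from the usual change of variables: if z = (x - mean) / std, then log p(x) = log p(z) - sum(log std). The sketch below checks that identity with a standard normal standing in for the inference network's log_prob. It is an illustration of the general rule, not the wrapper's code, and it assumes the correction applies to the standardized observations x (the variable whose density is modelled), not to the conditioning parameters.

```python
import math


def std_normal_logpdf(z: float) -> float:
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)


def log_prob_original_space(x, mean, std):
    """log p(x) from a density defined on standardized z = (x - mean) / std."""
    z = [(xi - m) / s for xi, m, s in zip(x, mean, std)]
    log_p_z = sum(std_normal_logpdf(zi) for zi in z)
    # Change of variables: |dz/dx| = 1 / std per dimension.
    return log_p_z - sum(math.log(s) for s in std)


def normal_logpdf(x: float, mean: float, std: float) -> float:
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)


x, mean, std = [0.9, 1.0], [0.5, 0.2], [2.0, 0.5]
via_standardization = log_prob_original_space(x, mean, std)
direct = sum(normal_logpdf(xi, m, s) for xi, m, s in zip(x, mean, std))
```

Dropping the `- sum(log std)` term would shift every log-likelihood by a constant, which matters once values from different exporters are compared.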

Sequence Diagram(s)

sequenceDiagram
  participant User
  participant transform_bayesflow_to_onnx
  participant WrapperModule
  participant torch_onnx_export as torch.onnx.export
  participant ONNX_Runtime
  participant HSSM

  User->>transform_bayesflow_to_onnx: approximator, path, mode="nle"
  transform_bayesflow_to_onnx->>transform_bayesflow_to_onnx: validate dims, backend, adapter
  transform_bayesflow_to_onnx->>WrapperModule: construct _BayesflowNLELogProbWrapper
  transform_bayesflow_to_onnx->>torch_onnx_export: trace wrapper with dummy input
  torch_onnx_export-->>User: ONNX file written
  
  User->>HSSM: load ONNX with loglik_kind="approx_differentiable"
  HSSM->>ONNX_Runtime: load model.onnx
  ONNX_Runtime-->>HSSM: per-trial log-likelihood
  HSSM-->>User: posterior samples via MCMC

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • lnccbrown/LANfactory#79: Extends the same ONNX public API surface in src/lanfactory/onnx/__init__.py to export new HSSM-compatible exporter functions (transform_sbi_to_onnx vs transform_bayesflow_to_onnx).

Poem

🐇 Hop, hop — the model's packed in ONNX today,
Frozen means and sigmas neatly stored away.
NLE and NRE both wrapped with care,
Three runtimes agree — torch, JAX, and Runtime there.
The DDM recovers; r_hat stays small and true,
A Bayesflow rabbit ships the exporter for you! 🎉

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 57.45%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title directly and accurately reflects the main change: introducing a bayesflow-to-ONNX exporter function (transform_bayesflow_to_onnx). It is concise, specific, and clearly communicates the primary feature being added. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

@codecov

codecov Bot commented Jun 20, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines Coverage Δ
src/lanfactory/onnx/bayesflow.py 100.00% <100.00%> (ø)

@AlexanderFengler AlexanderFengler changed the base branch from sbi-connector to main June 20, 2026 15:43
@AlexanderFengler

Member Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Jun 20, 2026

✅ Action performed

Review finished.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@AlexanderFengler

Member Author

@coderabbitai full review

@coderabbitai

coderabbitai Bot commented Jun 20, 2026

✅ Action performed

Full review finished.

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (1)
tests/test_bayesflow_nre_export.py (1)

136-138: ⚡ Quick win

Pin ONNX Runtime to CPU provider for deterministic test behavior.

On Line 136, provider auto-selection is host-dependent. Explicitly setting CPUExecutionProvider avoids environment-specific execution paths and reduces CI flakiness.

Suggested fix
-    sess = ort.InferenceSession(str(onnx_path))
+    sess = ort.InferenceSession(
+        str(onnx_path), providers=["CPUExecutionProvider"]
+    )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_bayesflow_nre_export.py` around lines 136 - 138, The
ort.InferenceSession initialization on line 136 does not specify an execution
provider, causing it to auto-select based on the host environment which leads to
non-deterministic behavior across different CI environments. Modify the
InferenceSession constructor call to explicitly pass the providers parameter
with CPUExecutionProvider to ensure consistent execution behavior across all
test environments.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/exporting_bayesflow_models.md`:
- Around line 88-97: Add JAX 64-bit mode configuration before the quick-start
code blocks that use ONNX models. Since ONNX graphs contain int64 tensors and
JAX silently truncates to int32 by default (causing incorrect log-probability
calculations), include the JAX x64 configuration at the beginning of the code
snippet containing the HSSM instantiation and sampling call. Apply this fix to
both the ONNX sampling quick-start block and the NRE quick-start section
mentioned in the comment.

In `@tests/test_bayesflow_hssm_integration.py`:
- Around line 22-23: Replace the os.environ.setdefault calls for KERAS_BACKEND
and KERAS_TORCH_DEVICE with direct assignment using os.environ["KERAS_BACKEND"]
= "torch" and os.environ["KERAS_TORCH_DEVICE"] = "cpu" to ensure these values
are forcibly set regardless of any pre-existing environment variables,
guaranteeing test isolation and preventing nondeterministic failures from leaked
environment state.

In `@tests/test_bayesflow_nle_export.py`:
- Around line 15-16: The use of setdefault for KERAS_BACKEND and
KERAS_TORCH_DEVICE allows pre-existing environment variable values set elsewhere
to persist, which can break the test assumptions. Replace both setdefault calls
with direct assignment using the subscript notation so that KERAS_BACKEND is
forcefully set to "torch" and KERAS_TORCH_DEVICE is forcefully set to "cpu",
ensuring these values are not overridden by any previously set environment
variables.

In `@tests/test_bayesflow_nre_export.py`:
- Around line 13-15: The use of setdefault in lines setting KERAS_BACKEND and
KERAS_TORCH_DEVICE allows preexisting environment variables to persist, creating
test order dependency and non-hermetic behavior. Replace both setdefault calls
with direct assignment using os.environ[] bracket notation to force the backend
to always be torch and device to always be cpu, ensuring this test remains
independent of other test configurations.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro Plus

Run ID: 6772c78e-d6c3-4fdf-979d-1d1df817aa19

📥 Commits

Reviewing files that changed from the base of the PR and between f59f0a8 and d8712dc.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (8)
  • docs/exporting_bayesflow_models.md
  • mkdocs.yml
  • pyproject.toml
  • src/lanfactory/onnx/__init__.py
  • src/lanfactory/onnx/bayesflow.py
  • tests/test_bayesflow_hssm_integration.py
  • tests/test_bayesflow_nle_export.py
  • tests/test_bayesflow_nre_export.py

Comment thread docs/exporting_bayesflow_models.md
Comment on lines +22 to +23
os.environ.setdefault("KERAS_BACKEND", "torch")
os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, find and examine the test file
find . -type f -name "test_bayesflow_hssm_integration.py" | head -5

Repository: lnccbrown/LANfactory

Length of output: 107


🏁 Script executed:

# Also check the overall test structure
git ls-files tests/ | grep -E "\.py$" | head -20

Repository: lnccbrown/LANfactory

Length of output: 615


🏁 Script executed:

# Search for patterns of environment variable usage in tests
rg "os\.environ\[|os\.environ\.setdefault|pytest\.fixture" tests/ -t py | head -30

Repository: lnccbrown/LANfactory

Length of output: 1728


🏁 Script executed:

# Read the test file to verify lines 22-23
cat -n tests/test_bayesflow_hssm_integration.py | head -40

Repository: lnccbrown/LANfactory

Length of output: 1824


🏁 Script executed:

# Check conftest.py for environment setup patterns
cat -n tests/conftest.py | head -100

Repository: lnccbrown/LANfactory

Length of output: 3232


🏁 Script executed:

# Check similar test files that use the same setdefault pattern
cat -n tests/test_bayesflow_nle_export.py | head -40

Repository: lnccbrown/LANfactory

Length of output: 1950


🏁 Script executed:

# Check if there's a pattern in other similar test files (e.g., test_sbi_hssm_integration)
cat -n tests/test_sbi_hssm_integration.py | head -40

Repository: lnccbrown/LANfactory

Length of output: 1986


🏁 Script executed:

# Check for any direct assignment patterns in the codebase
rg "os\.environ\[\"KERAS" tests/ -A 1

Repository: lnccbrown/LANfactory

Length of output: 46


🏁 Script executed:

# Search for any documentation or comments about environment variable handling
rg "KERAS_BACKEND|KERAS_TORCH_DEVICE" tests/ -B 3 -A 1 | head -60

Repository: lnccbrown/LANfactory

Length of output: 3654


🏁 Script executed:

# Check if there are any CI/configuration files that might set these variables
find . -type f \( -name "*.yml" -o -name "*.yaml" -o -name ".env*" -o -name "pytest.ini" -o -name "pyproject.toml" \) | head -20

Repository: lnccbrown/LANfactory

Length of output: 1168


🏁 Script executed:

# Check the pyproject.toml or setup files for pytest configuration
cat pyproject.toml 2>/dev/null | grep -A 20 "^\[tool.pytest" || echo "No pytest config found"

Repository: lnccbrown/LANfactory

Length of output: 508


🏁 Script executed:

# Search for any monkeypatch or environment variable handling in conftest or tests
rg "monkeypatch.*KERAS|KERAS.*monkeypatch" tests/ -B 2 -A 2

Repository: lnccbrown/LANfactory

Length of output: 46


🏁 Script executed:

# Check the CI workflow to see if environment variables are set before tests
cat .github/workflows/run_tests.yml | head -100

Repository: lnccbrown/LANfactory

Length of output: 1294


🏁 Script executed:

# Check if there's any documentation about environment setup for tests
find . -type f \( -name "README*" -o -name "CONTRIBUTING*" \) -exec cat {} \; | grep -i "keras\|backend\|environ" | head -20

Repository: lnccbrown/LANfactory

Length of output: 235


🏁 Script executed:

# Check if test_bayesflow_nle_export tests backend validation with monkeypatch
cat -n tests/test_bayesflow_nle_export.py | grep -A 20 "test_transform_rejects_wrong_backend"

Repository: lnccbrown/LANfactory

Length of output: 961


🏁 Script executed:

# Check the full structure of test_bayesflow_hssm_integration to see if there are similar backend tests
cat -n tests/test_bayesflow_hssm_integration.py | tail -100

Repository: lnccbrown/LANfactory

Length of output: 4040


Force Keras backend/device assignment to ensure test isolation.

Using setdefault allows pre-existing KERAS_BACKEND or KERAS_TORCH_DEVICE environment variables from the user's shell to leak in. This can cause nondeterministic test failures if the environment is configured for a different backend (e.g., jax). Since the code comment explicitly states these must be set before imports, use direct assignment to guarantee the correct values regardless of environment state.

Suggested patch
-os.environ.setdefault("KERAS_BACKEND", "torch")
-os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")
+os.environ["KERAS_BACKEND"] = "torch"
+os.environ["KERAS_TORCH_DEVICE"] = "cpu"

Comment thread tests/test_bayesflow_nle_export.py Outdated
Comment thread tests/test_bayesflow_nre_export.py Outdated
Comment on lines +13 to +15
os.environ.setdefault("KERAS_BACKEND", "torch")
os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make backend preconditions explicit, not setdefault-dependent.

On Line 13 and Line 14, setdefault keeps any preexisting environment values, which can make this module order-dependent if another test session config sets a different backend. Fail fast (or force values) so this test remains hermetic.

Suggested fix
-os.environ.setdefault("KERAS_BACKEND", "torch")
-os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")
+os.environ["KERAS_BACKEND"] = "torch"
+os.environ["KERAS_TORCH_DEVICE"] = "cpu"
 import keras  # noqa: E402
+if keras.backend.backend() != "torch":
+    pytest.skip("tests/test_bayesflow_nre_export.py requires KERAS_BACKEND=torch")

- docs/exporting_bayesflow_models.md: enable jax_enable_x64 in the NLE
  quick-start before HSSM imports jax — the "Known constraints" section
  already documents that ONNX int64 tensors get truncated under JAX's default
  32-bit mode, but the copy-paste quick-start omitted the setup.
- test_bayesflow_{nle,nre,hssm_integration}: force KERAS_BACKEND=torch instead
  of setdefault. These tests hard-require the torch backend; setdefault would
  silently honor a stray KERAS_BACKEND=jax in the environment and fail
  confusingly. KERAS_TORCH_DEVICE stays a setdefault (a preference, not a
  requirement — forcing cpu would override GPU users).

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
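
The difference between the two forms discussed in this thread can be seen with the standard library alone. The sketch below simulates a stray `KERAS_BACKEND=jax` in the environment; the scenario is illustrative and does not import keras or bayesflow.

```python
import os

# Simulate a stray value leaked in from the user's shell.
os.environ["KERAS_BACKEND"] = "jax"

# setdefault only fills in a missing key, so the stray value survives.
os.environ.setdefault("KERAS_BACKEND", "torch")
after_setdefault = os.environ["KERAS_BACKEND"]

# Direct assignment always wins; this suits a hard requirement.
os.environ["KERAS_BACKEND"] = "torch"
after_assignment = os.environ["KERAS_BACKEND"]

# A preference can stay a setdefault so an existing choice (e.g. a GPU) is respected.
os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")
```

Either form only has an effect if it runs before keras is first imported, which is why the test modules set these variables at the very top.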
Copilot AI review requested due to automatic review settings June 20, 2026 16:01

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 8 out of 9 changed files in this pull request and generated 2 comments.

Comment on lines +52 to +60
def _simulate_ddm_rows(theta: np.ndarray) -> np.ndarray:
    """Simulate one (rt, choice) per row of theta. Returns (n, 2) float32."""
    rts = np.empty(theta.shape[0], dtype=np.float32)
    choices = np.empty(theta.shape[0], dtype=np.float32)
    for i, th in enumerate(theta):
        out = simulator(theta=th[None, :], model="ddm", n_samples=1)
        rts[i] = out["rts"].squeeze()
        choices[i] = out["choices"].squeeze()
    return np.stack([rts, choices], axis=-1)
Comment on lines +63 to +68
def _build_observed_dataframe() -> pd.DataFrame:
    """Generate N_OBS trials at the true theta as an HSSM-shaped DataFrame."""
    out = simulator(theta=_TRUE_THETA[None, :], model="ddm", n_samples=_N_OBS)
    rts = out["rts"].squeeze().astype(np.float32)
    choices = out["choices"].squeeze().astype(np.float32)
    return pd.DataFrame({"rt": rts, "response": choices})
@AlexanderFengler AlexanderFengler merged commit c9a5b56 into main Jun 20, 2026
9 checks passed