Skip to content

Cpaniaguam/fix rl likelihood builder pymc6#975

Draft
cpaniaguam wants to merge 49 commits into
990-fix-sample-posterior-predictivefrom
cpaniaguam/fix-rl-likelihood-builder-pymc6
Draft

Cpaniaguam/fix rl likelihood builder pymc6#975
cpaniaguam wants to merge 49 commits into
990-fix-sample-posterior-predictivefrom
cpaniaguam/fix-rl-likelihood-builder-pymc6

Conversation

@cpaniaguam

@cpaniaguam cpaniaguam commented Jun 3, 2026

Copy link
Copy Markdown
Collaborator

Prevent PyTensor from incorrectly optimizing JAX-backed operations.

JAX-PyTensor Integration Improvements:

  • Added do_constant_folding methods to both LANLogpOp and LANLogpVJPOp classes in src/hssm/distribution_utils/jax.py to prevent PyTensor from attempting to precompute (constant fold) outputs of JAX-backed operations, ensuring correct runtime behavior. [1] [2]

Test Suite Updates:

  • Updated test_make_rl_logp_op in tests/rl/test_rl_likelihood_builder.py to use pytensor.function with mode="FAST_COMPILE" for gradient evaluation, improving test reliability and compatibility.
  • Marked test_predictive_idata_to_dataframe in tests/test_utils.py with @pytest.mark.xfail to indicate it is expected to fail due to recent changes in PyMC, preventing it from causing false negatives in CI. See Cpaniaguam/fix predictive idata to dataframe datatree #974 with a fix.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR improves the PyTensor↔JAX integration by preventing PyTensor from constant-folding JAX-backed Ops (which can lead to incorrect/unsupported compile-time evaluation), and adjusts tests to be more robust under recent PyTensor/PyMC behavior changes.

Changes:

  • Disabled PyTensor constant folding for the JAX-backed LANLogpOp and LANLogpVJPOp.
  • Updated the RL likelihood builder gradient test to evaluate gradients via a compiled pytensor.function(..., mode="FAST_COMPILE").
  • Marked test_predictive_idata_to_dataframe as xfail due to upstream PyMC changes.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
src/hssm/distribution_utils/jax.py Prevents constant folding of JAX-backed Ops to avoid incorrect compile-time evaluation.
tests/rl/test_rl_likelihood_builder.py Uses a compiled PyTensor function in FAST_COMPILE mode for gradient evaluation reliability.
tests/test_utils.py Marks a known-broken test as xfail to avoid CI false negatives.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread tests/test_utils.py Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

Comment thread tests/test_utils.py Outdated
Comment thread pyproject.toml Outdated
Comment on lines +29 to +30
"h5netcdf>=1.6.3",
"h5py>=3.14.0",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need these packages? Are these absolutely necessary to maintain basic functionalities?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. HSSM still restores .nc traces through az.from_netcdf(...)

traces = az.from_netcdf(traces)

traces = az.from_netcdf(traces)

idata_dict["idata_mcmc"] = az.from_netcdf(traces_path)

What seems to have changed is that newer xarray/ArviZ installs no longer seem to guarantee a working NetCDF backend transitively, and CI started failing when pytest tried to load the .nc fixture in

return az.from_netcdf("tests/fixtures/cavanagh_idata.nc")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it seems that both xarray and arviz now have them as optional and/or dev dependencies. I think these packages do that for a good reason, and we probably should do the same and then add these packages to dev dependencies. In code, we can probably check if these packages are installed and throw an informative error if not.

I like to keep the dependencies as slim as possible and not add additional packages for things that are occasionally used. @cpaniaguam @AlexanderFengler @krishnbera thoughts?

@cpaniaguam these .nc files might need to be generated again since they came from old InferenceData saves

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally, I think it could be a good idea to have an io optional group and have cloudpickle, h5py and h5netcdf in that group

Comment thread src/hssm/distribution_utils/jax.py
@digicosmos86

Copy link
Copy Markdown
Collaborator

@cpaniaguam the upstream errors should all be fixed now. Can you change the base to 996-fix-model-cartoon-plot and see if all tests pass now?

@cpaniaguam

Copy link
Copy Markdown
Collaborator Author

@cpaniaguam the upstream errors should all be fixed now. Can you change the base to 996-fix-model-cartoon-plot and see if all tests pass now?

Sure!

@cpaniaguam cpaniaguam changed the base branch from 970-compatibility-with-pymc6 to 996-fix-model-cartoon-plot June 22, 2026 13:05
@cpaniaguam

Copy link
Copy Markdown
Collaborator Author

digicosmos86 and others added 2 commits June 22, 2026 14:40
…-is-not-callable

[pymc6 migration] Basic errors in slow MCMC tests
@coderabbitai

coderabbitai Bot commented Jun 22, 2026

Copy link
Copy Markdown

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro Plus

Run ID: 4c8f4750-75cd-4c04-bffa-b7c85f6bb1d9

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch cpaniaguam/fix-rl-likelihood-builder-pymc6

Comment @coderabbitai help to get the list of available commands.

@cpaniaguam

Copy link
Copy Markdown
Collaborator Author

@digicosmos86 There still is one function that needs to be updated to handle xr.DataTree in src/hssm/utils.py -- predictive_idata_to_dataframe. The corresponding test is in tests/test_utils.py::test_predictive_idata_to_dataframe. I kept the failing marking on it.

@digicosmos86

Copy link
Copy Markdown
Collaborator

@digicosmos86 There still is one function that needs to be updated to handle xr.DataTree in src/hssm/utils.py -- predictive_idata_to_dataframe. The corresponding test is in tests/test_utils.py::test_predictive_idata_to_dataframe. I kept the failing marking on it.

This is fixed here https://github.com/lnccbrown/HSSM/pull/993/changes#diff-11c83362802db3028664e029ce20f83d181667e3631fe618614ad9672665d383R558

…ctive

Reapply [pymc6 migration] fixed `sample_posterior_predictive`, `sample_prior_predictive`, and `sample_do`
Base automatically changed from 996-fix-model-cartoon-plot to 990-fix-sample-posterior-predictive June 23, 2026 14:39
@digicosmos86 digicosmos86 changed the base branch from 990-fix-sample-posterior-predictive to migration-pymc6 June 23, 2026 15:08
@digicosmos86 digicosmos86 changed the base branch from migration-pymc6 to main June 23, 2026 16:40
@digicosmos86 digicosmos86 changed the base branch from main to migration-pymc6 June 23, 2026 16:40
@digicosmos86 digicosmos86 changed the base branch from migration-pymc6 to main June 23, 2026 16:50
@digicosmos86 digicosmos86 changed the base branch from main to migration-pymc6 June 23, 2026 16:51
…ive' into cpaniaguam/fix-rl-likelihood-builder-pymc6
@cpaniaguam cpaniaguam changed the base branch from migration-pymc6 to 990-fix-sample-posterior-predictive June 23, 2026 19:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[pymc6 migration] Fix likelihood-related issues with RL due to change of internals

3 participants