Skip to content

Add SubsampledLogDensity #271

Open
yebai wants to merge 7 commits into
mainfrom
refactor-dynamicppl-subsample-callback
Open

Add SubsampledLogDensity #271
yebai wants to merge 7 commits into
mainfrom
refactor-dynamicppl-subsample-callback

Conversation

@yebai

@yebai yebai commented Jun 5, 2026

Copy link
Copy Markdown
Member

Fix #259

Replaces the bespoke `DynamicPPLModelLogDensityFunction` ext wrapper with a
generic `SubsampledLogDensity` in main, plus a small DPPL-glue method in the
extension. Closes the documentation gap from #259.

- `SubsampledLogDensity(prob, make_prob, dataset_size)` wraps any
  `LogDensityProblem` and rebuilds the inner problem per batch via a
  user-supplied `(batch, scale) -> prob` callback; type-stable, validates
  `dataset_size > 0` and `length(batch) <= dataset_size`.
- `WeightedLogJoint(scale)` carries the SG correction; the DPPL extension
  defines its call method on `DynamicPPL.AbstractVarInfo` using the
  `LogPrior` / `LogLikelihood` / `LogJacobian` accumulators.
- Renames the subsampling protocol verb `subsample` → `with_batch` across
  algorithms, the test model, and the LogReg tutorial. `with_batch` is
  pure (returns a fresh wrapper) and reads correctly: it applies a batch
  rather than drawing one.
- DPPL users now write a model factory parametric in `N` and supply the
  data per batch via `|` (conditioning); the old `datapoints=` keyword
  convention is no longer required.

Trade-offs (intentional):
- The inner `DynamicPPL.LogDensityFunction` is rebuilt — and its AD prep
  re-run — on every `with_batch` call, because upstream LDF bakes the
  model into its prep context. Negligible for ForwardDiff; per-step cost
  for compiled backends (Mooncake, Enzyme).
- DPPL targets now expose `LogDensityOrder{1}` only; the previous wrapper
  also offered `LogDensityOrder{2}` via a separate Hessian prep.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
yebai and others added 3 commits June 5, 2026 17:05
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
The previous model wrote `x ~ MvNormal(μs_batch[:, i], I)` N times for the
same `x`. DPPL routes ALL contributions of a sampled variable into
`LogPriorAccumulator`, so `LogLikelihood` stayed at 0.0 and the
`scale * loglike + logprior - logjac` correction multiplied zero. The test
passed only because the per-batch gradient direction happened to point at
the right answer.

Inverted the model to a proper hierarchical setup: a latent `μ` with a
weak prior, observations conditioned on `μ`. Now the per-batch contributions
land in `LogLikelihoodAccumulator`, where the SG-correction scale is
actually applied.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
`μ_true = [-2.0, 2.0]` was a fixed point that q never converges to
exactly — q converges to the sample mean of `observations`, which differs
by O(1/√n_data). Compute the target from the data so the assertion checks
proximity to the actual fixed point.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@yebai yebai marked this pull request as ready for review June 5, 2026 16:14
Comment thread docs/src/tutorials/subsampling.md Outdated
yebai and others added 2 commits June 5, 2026 17:19
Reverts the protocol-verb rename to minimise the diff against main: the
algorithm files and `test/models/subsamplednormals.jl` are now byte-identical
to main. Only the new abstraction (`SubsampledLogDensity`, `WeightedLogJoint`)
and the rewired DPPL ext/test/doc remain changed.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Both `bayes_logreg(X_batch, N) | (y = y_batch,)` (conditioning) and
`bayes_logreg(X_batch, y_batch, N)` (argument) route per-batch logpdf
contributions into `LogLikelihoodAccumulator`, so the SG correction
applies identically. Adds a side note to the subsampling tutorial.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Details
Benchmark suite Current: 3c12ec4 Previous: c4aaa5f Ratio
normal/RepGradELBO/fullrank/Mooncake 688795402.5 ns 522191080.5 ns 1.32
normal/RepGradELBO/fullrank/ReverseDiff 575899447 ns 591786773 ns 0.97
normal/RepGradELBO/meanfield/Mooncake 245520817 ns 212394574 ns 1.16
normal/RepGradELBO/meanfield/ReverseDiff 295496401 ns 297190927 ns 0.99
normal/RepGradELBO + STL/fullrank/Mooncake 841880102 ns 660211320.5 ns 1.28
normal/RepGradELBO + STL/fullrank/ReverseDiff 1162326165 ns 1129261429 ns 1.03
normal/RepGradELBO + STL/meanfield/Mooncake 383625489 ns 319214528 ns 1.20
normal/RepGradELBO + STL/meanfield/ReverseDiff 598968879 ns 600170298 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

`SubsampledLogDensity` and `WeightedLogJoint` are not listed under any
`@docs`/`@autodocs` block in the manual (35 other docstrings have the same
status), so Documenter can't resolve `@ref`. Plain code-style backticks
render the same monospace name without breaking the build.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@github-actions

github-actions Bot commented Jun 5, 2026

Copy link
Copy Markdown
Contributor

AdvancedVI.jl documentation for PR #271 is available at:
https://TuringLang.github.io/AdvancedVI.jl/previews/PR271/

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Docs on subsample for DynamicPPL models?

1 participant