A minimal implementation of preconditioned Crank-Nicolson MCMC sampling.
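For orientation, here is a minimal NumPy sketch of one textbook pCN Metropolis-Hastings step applied to a batch of walkers. It is not minipcn's internal code. The function name pcn_step, the step-size parameter beta, and the Gaussian reference N(mean, chol @ chol.T) are assumptions made for illustration. The pCN proposal leaves that Gaussian reference invariant, so the acceptance ratio only compares target/reference at the proposed and current points.

```python
import numpy as np


def pcn_step(x, log_prob_fn, rng, beta=0.5, mean=None, chol=None):
    """One pCN Metropolis-Hastings step for a batch of walkers (illustrative sketch).

    x has shape (n_walkers, dims); log_prob_fn maps (n, dims) -> (n,).
    The proposal x' = mean + sqrt(1 - beta**2) * (x - mean) + beta * L z,
    with z ~ N(0, I) and L = chol, leaves N(mean, L L^T) invariant.
    """
    n, dims = x.shape
    mean = np.zeros(dims) if mean is None else mean
    chol = np.eye(dims) if chol is None else chol
    prec = np.linalg.inv(chol @ chol.T)  # precision of the Gaussian reference

    def log_ref(y):
        # Unnormalised Gaussian reference log-density; constants cancel in the ratio.
        d = y - mean
        return -0.5 * np.einsum("ni,ij,nj->n", d, prec, d)

    z = rng.standard_normal((n, dims))
    x_prop = mean + np.sqrt(1.0 - beta**2) * (x - mean) + beta * z @ chol.T
    # MH log-ratio of (target / reference) at the proposed vs the current point.
    log_alpha = (log_prob_fn(x_prop) - log_ref(x_prop)) - (log_prob_fn(x) - log_ref(x))
    accept = np.log(rng.uniform(size=n)) < log_alpha
    # Accepted walkers move to the proposal; rejected walkers stay put.
    return np.where(accept[:, None], x_prop, x), accept


# Demo: target N(1, I) in 2 dimensions, Gaussian reference N(0, I).
rng = np.random.default_rng(42)
log_prob = lambda y: -0.5 * np.sum((y - 1.0) ** 2, axis=-1)
x_final = rng.standard_normal((100, 2))
for _ in range(500):
    x_final, accepted = pcn_step(x_final, log_prob, rng, beta=0.5)
print(x_final.mean(axis=0))  # should be close to [1, 1]
```

When the target equals the reference, log_alpha is zero and every proposal is accepted. This is why pCN behaves well when the reference is a good approximation of the target.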
minipcn can be installed from PyPI using pip:
```bash
pip install minipcn
```

The basic usage is:

```python
from minipcn import Sampler
import numpy as np

log_prob_fn = ...  # Log-probability function - must be vectorized
dims = ...  # The number of dimensions
rng = np.random.default_rng(42)

sampler = Sampler(
    log_prob_fn=log_prob_fn,
    dims=dims,
    step_fn="pcn",  # Or "tpcn"
)

x0 = rng.normal(size=(100, dims))
chain, history = sampler.sample(x0, n_steps=500, rng=rng)
```

For a complete example, see the examples directory.
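The log_prob_fn must be vectorized. Based on the x0 arrays in the examples, this presumably means it receives a batch of shape (n_samples, dims) and returns one log-density per row, with shape (n_samples,). Below is a minimal NumPy sketch under that assumption. It uses an arbitrary correlated Gaussian as a stand-in target, and the matrix precision is made up for illustration.

```python
import numpy as np

dims = 5
rng = np.random.default_rng(42)

# Arbitrary symmetric positive-definite precision matrix for the stand-in target.
A = rng.normal(size=(dims, dims))
precision = A @ A.T + dims * np.eye(dims)


def log_prob_fn(x):
    """Unnormalised Gaussian log-density for a batch x of shape (n, dims)."""
    # einsum evaluates x_n^T P x_n for every row n in one call, with no Python loop.
    return -0.5 * np.einsum("ni,ij,nj->n", x, precision, x)


x0 = rng.normal(size=(100, dims))
print(log_prob_fn(x0).shape)  # (100,)
```

A non-vectorized function that loops over rows in Python would give the same numbers, but it would be much slower and would not work with the other array backends.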
minipcn also supports other array API backends via array-api-compat, with random number generation handled by orng.
Usage is then similar to the NumPy case, except that the RNG must come from orng and the backend must be passed via xp:
```python
from minipcn import Sampler
from orng import RandomGenerator
import torch

log_prob_fn = ...  # Log-probability function - must be vectorized
dims = ...  # The number of dimensions
rng = RandomGenerator(backend="torch", seed=42)

sampler = Sampler(
    log_prob_fn=log_prob_fn,
    dims=dims,
    step_fn="pcn",  # Or "tpcn"
    xp=torch,
)

# Generate initial samples
x0 = rng.randn(size=(100, dims))
# Run the sampler
chain, history = sampler.sample(x0, n_steps=500, rng=rng)
```

Note: the tpCN step falls back to NumPy for fitting the Student-t distribution.
minipcn also supports explicit, functional RNG state via Sampler.sample_functional(...).
Use this path for JAX compilation or for any workflow where the RNG state must be threaded explicitly.
Instead of an RNG object, the functional API takes an explicit RNG state, created here from an orng functional backend:
```python
import jax
import jax.numpy as jnp
from minipcn import Sampler
from orng.functional import create_functional_backend

dims = 4
rng_backend = create_functional_backend("jax")
rng_state = rng_backend.init_state(seed=42, generator=None)

x0, rng_state = rng_backend.normal(
    rng_state,
    loc=0.0,
    scale=1.0,
    size=(32, dims),
    dtype=jnp.float32,
)


def log_prob_fn(x):
    return -0.5 * jnp.sum(x**2, axis=-1)


sampler = Sampler(
    log_prob_fn=log_prob_fn,
    dims=dims,
    step_fn="pcn",
    xp=jnp,
)

samples, history, next_rng_state = sampler.sample_functional(
    x0,
    n_steps=8,
    rng_state=rng_state,
    verbose=False,
    return_last_only=True,
)
```

sample_functional(...) returns (chain, history, next_rng_state).
To use it under jax.jit, thread the state through the compiled function:
```python
@jax.jit
def run(x, state):
    samples, history, next_state = sampler.sample_functional(
        x,
        n_steps=8,
        rng_state=state,
        verbose=False,
        return_last_only=True,
    )
    return samples, history, next_state


samples, history, rng_state = run(x0, rng_state)
```

The backend for sample_functional(...) is inferred from xp. For example:
- xp=np uses the NumPy functional backend
- xp=jax.numpy uses the JAX functional backend
- xp=torch uses the PyTorch functional backend
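One plausible way to map xp to a functional backend is to inspect the namespace's module name. The helper below is hypothetical and is not minipcn's implementation. The backend strings "numpy" and "torch" are also assumptions; only "jax" appears in the example above.

```python
import types

import numpy as np


def infer_backend_name(xp):
    """Map an array namespace to a functional RNG backend name (hypothetical helper)."""
    name = getattr(xp, "__name__", "")
    # array-api-compat exposes wrapped namespaces such as "array_api_compat.numpy".
    name = name.removeprefix("array_api_compat.")
    for prefix in ("jax", "torch", "numpy"):
        # Match the package itself or any submodule, e.g. "jax.numpy".
        if name == prefix or name.startswith(prefix + "."):
            return prefix
    raise ValueError(f"Unsupported array namespace: {name!r}")


print(infer_backend_name(np))                             # numpy
print(infer_backend_name(types.ModuleType("jax.numpy")))  # jax
```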
Use sample(...) for stateful RNG objects and sample_functional(...) when
you want explicit RNG state.
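The difference between the two styles can be shown with plain NumPy. The helper normal_functional below is hypothetical and is not part of orng. It only shows the pattern that sample_functional(...) relies on: a state goes in and an updated state comes out, rather than being mutated inside an object.

```python
import numpy as np


def normal_functional(state, size):
    """Draw standard normals from an explicit PCG64 state; return (draws, new_state).

    Hypothetical helper: the input state dict is never mutated.
    """
    bit_gen = np.random.PCG64()
    bit_gen.state = state  # load the caller's state into a fresh bit generator
    draws = np.random.Generator(bit_gen).standard_normal(size)
    return draws, bit_gen.state


# Stateful: the Generator object advances internally on every call.
rng = np.random.default_rng(42)
a = rng.standard_normal(3)
b = rng.standard_normal(3)  # differs from a

# Functional: the state goes in and a new state comes out.
state0 = np.random.PCG64(42).state
x1, state1 = normal_functional(state0, 3)
x1_again, _ = normal_functional(state0, 3)  # same input state -> same draws
x2, state2 = normal_functional(state1, 3)   # threaded state -> fresh draws
```

Threading the returned state reproduces the stateful stream (x1 equals a and x2 equals b). Because nothing is mutated in place, the same input state always gives the same output, and this purity is what a compiled function under jax.jit requires.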
If you use minipcn in your work, please cite our DOI.
If you use the tpcn kernel, please also cite Grumitt et al.