
JointDistributionCoroutineAutoBatched seed behaviour with batches #2011

@chrism0dwk


Hi all,

I'm trying to compute a posterior predictive distribution over samples from a posterior distribution (Colab here), using TFP 0.25 with the JAX backend.

My model specification (a minimal reproducible example, and therefore contrived) is

import jax
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

@tfd.JointDistributionCoroutineAutoBatched
def model_autobatched():
    theta = yield tfd.Normal(loc=0., scale=1., name="theta")
    yield tfd.Normal(loc=theta, scale=0.1, name="y")

i.e. a Normally distributed observation model with a Normally distributed mean. To compute the posterior predictive distribution, I wish to sample the y component conditional on a vector of theta samples.
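
For reference, this is the posterior predictive distribution for this toy model, sampled by composition. The symbols $\mathcal{D}$ (the observed data) and $\theta^{(s)}$, $s = 1, \dots, S$ (the posterior samples of theta) are notation introduced here for illustration; the 0.1 is the standard deviation of y given theta, as in the model specification.

$$
p(\tilde{y} \mid \mathcal{D}) = \int p(\tilde{y} \mid \theta)\, p(\theta \mid \mathcal{D})\, \mathrm{d}\theta,
\qquad
\tilde{y}^{(s)} = \theta^{(s)} + 0.1\, \varepsilon^{(s)}, \quad \varepsilon^{(s)} \overset{\text{iid}}{\sim} \mathcal{N}(0, 1).
$$

Each $\tilde{y}^{(s)}$ therefore needs its own independent noise draw $\varepsilon^{(s)}$; a single $\varepsilon$ shared across all $s$ would make $\tilde{y}^{(s)} - \theta^{(s)}$ identical for every $s$.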

import numpy as np

theta_samples = np.arange(5.)
model_autobatched.sample(theta=theta_samples, seed=jax.random.key(0))

giving

StructTuple(
  theta=Array([0., 1., 2., 3., 4.], dtype=float32),
  y=Array([0.06215769, 1.0621576 , 2.0621576 , 3.0621576 , 4.0621576 ],      dtype=float32)
)

Oh dear: y - theta is the same constant for every element. This suggests that the same PRNG key is being reused for every draw of y, rather than a different key being used for each theta sample.
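
The suspected mechanism can be illustrated in plain JAX, without TFP. This is a sketch of the hypothesis only, not of TFP's internals; `draw_y` is a hypothetical helper that draws one scalar y ~ Normal(theta, 0.1). Broadcasting one key to every theta sample reproduces a constant y - theta, whereas splitting the key with `jax.random.split` gives independent noise per sample.

```python
import jax
import jax.numpy as jnp

theta = jnp.arange(5.)
key = jax.random.key(0)

def draw_y(t, k):
    """Hypothetical helper: one scalar draw of y ~ Normal(loc=t, scale=0.1)."""
    return t + 0.1 * jax.random.normal(k)

# Suspected behaviour: the same key is broadcast (in_axes=None) to every
# theta sample, so every draw uses identical standard-normal noise.
y_shared = jax.vmap(draw_y, in_axes=(0, None))(theta, key)

# Desired behaviour: one independent key per theta sample.
keys = jax.random.split(key, theta.shape[0])
y_split = jax.vmap(draw_y)(theta, keys)

print(y_shared - theta)  # the same value in every element (up to float rounding)
print(y_split - theta)   # a different value in each element
```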

Moreover, this approach fails entirely if sample_distributions is called.

model_autobatched.sample_distributions(theta=theta_samples, seed=jax.random.key(0))
ValueError: Attempt to convert a value (<object object at 0x7a53561590d0>) with an unsupported type (<class 'object'>) to a Tensor.

As a workaround, we could use the older JointDistributionCoroutine with a Root annotation, which works as desired (see Colab).

[edit] Actually, JDCoroutine/Root only works because the whole theta vector is passed to y's constructor, not because of vectorisation over the whole model.
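
The point in the edit can also be sketched in plain JAX, again as an analogy rather than TFP's actual code: when the whole theta vector is passed to y's constructor, y has a batch of five locations, so a single key suffices because the noise is drawn with theta's full shape in one call. `draw_y_batched` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

theta = jnp.arange(5.)
key = jax.random.key(0)

def draw_y_batched(theta, k):
    """Hypothetical helper, analogous to sampling Normal(loc=theta, scale=0.1)
    with a vector-valued loc: one key, but noise of theta's full shape."""
    return theta + 0.1 * jax.random.normal(k, theta.shape)

y = draw_y_batched(theta, key)
print(y - theta)  # noise differs across elements despite the single key
```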

Do we have a bug or a feature, I wonder?

Regards,

Chris
