
How to sample fresh noise at each step for a non-linear noise-perturbed ODE dX = f(X, z)dt? #759

@vadmbertr

Hi,
Thanks for the great work in developing diffrax!

I have a use case that involves introducing some noise at every integration time step and then applying a non-linear transformation to it (so it's not a "classical" SDE, but rather something like a perturbed ODE, if that makes sense).
Broadly, it can be written as $dX = f(X, z)\,dt$, where $z \sim \mathcal{N}(0, 1)$ and $f$ is an arbitrary function (a neural network, for example).
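As a point of reference, with a fixed step size $\Delta t$ and an explicit Euler scheme (an assumed discretization; the question does not fix a solver), the model above would amount to the following recursion, with a new independent draw $z_n$ at every step:

$$
t_n = t_0 + n\,\Delta t, \qquad X_{n+1} = X_n + f(X_n, z_n)\,\Delta t, \qquad z_n \overset{\text{iid}}{\sim} \mathcal{N}(0, 1).
$$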

I'm not sure how to implement a vector field corresponding to $f$, since a fresh sample $z$ is needed at every integration time step, and we cannot pass around a "fresh" jax.Key because args must stay fixed during the whole integration.
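A fixed args does not by itself rule out per-step randomness. One possible workaround, shown here as a minimal sketch, is to keep a single key in args and derive a per-step key from it with jax.random.fold_in, using the integer step index recovered from $t$. The non-linear f, the step size dt, and the hand-written fixed-step Euler loop (a jax.lax.scan standing in for a diffrax solve) are illustrative assumptions, not part of the original question.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


# Hypothetical non-linear f(X, z); a neural network could go here instead.
def f(x, z):
    return -x + jnp.tanh(z)


def vector_field(t, x, args):
    # args stays fixed for the whole integration: one base key, t0 and dt.
    key, t0, dt = args
    # Recover the integer step index from t (rounding absorbs float error)
    # and fold it into the base key, giving a distinct key per step.
    step = jnp.round((t - t0) / dt).astype(jnp.int32)
    z = jr.normal(jr.fold_in(key, step), x.shape)
    return f(x, z)


def euler_solve(x0, key, t0=0.0, t1=1.0, dt=0.01):
    # Plain fixed-step Euler loop (assumption: stands in for a diffrax solve).
    n_steps = int(round((t1 - t0) / dt))
    args = (key, t0, dt)

    def body(x, n):
        t = t0 + n * dt
        return x + dt * vector_field(t, x, args), None

    x1, _ = jax.lax.scan(body, x0, jnp.arange(n_steps))
    return x1
```

The vector field keeps the (t, y, args) signature that diffrax uses for ODETerm, so the same function should plug into a diffrax solve with Euler and a constant step size. With multi-stage or adaptive solvers, intermediate stage times would be rounded to a neighbouring step index, so z would no longer be exactly one fresh draw per step.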
The way I see it, one option would be to pass a VirtualBrownianTree to the vector field through args, and interpolate this path at every integration time $t$ to sample $z$. Is this correct? If so, would it be "safe" to backpropagate through the solve when using UnsafeBrownianPath with a fixed step size?

Can you think of another approach to this?

Thank you for the feedback!
