From e3ae2c3ee0f0085a0829031c34b95320182b2a81 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 7 Apr 2026 16:33:37 -0700 Subject: [PATCH] use_running_average and deterministic default to None PiperOrigin-RevId: 896146741 --- flax/nnx/nn/normalization.py | 2 +- flax/nnx/nn/stochastic.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index e06ab9fcd..1766da5b0 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -290,7 +290,7 @@ def __init__( self, num_features: int, *, - use_running_average: bool | None = False, + use_running_average: bool | None = None, axis: int = -1, momentum: float = 0.99, epsilon: float = 1e-5, diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py index 6d03e7353..4dcc78925 100644 --- a/flax/nnx/nn/stochastic.py +++ b/flax/nnx/nn/stochastic.py @@ -73,7 +73,7 @@ def __init__( rate: float, *, broadcast_dims: Sequence[int] = (), - deterministic: bool | None = False, + deterministic: bool | None = None, rng_collection: str = 'dropout', rngs: rnglib.Rngs | rnglib.RngStream | None = None, ):