diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index e06ab9fcd..1766da5b0 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -290,7 +290,7 @@ def __init__( self, num_features: int, *, - use_running_average: bool | None = False, + use_running_average: bool | None = None, axis: int = -1, momentum: float = 0.99, epsilon: float = 1e-5, diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py index 6d03e7353..4dcc78925 100644 --- a/flax/nnx/nn/stochastic.py +++ b/flax/nnx/nn/stochastic.py @@ -73,7 +73,7 @@ def __init__( rate: float, *, broadcast_dims: Sequence[int] = (), - deterministic: bool | None = False, + deterministic: bool | None = None, rng_collection: str = 'dropout', rngs: rnglib.Rngs | rnglib.RngStream | None = None, ):