From d40e1578e1173fd5bc95b7d46e68d6f53ef8c464 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 7 Apr 2026 15:27:47 -0700 Subject: [PATCH] use State instead of GraphState PiperOrigin-RevId: 896115256 --- flax/nnx/extract.py | 4 +-- flax/nnx/graphlib.py | 55 ++++++++++++++++---------------- flax/nnx/module.py | 3 +- flax/nnx/transforms/autodiff.py | 2 +- flax/nnx/transforms/iteration.py | 8 ++--- 5 files changed, 35 insertions(+), 37 deletions(-) diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 3bf687aae..f839149b7 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -223,10 +223,10 @@ def broadcast_prefix_map( class GraphDefState(struct.PyTreeNode): graphdef: graphlib.GraphDef[tp.Any] = struct.field(pytree_node=False) - state: graphlib.GraphState = struct.field(pytree_node=True) + state: graphlib.State = struct.field(pytree_node=True) S = tp.TypeVar( - 'S', bound=graphlib.GraphState | graphlib.GraphFlatState | list[tp.Any] + 'S', bound=graphlib.State | graphlib.GraphFlatState | list[tp.Any] ) class NodeStates(struct.PyTreeNode): diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index 6cf896f23..d3740ddb0 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -143,7 +143,6 @@ def __treescope_repr__(self, path, subtree_renderer): ArrayRefOutput, NoUpdate, ] -GraphState = State[Key, LeafType] GraphFlatState = FlatState[LeafType] @@ -657,14 +656,14 @@ def with_same_outer_index(self) -> GraphDef[Node]: # TODO(cgarciae): remove this method def apply( - self, state: GraphState, *states: GraphState, + self, state: State, *states: State, graph: bool | None = None, - ) -> ApplyCaller[tuple[GraphDef[Node], GraphState]]: + ) -> ApplyCaller[tuple[GraphDef[Node], State]]: accessor = DelayedAccessor() def _apply( accessor: DelayedAccessor, *args, **kwargs - ) -> tuple[tp.Any, tuple[GraphDef[Node], GraphState]]: + ) -> tuple[tp.Any, tuple[GraphDef[Node], State]]: module = merge(self, state, *states) fn = accessor(module) out = fn(*args, **kwargs) @@ -679,7 +678,7 @@ def _apply( return CallableProxy(_apply, accessor) # type: ignore -PureState = tuple[GraphDef[Node], GraphState] +PureState = tuple[GraphDef[Node], State] def _tree_flatten( @@ -1365,7 +1364,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]: def graph_pop( node: tp.Any, filters: tuple[filterlib.Filter, ...], -) -> tuple[GraphState, ...]: +) -> tuple[State, ...]: id_to_index: dict[int, Index] = {} path_parts: PathParts = () predicates = tuple(filterlib.to_predicate(filter) for filter in filters) @@ -1693,12 +1692,12 @@ class SplitContext: is_inner: bool | None @tp.overload - def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... # type: ignore[invalid-annotation] + def split(self, graph_node: A, /) -> tuple[GraphDef[A], State]: ... # type: ignore[invalid-annotation] @tp.overload def split( # type: ignore[invalid-annotation] self, graph_node: A, first: filterlib.Filter, / - ) -> tuple[GraphDef[A], GraphState]: ... + ) -> tuple[GraphDef[A], State]: ... @tp.overload def split( @@ -1708,11 +1707,11 @@ def split( second: filterlib.Filter, /, *filters: filterlib.Filter, - ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... # type: ignore[not-supported-yet] + ) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: ... # type: ignore[not-supported-yet] def split( self, node: A, *filters: filterlib.Filter - ) -> tuple[GraphDef[A], tpe.Unpack[tuple[GraphState, ...]]]: # type: ignore[not-supported-yet] + ) -> tuple[GraphDef[A], tpe.Unpack[tuple[State, ...]]]: # type: ignore[not-supported-yet] ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) @@ -1872,9 +1871,9 @@ class MergeContext: def merge( # type: ignore[invalid-annotation] self, graphdef: GraphDef[A], - state: GraphState, + state: State, /, - *states: GraphState, + *states: State, ) -> A: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None @@ -2232,11 +2231,11 @@ def _split_state( @tp.overload def split( # type: ignore[invalid-annotation] graph_node: A, /, *, graph: bool | None = None, -) -> tuple[GraphDef[A], GraphState]: ... +) -> tuple[GraphDef[A], State]: ... @tp.overload def split( # type: ignore[invalid-annotation] graph_node: A, first: filterlib.Filter, /, *, graph: bool | None = None, -) -> tuple[GraphDef[A], GraphState]: ... +) -> tuple[GraphDef[A], State]: ... @tp.overload def split( # type: ignore[invalid-annotation] graph_node: A, @@ -2247,15 +2246,15 @@ def split( # type: ignore[invalid-annotation] graph: bool | None = None, ) -> tuple[ GraphDef[A], - GraphState, - tpe.Unpack[tuple[GraphState, ...]], + State, + tpe.Unpack[tuple[State, ...]], ]: ... def split( # type: ignore[invalid-annotation] node: A, *filters: filterlib.Filter, graph: bool | None = None, ) -> tuple[ GraphDef[A], - GraphState, - tpe.Unpack[tuple[GraphState, ...]], + State, + tpe.Unpack[tuple[State, ...]], ]: """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef @@ -2467,9 +2466,9 @@ def update(node, state: tp.Any, /, *states: tp.Any) -> None: @tp.overload -def state(node, /, *, graph: bool | None = None) -> GraphState: ... +def state(node, /, *, graph: bool | None = None) -> State: ... @tp.overload -def state(node, first: filterlib.Filter, /, *, graph: bool | None = None) -> GraphState: ... +def state(node, first: filterlib.Filter, /, *, graph: bool | None = None) -> State: ... @tp.overload def state( node, @@ -2478,12 +2477,12 @@ def state( /, *filters: filterlib.Filter, graph: bool | None = None, -) -> tuple[GraphState, ...]: ... +) -> tuple[State, ...]: ... def state( node, *filters: filterlib.Filter, graph: bool | None = None, -) -> tp.Union[GraphState, tuple[GraphState, ...]]: +) -> tp.Union[State, tuple[State, ...]]: """Similar to :func:`split` but only returns the :class:`State`'s indicated by the filters. Example usage:: @@ -2522,7 +2521,7 @@ def state( _, flat_state = flatten(node, graph=graph) state = flat_state.to_nested_state() - states: GraphState | tuple[GraphState, ...] + states: State | tuple[State, ...] if len(filters) == 0: states = state # type: ignore[assignment] elif len(filters) == 1: @@ -2611,7 +2610,7 @@ def pop( node, filter: filterlib.Filter, /, -) -> GraphState: ... +) -> State: ... @tp.overload @@ -2621,12 +2620,12 @@ def pop( filter2: filterlib.Filter, /, *filters: filterlib.Filter, -) -> tuple[GraphState, ...]: ... +) -> tuple[State, ...]: ... def pop( node, *filters: filterlib.Filter -) -> tp.Union[GraphState, tuple[GraphState, ...]]: +) -> tp.Union[State, tuple[State, ...]]: """Pop one or more :class:`Variable` types from the graph node. Example usage:: @@ -2818,8 +2817,8 @@ def _pure_fn(x): pure = deprecated(as_pure) def call( - graphdef_state: tuple[GraphDef[A], GraphState], / -) -> ApplyCaller[tuple[GraphDef[A], GraphState]]: + graphdef_state: tuple[GraphDef[A], State], / +) -> ApplyCaller[tuple[GraphDef[A], State]]: """Calls a method underlying graph node defined by a (GraphDef, State) pair. ``call`` takes a ``(GraphDef, State)`` pair and creates a proxy object that can be diff --git a/flax/nnx/module.py b/flax/nnx/module.py index a2d977ad9..ad236734b 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -27,7 +27,6 @@ ) from flax.nnx import variablelib as variableslib from flax.nnx.pytreelib import Pytree, PytreeMeta -from flax.nnx.graphlib import GraphState from flax.nnx.statelib import split_state, State import functools as ft from flax.typing import Key, Path, PathParts @@ -37,7 +36,7 @@ A = tp.TypeVar('A') B = tp.TypeVar('B') M = tp.TypeVar('M', bound='Module') -S = tp.TypeVar('S', bound=tp.Union[GraphState, tuple[GraphState, ...]]) +S = tp.TypeVar('S', bound=tp.Union[State, tuple[State, ...]]) V = tp.TypeVar('V', bound=variableslib.Variable[tp.Any]) F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 9ea5ff3cc..758a95b8d 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -1063,7 +1063,7 @@ def _custom_vjp_split_fn( *, nondiff_states: list[extract.GraphDefState], ): - broadcast: graphlib.GraphState + broadcast: State if prefix is False: # pure non-differentiable arg, not supported raise TypeError( diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 67ec87740..165bd7522 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -223,8 +223,8 @@ def __hash__(self): AxisFn = tp.Callable[ - [graphlib.GraphState | variablelib.Variable, int, tp.Mapping], - graphlib.GraphState | variablelib.Variable, + [State | variablelib.Variable, int, tp.Mapping], + State | variablelib.Variable, ] @@ -241,9 +241,9 @@ def _update_axes_fn(node_states): state = axis_fn(state, node_states.metadata, transform_metadata) return node_states.replace(states=(state,)) else: - states_out: list[graphlib.GraphState | variablelib.Variable] = [] + states_out: list[State | variablelib.Variable] = [] for state, axis in zip(node_states.states, node_states.metadata.axes): - assert isinstance(state, graphlib.State | variablelib.Variable) + assert isinstance(state, State | variablelib.Variable) if isinstance(axis, int): state = axis_fn(state, axis, transform_metadata) states_out.append(state)