Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,10 @@ def broadcast_prefix_map(

class GraphDefState(struct.PyTreeNode):
graphdef: graphlib.GraphDef[tp.Any] = struct.field(pytree_node=False)
state: graphlib.GraphState = struct.field(pytree_node=True)
state: graphlib.State = struct.field(pytree_node=True)

S = tp.TypeVar(
'S', bound=graphlib.GraphState | graphlib.GraphFlatState | list[tp.Any]
'S', bound=graphlib.State | graphlib.GraphFlatState | list[tp.Any]
)

class NodeStates(struct.PyTreeNode):
Expand Down
55 changes: 27 additions & 28 deletions flax/nnx/graphlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def __treescope_repr__(self, path, subtree_renderer):
ArrayRefOutput,
NoUpdate,
]
GraphState = State[Key, LeafType]
GraphFlatState = FlatState[LeafType]


Expand Down Expand Up @@ -657,14 +656,14 @@ def with_same_outer_index(self) -> GraphDef[Node]:

# TODO(cgarciae): remove this method
def apply(
self, state: GraphState, *states: GraphState,
self, state: State, *states: State,
graph: bool | None = None,
) -> ApplyCaller[tuple[GraphDef[Node], GraphState]]:
) -> ApplyCaller[tuple[GraphDef[Node], State]]:
accessor = DelayedAccessor()

def _apply(
accessor: DelayedAccessor, *args, **kwargs
) -> tuple[tp.Any, tuple[GraphDef[Node], GraphState]]:
) -> tuple[tp.Any, tuple[GraphDef[Node], State]]:
module = merge(self, state, *states)
fn = accessor(module)
out = fn(*args, **kwargs)
Expand All @@ -679,7 +678,7 @@ def _apply(
return CallableProxy(_apply, accessor) # type: ignore


PureState = tuple[GraphDef[Node], GraphState]
PureState = tuple[GraphDef[Node], State]


def _tree_flatten(
Expand Down Expand Up @@ -1365,7 +1364,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
def graph_pop(
node: tp.Any,
filters: tuple[filterlib.Filter, ...],
) -> tuple[GraphState, ...]:
) -> tuple[State, ...]:
id_to_index: dict[int, Index] = {}
path_parts: PathParts = ()
predicates = tuple(filterlib.to_predicate(filter) for filter in filters)
Expand Down Expand Up @@ -1693,12 +1692,12 @@ class SplitContext:
is_inner: bool | None

@tp.overload
def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... # type: ignore[invalid-annotation]
def split(self, graph_node: A, /) -> tuple[GraphDef[A], State]: ... # type: ignore[invalid-annotation]

@tp.overload
def split( # type: ignore[invalid-annotation]
self, graph_node: A, first: filterlib.Filter, /
) -> tuple[GraphDef[A], GraphState]: ...
) -> tuple[GraphDef[A], State]: ...

@tp.overload
def split(
Expand All @@ -1708,11 +1707,11 @@ def split(
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... # type: ignore[not-supported-yet]
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: ... # type: ignore[not-supported-yet]

def split(
self, node: A, *filters: filterlib.Filter
) -> tuple[GraphDef[A], tpe.Unpack[tuple[GraphState, ...]]]: # type: ignore[not-supported-yet]
) -> tuple[GraphDef[A], tpe.Unpack[tuple[State, ...]]]: # type: ignore[not-supported-yet]
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
Expand Down Expand Up @@ -1872,9 +1871,9 @@ class MergeContext:
def merge( # type: ignore[invalid-annotation]
self,
graphdef: GraphDef[A],
state: GraphState,
state: State,
/,
*states: GraphState,
*states: State,
) -> A:
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
Expand Down Expand Up @@ -2232,11 +2231,11 @@ def _split_state(
@tp.overload
def split( # type: ignore[invalid-annotation]
graph_node: A, /, *, graph: bool | None = None,
) -> tuple[GraphDef[A], GraphState]: ...
) -> tuple[GraphDef[A], State]: ...
@tp.overload
def split( # type: ignore[invalid-annotation]
graph_node: A, first: filterlib.Filter, /, *, graph: bool | None = None,
) -> tuple[GraphDef[A], GraphState]: ...
) -> tuple[GraphDef[A], State]: ...
@tp.overload
def split( # type: ignore[invalid-annotation]
graph_node: A,
Expand All @@ -2247,15 +2246,15 @@ def split( # type: ignore[invalid-annotation]
graph: bool | None = None,
) -> tuple[
GraphDef[A],
GraphState,
tpe.Unpack[tuple[GraphState, ...]],
State,
tpe.Unpack[tuple[State, ...]],
]: ...
def split( # type: ignore[invalid-annotation]
node: A, *filters: filterlib.Filter, graph: bool | None = None,
) -> tuple[
GraphDef[A],
GraphState,
tpe.Unpack[tuple[GraphState, ...]],
State,
tpe.Unpack[tuple[State, ...]],
]:
"""Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
Expand Down Expand Up @@ -2467,9 +2466,9 @@ def update(node, state: tp.Any, /, *states: tp.Any) -> None:


@tp.overload
def state(node, /, *, graph: bool | None = None) -> GraphState: ...
def state(node, /, *, graph: bool | None = None) -> State: ...
@tp.overload
def state(node, first: filterlib.Filter, /, *, graph: bool | None = None) -> GraphState: ...
def state(node, first: filterlib.Filter, /, *, graph: bool | None = None) -> State: ...
@tp.overload
def state(
node,
Expand All @@ -2478,12 +2477,12 @@ def state(
/,
*filters: filterlib.Filter,
graph: bool | None = None,
) -> tuple[GraphState, ...]: ...
) -> tuple[State, ...]: ...
def state(
node,
*filters: filterlib.Filter,
graph: bool | None = None,
) -> tp.Union[GraphState, tuple[GraphState, ...]]:
) -> tp.Union[State, tuple[State, ...]]:
"""Similar to :func:`split` but only returns the :class:`State`'s indicated by the filters.

Example usage::
Expand Down Expand Up @@ -2522,7 +2521,7 @@ def state(
_, flat_state = flatten(node, graph=graph)
state = flat_state.to_nested_state()

states: GraphState | tuple[GraphState, ...]
states: State | tuple[State, ...]
if len(filters) == 0:
states = state # type: ignore[assignment]
elif len(filters) == 1:
Expand Down Expand Up @@ -2611,7 +2610,7 @@ def pop(
node,
filter: filterlib.Filter,
/,
) -> GraphState: ...
) -> State: ...


@tp.overload
Expand All @@ -2621,12 +2620,12 @@ def pop(
filter2: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[GraphState, ...]: ...
) -> tuple[State, ...]: ...


def pop(
node, *filters: filterlib.Filter
) -> tp.Union[GraphState, tuple[GraphState, ...]]:
) -> tp.Union[State, tuple[State, ...]]:
"""Pop one or more :class:`Variable` types from the graph node.

Example usage::
Expand Down Expand Up @@ -2818,8 +2817,8 @@ def _pure_fn(x):
pure = deprecated(as_pure)

def call(
graphdef_state: tuple[GraphDef[A], GraphState], /
) -> ApplyCaller[tuple[GraphDef[A], GraphState]]:
graphdef_state: tuple[GraphDef[A], State], /
) -> ApplyCaller[tuple[GraphDef[A], State]]:
"""Calls a method underlying graph node defined by a (GraphDef, State) pair.

``call`` takes a ``(GraphDef, State)`` pair and creates a proxy object that can be
Expand Down
3 changes: 1 addition & 2 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
)
from flax.nnx import variablelib as variableslib
from flax.nnx.pytreelib import Pytree, PytreeMeta
from flax.nnx.graphlib import GraphState
from flax.nnx.statelib import split_state, State
import functools as ft
from flax.typing import Key, Path, PathParts
Expand All @@ -37,7 +36,7 @@
A = tp.TypeVar('A')
B = tp.TypeVar('B')
M = tp.TypeVar('M', bound='Module')
S = tp.TypeVar('S', bound=tp.Union[GraphState, tuple[GraphState, ...]])
S = tp.TypeVar('S', bound=tp.Union[State, tuple[State, ...]])
V = tp.TypeVar('V', bound=variableslib.Variable[tp.Any])
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])

Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ def _custom_vjp_split_fn(
*,
nondiff_states: list[extract.GraphDefState],
):
broadcast: graphlib.GraphState
broadcast: State
if prefix is False:
# pure non-differentiable arg, not supported
raise TypeError(
Expand Down
8 changes: 4 additions & 4 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ def __hash__(self):


AxisFn = tp.Callable[
[graphlib.GraphState | variablelib.Variable, int, tp.Mapping],
graphlib.GraphState | variablelib.Variable,
[State | variablelib.Variable, int, tp.Mapping],
State | variablelib.Variable,
]


Expand All @@ -241,9 +241,9 @@ def _update_axes_fn(node_states):
state = axis_fn(state, node_states.metadata, transform_metadata)
return node_states.replace(states=(state,))
else:
states_out: list[graphlib.GraphState | variablelib.Variable] = []
states_out: list[State | variablelib.Variable] = []
for state, axis in zip(node_states.states, node_states.metadata.axes):
assert isinstance(state, graphlib.State | variablelib.Variable)
assert isinstance(state, State | variablelib.Variable)
if isinstance(axis, int):
state = axis_fn(state, axis, transform_metadata)
states_out.append(state)
Expand Down
Loading