I am quite new to flax, so I am not 100% sure that I didn't make an obvious mistake.
Small example included at the bottom of this report.
System information
- OS Platform and Distribution: AlmaLinux
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax: 0.12.5, jax: 0.9.1, jaxlib: 0.9., optax: 0.2.6
- Python version: 3.11.14
- GPU/TPU model and memory:
- CUDA version (if applicable): 13.1
Problem you have encountered:
The nnx training loop fails when I use optax.partition as the optimizer. It works fine with other transformations (e.g. plain optax.adamw), and I can use the same optax.partition optimizer in a pure JAX training loop without problems.
What you expected to happen:
Training loop should also work with nnx.
Logs, error messages, etc:
ValueError: Custom node type mismatch: expected type: <class 'flax.nnx.variablelib.Param'>, value: JitTracer(float32[2,1]).
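For quick reference, the only difference in the nnx part of the full example below is the optimizer construction (model and partitioned_chain are defined there):
# Works in the nnx training loop:
nnx_optimizer = nnx.Optimizer(model, tx=optax.adamw(learning_rate=0.05), wrt=nnx.Param)
# Fails in the nnx training loop with the ValueError above:
nnx_optimizer = nnx.Optimizer(model, tx=partitioned_chain, wrt=nnx.Param)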
Steps to reproduce:
Run the minimal example below as a single script.
import jax
import jax.numpy as jnp
import flax
import flax.nnx as nnx
import optax
print("__jax_version__", jax.__version__
, "__optax_version__", optax.__version__
, "__flax_version__", flax.__version__)
BATCH_SIZE=8
# X is random data with two features, the target is the sum of those features + 3
X = jax.random.normal(key=jax.random.key(42), shape=(160,2))
y = jnp.sum(X, axis=1) + 3
# A very simple model with one linear layer
class SimpleModel(nnx.Module):
    def __init__(self, *, rngs):
        self.linear = nnx.Linear(in_features=2, out_features=1, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)
model = SimpleModel(rngs=nnx.Rngs(0))
# Define the label function that will be used to assign labels to the parameters for partitioning
def label_fn(path, leaf):
    """Assigns a label to a parameter based on its path."""
    param_name = path[-2].key
    if 'bias' in param_name:
        return 'non_decay_group'
    return 'decay_group'
param_labels_pytree = jax.tree.map_with_path(
    label_fn, nnx.state(model, nnx.Param)
)
partitioned_chain = optax.partition(
    transforms={
        'decay_group': optax.adamw(learning_rate=0.05),
        'non_decay_group': optax.adam(learning_rate=0.05),
    },
    param_labels=param_labels_pytree,
)
jax_optimizer = partitioned_chain
# Training loop with pure JAX --> This works fine
def make_train_step(graphdef, solver):
    def loss_fn(param_state, x, y):
        model = nnx.merge(graphdef, param_state)
        return jnp.mean((model(x) - y)**2)

    @jax.jit
    def train_step(param_state, opt_state, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(param_state, x, y)
        updates, opt_state = solver.update(grads, opt_state, param_state)
        param_state = optax.apply_updates(param_state, updates)
        return param_state, opt_state

    return train_step
graphdef, param_state = nnx.split(model, nnx.Param)
opt_state = jax_optimizer.init(params=param_state)
train_step = make_train_step(graphdef, jax_optimizer)
for _ in range(5):
    for _idx in range(0, X.shape[0], BATCH_SIZE):
        _x_batch, _y_batch = X[_idx:_idx+BATCH_SIZE], y[_idx:_idx+BATCH_SIZE]
        param_state, opt_state = train_step(param_state, opt_state, _x_batch, _y_batch)
    print(f"MSE on train {jnp.mean((nnx.merge(graphdef, param_state)(X) - y)**2)}")
print("JAX training loop finished. Note. Model still unchanged.")
@nnx.jit
def train_step_nnx(model, optimizer, X, y):
    def loss_fn_nnx(model):
        return jnp.mean((model(X) - y)**2)

    grad_fn = nnx.value_and_grad(loss_fn_nnx)
    loss, grads = grad_fn(model)
    optimizer.update(model, grads)
    return loss
# Recreate just to make sure
partitioned_chain = optax.partition(
    transforms={
        'decay_group': optax.adamw(learning_rate=0.05),
        'non_decay_group': optax.adam(learning_rate=0.05),
    },
    param_labels=param_labels_pytree,
)
# Choose one. Works with adamw, but not with the partitioned chain
#nnx_optimizer = nnx.Optimizer(model, tx=optax.adamw(learning_rate=0.05), wrt=nnx.Param)
nnx_optimizer = nnx.Optimizer(model, tx=partitioned_chain, wrt=nnx.Param)
print("Starting nnx training loop")
for _ in range(5):
    for _idx in range(0, X.shape[0], BATCH_SIZE):
        _x_batch, _y_batch = X[_idx:_idx+BATCH_SIZE], y[_idx:_idx+BATCH_SIZE]
        loss = train_step_nnx(model, nnx_optimizer, _x_batch, _y_batch)
    print(f"MSE on train {jnp.mean((model(X) - y)**2)}")
print("nnx training loop finished")