
optimizer.update fails with optax.partition #5372

@slievens

Description

I am quite new to flax, so I am not 100% sure that I didn't make an obvious mistake.
Small example included at the bottom of this report.

System information

  • OS Platform and Distribution: Alma linux
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax: 0.12.5, jax: 0.9.1, jaxlib: 0.9., optax: 0.2.6
  • Python version: 3.11.14
  • GPU/TPU model and memory:
  • CUDA version (if applicable): 13.1

Problem you have encountered:

The nnx training loop fails when I use optax.partition as the optimizer. It works fine with other transformations (for example plain optax.adamw), and the same optax.partition optimizer works in a pure JAX training loop.
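
My best guess, though I have not verified it, is that the param_labels pytree I build from nnx.state(model, nnx.Param) still contains the Param wrappers as custom pytree nodes, while the grads that nnx.Optimizer.update passes to optax inside nnx.jit are plain traced arrays, so optax.partition ends up mapping a labels pytree over an updates pytree with a different structure. That would at least match the error message below.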

What you expected to happen:

The training loop should also work with nnx.

Logs, error messages, etc:

ValueError: Custom node type mismatch: expected type: <class 'flax.nnx.variablelib.Param'>, value: JitTracer(float32[2,1]).

Steps to reproduce:


import jax
import jax.numpy as jnp
import flax
import flax.nnx as nnx
import optax

print("__jax_version__", jax.__version__
      , "__optax_version__", optax.__version__
      , "__flax_version__", flax.__version__)


BATCH_SIZE=8

# X is random data with two features, the target is the sum of those features + 3
X = jax.random.normal(key=jax.random.key(42), shape=(160,2))
# keepdims=True so that y has shape (160, 1) and matches the model's output shape
y = jnp.sum(X, axis=1, keepdims=True) + 3

# A very simple model with one linear layer
class SimpleModel(nnx.Module):

    def __init__(self, *, rngs):
        self.linear = nnx.Linear(in_features=2, out_features=1, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

model = SimpleModel(rngs=nnx.Rngs(0))

# Define the label function that will be used to assign labels to the parameters for partitioning
def label_fn(path, leaf):
    """ Assigns a label to a parameter based on its path"""
    param_name = path[-2].key   
    if 'bias' in param_name:
        return 'non_decay_group'

    return 'decay_group'

param_labels_pytree = jax.tree.map_with_path(
    label_fn, nnx.state(model, nnx.Param)
)
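# NOTE: because Param is itself a registered pytree node, the labels pytree built
# here appears to keep the same structure as the State, with each label string
# still wrapped in a Param node.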

partitioned_chain = optax.partition(
    transforms={
        'decay_group' : optax.adamw(learning_rate=0.05),
        'non_decay_group' : optax.adam(learning_rate=0.05)
    },
    param_labels=param_labels_pytree
)

jax_optimizer = partitioned_chain


# Training loop with pure JAX --> This works fine
def make_train_step(graphdef, solver):
    def loss_fn(param_state, x, y):
        model = nnx.merge(graphdef, param_state)
        # Use the batch x passed in, not the full dataset X from the outer scope.
        return jnp.mean((model(x) - y)**2)

    @jax.jit
    def train_step(param_state, opt_state, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(param_state, x, y)
        updates, opt_state = solver.update(grads, opt_state, param_state)
        param_state = optax.apply_updates(param_state, updates)
        return param_state, opt_state

    return train_step

graphdef, param_state = nnx.split(model, nnx.Param)
opt_state = jax_optimizer.init(params=param_state)
train_step = make_train_step(graphdef, jax_optimizer)


for _ in range(5):
    for _idx in range(0, X.shape[0], BATCH_SIZE):
        _x_batch, _y_batch = X[_idx:_idx+BATCH_SIZE], y[_idx:_idx+BATCH_SIZE]       
        param_state, opt_state = train_step(param_state, opt_state, _x_batch, _y_batch)
    
    print(f"MSE on train {jnp.mean((nnx.merge(graphdef, param_state)(X) - y)**2)}")

print("JAX training loop finished. Note. Model still unchanged.")


@nnx.jit
def train_step_nnx(model, optimizer, X, y):
    def loss_fn_nnx(model):
        return jnp.mean((model(X) - y)**2)

    grad_fn = nnx.value_and_grad(loss_fn_nnx)
    loss, grads = grad_fn(model)
    optimizer.update(model, grads)
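    # With the partitioned chain, this optimizer.update call is where the ValueError above is raised.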
    return loss

# Recreate just to make sure
partitioned_chain = optax.partition(
    transforms={
        'decay_group' : optax.adamw(learning_rate=0.05),
        'non_decay_group' : optax.adam(learning_rate=0.05)
    },
    param_labels=param_labels_pytree
)

# Choose one. Works with adamw, but not with the partitioned chain
#nnx_optimizer =  nnx.Optimizer(model, tx=optax.adamw(learning_rate=0.05), wrt=nnx.Param)
nnx_optimizer = nnx.Optimizer(model, tx=partitioned_chain, wrt=nnx.Param) 

print("Starting nnx training loop")
for _ in range(5):
    for _idx in range(0, X.shape[0], BATCH_SIZE):
        _x_batch, _y_batch = X[_idx:_idx+BATCH_SIZE], y[_idx:_idx+BATCH_SIZE]       
        loss = train_step_nnx(model, nnx_optimizer, _x_batch, _y_batch)
    
    print(f"MSE on train {jnp.mean((model(X) - y)**2)}")

print("nnx training loop finished")
