nnx.clone creates buffer copies with the same IDs, causing errors with donate_argnums #5461

@johnlyzhou

Description

There are a number of cases where I would like to use nnx.clone to create a copy of a network (e.g., target networks in RL) inside an nnx.Module. This works fine in general, but combining it with jax.jit's donate_argnums option on the module state raises an error about donating the same buffer twice.

I've provided a minimal reproducible example below, with assertions showing that the ID of the nnx.clone'd buffer is the same as the original:

JAX version 0.6.2, Flax version 0.10.7

import functools

import jax
from flax import nnx

class TestModule(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(in_features=4, out_features=2, rngs=rngs)
        self.linear_copy = nnx.clone(self.linear)

rngs = nnx.Rngs(0)
test_module = TestModule(rngs)

linear_state = nnx.state(test_module.linear)
linear_copy_state = nnx.state(test_module.linear_copy)

online_kernel_array = linear_state.kernel.value
target_kernel_array = linear_copy_state.kernel.value

assert id(linear_state.kernel.value) == id(linear_copy_state.kernel.value)
assert online_kernel_array.unsafe_buffer_pointer() == target_kernel_array.unsafe_buffer_pointer()

# Test donate_argnums with nnx.clone'd layers.
@functools.partial(jax.jit, donate_argnums=(0,))
def update(state_pytree):
    return jax.tree.map(lambda x: x + 0.01, state_pytree)

state = nnx.state(test_module)
state = update(state)

This yields the following error:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/johnzhou/Library/Application Support/JetBrains/PyCharm2025.2/scratches/clone_test.py", line 29, in <module>
    state = update(state)
jaxlib._jax.XlaRuntimeError: INVALID_ARGUMENT: Attempt to donate the same buffer twice in Execute() (flattened argument 2, replica 0, partition 0, first use: 0). Toy example for this bug: `f(donate(a), donate(a))`.
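The toy example in the error message, `f(donate(a), donate(a))`, corresponds to one array being reachable as two leaves of the donated pytree. Below is a minimal sketch for spotting such aliasing before calling the jitted function. `find_aliased_leaves` is a hypothetical helper (not part of JAX or Flax), it compares leaves by Python object identity, and it is shown on a plain dict rather than an nnx state, assuming the state flattens to array leaves in the same way.

```python
from collections import defaultdict

import jax
import jax.numpy as jnp


def find_aliased_leaves(tree):
    """Hypothetical helper: group pytree leaf paths by leaf object identity."""
    groups = defaultdict(list)
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in leaves_with_paths:
        groups[id(leaf)].append(jax.tree_util.keystr(path))
    # Keep only the groups where one array is reachable via several paths.
    return [paths for paths in groups.values() if len(paths) > 1]


shared = jnp.ones((4, 2))

# The same array object appears twice, as in the cloned-layer case above.
aliased_tree = {"online": shared, "target": shared}
# An explicit copy gives the second leaf its own array object.
copied_tree = {"online": shared, "target": jnp.copy(shared)}

aliased = find_aliased_leaves(aliased_tree)
not_aliased = find_aliased_leaves(copied_tree)
print("aliased groups:", aliased)
print("after copying:", not_aliased)
```

Any non-empty result would indicate leaves that cannot both be donated in the same call.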

Either using copy.deepcopy or something like

import jax
import jax.numpy as jnp
from flax import nnx

def clone(module: nnx.Module):
    graphdef, state = nnx.split(module)
    state_copy = jax.tree_util.tree_map(jnp.copy, state)
    return nnx.merge(graphdef, state_copy)

works fine. I'm not sure whether this is expected behavior, but if it is, noting it in the nnx.clone docs might be useful for users.
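As a sanity check of the copy-based workaround, here is a small sketch in plain JAX with no Flax dependency. The dict keys `online` and `target` are illustrative stand-ins for the two layers, and the `update` function mirrors the one in the reproduction. On backends that do not implement donation, JAX may only emit a warning instead of donating.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, donate_argnums=(0,))
def update(state_pytree):
    # Same update as in the reproduction: shift every leaf by 0.01.
    return jax.tree.map(lambda x: x + 0.01, state_pytree)


online = jnp.ones((4, 2))
# jnp.copy returns a new array, so the two leaves are no longer the same object.
tree = {"online": online, "target": jnp.copy(online)}
distinct = tree["online"] is not tree["target"]

# The donated inputs should not be read again after this call.
new_tree = update(tree)
print(distinct, new_tree["online"][0, 0], new_tree["target"][0, 0])
```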
