There are a number of cases where I would like to use `nnx.clone` to create a copy of a network (e.g., target networks in RL) inside an `nnx.Module`. This works fine in general, but I've found that combining it with the `donate_argnums` option of `jax.jit` on the module state causes an error about donating the same buffer twice.
I've provided a minimal reproducible example below, with assertions showing that the `nnx.clone`'d kernel is the same array object, backed by the same buffer, as the original:
JAX version 0.6.2, Flax version 0.10.7
```python
import functools

import jax
from flax import nnx


class TestModule(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(in_features=4, out_features=2, rngs=rngs)
        self.linear_copy = nnx.clone(self.linear)


rngs = nnx.Rngs(0)
test_module = TestModule(rngs)

linear_state = nnx.state(test_module.linear)
linear_copy_state = nnx.state(test_module.linear_copy)
online_kernel_array = linear_state.kernel.value
target_kernel_array = linear_copy_state.kernel.value
assert id(linear_state.kernel.value) == id(linear_copy_state.kernel.value)
assert online_kernel_array.unsafe_buffer_pointer() == target_kernel_array.unsafe_buffer_pointer()


# Test donate_argnums with nnx.clone'd layers.
@functools.partial(jax.jit, donate_argnums=(0,))
def update(state_pytree):
    return jax.tree.map(lambda x: x + 0.01, state_pytree)


state = nnx.state(test_module)
state = update(state)
```
This yields the error:

```
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/johnzhou/Library/Application Support/JetBrains/PyCharm2025.2/scratches/clone_test.py", line 29, in <module>
    state = update(state)
jaxlib._jax.XlaRuntimeError: INVALID_ARGUMENT: Attempt to donate the same buffer twice in Execute() (flattened argument 2, replica 0, partition 0, first use: 0). Toy example for this bug: `f(donate(a), donate(a))`.
```
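The error says one buffer appears twice among the donated arguments. One way to spot this before calling a donating `jit` is to flatten the state and look for leaves that are the same Python object, mirroring the `id(...)` assertion in the example. The helper `find_shared_leaves` below is hypothetical (not part of JAX or Flax) and is only a sketch: it assumes object identity is a good enough proxy for buffer sharing, and it uses a plain dict to stand in for the module state.

```python
import jax
import jax.numpy as jnp


def find_shared_leaves(tree):
    """Hypothetical helper: return (first_path, repeat_path) pairs whose leaves
    are the same Python object. Identity is only a proxy for buffer sharing."""
    first_seen = {}
    shared = []
    for path, leaf in jax.tree_util.tree_leaves_with_path(tree):
        key = id(leaf)
        if key in first_seen:
            shared.append((jax.tree_util.keystr(first_seen[key]),
                           jax.tree_util.keystr(path)))
        else:
            first_seen[key] = path
    return shared


# Mimics the aliasing seen with nnx.clone: "target" reuses the online array.
kernel = jnp.ones((4, 2))
aliased = {"online": kernel, "target": kernel}
# Here the target leaf gets its own copy, so nothing is shared.
copied = {"online": kernel, "target": jnp.copy(kernel)}

print(find_shared_leaves(aliased))
print(find_shared_leaves(copied))
```

Running this on `nnx.state(test_module)` instead of the toy dicts should flag the `linear` / `linear_copy` kernel pair, assuming the state leaves are the arrays themselves, as the `id(...)` assertion suggests.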
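Using either `copy.deepcopy` or something like

```python
import jax
import jax.numpy as jnp
from flax import nnx


def clone(module: nnx.Module):
    # Split into the static graph definition and a pytree of state arrays.
    graphdef, state = nnx.split(module)
    # jnp.copy gives every array leaf its own, independent buffer.
    state_copy = jax.tree_util.tree_map(jnp.copy, state)
    return nnx.merge(graphdef, state_copy)
```

works fine. I'm not sure whether this is expected behavior, but would it be worth noting in the `nnx.clone` docs for users?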
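The same copy-before-donate idea can be tried without Flax. The sketch below is a pure-JAX stand-in, assuming a plain dict in place of the module state and a hypothetical `copy_leaves` helper: each target leaf gets its own buffer via `jnp.copy`, so the donating `update` never receives one buffer twice. It does not reproduce the failure itself, and whether donation actually reuses memory depends on the backend.

```python
import functools

import jax
import jax.numpy as jnp


def copy_leaves(tree):
    """Hypothetical helper: give every array leaf its own buffer,
    the same idea as the clone() workaround above."""
    return jax.tree.map(jnp.copy, tree)


@functools.partial(jax.jit, donate_argnums=(0,))
def update(state_pytree):
    # Same update as in the repro; the input pytree is donated.
    return jax.tree.map(lambda x: x + 0.01, state_pytree)


online = {"kernel": jnp.ones((4, 2)), "bias": jnp.zeros((2,))}
# The "target network" holds copies rather than aliases of the online arrays.
state = {"online": online, "target": copy_leaves(online)}

# The donated arrays may be invalidated, so only the returned pytree is used.
state = update(state)
print(state["target"]["kernel"][0, 0])
```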