The following throws the error "TypeError: iteration over a 0-d array":
from flax import nnx
import jax.numpy as jnp
class Foo(nnx.Module):
def __init__(self, x, y):
self.x = nnx.Param(x)
self.y = nnx.Param(y)
@nnx.custom_vjp(graph=False, graph_updates=False)
def f(m: Foo):
return jnp.sin(m.x) * m.y
def f_fwd(m: Foo):
return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
def f_bwd(res, g):
input_updates_g, out_g = g
cos_x, sin_x, m = res
(m_updates_g,) = input_updates_g
m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
m_g['x'][...] = cos_x * out_g * m.y
m_g['y'][...] = sin_x * out_g
return (m_g,)
f.defvjp(f_fwd, f_bwd)
def test_it():
m = Foo(x=jnp.array(1.), y=jnp.array(2.))
grads = nnx.grad(f, graph=False, graph_updates=False)(m)
The following throws the error "TypeError: iteration over a 0-d array":