diff --git a/flax/experimental/nnx/README.md b/flax/experimental/nnx/README.md index db5880824..1866025db 100644 --- a/flax/experimental/nnx/README.md +++ b/flax/experimental/nnx/README.md @@ -142,7 +142,7 @@ NNX Modules are normal python classes, they obey regular python semantics such a ```python class Foo(nnx.Module): - def __init__(self, rngs: nnx.Rngs): + def __init__(self, *, rngs: nnx.Rngs): # node attributes self.param = nnx.Param(jnp.array(1)) self.submodule = nnx.Linear(12, 3, rngs=rngs)