Hey @ymahlau, thanks for bringing this up!
Agreed. We've been thinking about this and have a plan that will allow NNX to be used with raw JAX transformations and other libraries by making NNX objects implement the pytree protocol. We will publish an announcement soon-ish.
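For readers less familiar with the term, here is a minimal plain-JAX sketch of what "implementing the pytree protocol" means. It assumes nothing about the eventual NNX API: `Linear` is a hypothetical toy class registered with `jax.tree_util.register_pytree_node_class`, which is what lets raw `jax.grad` and `jax.jit` traverse its array attributes.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    # Hypothetical toy layer for illustration; this is not an NNX class.
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        return x @ self.w + self.b

    def tree_flatten(self):
        # children: the array leaves JAX transforms see; aux_data: static metadata (none here)
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild an instance from (possibly transformed) leaves
        return cls(*children)


def loss(model, x, y):
    return jnp.mean((model(x) - y) ** 2)


model = Linear(jnp.ones((3, 2)), jnp.zeros((2,)))
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

# Because Linear is a pytree, the gradient comes back as a Linear with the same structure.
grads = jax.grad(loss)(model, x, y)
```

Once a class is a pytree, it can be passed straight into raw transformations and into libraries that only understand pytrees, without any framework-specific wrapper.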
-
Related: #4431.
-
Is there any update on this new functionality that would let NNX work with raw JAX transformations and other libraries?
-
Hi everyone,
I would like to start a discussion about the design decisions of the NNX library. Specifically, I have been wondering why reference sharing and mutability were introduced, since they somewhat go against the design of JAX itself.
The only reason I can see for supporting mutability and reference sharing within JAX is to enable a more "PyTorch-like" interface, and NNX has definitely succeeded at this. In my opinion, implementing models with NNX feels very smooth and avoids many of the annoyances of flax.linen.
However, problems start when trying to integrate NNX modules with more complicated JAX transformations and existing libraries, as many of the current issues in this repo show.
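A small plain-JAX illustration (not NNX code) of why reference sharing sits awkwardly with JAX: pytrees are trees, not graphs, so an array referenced from two places is flattened into two independent leaves. The parameter names and the "tied weights" setup below are hypothetical.

```python
import jax
import jax.numpy as jnp

# One array deliberately referenced twice, as tied weights would be.
shared = jnp.ones((2,))
params = {"encoder": shared, "decoder": shared}

# Flattening does not detect the shared reference: it yields two separate leaves.
leaves, treedef = jax.tree_util.tree_flatten(params)


def loss(p):
    return jnp.sum(p["encoder"]) + 2.0 * jnp.sum(p["decoder"])


# Each occurrence gets its own gradient instead of a single, summed gradient.
grads = jax.grad(loss)(params)

# A plain SGD step therefore updates the two copies differently,
# and the "shared" weights silently diverge.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

Here `new_params["encoder"]` ends up at 0.9 and `new_params["decoder"]` at 0.8, so any library that works purely on pytrees cannot preserve the sharing on its own; a framework that wants shared references has to track and restore them outside the pytree machinery.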
As another point, mutability and reference sharing are not strictly necessary for holding state inside the model itself, as the serket library demonstrates.
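To make that point concrete, here is a minimal plain-JAX sketch of state held inside an immutable model. It does not use serket's actual API: `Normalize` is a hypothetical `NamedTuple` layer whose running mean is "updated" by returning a new instance, so the model remains an ordinary pytree that `jax.jit` accepts directly. `MOMENTUM` is an assumed fixed hyperparameter.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

MOMENTUM = 0.9  # hypothetical fixed hyperparameter, kept out of the pytree


class Normalize(NamedTuple):
    # NamedTuples are pytrees out of the box; both fields become array leaves.
    scale: jax.Array
    running_mean: jax.Array

    def __call__(self, x):
        batch_mean = jnp.mean(x, axis=0)
        new_mean = MOMENTUM * self.running_mean + (1 - MOMENTUM) * batch_mean
        # No in-place mutation: return the output together with an updated copy.
        return (x - batch_mean) * self.scale, self._replace(running_mean=new_mean)


@jax.jit
def step(model, x):
    return model(x)


model = Normalize(scale=jnp.ones((2,)), running_mean=jnp.zeros((2,)))
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y, new_model = step(model, x)  # model itself is left unchanged
```

The trade-off is that the caller has to thread `new_model` back in explicitly, which is roughly the bookkeeping that mutable objects are meant to hide.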
So my question is: what do others see as the advantages and disadvantages of mutability and reference sharing? Also, maybe some of the developers could share more about the design decisions they had to face.