Skip to content

Commit 6c84097

Browse files
author
Flax Authors
committed
Merge pull request #4415 from tilakrayal:patch-3
PiperOrigin-RevId: 702701233
2 parents 27b1829 + 4756b34 commit 6c84097

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

docs_nnx/guides/linen_to_nnx.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ Scan-over-layers is a technique where you run an input through a sequence of N r
531531
* Up close, in the logic of this model there actually is no need for the ``jax.lax.scan`` operation at initialization time. What happens there is more like a ``jax.vmap`` operation - you are given a ``Block`` sub-``Module`` that accepts ``(in_dim, out_dim)``, and you "vmap" it over ``num_layers`` of times to create a larger array.
532532
* In Flax NNX, you take advantage of the fact that model initialization and running code are completely decoupled, and instead use the :func:`nnx.vmap<flax.nnx.vmap>` transform to initialize the underlying ``Block`` parameters, and the :func:`nnx.scan<flax.nnx.scan>` transform to run the model input through them.
533533

534-
For more information on Flax NNX transforms, check out the `Transforms guide <https://flax.readthedocs.build/en/guides/transforms.html>`__.
534+
For more information on Flax NNX transforms, check out the `Transforms guide <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__.
535535

536536
.. codediff::
537537
:title: Linen, NNX

0 commit comments

Comments
 (0)