Suppose I have a module that takes some number of inputs, but returns just one output. Something like
Now suppose I want to use
Unfortunately, I cannot figure out from the documentation how to achieve that with nn.scan. Any suggestions would be much appreciated.
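Hey @hr0nix, you can set `in_axes` to `nn.broadcast` to indicate that `y` should be the same on all steps:

    import flax.linen as nn

    class MyModule(nn.Module):
        @nn.compact
        def __call__(self, x, y):
            return x + y, None

    # Nothing is scanned over, so `length` (number of steps) must be set explicitly.
    StackedMyModule = nn.scan(MyModule, in_axes=nn.broadcast, length=3)

Note that you must return `(carry, output)`; here I am returning `None` as the scan output.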