
Scan over a module with a different number of inputs and outputs #3230

Answered by cgarciae
hr0nix asked this question in Q&A

Hey @hr0nix, you can set in_axes to nn.broadcast so that y is passed unchanged to every step instead of being sliced along a scan axis:

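import flax.linen as nn

class MyModule(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    # x is the carry; y is broadcast, so every step sees the same value.
    # The body returns (carry, output); there is no per-step output here.
    return x + y, None

# in_axes applies to the arguments after the carry (here, only y).
StackedMyModule = nn.scan(MyModule, in_axes=nn.broadcast)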

Note that the scanned module must return a (carry, output) pair; here I am returning None as the scan output.

Answer selected by hr0nix