diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index 7f2d4da26..87efdca08 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -828,9 +828,10 @@ def __call__(self, carry, inputs): A tuple with the new carry and the output. """ c, h = carry + features = c.shape[-1] input_to_hidden = partial( Conv, - features=4 * self.features, + features=4 * features, kernel_size=self.kernel_size, strides=self.strides, padding=self.padding, @@ -842,7 +843,7 @@ def __call__(self, carry, inputs): hidden_to_hidden = partial( Conv, - features=4 * self.features, + features=4 * features, kernel_size=self.kernel_size, strides=self.strides, padding=self.padding,