diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 6e9a60011..0e1f7fa21 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -1026,6 +1026,19 @@ def _helper(self): with self.assertRaises(errors.TransformedMethodReturnValueError): b.apply({}, jnp.ones(2)) + def test_returned_variable_warning(self): + class Bar(nn.Module): + @nn.compact + def __call__(self, x): + f = self._helper() + return f(x) + @nn.jit + def _helper(self): + return nn.Variable(None, None, None) + b = Bar() + with self.assertRaises(errors.TransformedMethodReturnValueError): + b.apply({}, jnp.ones(2)) + def test_nowrap(self): class Bar(nn.Module): @nn.compact