Skip to content

Commit 83eebe2

Browse files
Jake VanderPlasFlax Authors
authored andcommitted
Avoid assert_array_equal for PRNG keys.
This will soon error due to jax-ml/jax#24481 PiperOrigin-RevId: 694612853
1 parent c9cebee commit 83eebe2

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/linen/linen_transforms_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def __call__(self, inputs):
8686

8787
class TransformTest(parameterized.TestCase):
8888

89+
def assert_keys_equal(self, key1, key2):
90+
self.assertEqual(key1.dtype, key2.dtype)
91+
np.testing.assert_array_equal(random.key_data(key1), random.key_data(key2))
92+
8993
def test_jit(self):
9094
key1, key2 = random.split(random.key(3), 2)
9195
x = random.uniform(key1, (4, 4))
@@ -1852,7 +1856,7 @@ def f(foo: Foo):
18521856
key_jit = foo.apply({}, True, rngs={'params': random.key(0)})
18531857
key_fold_rngs = foo.apply({}, False, rngs={'params': random.key(0)})
18541858

1855-
np.testing.assert_array_equal(key_jit, key_fold_rngs)
1859+
self.assert_keys_equal(key_jit, key_fold_rngs)
18561860

18571861
def test_same_key(self):
18581862

@@ -1885,9 +1889,9 @@ def __call__(self):
18851889
keys3, _ = model.init_with_output(jax.random.key(1))
18861890
keys4, _ = model.init_with_output(jax.random.key(1))
18871891

1888-
np.testing.assert_array_equal(keys1, keys2)
1889-
np.testing.assert_array_equal(keys2, keys3)
1890-
np.testing.assert_array_equal(keys2, keys3)
1892+
self.assert_keys_equal(keys1, keys2)
1893+
self.assert_keys_equal(keys2, keys3)
1894+
self.assert_keys_equal(keys2, keys3)
18911895

18921896
def test_jit_repr_hash(self):
18931897
n = 0

0 commit comments

Comments
 (0)