@@ -86,6 +86,10 @@ def __call__(self, inputs):
86
86
87
87
class TransformTest (parameterized .TestCase ):
88
88
89
+ def assert_keys_equal (self , key1 , key2 ):
90
+ self .assertEqual (key1 .dtype , key2 .dtype )
91
+ np .testing .assert_array_equal (random .key_data (key1 ), random .key_data (key2 ))
92
+
89
93
def test_jit (self ):
90
94
key1 , key2 = random .split (random .key (3 ), 2 )
91
95
x = random .uniform (key1 , (4 , 4 ))
@@ -1852,7 +1856,7 @@ def f(foo: Foo):
1852
1856
key_jit = foo .apply ({}, True , rngs = {'params' : random .key (0 )})
1853
1857
key_fold_rngs = foo .apply ({}, False , rngs = {'params' : random .key (0 )})
1854
1858
1855
- np . testing . assert_array_equal (key_jit , key_fold_rngs )
1859
+ self . assert_keys_equal (key_jit , key_fold_rngs )
1856
1860
1857
1861
def test_same_key (self ):
1858
1862
@@ -1885,9 +1889,9 @@ def __call__(self):
1885
1889
keys3 , _ = model .init_with_output (jax .random .key (1 ))
1886
1890
keys4 , _ = model .init_with_output (jax .random .key (1 ))
1887
1891
1888
- np . testing . assert_array_equal (keys1 , keys2 )
1889
- np . testing . assert_array_equal (keys2 , keys3 )
1890
- np . testing . assert_array_equal (keys2 , keys3 )
1892
+ self . assert_keys_equal (keys1 , keys2 )
1893
+ self . assert_keys_equal (keys2 , keys3 )
1894
+ self . assert_keys_equal (keys2 , keys3 )
1891
1895
1892
1896
def test_jit_repr_hash (self ):
1893
1897
n = 0
0 commit comments