1
+ # Copyright 2022 The Flax Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for flax.linen.initializers."""
16
+
17
+ from absl .testing import absltest
18
+ from absl .testing import parameterized
19
+
20
+ from flax import linen as nn
21
+ from flax .linen .initializers import zeros_init , ones_init
22
+
23
+ import jax
24
+ from jax import random
25
+ import jax .numpy as jnp
26
+
27
+ import numpy as np
28
+
29
+ # Parse absl flags test_srcdir and test_tmpdir.
30
+ jax .config .parse_flags_with_absl ()
31
+
32
+
33
+ class InitializersTest (parameterized .TestCase ):
34
+
35
+ @parameterized .parameters (
36
+ {
37
+ 'builder_fn' : zeros_init ,
38
+ 'params_shape' : (2 , 3 ),
39
+ 'expected_params' : jnp .zeros ((2 , 3 )),
40
+ }, {
41
+ 'builder_fn' : ones_init ,
42
+ 'params_shape' : (3 , 2 ),
43
+ 'expected_params' : jnp .ones ((3 , 2 )),
44
+ })
45
+ def test_call_builder (self , builder_fn , params_shape , expected_params ):
46
+ params = builder_fn ()(random .PRNGKey (42 ), params_shape , jnp .float32 )
47
+ np .testing .assert_allclose (params , expected_params )
48
+
49
+ @parameterized .parameters (
50
+ {
51
+ 'builder_fn' : zeros_init ,
52
+ 'expected_params' : jnp .zeros ((2 , 5 )),
53
+ }, {
54
+ 'builder_fn' : ones_init ,
55
+ 'expected_params' : jnp .ones ((2 , 5 )),
56
+ })
57
+ def test_kernel_builder (self , builder_fn , expected_params ):
58
+ layer = nn .Dense (5 , kernel_init = builder_fn ())
59
+ params = layer .init (random .PRNGKey (42 ), jnp .empty ((3 , 2 )))['params' ]
60
+ np .testing .assert_allclose (params ['kernel' ], expected_params )
61
+
62
+
63
+ if __name__ == '__main__' :
64
+ absltest .main ()
0 commit comments