Skip to content

Commit e51d017

Browse files
author
Flax Authors
committed
Merge pull request #2790 from chiamp:initializers
PiperOrigin-RevId: 502764439
2 parents 71772f6 + 871638e commit e51d017

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

flax/linen/initializers.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,30 @@
3434
from jax.nn.initializers import xavier_normal as xavier_normal
3535
from jax.nn.initializers import xavier_uniform as xavier_uniform
3636
from jax.nn.initializers import zeros as zeros
37+
from jax.nn.initializers import Initializer as Initializer
3738
# pylint: enable=unused-import
39+
40+
def zeros_init() -> Initializer:
41+
"""Builds an initializer that returns a constant array full of zeros.
42+
43+
>>> import jax, jax.numpy as jnp
44+
>>> from flax.linen.initializers import zeros_init
45+
>>> zeros_initializer = zeros_init()
46+
>>> zeros_initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
47+
Array([[0., 0., 0.],
48+
[0., 0., 0.]], dtype=float32)
49+
"""
50+
return zeros
51+
52+
def ones_init() -> Initializer:
53+
"""Builds an initializer that returns a constant array full of ones.
54+
55+
>>> import jax, jax.numpy as jnp
56+
>>> from flax.linen.initializers import ones_init
57+
>>> ones_initializer = ones_init()
58+
>>> ones_initializer(jax.random.PRNGKey(42), (3, 2), jnp.float32)
59+
Array([[1., 1.],
60+
[1., 1.],
61+
[1., 1.]], dtype=float32)
62+
"""
63+
return ones

tests/linen/initializers_test.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2022 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for flax.linen.initializers."""
16+
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
19+
20+
from flax import linen as nn
21+
from flax.linen.initializers import zeros_init, ones_init
22+
23+
import jax
24+
from jax import random
25+
import jax.numpy as jnp
26+
27+
import numpy as np
28+
29+
# Parse absl flags test_srcdir and test_tmpdir.
30+
jax.config.parse_flags_with_absl()
31+
32+
33+
class InitializersTest(parameterized.TestCase):
34+
35+
@parameterized.parameters(
36+
{
37+
'builder_fn': zeros_init,
38+
'params_shape': (2, 3),
39+
'expected_params': jnp.zeros((2, 3)),
40+
}, {
41+
'builder_fn': ones_init,
42+
'params_shape': (3, 2),
43+
'expected_params': jnp.ones((3, 2)),
44+
})
45+
def test_call_builder(self, builder_fn, params_shape, expected_params):
46+
params = builder_fn()(random.PRNGKey(42), params_shape, jnp.float32)
47+
np.testing.assert_allclose(params, expected_params)
48+
49+
@parameterized.parameters(
50+
{
51+
'builder_fn': zeros_init,
52+
'expected_params': jnp.zeros((2, 5)),
53+
}, {
54+
'builder_fn': ones_init,
55+
'expected_params': jnp.ones((2, 5)),
56+
})
57+
def test_kernel_builder(self, builder_fn, expected_params):
58+
layer = nn.Dense(5, kernel_init=builder_fn())
59+
params = layer.init(random.PRNGKey(42), jnp.empty((3, 2)))['params']
60+
np.testing.assert_allclose(params['kernel'], expected_params)
61+
62+
63+
if __name__ == '__main__':
64+
absltest.main()

0 commit comments

Comments
 (0)