Skip to content

Added builder functions for zeros and ones initializers #2790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions flax/linen/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,30 @@
from jax.nn.initializers import xavier_normal as xavier_normal
from jax.nn.initializers import xavier_uniform as xavier_uniform
from jax.nn.initializers import zeros as zeros
from jax.nn.initializers import Initializer as Initializer
# pylint: enable=unused-import

def zeros_init() -> Initializer:
"""Builds an initializer that returns a constant array full of zeros.

>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import zeros_init
>>> zeros_initializer = zeros_init()
>>> zeros_initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)
"""
return zeros

def ones_init() -> Initializer:
"""Builds an initializer that returns a constant array full of ones.

>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import ones_init
>>> ones_initializer = ones_init()
>>> ones_initializer(jax.random.PRNGKey(42), (3, 2), jnp.float32)
Array([[1., 1.],
[1., 1.],
[1., 1.]], dtype=float32)
"""
return ones
64 changes: 64 additions & 0 deletions tests/linen/initializers_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2022 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for flax.linen.initializers."""

from absl.testing import absltest
from absl.testing import parameterized

from flax import linen as nn
from flax.linen.initializers import zeros_init, ones_init

import jax
from jax import random
import jax.numpy as jnp

import numpy as np

# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()


class InitializersTest(parameterized.TestCase):

@parameterized.parameters(
{
'builder_fn': zeros_init,
'params_shape': (2, 3),
'expected_params': jnp.zeros((2, 3)),
}, {
'builder_fn': ones_init,
'params_shape': (3, 2),
'expected_params': jnp.ones((3, 2)),
})
def test_call_builder(self, builder_fn, params_shape, expected_params):
params = builder_fn()(random.PRNGKey(42), params_shape, jnp.float32)
np.testing.assert_allclose(params, expected_params)

@parameterized.parameters(
{
'builder_fn': zeros_init,
'expected_params': jnp.zeros((2, 5)),
}, {
'builder_fn': ones_init,
'expected_params': jnp.ones((2, 5)),
})
def test_kernel_builder(self, builder_fn, expected_params):
layer = nn.Dense(5, kernel_init=builder_fn())
params = layer.init(random.PRNGKey(42), jnp.empty((3, 2)))['params']
np.testing.assert_allclose(params['kernel'], expected_params)


if __name__ == '__main__':
absltest.main()