Commit 8a9fbc9

Author: Flax Authors

Merge pull request #2701 from cgarciae:cleanup-dense-general

PiperOrigin-RevId: 495114816

2 parents: 1d678b5 + bf60167

2 files changed (+12, -15 lines)

flax/linen/linear.py

Lines changed: 9 additions & 11 deletions
@@ -14,7 +14,6 @@
 
 """Linear modules."""
 
-import abc
 import dataclasses
 from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple,
                     Union)
@@ -31,6 +30,7 @@
 from jax import ShapedArray
 import jax.numpy as jnp
 import numpy as np
+import jax
 
 
 PRNGKey = Any
@@ -107,12 +107,11 @@ def __call__(self, inputs: Array) -> Array:
     n_axis, n_features = len(axis), len(features)
 
     def kernel_init_wrap(rng, shape, dtype=jnp.float32):
-      size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
-      flat_shape = (np.prod(shape[n_batch_dims:n_axis + n_batch_dims]),
+      flat_shape = (np.prod(shape[:n_batch_dims]) *
+                    np.prod(shape[n_batch_dims:n_axis + n_batch_dims]),
                     np.prod(shape[-n_features:]),)
-      kernel = jnp.concatenate(
-          [self.kernel_init(rng, flat_shape, dtype)
-           for rng in random.split(rng, size_batch_dims)], axis=0)
+      flat_shape = jax.tree_map(int, flat_shape)
+      kernel = self.kernel_init(rng, flat_shape, dtype)
       return jnp.reshape(kernel, shape)
 
     batch_shape = tuple(inputs.shape[ax] for ax in batch_dims)
@@ -129,11 +128,10 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32):
 
     if self.use_bias:
       def bias_init_wrap(rng, shape, dtype=jnp.float32):
-        size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
-        flat_shape = (np.prod(shape[-n_features:]),)
-        bias = jnp.concatenate(
-            [self.bias_init(rng, flat_shape, dtype)
-             for rng in random.split(rng, size_batch_dims)], axis=0)
+        flat_shape = (np.prod(shape[:n_batch_dims]) *
+                      np.prod(shape[-n_features:]),)
+        flat_shape = jax.tree_map(int, flat_shape)
+        bias = self.bias_init(rng, flat_shape, dtype)
        return jnp.reshape(bias, shape)
 
       bias = self.param('bias', bias_init_wrap, batch_shape + features,
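
For illustration, the following is a small self-contained sketch (not part of the commit) of the flattened-initialization pattern the new kernel_init_wrap uses. The shape values and the lecun_normal initializer are hypothetical stand-ins: batch and contraction axes are collapsed into one row dimension, features into one column dimension, the initializer is called once over that flat 2D shape, and the result is reshaped back. np.prod returns NumPy scalars, so jax.tree_map(int, ...) casts the entries to plain Python ints before they reach the initializer.

import jax
import jax.numpy as jnp
import numpy as np
from jax.nn import initializers

# Hypothetical parameter shape: (batch dims..., contraction axes..., features...)
shape = (2, 3, 4, 7)
n_batch_dims, n_axis, n_features = 1, 2, 1

# Collapse batch and contraction axes into rows, features into columns.
# np.prod yields NumPy scalars, so cast each entry to a Python int.
flat_shape = (np.prod(shape[:n_batch_dims]) *
              np.prod(shape[n_batch_dims:n_batch_dims + n_axis]),
              np.prod(shape[-n_features:]))
flat_shape = jax.tree_map(int, flat_shape)            # (24, 7)

kernel_init = initializers.lecun_normal()             # stand-in initializer
kernel = kernel_init(jax.random.PRNGKey(0), flat_shape, jnp.float32)
kernel = jnp.reshape(kernel, shape)                   # back to (2, 3, 4, 7)
print(kernel.shape)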

tests/linen/linen_linear_test.py

Lines changed: 3 additions & 4 deletions
@@ -77,13 +77,13 @@ def test_dense_is_dense_general(self):
         use_bias=True,
         bias_init=initializers.normal(),
     )
-    y1, params = dense_module.init_with_output(dict(params=random.PRNGKey(1)), x)
+    y1, _ = dense_module.init_with_output(dict(params=random.PRNGKey(1)), x)
     dg_module = nn.DenseGeneral(
         features=4,
         use_bias=True,
         bias_init=initializers.normal(),
     )
-    y2 = dg_module.apply(params, x)
+    y2, _ = dg_module.init_with_output(dict(params=random.PRNGKey(1)), x)
 
     np.testing.assert_allclose(y1, y2)
 
@@ -141,8 +141,7 @@ def _counter_init(rng, shape, dtype, state):
         kernel_init=counter_init,
     )
     y, _ = dg_module.init_with_output(rng, x)
-    target = np.concatenate(
-        [np.full((1, 1, 7), 16.), np.full((1, 1, 7), 31.)], axis=0)
+    target = np.full((2, 1, 7), 16.)
     np.testing.assert_allclose(y, target)
 
   @parameterized.parameters([((-2, 3), (), 'bijk,jklm->bilm'),
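
The updated equivalence check boils down to the pattern below: initialize each module from the same PRNG key via init_with_output and compare their outputs directly, rather than reusing one module's params with the other's apply(). This is a hedged, self-contained sketch with a hypothetical input shape and default kernel initializers, not a verbatim copy of the surrounding test.

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from jax.nn import initializers

x = jnp.ones((1, 3))                                  # hypothetical input

dense_module = nn.Dense(features=4, use_bias=True,
                        bias_init=initializers.normal())
dg_module = nn.DenseGeneral(features=4, use_bias=True,
                            bias_init=initializers.normal())

# Same key for both modules, so identically shaped params get identical values.
y1, _ = dense_module.init_with_output(dict(params=jax.random.PRNGKey(1)), x)
y2, _ = dg_module.init_with_output(dict(params=jax.random.PRNGKey(1)), x)
np.testing.assert_allclose(y1, y2)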
