14
14
15
15
"""Linear modules."""
16
16
17
- import abc
18
17
import dataclasses
19
18
from typing import (Any , Callable , Iterable , List , Optional , Sequence , Tuple ,
20
19
Union )
31
30
from jax import ShapedArray
32
31
import jax .numpy as jnp
33
32
import numpy as np
33
+ import jax
34
34
35
35
36
36
PRNGKey = Any
@@ -107,12 +107,11 @@ def __call__(self, inputs: Array) -> Array:
107
107
n_axis , n_features = len (axis ), len (features )
108
108
109
109
def kernel_init_wrap (rng , shape , dtype = jnp .float32 ):
110
- size_batch_dims = np .prod (shape [:n_batch_dims ], dtype = np . int32 )
111
- flat_shape = ( np .prod (shape [n_batch_dims :n_axis + n_batch_dims ]),
110
+ flat_shape = ( np .prod (shape [:n_batch_dims ]) *
111
+ np .prod (shape [n_batch_dims :n_axis + n_batch_dims ]),
112
112
np .prod (shape [- n_features :]),)
113
- kernel = jnp .concatenate (
114
- [self .kernel_init (rng , flat_shape , dtype )
115
- for rng in random .split (rng , size_batch_dims )], axis = 0 )
113
+ flat_shape = jax .tree_map (int , flat_shape )
114
+ kernel = self .kernel_init (rng , flat_shape , dtype )
116
115
return jnp .reshape (kernel , shape )
117
116
118
117
batch_shape = tuple (inputs .shape [ax ] for ax in batch_dims )
@@ -129,11 +128,10 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32):
129
128
130
129
if self .use_bias :
131
130
def bias_init_wrap (rng , shape , dtype = jnp .float32 ):
132
- size_batch_dims = np .prod (shape [:n_batch_dims ], dtype = np .int32 )
133
- flat_shape = (np .prod (shape [- n_features :]),)
134
- bias = jnp .concatenate (
135
- [self .bias_init (rng , flat_shape , dtype )
136
- for rng in random .split (rng , size_batch_dims )], axis = 0 )
131
+ flat_shape = (np .prod (shape [:n_batch_dims ]) *
132
+ np .prod (shape [- n_features :]),)
133
+ flat_shape = jax .tree_map (int , flat_shape )
134
+ bias = self .bias_init (rng , flat_shape , dtype )
137
135
return jnp .reshape (bias , shape )
138
136
139
137
bias = self .param ('bias' , bias_init_wrap , batch_shape + features ,
0 commit comments