Skip to content

Commit 45a8f84

Browse files
author
Flax Authors
committed
Merge pull request google#4592 from jakevdp:fix-shape
PiperOrigin-RevId: 733430128
2 parents a24d790 + 4e300d4 commit 45a8f84

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

flax/core/scope.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -956,9 +956,9 @@ def param(
956956
# NOTE: We could check dtype consistency here as well but it's
957957
# usefuleness is less obvious. We might intentionally change the dtype
958958
# for inference to a half float type for example.
959-
if jnp.shape(val) != jnp.shape(abs_val):
959+
if np.shape(val) != np.shape(abs_val):
960960
raise errors.ScopeParamShapeError(
961-
name, self.path_text, jnp.shape(abs_val), jnp.shape(val)
961+
name, self.path_text, np.shape(abs_val), np.shape(val)
962962
)
963963
else:
964964
if not self.is_mutable_collection('params'):

flax/nnx/bridge/module.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from flax.nnx.object import Object
3333
from flax.nnx import variablelib
3434
from flax.nnx.bridge import variables as bridge_variables
35-
import jax.numpy as jnp
35+
import numpy as np
3636

3737
A = tp.TypeVar('A')
3838
M = tp.TypeVar('M', bound='Module')
@@ -231,9 +231,9 @@ def param( # type: ignore[invalid-annotation]
231231
abs_value_flat = jax.tree_util.tree_leaves(abs_value)
232232
value_flat = jax.tree_util.tree_leaves(value)
233233
for val, abs_val in zip(value_flat, abs_value_flat):
234-
if jnp.shape(val) != jnp.shape(abs_val):
234+
if np.shape(val) != np.shape(abs_val):
235235
raise errors.ScopeParamShapeError(
236-
name, '', jnp.shape(abs_val), jnp.shape(val)
236+
name, '', np.shape(abs_val), np.shape(val)
237237
)
238238

239239
if isinstance(abs_value, variablelib.VariableMetadata):
@@ -282,9 +282,9 @@ def variable( # type: ignore[invalid-annotation]
282282
abs_value_flat = jax.tree_util.tree_leaves(abs_value)
283283
value_flat = jax.tree_util.tree_leaves(value)
284284
for val, abs_val in zip(value_flat, abs_value_flat):
285-
if jnp.shape(val) != jnp.shape(abs_val):
285+
if np.shape(val) != np.shape(abs_val):
286286
raise errors.ScopeParamShapeError(
287-
name, '', jnp.shape(abs_val), jnp.shape(val)
287+
name, '', np.shape(abs_val), np.shape(val)
288288
)
289289

290290
if isinstance(abs_value, variablelib.VariableMetadata):

0 commit comments

Comments
 (0)