|
32 | 32 | from flax.nnx.object import Object
|
33 | 33 | from flax.nnx import variablelib
|
34 | 34 | from flax.nnx.bridge import variables as bridge_variables
|
35 |
| -import jax.numpy as jnp |
| 35 | +import numpy as np |
36 | 36 |
|
37 | 37 | A = tp.TypeVar('A')
|
38 | 38 | M = tp.TypeVar('M', bound='Module')
|
@@ -231,9 +231,9 @@ def param( # type: ignore[invalid-annotation]
|
231 | 231 | abs_value_flat = jax.tree_util.tree_leaves(abs_value)
|
232 | 232 | value_flat = jax.tree_util.tree_leaves(value)
|
233 | 233 | for val, abs_val in zip(value_flat, abs_value_flat):
|
234 |
| - if jnp.shape(val) != jnp.shape(abs_val): |
| 234 | + if np.shape(val) != np.shape(abs_val): |
235 | 235 | raise errors.ScopeParamShapeError(
|
236 |
| - name, '', jnp.shape(abs_val), jnp.shape(val) |
| 236 | + name, '', np.shape(abs_val), np.shape(val) |
237 | 237 | )
|
238 | 238 |
|
239 | 239 | if isinstance(abs_value, variablelib.VariableMetadata):
|
@@ -282,9 +282,9 @@ def variable( # type: ignore[invalid-annotation]
|
282 | 282 | abs_value_flat = jax.tree_util.tree_leaves(abs_value)
|
283 | 283 | value_flat = jax.tree_util.tree_leaves(value)
|
284 | 284 | for val, abs_val in zip(value_flat, abs_value_flat):
|
285 |
| - if jnp.shape(val) != jnp.shape(abs_val): |
| 285 | + if np.shape(val) != np.shape(abs_val): |
286 | 286 | raise errors.ScopeParamShapeError(
|
287 |
| - name, '', jnp.shape(abs_val), jnp.shape(val) |
| 287 | + name, '', np.shape(abs_val), np.shape(val) |
288 | 288 | )
|
289 | 289 |
|
290 | 290 | if isinstance(abs_value, variablelib.VariableMetadata):
|
|
0 commit comments