@@ -160,7 +160,7 @@ def average_func(*args) -> PyTree:
160
160
def scan_fn (accumulator , args_ ):
161
161
vmap_value = vmap_fn (* args_ )
162
162
avg_value = jax .tree_map (lambda x : jnp .mean (x , axis = 0 ), vmap_value )
163
- return jax .tree_multimap (jnp .add , accumulator , avg_value ), None
163
+ return jax .tree_map (jnp .add , accumulator , avg_value ), None
164
164
165
165
loop_shape = (num_parallel_chunks , parallel_size )
166
166
loop_args = jax .tree_map (
@@ -435,7 +435,7 @@ def weighted_sum_of_objects(
435
435
if not abstract_objects_equal (accumulator , o_i ):
436
436
raise ValueError ("One or more objects do not have equivalent abstract "
437
437
"structure." )
438
- accumulator = jax .tree_multimap (jnp .add , accumulator , scalar_mul (o_i , c_i ))
438
+ accumulator = jax .tree_map (jnp .add , accumulator , scalar_mul (o_i , c_i ))
439
439
return accumulator
440
440
441
441
@@ -448,7 +448,7 @@ def array_ip(x, y):
448
448
return jnp .dot (x , y , precision = lax .Precision .HIGHEST )
449
449
450
450
with jax .experimental .enable_x64 ():
451
- elements_inner_products = jax .tree_multimap (array_ip , obj1 , obj2 )
451
+ elements_inner_products = jax .tree_map (array_ip , obj1 , obj2 )
452
452
flat_list = jax .tree_leaves (elements_inner_products )
453
453
result = flat_list [0 ]
454
454
for element_ip in flat_list [1 :]:
@@ -484,7 +484,7 @@ def inner_product(
484
484
raise ValueError ("The objects do not have identical abstract structure." )
485
485
if in_float64 :
486
486
return _inner_product_float64 (obj1 , obj2 )
487
- elements_product = jax .tree_multimap (lambda x , y : jnp .sum (x * y ), obj1 , obj2 )
487
+ elements_product = jax .tree_map (lambda x , y : jnp .sum (x * y ), obj1 , obj2 )
488
488
return sum (jax .tree_leaves (elements_product ))
489
489
490
490
@@ -587,7 +587,7 @@ def block_permuted(
587
587
588
588
def norm (obj : PyTree ) -> chex .Array :
589
589
"""Computes the Euclidean norm of the provided PyTree object."""
590
- elements_squared_norm = jax .tree_multimap (
590
+ elements_squared_norm = jax .tree_map (
591
591
lambda x : jnp .sum (jnp .square (x )), obj )
592
592
return jnp .sqrt (sum (jax .tree_flatten (elements_squared_norm )[0 ]))
593
593
@@ -985,7 +985,7 @@ def add(self, value_obj: PyTree, weight: chex.Numeric = 1) -> None:
985
985
raise ValueError ("The provided `value_obj` has an empty PyTree "
986
986
"structure, but the accumulator has been initialized "
987
987
"with a non-empty PyTree object." )
988
- self ._accumulator = jax .tree_multimap (
988
+ self ._accumulator = jax .tree_map (
989
989
jnp .add , self ._accumulator , value_obj )
990
990
elif not tree_is_empty (value_obj ):
991
991
raise ValueError ("The provided `value_obj` has a non-empty PyTree "
@@ -1261,7 +1261,7 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> PyTree:
1261
1261
]
1262
1262
outs .append (method (instance , * args_i ))
1263
1263
1264
- outs = jax .tree_multimap (jnp .stack , * outs )
1264
+ outs = jax .tree_map (jnp .stack , * outs )
1265
1265
1266
1266
elif instance .debug :
1267
1267
outs = method (instance , * args )
0 commit comments