Skip to content

Commit f4a35ce

Browse files
botevKfacJaxDev
authored and
KfacJaxDev
committed
Changing deprecated tree_multimap to tree_map.
PiperOrigin-RevId: 439855058
1 parent 7615ee7 commit f4a35ce

File tree

5 files changed

+15
-15
lines changed

5 files changed

+15
-15
lines changed

examples/optimizers.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -164,16 +164,16 @@ def tf1_rmsprop(
164164
def tf1_scale_by_rms(decay_=0.9, epsilon_=1e-8):
165165
"""Same as optax.scale_by_rms, but initializes second moment to one."""
166166
def init_fn(params):
167-
nu = jax.tree_multimap(jnp.ones_like, params) # second moment
167+
nu = jax.tree_map(jnp.ones_like, params) # second moment
168168
return optax.ScaleByRmsState(nu=nu)
169169
def _update_moment(updates, moments, decay, order):
170-
return jax.tree_multimap(
170+
return jax.tree_map(
171171
lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)
172172
def update_fn(updates, state, params=None):
173173
del params
174174
nu = _update_moment(updates, state.nu, decay_, 2)
175-
updates = jax.tree_multimap(lambda g, n: g / (jnp.sqrt(n + epsilon_)),
176-
updates, nu)
175+
updates = jax.tree_map(lambda g, n: g / (jnp.sqrt(n + epsilon_)),
176+
updates, nu)
177177
return updates, optax.ScaleByRmsState(nu=nu)
178178
return optax.GradientTransformation(init_fn, update_fn)
179179

kfac_jax/_src/optimizer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ def _step(
805805
update_norm = utils.norm(delta)
806806

807807
# Update parameters
808-
params = jax.tree_multimap(jnp.add, params, delta)
808+
params = jax.tree_map(jnp.add, params, delta)
809809

810810
# Optionally compute the reduction ratio and update the damping
811811
if self._use_adaptive_damping:

kfac_jax/_src/tag_graph_matcher.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def write_env(
678678
if isinstance(var, list):
679679
if not isinstance(val, list):
680680
val = [val]
681-
return jax.tree_multimap(lambda x, y: write_env(env, x, y), var, val)
681+
return jax.tree_map(lambda x, y: write_env(env, x, y), var, val)
682682
elif isinstance(var, (core.Literal, core.Var)):
683683
env[var] = val
684684
else:

kfac_jax/_src/utils.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def average_func(*args) -> PyTree:
160160
def scan_fn(accumulator, args_):
161161
vmap_value = vmap_fn(*args_)
162162
avg_value = jax.tree_map(lambda x: jnp.mean(x, axis=0), vmap_value)
163-
return jax.tree_multimap(jnp.add, accumulator, avg_value), None
163+
return jax.tree_map(jnp.add, accumulator, avg_value), None
164164

165165
loop_shape = (num_parallel_chunks, parallel_size)
166166
loop_args = jax.tree_map(
@@ -435,7 +435,7 @@ def weighted_sum_of_objects(
435435
if not abstract_objects_equal(accumulator, o_i):
436436
raise ValueError("One or more objects do not have equivalent abstract "
437437
"structure.")
438-
accumulator = jax.tree_multimap(jnp.add, accumulator, scalar_mul(o_i, c_i))
438+
accumulator = jax.tree_map(jnp.add, accumulator, scalar_mul(o_i, c_i))
439439
return accumulator
440440

441441

@@ -448,7 +448,7 @@ def array_ip(x, y):
448448
return jnp.dot(x, y, precision=lax.Precision.HIGHEST)
449449

450450
with jax.experimental.enable_x64():
451-
elements_inner_products = jax.tree_multimap(array_ip, obj1, obj2)
451+
elements_inner_products = jax.tree_map(array_ip, obj1, obj2)
452452
flat_list = jax.tree_leaves(elements_inner_products)
453453
result = flat_list[0]
454454
for element_ip in flat_list[1:]:
@@ -484,7 +484,7 @@ def inner_product(
484484
raise ValueError("The objects do not have identical abstract structure.")
485485
if in_float64:
486486
return _inner_product_float64(obj1, obj2)
487-
elements_product = jax.tree_multimap(lambda x, y: jnp.sum(x * y), obj1, obj2)
487+
elements_product = jax.tree_map(lambda x, y: jnp.sum(x * y), obj1, obj2)
488488
return sum(jax.tree_leaves(elements_product))
489489

490490

@@ -587,7 +587,7 @@ def block_permuted(
587587

588588
def norm(obj: PyTree) -> chex.Array:
589589
"""Computes the Euclidean norm of the provided PyTree object."""
590-
elements_squared_norm = jax.tree_multimap(
590+
elements_squared_norm = jax.tree_map(
591591
lambda x: jnp.sum(jnp.square(x)), obj)
592592
return jnp.sqrt(sum(jax.tree_flatten(elements_squared_norm)[0]))
593593

@@ -985,7 +985,7 @@ def add(self, value_obj: PyTree, weight: chex.Numeric = 1) -> None:
985985
raise ValueError("The provided `value_obj` has an empty PyTree "
986986
"structure, but the accumulator has been initialized "
987987
"with a non-empty PyTree object.")
988-
self._accumulator = jax.tree_multimap(
988+
self._accumulator = jax.tree_map(
989989
jnp.add, self._accumulator, value_obj)
990990
elif not tree_is_empty(value_obj):
991991
raise ValueError("The provided `value_obj` has a non-empty PyTree "
@@ -1261,7 +1261,7 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> PyTree:
12611261
]
12621262
outs.append(method(instance, *args_i))
12631263

1264-
outs = jax.tree_multimap(jnp.stack, *outs)
1264+
outs = jax.tree_map(jnp.stack, *outs)
12651265

12661266
elif instance.debug:
12671267
outs = method(instance, *args)

tests/test_tracer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,10 @@ def compare_multi_batch(
102102
for d in (data1, data2):
103103
outputs.append(func(d))
104104
if combine == "concatenate":
105-
outputs = jax.tree_multimap(
105+
outputs = jax.tree_map(
106106
lambda x, y: jnp.concatenate([x, y], axis=0), *outputs)
107107
elif combine == "sum":
108-
outputs = jax.tree_multimap(lambda x, y: x + y, *outputs)
108+
outputs = jax.tree_map(lambda x, y: x + y, *outputs)
109109
else:
110110
raise NotImplementedError()
111111

0 commit comments

Comments
 (0)