
Commit ffba15b

Author: Flax Authors (committed)
Merge pull request #2026 from jheek:remove-more-tree_multimap
PiperOrigin-RevId: 439524627
2 parents 0461745 + c2a6212 commit ffba15b

8 files changed: +44 -43 lines


docs/flip/1009-optimizer-api.md

Lines changed: 1 addition & 1 deletion
@@ -349,7 +349,7 @@ Remarks:
 `OptimizerDef`).
 - The functions `init_param_state()` and `apply_param_gradient()` are called
 for every leaf in the params/grads pytree. This makes it possible to write the
-calculations directly without `jax.tree_multimap()`.
+calculations directly without `jax.tree_map()`.
 - The interface was defined in pre-Linen without the distinction of `params` vs.
 other collections in `variables` in mind. The original API was elegant because
 one only needed to pass around the optimizer, which included the parameters,
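
For context on the rename: `jax.tree_map` accepts additional pytrees as positional arguments, so it already covers the multi-tree case that `jax.tree_multimap` used to serve, and the call sites in this commit can simply be renamed. A minimal sketch of the per-leaf update pattern the FLIP text refers to (the params/grads pytrees here are invented for illustration):

    import jax
    import jax.numpy as jnp

    # Hypothetical params/grads pytrees, only to illustrate the per-leaf update.
    params = {"kernel": jnp.ones((3, 3)), "bias": jnp.zeros(3)}
    grads = {"kernel": jnp.full((3, 3), 0.1), "bias": jnp.full(3, 0.2)}

    # jax.tree_map applies the function to matching leaves of all given pytrees,
    # so former jax.tree_multimap call sites can be renamed without other changes.
    new_params = jax.tree_map(lambda p, g: p - 0.01 * g, params, grads)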

docs/notebooks/flax_basics.ipynb

Lines changed: 1 addition & 1 deletion
@@ -382,7 +382,7 @@
 "\n",
 "@jax.jit\n",
 "def update_params(params, learning_rate, grads):\n",
-" params = jax.tree_multimap(\n",
+" params = jax.tree_map(\n",
 " lambda p, g: p - learning_rate * g, params, grads)\n",
 " return params\n",
 "\n",

docs/notebooks/jax_for_the_impatient.ipynb

Lines changed: 34 additions & 33 deletions
@@ -1,18 +1,4 @@
 {
-"nbformat": 4,
-"nbformat_minor": 0,
-"metadata": {
-"colab": {
-"name": "JAX for the impatient.ipynb",
-"provenance": [],
-"collapsed_sections": [],
-"toc_visible": true
-},
-"kernelspec": {
-"name": "python3",
-"display_name": "Python 3"
-}
-},
 "cells": [
 {
 "cell_type": "markdown",
@@ -64,20 +50,14 @@
 },
 {
 "cell_type": "code",
+"execution_count": 2,
 "metadata": {
-"id": "L2HKiLTNJ4Eh",
-"outputId": "c4297a1a-4e4b-4bdc-ca5d-3d33aca92b3b",
 "colab": {
 "base_uri": "https://localhost:8080/"
-}
+},
+"id": "L2HKiLTNJ4Eh",
+"outputId": "c4297a1a-4e4b-4bdc-ca5d-3d33aca92b3b"
 },
-"source": [
-"m = jnp.ones((4,4)) # We're generating one 4 by 4 matrix filled with ones.\n",
-"n = jnp.array([[1.0, 2.0, 3.0, 4.0],\n",
-" [5.0, 6.0, 7.0, 8.0]]) # An explicit 2 by 4 array\n",
-"m"
-],
-"execution_count": 2,
 "outputs": [
 {
 "name": "stderr",
@@ -99,6 +79,12 @@
 "metadata": {},
 "output_type": "execute_result"
 }
+],
+"source": [
+"m = jnp.ones((4,4)) # We're generating one 4 by 4 matrix filled with ones.\n",
+"n = jnp.array([[1.0, 2.0, 3.0, 4.0],\n",
+" [5.0, 6.0, 7.0, 8.0]]) # An explicit 2 by 4 array\n",
+"m"
 ]
 },
 {
@@ -116,11 +102,11 @@
 "cell_type": "code",
 "execution_count": 3,
 "metadata": {
-"id": "9do-ZRGaRThn",
-"outputId": "9c4feb4d-3bd1-4921-97ce-c8087b37496f",
 "colab": {
 "base_uri": "https://localhost:8080/"
-}
+},
+"id": "9do-ZRGaRThn",
+"outputId": "9c4feb4d-3bd1-4921-97ce-c8087b37496f"
 },
 "outputs": [
 {
@@ -1069,7 +1055,7 @@
 "id": "3s167WGKUlZ9"
 },
 "source": [
-"A more flexible version of `tree_map` would be `tree_multimap`. Instead of applying a standalone function to each of the tree leaves, you also provide a tuple of additional trees with similar shape to the input tree that will provide per leaf arguments to the function."
+"Instead of applying a standalone function to each of the tree leaves, you can also provide a tuple of additional trees with similar shape to the input tree that will provide per leaf arguments to the function."
 ]
 },
 {
@@ -1095,7 +1081,8 @@
 }
 ],
 "source": [
-"tree_util.tree_multimap(lambda x,y: x+y, t, tree_util.tree_map(lambda x: x*x, t))"
+"t2 = tree_util.tree_map(lambda x: x*x, t)\n",
+"tree_util.tree_map(lambda x,y: x+y, t, t2)"
 ]
 },
 {
@@ -1199,7 +1186,7 @@
 "id": "nW1IKnjqXFdN"
 },
 "source": [
-"Now using our tree of params, we can write the gradient descent in a simpler way using `jax.tree_multimap`:"
+"Now using our tree of params, we can write the gradient descent in a simpler way using `jax.tree_map`:"
 ]
 },
 {
@@ -1246,7 +1233,7 @@
 "# Always remember to jit!\n",
 "@jax.jit\n",
 "def update_params_pytree(params, learning_rate, x_samples, y_samples):\n",
-" params = jax.tree_multimap(\n",
+" params = jax.tree_map(\n",
 " lambda p, g: p - learning_rate * g, params,\n",
 " jax.grad(mse_pytree)(params, x_samples, y_samples))\n",
 " return params\n",
@@ -1280,7 +1267,7 @@
 "for i in range(101):\n",
 " # Note that here the loss is computed before the param update.\n",
 " loss_val, grads = loss_grad_fn(params, x_samples, y_samples)\n",
-" params = jax.tree_multimap(\n",
+" params = jax.tree_map(\n",
 " lambda p, g: p - learning_rate * g, params, grads)\n",
 " if (i % 5 == 0):\n",
 " print(f\"Loss step {i}: \", loss_val)"
@@ -1295,5 +1282,19 @@
 "That's all you needed to know to get started with Flax! To dive deeper, we very much recommend checking the JAX [docs](https://jax.readthedocs.io/en/latest/index.html)."
 ]
 }
-]
+],
+"metadata": {
+"colab": {
+"collapsed_sections": [],
+"name": "JAX for the impatient.ipynb",
+"provenance": [],
+"toc_visible": true
+},
+"kernelspec": {
+"display_name": "Python 3",
+"name": "python3"
+}
+},
+"nbformat": 4,
+"nbformat_minor": 0
 }
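
For reference, a minimal sketch of the multi-tree mapping the notebook cell above now demonstrates with `jax.tree_util.tree_map` (the example pytree `t` below is invented for illustration; the notebook defines its own `t`):

    from jax import tree_util

    # A small example pytree (hypothetical, for illustration only).
    t = [1, {"a": 2, "b": 3}, [4, 5]]

    # Map over a single tree: square every leaf.
    t2 = tree_util.tree_map(lambda x: x * x, t)

    # Map over several trees with the same structure:
    # each leaf of t is paired with the matching leaf of t2.
    print(tree_util.tree_map(lambda x, y: x + y, t, t2))
    # [2, {'a': 6, 'b': 12}, [20, 30]]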

examples/imagenet/train.py

Lines changed: 2 additions & 2 deletions
@@ -145,11 +145,11 @@ def loss_fn(params):
 # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
 # params should be restored (= skip this step).
 new_state = new_state.replace(
-opt_state=jax.tree_multimap(
+opt_state=jax.tree_map(
 functools.partial(jnp.where, is_fin),
 new_state.opt_state,
 state.opt_state),
-params=jax.tree_multimap(
+params=jax.tree_map(
 functools.partial(jnp.where, is_fin),
 new_state.params,
 state.params))
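
The pattern migrated here selects, leaf by leaf, between the updated and the previous pytree depending on whether the gradients were finite. A minimal standalone sketch of that selection (the pytrees and the `is_fin` flag below are invented for illustration):

    import functools
    import jax
    import jax.numpy as jnp

    # Hypothetical old/new parameter pytrees and a finiteness flag.
    old_params = {"w": jnp.zeros(3)}
    new_params = {"w": jnp.ones(3)}
    is_fin = jnp.array(False)  # pretend the gradients contained Inf/NaN

    # Keep each new leaf only where is_fin is True, otherwise fall back to the old one.
    select_fn = functools.partial(jnp.where, is_fin)
    params = jax.tree_map(select_fn, new_params, old_params)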

examples/linen_design_test/linear_regression.py

Lines changed: 1 addition & 1 deletion
@@ -45,4 +45,4 @@ def init_params(rng):
 loss, grad = jax.value_and_grad(loss_fn)(params)
 print(i, "loss = ", loss, "Yhat = ", predict(params))
 lr = 0.03
-params = jax.tree_multimap(lambda x, d: x - lr * d, params, grad)
+params = jax.tree_map(lambda x, d: x - lr * d, params, grad)

examples/wmt/train.py

Lines changed: 2 additions & 2 deletions
@@ -224,9 +224,9 @@ def loss_fn(params):
 # params should be restored (= skip this step).
 select_fn = functools.partial(jnp.where, is_fin)
 new_state = new_state.replace(
-opt_state=jax.tree_multimap(
+opt_state=jax.tree_map(
 select_fn, new_state.opt_state, state.opt_state),
-params=jax.tree_multimap(
+params=jax.tree_map(
 select_fn, new_state.params, state.params)
 )
 metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"]

tests/core/core_frozen_dict_test.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def test_frozen_dict_pop(self):
 self.assertEqual(unfreeze(b), {'b': {'c': 2}})

 def test_frozen_dict_partially_maps(self):
-x = jax.tree_multimap(
+x = jax.tree_map(
 lambda a, b: (a, b),
 freeze({'a': 2}), freeze({'a': {'b': 1}}))
 self.assertEqual(unfreeze(x), {'a': (2, {'b': 1})})
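
A side note on why this test works: when `jax.tree_map` is given several trees, the extra trees are only flattened up to the structure of the first one, so a deeper subtree in the second tree is handed to the function as a single value. This is generic pytree behavior rather than anything FrozenDict-specific; a minimal sketch with plain dicts (invented for illustration):

    import jax

    # The first tree has a single leaf under 'a', so the second tree's deeper
    # subtree under 'a' is passed to the lambda as one value.
    x = jax.tree_map(lambda a, b: (a, b), {'a': 2}, {'a': {'b': 1}})
    print(x)  # {'a': (2, {'b': 1})}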

tests/linen/linen_module_test.py

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@

 def tree_equals(x, y):
 return jax.tree_util.tree_all(
-jax.tree_multimap(operator.eq, x, y))
+jax.tree_map(operator.eq, x, y))


 class DummyModule(nn.Module):
@@ -1074,7 +1074,7 @@ def __call__(self, c, x):
 },
 })
 self.assertTrue(jax.tree_util.tree_all(
-jax.tree_multimap(
+jax.tree_map(
 lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7),
 cntrs, ref_cntrs)
 ))
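
The `tree_equals` helper touched above combines `jax.tree_map` with `jax.tree_util.tree_all` to compare two pytrees leaf by leaf. A minimal sketch of that pattern in isolation (the example pytrees are invented; with array leaves you would reduce each comparison with something like `jnp.array_equal` rather than `operator.eq`):

    import operator
    import jax

    def tree_equals(x, y):
      # True only if every pair of corresponding scalar leaves compares equal.
      return jax.tree_util.tree_all(jax.tree_map(operator.eq, x, y))

    a = {"lr": 0.5, "steps": 3}
    b = {"lr": 0.5, "steps": 3}
    print(tree_equals(a, b))                        # True
    print(tree_equals(a, {"lr": 0.5, "steps": 4}))  # False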
