Skip to content

Commit 595e711

Browse files
author
Flax Authors
committed
Merge pull request #4472 from google:nnx-fix-fori
PiperOrigin-RevId: 712665956
2 parents 53bde74 + f8164dd commit 595e711

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

tests/jax_utils_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""Tests for flax.jax_utils."""
1616

1717
from functools import partial
18+
import os
19+
import re
1820

1921
from absl.testing import absltest
2022
from absl.testing import parameterized
@@ -26,9 +28,21 @@
2628

2729
NDEV = 4
2830

31+
_xla_device_count_flag_regexp = (
32+
r'[-]{0,2}xla_force_host_platform_device_count=(\d+)?(\s|$)'
33+
)
34+
35+
36+
def set_n_cpu_devices(n: int):
37+
xla_flags = os.getenv('XLA_FLAGS', '')
38+
xla_flags = re.sub(_xla_device_count_flag_regexp, '', xla_flags)
39+
os.environ['XLA_FLAGS'] = ' '.join(
40+
[f'--xla_force_host_platform_device_count={n}'] + xla_flags.split()
41+
)
42+
2943

3044
def setUpModule():
31-
chex.set_n_cpu_devices(NDEV)
45+
set_n_cpu_devices(NDEV)
3246

3347

3448
class PadShardUnpadTest(chex.TestCase):

tests/nnx/transforms_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2985,6 +2985,18 @@ def loop_fn(inputs):
29852985
nnx.while_loop(lambda input: input[-1] > 0, while_loop_fn, (a, b, 2))
29862986
nnx.fori_loop(0, 2, fori_loop_fn, (a, b))
29872987

2988+
def test_fori_output(self):
2989+
model = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(0)))
2990+
model2 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(1)))
2991+
2992+
def f(i, x):
2993+
return x
2994+
2995+
model_out, model2_out = nnx.fori_loop(0, 10, f, (model, model2))
2996+
2997+
self.assertIs(model, model_out)
2998+
self.assertIs(model2, model2_out)
2999+
29883000

29893001
class TestSplitMergeInputs(absltest.TestCase):
29903002
def test_split_inputs(self):

0 commit comments

Comments
 (0)