Skip to content

Commit 4f24933

Browse files
author
Flax Authors
committed
Merge pull request #2729 from levskaya:queue_runner_fix
PiperOrigin-RevId: 495767913
2 parents 4669ea0 + eb9cecf commit 4f24933

File tree

2 files changed

+1
-25
lines changed

2 files changed

+1
-25
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ jobs:
7979
tests:
8080
name: Run Tests
8181
needs: [cancel-previous, pre-commit, commit-count, test-import]
82-
runs-on: ubuntu-20.04-16core
82+
runs-on: ubuntu-latest
8383
strategy:
8484
matrix:
8585
python-version: ['3.8', '3.9', '3.10']

tests/checkpoints_test.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -374,30 +374,6 @@ def test_auto_restore(self):
374374
os.path.join(tmp_dir, 'test_1'), target=target)
375375
check_eq(restored, to_save)
376376

377-
378-
# This is for fully addressable JAX arrays. For multiprocess JAX arrays like
379-
# GDA, see multihost_test.py (internal only)
380-
@parameterized.parameters({'use_orbax': True, 'jax_array_config': True},
381-
{'use_orbax': False, 'jax_array_config': False})
382-
def test_jax_array(self, use_orbax, jax_array_config):
383-
config.flax_use_orbax_checkpointing = use_orbax
384-
jax.config.update('jax_array', jax_array_config)
385-
tmp_dir = pathlib.Path(self.create_tempdir().full_path)
386-
test_object0 = {'a': jnp.zeros(3), 'b': jnp.arange(3)}
387-
test_object1 = {'a': jnp.ones(3), 'b': jnp.arange(3, 6)}
388-
new_object = checkpoints.restore_checkpoint(
389-
tmp_dir, test_object0, prefix='test_')
390-
check_eq(new_object, test_object0)
391-
# Create leftover temporary checkpoint, which should be ignored.
392-
io.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w')
393-
checkpoints.save_checkpoint(
394-
tmp_dir, test_object1, 0, prefix='test_', keep=1)
395-
self.assertIn('test_0', os.listdir(tmp_dir))
396-
new_object = checkpoints.restore_checkpoint(
397-
tmp_dir, test_object0, prefix='test_')
398-
check_eq(new_object, {'a': np.ones(3), 'b': np.arange(3, 6)})
399-
400-
401377
def test_convert_pre_linen(self):
402378
params = checkpoints.convert_pre_linen({
403379
'mod_0': {

0 commit comments

Comments
 (0)