@@ -374,30 +374,6 @@ def test_auto_restore(self):
374
374
os .path .join (tmp_dir , 'test_1' ), target = target )
375
375
check_eq (restored , to_save )
376
376
377
-
378
- # This is for fully addressable JAX arrays. For multiprocess JAX arrays like
379
- # GDA, see multihost_test.py (internal only)
380
- @parameterized .parameters ({'use_orbax' : True , 'jax_array_config' : True },
381
- {'use_orbax' : False , 'jax_array_config' : False })
382
- def test_jax_array (self , use_orbax , jax_array_config ):
383
- config .flax_use_orbax_checkpointing = use_orbax
384
- jax .config .update ('jax_array' , jax_array_config )
385
- tmp_dir = pathlib .Path (self .create_tempdir ().full_path )
386
- test_object0 = {'a' : jnp .zeros (3 ), 'b' : jnp .arange (3 )}
387
- test_object1 = {'a' : jnp .ones (3 ), 'b' : jnp .arange (3 , 6 )}
388
- new_object = checkpoints .restore_checkpoint (
389
- tmp_dir , test_object0 , prefix = 'test_' )
390
- check_eq (new_object , test_object0 )
391
- # Create leftover temporary checkpoint, which should be ignored.
392
- io .GFile (os .path .join (tmp_dir , 'test_tmp' ), 'w' )
393
- checkpoints .save_checkpoint (
394
- tmp_dir , test_object1 , 0 , prefix = 'test_' , keep = 1 )
395
- self .assertIn ('test_0' , os .listdir (tmp_dir ))
396
- new_object = checkpoints .restore_checkpoint (
397
- tmp_dir , test_object0 , prefix = 'test_' )
398
- check_eq (new_object , {'a' : np .ones (3 ), 'b' : np .arange (3 , 6 )})
399
-
400
-
401
377
def test_convert_pre_linen (self ):
402
378
params = checkpoints .convert_pre_linen ({
403
379
'mod_0' : {
0 commit comments