24
24
import os
25
25
import pathlib
26
26
import re
27
+ import time
27
28
from typing import Any , Callable , Dict , Iterable , List , Optional , Tuple , Union
28
29
29
30
from absl import logging
34
35
from flax import serialization
35
36
from flax import traverse_util
36
37
import jax
38
+ from jax import monitoring
37
39
from jax import process_index
38
40
from jax import sharding
39
41
from jax .experimental .global_device_array import GlobalDeviceArray
40
42
from jax .experimental .multihost_utils import sync_global_devices
41
43
import orbax .checkpoint as orbax
42
44
43
-
45
+ _READ_CHECKPOINT_EVENT : str = '/jax/checkpoint/read/durations_sec'
46
+ _WRITE_CHECKPOINT_EVENT : str = '/jax/checkpoint/write/durations_sec'
44
47
_IMPORT_GDAM_SUCCESSFUL = False
45
48
try :
46
49
from jax .experimental .gda_serialization .serialization import get_tensorstore_spec
@@ -193,6 +196,7 @@ def _make_mpa_dirs(mpa_targets: List[Tuple[MultiprocessArrayType, str]],
193
196
def _save_mpas (gda_manager , mpa_targets : List [Tuple [MultiprocessArrayType , str ]],
194
197
tmp_path : str , final_path : str , base_path : str , keep : int ,
195
198
overwrite : bool , keep_every_n_steps : Optional [int ],
199
+ ckpt_start_time : float ,
196
200
async_manager : Optional [AsyncManager ] = None ):
197
201
"""Save the multiprocess arrays given the paths."""
198
202
mpa_list , mpa_subpaths = zip (* mpa_targets )
@@ -219,6 +223,7 @@ def _save_mpas(gda_manager, mpa_targets: List[Tuple[MultiprocessArrayType, str]]
219
223
keep ,
220
224
overwrite ,
221
225
keep_every_n_steps ,
226
+ ckpt_start_time ,
222
227
has_mpa = True ,
223
228
write_commit_success = write_commit_success ,
224
229
async_manager = async_manager ))
@@ -392,7 +397,8 @@ def _remove_invalid_ckpts(ckpt_path: str, base_path: str, keep: int,
392
397
393
398
def _save_commit (ckpt_tmp_path : str , ckpt_path : str , base_path : str , keep : int ,
394
399
overwrite : bool , keep_every_n_steps : Optional [int ],
395
- has_mpa : bool , write_commit_success : bool ,
400
+ ckpt_start_time : float , has_mpa : bool ,
401
+ write_commit_success : bool ,
396
402
async_manager : Optional [AsyncManager ] = None ) -> None :
397
403
"""Commit changes after saving checkpoints to disk.
398
404
@@ -402,6 +408,7 @@ def _save_commit(ckpt_tmp_path: str, ckpt_path: str, base_path: str, keep: int,
402
408
2. Remove newer checkpoints (files that ordered larger than this save) if
403
409
`overwrite=True`.
404
410
3. Remove old checkpoint files based on `keep` and `keep_every_n_steps`.
411
+ 4. Record program duration saved by this checkpoint.
405
412
"""
406
413
mpa_ckpt_tmp_path , mpa_ckpt_path = ckpt_tmp_path + MP_ARRAY_POSTFIX , ckpt_path + MP_ARRAY_POSTFIX
407
414
# Rename the multiprocess array path once serialization and writing finished.
@@ -429,6 +436,7 @@ def _save_commit(ckpt_tmp_path: str, ckpt_path: str, base_path: str, keep: int,
429
436
# Remove newer and older invalid checkpoints.
430
437
_remove_invalid_ckpts (ckpt_path , base_path , keep , overwrite ,
431
438
keep_every_n_steps , has_mpa )
439
+ orbax .utils .record_saved_duration (ckpt_start_time )
432
440
433
441
434
442
@@ -458,7 +466,8 @@ def _check_overwrite_error(ckpt_tmp_path: str, ckpt_path: str, base_path: str,
458
466
def _save_main_ckpt_file (target : bytes , has_mpa : bool , paths : Tuple [str , str ],
459
467
base_path : str , step : int ,
460
468
keep : int , overwrite : bool ,
461
- keep_every_n_steps : Optional [int ]):
469
+ keep_every_n_steps : Optional [int ],
470
+ ckpt_start_time : float ):
462
471
"""Save the main checkpoint file via file system."""
463
472
ckpt_tmp_path , ckpt_path = paths
464
473
io .makedirs (os .path .dirname (ckpt_path ))
@@ -475,6 +484,7 @@ def _save_main_ckpt_file(target: bytes, has_mpa: bool, paths: Tuple[str, str],
475
484
keep ,
476
485
overwrite ,
477
486
keep_every_n_steps ,
487
+ ckpt_start_time ,
478
488
has_mpa = False ,
479
489
write_commit_success = False )
480
490
@@ -536,6 +546,7 @@ def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
536
546
Returns:
537
547
Filename of saved checkpoint.
538
548
"""
549
+ start_time = time .time ()
539
550
# Make sure all saves are finished before the logic of checking and removing
540
551
# outdated checkpoints happens.
541
552
if async_manager :
@@ -562,6 +573,9 @@ def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
562
573
ckpt_path , target , save_args = save_args , force = overwrite )
563
574
_remove_invalid_ckpts (ckpt_path , base_path , keep , overwrite ,
564
575
keep_every_n_steps , True )
576
+ end_time = time .time ()
577
+ monitoring .record_event_duration_secs (_WRITE_CHECKPOINT_EVENT ,
578
+ end_time - start_time )
565
579
return ckpt_path
566
580
567
581
if not overwrite :
@@ -571,12 +585,15 @@ def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
571
585
# Save the files via I/O sync or async.
572
586
def save_main_ckpt_task ():
573
587
return _save_main_ckpt_file (target , False , (ckpt_tmp_path , ckpt_path ),
574
- base_path , step , keep ,
575
- overwrite , keep_every_n_steps )
588
+ base_path , step , keep , overwrite ,
589
+ keep_every_n_steps , start_time )
576
590
if async_manager :
577
591
async_manager .save_async (save_main_ckpt_task )
578
592
else :
579
593
save_main_ckpt_task ()
594
+ end_time = time .time ()
595
+ monitoring .record_event_duration_secs (_WRITE_CHECKPOINT_EVENT ,
596
+ end_time - start_time )
580
597
return ckpt_path
581
598
582
599
@@ -629,6 +646,7 @@ def save_checkpoint_multiprocess(
629
646
Returns:
630
647
Filename of saved checkpoint.
631
648
"""
649
+ start_time = time .time ()
632
650
# Make sure all saves are finished before the logic of checking and removing
633
651
# outdated checkpoints happens.
634
652
sync_global_devices ('starting_save_checkpoint' )
@@ -657,6 +675,9 @@ def save_checkpoint_multiprocess(
657
675
aggregate = not _use_multiprocess_serialization (x )), target )
658
676
orbax_checkpointer .save (
659
677
ckpt_path , target , save_args = save_args , force = overwrite )
678
+ end_time = time .time ()
679
+ monitoring .record_event_duration_secs (_WRITE_CHECKPOINT_EVENT ,
680
+ end_time - start_time )
660
681
return ckpt_path
661
682
662
683
target = serialization .to_state_dict (target )
@@ -670,8 +691,8 @@ def save_checkpoint_multiprocess(
670
691
# Save the files via I/O sync or async.
671
692
def save_main_ckpt_task ():
672
693
return _save_main_ckpt_file (target , has_mpa , (ckpt_tmp_path , ckpt_path ),
673
- base_path , step , keep ,
674
- overwrite , keep_every_n_steps )
694
+ base_path , step , keep , overwrite ,
695
+ keep_every_n_steps , start_time )
675
696
# Write the main checkpoint file only via process 0, to avoid race condition.
676
697
if process_index () == 0 :
677
698
if async_manager :
@@ -688,8 +709,11 @@ def save_main_ckpt_task():
688
709
_make_mpa_dirs (mpa_targets , ckpt_tmp_path )
689
710
sync_global_devices ('Flax:Checkpointing:AfterCreateMPADir' )
690
711
_save_mpas (gda_manager , mpa_targets , ckpt_tmp_path , ckpt_path , base_path ,
691
- keep , overwrite , keep_every_n_steps , async_manager )
712
+ keep , overwrite , keep_every_n_steps , start_time , async_manager )
692
713
714
+ end_time = time .time ()
715
+ monitoring .record_event_duration_secs (_WRITE_CHECKPOINT_EVENT ,
716
+ end_time - start_time )
693
717
return ckpt_path
694
718
695
719
@@ -769,6 +793,7 @@ def restore_checkpoint(
769
793
returned. This is to match the behavior of the case where a directory path
770
794
is specified but the directory has not yet been created.
771
795
"""
796
+ start_time = time .time ()
772
797
# Make sure any previous work is done before checking files.
773
798
if orbax_checkpointer and isinstance (orbax_checkpointer ,
774
799
orbax .AsyncCheckpointer ):
@@ -815,6 +840,9 @@ def make_restore_args(x):
815
840
restore_args = jax .tree_util .tree_map (make_restore_args , target )
816
841
restored = orbax_checkpointer .restore (
817
842
ckpt_path , item = target , restore_args = restore_args )
843
+ end_time = time .time ()
844
+ monitoring .record_event_duration_secs (_READ_CHECKPOINT_EVENT ,
845
+ end_time - start_time )
818
846
return restored
819
847
820
848
with io .GFile (ckpt_path , 'rb' ) as fp :
@@ -849,8 +877,15 @@ def read_chunk(i):
849
877
allow_partial_mpa_restoration )
850
878
851
879
if target is None :
852
- return state_dict
853
- return serialization .from_state_dict (target , state_dict )
880
+ restored_checkpoint = state_dict
881
+ else :
882
+ restored_checkpoint = serialization .from_state_dict (target , state_dict )
883
+
884
+ end_time = time .time ()
885
+ monitoring .record_event_duration_secs (_READ_CHECKPOINT_EVENT ,
886
+ end_time - start_time )
887
+
888
+ return restored_checkpoint
854
889
855
890
856
891
def convert_pre_linen (params : PyTree ) -> PyTree :
0 commit comments