Skip to content

Commit 4960e96

Browse files
author
Flax Team
committed
Monitor checkpoint load/save durations.
PiperOrigin-RevId: 496431428
1 parent 7cbbe0c commit 4960e96

File tree

1 file changed

+45
-10
lines changed

1 file changed

+45
-10
lines changed

flax/training/checkpoints.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import os
2525
import pathlib
2626
import re
27+
import time
2728
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
2829

2930
from absl import logging
@@ -34,13 +35,15 @@
3435
from flax import serialization
3536
from flax import traverse_util
3637
import jax
38+
from jax import monitoring
3739
from jax import process_index
3840
from jax import sharding
3941
from jax.experimental.global_device_array import GlobalDeviceArray
4042
from jax.experimental.multihost_utils import sync_global_devices
4143
import orbax.checkpoint as orbax
4244

43-
45+
_READ_CHECKPOINT_EVENT: str = '/jax/checkpoint/read/durations_sec'
46+
_WRITE_CHECKPOINT_EVENT: str = '/jax/checkpoint/write/durations_sec'
4447
_IMPORT_GDAM_SUCCESSFUL = False
4548
try:
4649
from jax.experimental.gda_serialization.serialization import get_tensorstore_spec
@@ -193,6 +196,7 @@ def _make_mpa_dirs(mpa_targets: List[Tuple[MultiprocessArrayType, str]],
193196
def _save_mpas(gda_manager, mpa_targets: List[Tuple[MultiprocessArrayType, str]],
194197
tmp_path: str, final_path: str, base_path: str, keep: int,
195198
overwrite: bool, keep_every_n_steps: Optional[int],
199+
ckpt_start_time: float,
196200
async_manager: Optional[AsyncManager] = None):
197201
"""Save the multiprocess arrays given the paths."""
198202
mpa_list, mpa_subpaths = zip(*mpa_targets)
@@ -219,6 +223,7 @@ def _save_mpas(gda_manager, mpa_targets: List[Tuple[MultiprocessArrayType, str]]
219223
keep,
220224
overwrite,
221225
keep_every_n_steps,
226+
ckpt_start_time,
222227
has_mpa=True,
223228
write_commit_success=write_commit_success,
224229
async_manager=async_manager))
@@ -392,7 +397,8 @@ def _remove_invalid_ckpts(ckpt_path: str, base_path: str, keep: int,
392397

393398
def _save_commit(ckpt_tmp_path: str, ckpt_path: str, base_path: str, keep: int,
394399
overwrite: bool, keep_every_n_steps: Optional[int],
395-
has_mpa: bool, write_commit_success: bool,
400+
ckpt_start_time: float, has_mpa: bool,
401+
write_commit_success: bool,
396402
async_manager: Optional[AsyncManager] = None) -> None:
397403
"""Commit changes after saving checkpoints to disk.
398404
@@ -402,6 +408,7 @@ def _save_commit(ckpt_tmp_path: str, ckpt_path: str, base_path: str, keep: int,
402408
2. Remove newer checkpoints (files that ordered larger than this save) if
403409
`overwrite=True`.
404410
3. Remove old checkpoint files based on `keep` and `keep_every_n_steps`.
411+
4. Record program duration saved by this checkpoint.
405412
"""
406413
mpa_ckpt_tmp_path, mpa_ckpt_path = ckpt_tmp_path + MP_ARRAY_POSTFIX, ckpt_path + MP_ARRAY_POSTFIX
407414
# Rename the multiprocess array path once serialization and writing finished.
@@ -429,6 +436,7 @@ def _save_commit(ckpt_tmp_path: str, ckpt_path: str, base_path: str, keep: int,
429436
# Remove newer and older invalid checkpoints.
430437
_remove_invalid_ckpts(ckpt_path, base_path, keep, overwrite,
431438
keep_every_n_steps, has_mpa)
439+
orbax.utils.record_saved_duration(ckpt_start_time)
432440

433441

434442

@@ -458,7 +466,8 @@ def _check_overwrite_error(ckpt_tmp_path: str, ckpt_path: str, base_path: str,
458466
def _save_main_ckpt_file(target: bytes, has_mpa: bool, paths: Tuple[str, str],
459467
base_path: str, step: int,
460468
keep: int, overwrite: bool,
461-
keep_every_n_steps: Optional[int]):
469+
keep_every_n_steps: Optional[int],
470+
ckpt_start_time: float):
462471
"""Save the main checkpoint file via file system."""
463472
ckpt_tmp_path, ckpt_path = paths
464473
io.makedirs(os.path.dirname(ckpt_path))
@@ -475,6 +484,7 @@ def _save_main_ckpt_file(target: bytes, has_mpa: bool, paths: Tuple[str, str],
475484
keep,
476485
overwrite,
477486
keep_every_n_steps,
487+
ckpt_start_time,
478488
has_mpa=False,
479489
write_commit_success=False)
480490

@@ -536,6 +546,7 @@ def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
536546
Returns:
537547
Filename of saved checkpoint.
538548
"""
549+
start_time = time.time()
539550
# Make sure all saves are finished before the logic of checking and removing
540551
# outdated checkpoints happens.
541552
if async_manager:
@@ -562,6 +573,9 @@ def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
562573
ckpt_path, target, save_args=save_args, force=overwrite)
563574
_remove_invalid_ckpts(ckpt_path, base_path, keep, overwrite,
564575
keep_every_n_steps, True)
576+
end_time = time.time()
577+
monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT,
578+
end_time - start_time)
565579
return ckpt_path
566580

567581
if not overwrite:
@@ -571,12 +585,15 @@ def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
571585
# Save the files via I/O sync or async.
572586
def save_main_ckpt_task():
573587
return _save_main_ckpt_file(target, False, (ckpt_tmp_path, ckpt_path),
574-
base_path, step, keep,
575-
overwrite, keep_every_n_steps)
588+
base_path, step, keep, overwrite,
589+
keep_every_n_steps, start_time)
576590
if async_manager:
577591
async_manager.save_async(save_main_ckpt_task)
578592
else:
579593
save_main_ckpt_task()
594+
end_time = time.time()
595+
monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT,
596+
end_time - start_time)
580597
return ckpt_path
581598

582599

@@ -629,6 +646,7 @@ def save_checkpoint_multiprocess(
629646
Returns:
630647
Filename of saved checkpoint.
631648
"""
649+
start_time = time.time()
632650
# Make sure all saves are finished before the logic of checking and removing
633651
# outdated checkpoints happens.
634652
sync_global_devices('starting_save_checkpoint')
@@ -657,6 +675,9 @@ def save_checkpoint_multiprocess(
657675
aggregate=not _use_multiprocess_serialization(x)), target)
658676
orbax_checkpointer.save(
659677
ckpt_path, target, save_args=save_args, force=overwrite)
678+
end_time = time.time()
679+
monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT,
680+
end_time - start_time)
660681
return ckpt_path
661682

662683
target = serialization.to_state_dict(target)
@@ -670,8 +691,8 @@ def save_checkpoint_multiprocess(
670691
# Save the files via I/O sync or async.
671692
def save_main_ckpt_task():
672693
return _save_main_ckpt_file(target, has_mpa, (ckpt_tmp_path, ckpt_path),
673-
base_path, step, keep,
674-
overwrite, keep_every_n_steps)
694+
base_path, step, keep, overwrite,
695+
keep_every_n_steps, start_time)
675696
# Write the main checkpoint file only via process 0, to avoid race condition.
676697
if process_index() == 0:
677698
if async_manager:
@@ -688,8 +709,11 @@ def save_main_ckpt_task():
688709
_make_mpa_dirs(mpa_targets, ckpt_tmp_path)
689710
sync_global_devices('Flax:Checkpointing:AfterCreateMPADir')
690711
_save_mpas(gda_manager, mpa_targets, ckpt_tmp_path, ckpt_path, base_path,
691-
keep, overwrite, keep_every_n_steps, async_manager)
712+
keep, overwrite, keep_every_n_steps, start_time, async_manager)
692713

714+
end_time = time.time()
715+
monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT,
716+
end_time - start_time)
693717
return ckpt_path
694718

695719

@@ -769,6 +793,7 @@ def restore_checkpoint(
769793
returned. This is to match the behavior of the case where a directory path
770794
is specified but the directory has not yet been created.
771795
"""
796+
start_time = time.time()
772797
# Make sure any previous work is done before checking files.
773798
if orbax_checkpointer and isinstance(orbax_checkpointer,
774799
orbax.AsyncCheckpointer):
@@ -815,6 +840,9 @@ def make_restore_args(x):
815840
restore_args = jax.tree_util.tree_map(make_restore_args, target)
816841
restored = orbax_checkpointer.restore(
817842
ckpt_path, item=target, restore_args=restore_args)
843+
end_time = time.time()
844+
monitoring.record_event_duration_secs(_READ_CHECKPOINT_EVENT,
845+
end_time - start_time)
818846
return restored
819847

820848
with io.GFile(ckpt_path, 'rb') as fp:
@@ -849,8 +877,15 @@ def read_chunk(i):
849877
allow_partial_mpa_restoration)
850878

851879
if target is None:
852-
return state_dict
853-
return serialization.from_state_dict(target, state_dict)
880+
restored_checkpoint = state_dict
881+
else:
882+
restored_checkpoint = serialization.from_state_dict(target, state_dict)
883+
884+
end_time = time.time()
885+
monitoring.record_event_duration_secs(_READ_CHECKPOINT_EVENT,
886+
end_time - start_time)
887+
888+
return restored_checkpoint
854889

855890

856891
def convert_pre_linen(params: PyTree) -> PyTree:

0 commit comments

Comments
 (0)