Skip to content

Commit e5778a3

Browse files
author
v-chen_data
committed
Merge branch 'hsdp-ci-tests' of g.yxqyang.asia-regular:mosaicml/composer into hsdp-ci-tests
2 parents 6352972 + 13ab59c commit e5778a3

File tree

13 files changed

+525
-72
lines changed

13 files changed

+525
-72
lines changed

composer/callbacks/checkpoint_saver.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,14 @@
3030
is_model_deepspeed,
3131
partial_format,
3232
)
33+
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME
3334
from composer.utils.compression import get_compressor, is_compressed_pt
3435
from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY
3536

3637
log = logging.getLogger(__name__)
3738

3839
__all__ = ['CheckpointSaver']
3940

40-
_TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME = '.metadata'
41-
4241

4342
class CheckpointSaver(Callback): # noqa: D101
4443
__doc__ = f"""Callback to save checkpoints.

composer/callbacks/eval_output_logging_callback.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:
114114
self.rows.extend(rows)
115115

116116
def eval_end(self, state: State, logger: Logger) -> None:
117+
# eval_batch_end will have set these if there is anything to log
118+
if self.name is None or self.columns is None:
119+
return
120+
117121
list_of_rows = dist.all_gather_object(self.rows)
118122
rows = [row for rows in list_of_rows for row in rows]
119123
for dest_logger in logger.destinations:

composer/checkpoint/save.py

Lines changed: 281 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,291 @@
33

44
"""Useful functions for saving state dicts to disk."""
55

6+
import json
67
import logging
78
import os
9+
import pickle
810
import textwrap
911
import warnings
12+
from dataclasses import dataclass
1013
from pathlib import Path
11-
from typing import Any, Dict, Optional, Union
14+
from typing import Any, Dict, Optional, Sequence, Union
1215

1316
import torch
1417
import torch.distributed.checkpoint as DCP
1518
from packaging import version
1619
from torch.distributed._shard.sharded_tensor import ShardedTensor
1720
from torch.distributed._tensor import DTensor
1821

22+
from composer.checkpoint.state_dict import (
23+
get_metadata_state_dict,
24+
get_model_state_dict,
25+
get_optim_state_dict,
26+
get_resumption_state_dict,
27+
)
28+
from composer.core import State, Time
29+
from composer.devices import Device
30+
from composer.models import ComposerModel
31+
from composer.utils import dist
32+
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file
33+
from composer.utils.file_helpers import format_name_with_dist_and_time
34+
35+
log = logging.getLogger(__name__)
36+
37+
MODEL_CHECKPOINT_DIRECTORY_NAME = 'model'
38+
MONOLITHIC_MODEL_CHECKPOINT_FILENAME = 'model.pt'
39+
OPTIM_CHECKPOINT_DIRECTORY_NAME = 'optim'
40+
OPTIM_MONO_CHECKPOINT_FILENAME = 'optim.pt'
41+
METADATA_CHECKPOINT_FILENAME = 'composer_metadata.json'
42+
RESUMPTION_CHECKPOINT_FILENAME = 'resumption.pkl'
43+
44+
45+
@dataclass
class CheckpointSaveOptions:
    """Options for saving a checkpoint to disk.

    Args:
        destination_dir (str): The directory to save the checkpoint to.
        save_frequency (Union[str, int, Time]): The frequency to save the checkpoint.
            If '1ep', the checkpoint will be saved after each epoch.
            If '1ba', the checkpoint will be saved after each batch.
            If an int, the checkpoint will be saved after that many epochs.
        dir_prefix (str): The prefix to use for the directory name. Can include {epoch} and {batch}.
        overwrite (bool): Whether to overwrite the checkpoint if it already exists.
        save_model (bool): Whether to save the model.
        save_optimizer (bool): Whether to save the optimizer.
        save_resumption_state (bool): Whether to save the resumption state.
        num_checkpoints_to_keep (int): The number of checkpoints to keep.
            If -1, all checkpoints will be kept.
        save_format (str): The format to save the model in. 'pt', which is the standard pytorch serialization, is the only option for now.
        sharded_checkpoint (bool): Whether to save the model as a sharded checkpoint.
        precision (str): The precision to save the model in. One of 'bf16', 'fp32', 'fp16', 'fp64'.
        include_keys (Optional[Union[str, Sequence[str]]]): Keys to include in the saved model.
        ignore_keys (Optional[Union[str, Sequence[str]]]): Keys to ignore in the saved model.
    """
    # Only required field; every other option has a default.
    destination_dir: str
    save_frequency: Union[str, int, Time] = '1ep'
    dir_prefix: str = 'ep{epoch}-ba{batch}'
    overwrite: bool = False
    save_model: bool = True
    save_optimizer: bool = True
    save_resumption_state: bool = True
    num_checkpoints_to_keep: int = -1
    save_format: str = 'pt'
    sharded_checkpoint: bool = False
    # NOTE(review): defaults to 'bf16' here, while the standalone save_*_to_disk
    # functions default their precision argument to 'fp32' -- confirm intended.
    precision: str = 'bf16'
    include_keys: Optional[Union[str, Sequence[str]]] = None
    ignore_keys: Optional[Union[str, Sequence[str]]] = None
81+
82+
83+
def save_checkpoint_to_disk(
    state: State,
    options: Optional[Union[CheckpointSaveOptions, Dict]] = None,
    destination_dir: Optional[str] = None,
):
    """Saves a checkpoint to disk.

    The checkpoint directory is ``os.path.join(options.destination_dir, options.dir_prefix)``
    after formatting with ``format_name_with_dist_and_time`` using the run name and
    timestamp of ``state``. Depending on ``options``, the model, the first optimizer in
    ``state.optimizers`` and the resumption state are written into that directory.
    The composer metadata file is always written.

    Args:
        state (State): The state to save.
        options (Optional[Union[CheckpointSaveOptions, Dict]]): The options for saving the checkpoint.
            A dict is expanded as keyword arguments to :class:`CheckpointSaveOptions`.
            If None, destination_dir must be provided.
        destination_dir (Optional[str]): The directory to save the checkpoint to.
            If options is provided, this takes precedence over options.destination_dir.
            The options object passed in by the caller is left unmodified.

    Raises:
        ValueError: If both ``options`` and ``destination_dir`` are None.
    """
    # Function-scope import: the module only imports ``dataclass`` from ``dataclasses``.
    from dataclasses import replace

    if options is None:
        if destination_dir is None:
            raise ValueError('destination_dir must be provided if options is None')
        options = CheckpointSaveOptions(destination_dir=destination_dir)
    else:
        if isinstance(options, dict):
            options = CheckpointSaveOptions(**options)
        if destination_dir is not None:
            # Work on a copy so the override does not leak into the caller's
            # options object (which may be reused for later saves).
            options = replace(options, destination_dir=destination_dir)

    save_path = os.path.join(options.destination_dir, options.dir_prefix)
    save_path = format_name_with_dist_and_time(save_path, state.run_name, state.timestamp)
    os.makedirs(save_path, exist_ok=True)

    if options.save_model:
        save_model_to_disk(
            state.model,
            save_path,
            options.sharded_checkpoint,
            options.precision,
            options.include_keys,
            options.ignore_keys,
            options.overwrite,
            options.save_format,
        )
    if options.save_optimizer:
        # Only the first optimizer of the state is checkpointed.
        optimizer = state.optimizers[0]
        save_optim_to_disk(
            state.model,
            optimizer,
            save_path,
            options.sharded_checkpoint,
            options.precision,
            options.overwrite,
            options.save_format,
        )
    if options.save_resumption_state:
        save_resumption_state_to_disk(state, save_path)

    # Metadata is written unconditionally, regardless of the save_* flags.
    save_composer_metadata_to_disk(
        save_path,
        state.model,
        options.sharded_checkpoint,
        options.precision,
        state.device,
        state.device_train_microbatch_size,
    )
142+
143+
144+
def save_model_to_disk(
    model: Union[ComposerModel, torch.nn.Module],
    destination_dir: str,
    sharded_checkpoint: bool = False,
    precision: str = 'fp32',
    include_keys: Optional[Union[str, Sequence[str]]] = None,
    ignore_keys: Optional[Union[str, Sequence[str]]] = None,
    overwrite: bool = False,
    save_format: str = 'pt',  # 'pt' is the only value accepted below
) -> Optional[str]:
    """Saves a model's state dict to disk.

    Args:
        model (Union[ComposerModel, torch.nn.Module]): The model to save.
        destination_dir (str): The directory to save the model to.
            The target is destination_dir/model/model.pt if sharded_checkpoint is False,
            otherwise the directory destination_dir/model is handed to the state dict writer.
        sharded_checkpoint (bool): Whether to save the model as a sharded checkpoint.
        precision (str): The precision to save the model in. One of 'bf16', 'fp32', 'fp16', 'fp64'.
        include_keys (Optional[Union[str, Sequence[str]]]): Keys to include in the saved model.
        ignore_keys (Optional[Union[str, Sequence[str]]]): Keys to ignore in the saved model.
        overwrite (bool): If True, the file will be overwritten if it exists.
        save_format (str): The format to save the model in. Only 'pt' is supported.

    Returns:
        Optional[str]: The value returned by ``save_state_dict_to_disk`` for the target path.

    Raises:
        NotImplementedError: If ``save_format`` is anything other than 'pt'.
    """
    if save_format != 'pt':
        raise NotImplementedError(
            f"Saving checkpoint in format {save_format} is not supported. Please choose from ['pt'].",
        )

    state_dict_to_save = get_model_state_dict(model, sharded_checkpoint, precision, include_keys, ignore_keys)

    # Sharded saves target the directory; monolithic saves target a file inside it.
    target_path = os.path.join(destination_dir, MODEL_CHECKPOINT_DIRECTORY_NAME)
    if not sharded_checkpoint:
        target_path = os.path.join(target_path, MONOLITHIC_MODEL_CHECKPOINT_FILENAME)

    return save_state_dict_to_disk(
        state_dict=state_dict_to_save,
        destination_file_path=target_path,
        overwrite=overwrite,
        save_format=save_format,
    )
194+
195+
196+
def save_optim_to_disk(
    model: Union[ComposerModel, torch.nn.Module],
    optimizer: torch.optim.Optimizer,
    destination_dir: str,
    sharded_checkpoint: bool = False,
    precision: str = 'fp32',
    overwrite: bool = False,
    save_format: str = 'pt',
) -> Optional[str]:
    """Saves an optimizer's state dict to disk.

    Args:
        model (Union[ComposerModel, torch.nn.Module]): The model the optimizer belongs to.
        optimizer (torch.optim.Optimizer): The optimizer to save.
        destination_dir (str): The directory to save the optimizer to.
            The target is destination_dir/optim/optim.pt if sharded_checkpoint is False,
            otherwise the directory destination_dir/optim is handed to the state dict writer.
        sharded_checkpoint (bool): Whether to save the optimizer as a sharded checkpoint.
        precision (str): The precision to save the optimizer in. One of 'bf16', 'fp32', 'fp16', 'fp64'.
        overwrite (bool): If True, the file will be overwritten if it exists.
        save_format (str): The format to save the optimizer in. One of 'pt'.

    Returns:
        Optional[str]: The value returned by ``save_state_dict_to_disk`` for the target path.
    """
    state_dict_to_save = get_optim_state_dict(
        model,
        optimizer,
        sharded_state_dict=sharded_checkpoint,
        precision=precision,
    )

    # Sharded saves target the directory; monolithic saves target a file inside it.
    target_path = os.path.join(destination_dir, OPTIM_CHECKPOINT_DIRECTORY_NAME)
    if not sharded_checkpoint:
        target_path = os.path.join(target_path, OPTIM_MONO_CHECKPOINT_FILENAME)

    return save_state_dict_to_disk(
        state_dict=state_dict_to_save,
        destination_file_path=target_path,
        overwrite=overwrite,
        save_format=save_format,
    )
238+
239+
240+
def save_composer_metadata_to_disk(
    destination_dir: str,
    model: Optional[Union[ComposerModel, torch.nn.Module]] = None,
    sharded_state_dict: Optional[bool] = None,
    precision: Optional[Union[str, torch.dtype]] = None,
    device: Optional[Device] = None,
    device_train_microbatch_size: Optional[Union[int, float]] = None,
):
    """Writes the composer metadata for a checkpoint as JSON.

    The metadata dict is built by ``get_metadata_state_dict`` from the arguments below
    and dumped to ``destination_dir/composer_metadata.json``. The directory is created
    if needed; only the process with global rank 0 writes the file.

    Args:
        destination_dir (str): The directory to save the metadata to.
        model (Optional[Union[ComposerModel, torch.nn.Module]]): The model to save metadata about.
        sharded_state_dict (Optional[bool]): Whether the model is sharded.
        precision (Optional[Union[str, torch.dtype]]): The precision of the model.
        device (Optional[Device]): The device the model is on.
        device_train_microbatch_size (Optional[Union[int, float]]): The device train microbatch size.

    Returns:
        str: The path of the metadata file (returned on every rank, including
        ranks that did not write it).
    """
    metadata = get_metadata_state_dict(
        model,
        sharded_state_dict,
        precision,
        device,
        device_train_microbatch_size,
    )
    os.makedirs(destination_dir, exist_ok=True)
    metadata_path = os.path.join(destination_dir, METADATA_CHECKPOINT_FILENAME)

    is_rank_zero = dist.get_global_rank() == 0
    if is_rank_zero:
        with open(metadata_path, 'w') as metadata_file:
            json.dump(metadata, metadata_file, indent=4)

    return metadata_path
272+
273+
274+
def save_resumption_state_to_disk(
    state: State,
    destination_dir: str,
):
    """Saves the resumption state to disk.

    The dict returned by ``get_resumption_state_dict(state)`` is pickled to
    ``destination_dir/resumption.pkl``. The directory is created if it does not
    exist, so the function can be called on its own.

    Args:
        state (State): The state to save.
        destination_dir (str): The directory to save the resumption state to.

    Returns:
        str: The path of the written pickle file.
    """
    resumption_state_dict = get_resumption_state_dict(state)

    # Without this, open() below raises FileNotFoundError for a fresh directory.
    # An empty string means "current directory", which os.makedirs rejects.
    if destination_dir:
        os.makedirs(destination_dir, exist_ok=True)

    destination_file_path = os.path.join(destination_dir, RESUMPTION_CHECKPOINT_FILENAME)
    # NOTE(review): every calling process writes this file (no rank check here) --
    # confirm that is intended when ranks share a filesystem.
    with open(destination_file_path, 'wb') as f:
        pickle.dump(resumption_state_dict, f)
    return destination_file_path
289+
290+
19291
from composer.utils import dist
20292
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file
21293

@@ -80,6 +352,8 @@ def _save_sharded_state_dict_to_disk(
80352
)
81353
destination_file_path = stripped_path
82354

355+
# Wait for all ranks to get here before checking if the directory exists.
356+
dist.barrier()
83357
if dist.get_global_rank() == 0 and not overwrite and os.path.exists(destination_file_path):
84358
raise ValueError(f'Directory {destination_file_path} already exists. Set overwrite=True to overwrite it.')
85359

@@ -94,6 +368,9 @@ def _save_sharded_state_dict_to_disk(
94368
else:
95369
DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))
96370

371+
log.debug(
372+
f'Finished saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}',
373+
)
97374
return destination_file_path + '/' + _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
98375

99376

@@ -106,13 +383,14 @@ def _save_full_state_dict_to_disk(
106383

107384
if save_format != 'pt':
108385
raise NotImplementedError(
109-
f"Saving sharded state dict to disk in format {save_format} is not supported. Please choose from ['pt'].",
386+
f"Saving full state dict to disk in format {save_format} is not supported. Please choose from ['pt'].",
110387
)
111388

112389
if not overwrite and os.path.exists(destination_file_path):
113390
raise ValueError(f'File {destination_file_path} already exists. Set overwrite=True to overwrite it.')
114391

115392
if dist.get_global_rank() == 0:
393+
os.makedirs(os.path.dirname(destination_file_path), exist_ok=True)
116394
_write_checkpoint_file(state_dict=state_dict, filename=destination_file_path)
117395
return destination_file_path
118396
return None
@@ -130,7 +408,7 @@ def is_state_dict_sharded(state_dict: Dict[str, Any]) -> bool:
130408
for value in state_dict.values():
131409
if isinstance(value, ShardedTensor) or isinstance(value, DTensor):
132410
return True
133-
if isinstance(value, Dict):
411+
elif isinstance(value, Dict):
134412
is_sharded = is_state_dict_sharded(value)
135413
if is_sharded:
136414
return True

composer/checkpoint/state_dict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def get_metadata_state_dict(
380380
sharded_state_dict: Optional[bool] = None,
381381
precision: Optional[Union[str, torch.dtype]] = None,
382382
device: Optional[Device] = None,
383-
device_train_microbatch_size: Optional[int] = None,
383+
device_train_microbatch_size: Optional[Union[int, float]] = None,
384384
) -> dict[str, Any]:
385385
"""Generate the metadata and integrations for a training run.
386386

0 commit comments

Comments
 (0)