3
3
4
4
"""Useful functions for saving state dicts to disk."""
5
5
6
+ import json
6
7
import logging
7
8
import os
9
+ import pickle
8
10
import textwrap
9
11
import warnings
12
+ from dataclasses import dataclass
10
13
from pathlib import Path
11
- from typing import Any , Dict , Optional , Union
14
+ from typing import Any , Dict , Optional , Sequence , Union
12
15
13
16
import torch
14
17
import torch .distributed .checkpoint as DCP
15
18
from packaging import version
16
19
from torch .distributed ._shard .sharded_tensor import ShardedTensor
17
20
from torch .distributed ._tensor import DTensor
18
21
22
+ from composer .checkpoint .state_dict import (
23
+ get_metadata_state_dict ,
24
+ get_model_state_dict ,
25
+ get_optim_state_dict ,
26
+ get_resumption_state_dict ,
27
+ )
28
+ from composer .core import State , Time
29
+ from composer .devices import Device
30
+ from composer .models import ComposerModel
31
+ from composer .utils import dist
32
+ from composer .utils .checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME , _write_checkpoint_file
33
+ from composer .utils .file_helpers import format_name_with_dist_and_time
34
+
35
+ log = logging .getLogger (__name__ )
36
+
37
+ MODEL_CHECKPOINT_DIRECTORY_NAME = 'model'
38
+ MONOLITHIC_MODEL_CHECKPOINT_FILENAME = 'model.pt'
39
+ OPTIM_CHECKPOINT_DIRECTORY_NAME = 'optim'
40
+ OPTIM_MONO_CHECKPOINT_FILENAME = 'optim.pt'
41
+ METADATA_CHECKPOINT_FILENAME = 'composer_metadata.json'
42
+ RESUMPTION_CHECKPOINT_FILENAME = 'resumption.pkl'
43
+
44
+
45
+ @dataclass
46
+ class CheckpointSaveOptions :
47
+ """Options for saving a checkpoint to disk.
48
+
49
+ Args:
50
+ destination_dir (str): The directory to save the checkpoint to.
51
+ save_frequency (Union[str, int, Time]): The frequency to save the checkpoint.
52
+ If '1ep', the checkpoint will be saved after each epoch.
53
+ If '1ba', the checkpoint will be saved after each batch.
54
+ If an int, the checkpoint will be saved after that many epochs.
55
+ dir_prefix (str): The prefix to use for the directory name. Can include {epoch} and {batch}.
56
+ overwrite (bool): Whether to overwrite the checkpoint if it already exists.
57
+ save_model (bool): Whether to save the model.
58
+ save_optimizer (bool): Whether to save the optimizer.
59
+ save_resumption_state (bool): Whether to save the resumption state.
60
+ num_checkpoints_to_keep (int): The number of checkpoints to keep.
61
+ If -1, all checkpoints will be kept.
62
+ save_format (str): The format to save the model in. 'pt', which is the standard pytorch serializarion, is the only option for now.
63
+ sharded_checkpoint (bool): Whether to save the model as a sharded checkpoint.
64
+ precision (str): The precision to save the model in. One of 'bf16', 'fp32', 'fp16', 'fp64'.
65
+ include_keys (Optional[Union[str, Sequence[str]]]): Keys to include in the saved model.
66
+ ignore_keys (Optional[Union[str, Sequence[str]]]): Keys to ignore in the saved model.
67
+ """
68
+ destination_dir : str
69
+ save_frequency : Union [str , int , Time ] = '1ep'
70
+ dir_prefix : str = 'ep{epoch}-ba{batch}'
71
+ overwrite : bool = False
72
+ save_model : bool = True
73
+ save_optimizer : bool = True
74
+ save_resumption_state : bool = True
75
+ num_checkpoints_to_keep : int = - 1
76
+ save_format : str = 'pt'
77
+ sharded_checkpoint : bool = False
78
+ precision : str = 'bf16'
79
+ include_keys : Optional [Union [str , Sequence [str ]]] = None
80
+ ignore_keys : Optional [Union [str , Sequence [str ]]] = None
81
+
82
+
83
+ def save_checkpoint_to_disk (
84
+ state : State ,
85
+ options : Optional [Union [CheckpointSaveOptions , Dict ]] = None ,
86
+ destination_dir : Optional [str ] = None ,
87
+ ):
88
+ """Saves a checkpoint to disk.
89
+
90
+ Args:
91
+ state (State): The state to save.
92
+ options (Optional[Union[CheckpointSaveOptions, Dict]]): The options for saving the checkpoint.
93
+ If None, destination_dir must be provided.
94
+ destination_dir (Optional[str]): The directory to save the checkpoint to.
95
+ If options is provided, this will overwrite options.destination_dir.
96
+ """
97
+ if options is None :
98
+ if destination_dir is None :
99
+ raise ValueError ('destination_dir must be provided if options is None' )
100
+ options = CheckpointSaveOptions (destination_dir = destination_dir )
101
+ else :
102
+ if isinstance (options , Dict ):
103
+ options = CheckpointSaveOptions (** options )
104
+ if destination_dir is not None :
105
+ options .destination_dir = destination_dir
106
+ save_path = os .path .join (options .destination_dir , options .dir_prefix )
107
+ save_path = format_name_with_dist_and_time (save_path , state .run_name , state .timestamp )
108
+ os .makedirs (save_path , exist_ok = True )
109
+ if options .save_model :
110
+ save_model_to_disk (
111
+ state .model ,
112
+ save_path ,
113
+ options .sharded_checkpoint ,
114
+ options .precision ,
115
+ options .include_keys ,
116
+ options .ignore_keys ,
117
+ options .overwrite ,
118
+ options .save_format ,
119
+ )
120
+ if options .save_optimizer :
121
+ optimizer = state .optimizers [0 ]
122
+ save_optim_to_disk (
123
+ state .model ,
124
+ optimizer ,
125
+ save_path ,
126
+ options .sharded_checkpoint ,
127
+ options .precision ,
128
+ options .overwrite ,
129
+ options .save_format ,
130
+ )
131
+ if options .save_resumption_state :
132
+ save_resumption_state_to_disk (state , save_path )
133
+
134
+ save_composer_metadata_to_disk (
135
+ save_path ,
136
+ state .model ,
137
+ options .sharded_checkpoint ,
138
+ options .precision ,
139
+ state .device ,
140
+ state .device_train_microbatch_size ,
141
+ )
142
+
143
+
144
+ def save_model_to_disk (
145
+ model : Union [ComposerModel , torch .nn .Module ],
146
+ destination_dir : str ,
147
+ sharded_checkpoint : bool = False ,
148
+ precision : str = 'fp32' ,
149
+ include_keys : Optional [Union [str , Sequence [str ]]] = None ,
150
+ ignore_keys : Optional [Union [str , Sequence [str ]]] = None ,
151
+ overwrite : bool = False ,
152
+ save_format : str = 'pt' , # or hf, safetensor
153
+ ) -> Optional [str ]:
154
+ """Saves a model to disk.
155
+
156
+ Args:
157
+ model (Union[ComposerModel, torch.nn.Module]): The model to save.
158
+ destination_dir (str): The directory to save the model to.
159
+ Model will be saved as distination_dir/models/model.pt if sharded_checkpoint is False,
160
+ otherwise all shards will be saved as destination_dir/models/__<rank>_0.distcp.
161
+ sharded_checkpoint (bool): Whether to save the model as a sharded checkpoint.
162
+ precision (str): The precision to save the model in. One of 'bf16', 'fp32', 'fp16', 'fp64'.
163
+ include_keys (Optional[Union[str, Sequence[str]]]): Keys to include in the saved model.
164
+ ignore_keys (Optional[Union[str, Sequence[str]]]): Keys to ignore in the saved model.
165
+ overwrite (bool): If True, the file will be overwritten if it exists.
166
+ save_format (str): The format to save the model in. One of 'pt', 'hf', or 'safetensor'.
167
+
168
+ Returns:
169
+ str: The full path to the saved model.
170
+ """
171
+ if save_format != 'pt' :
172
+ raise NotImplementedError (
173
+ f"Saving checkpoint in format { save_format } is not supported. Please choose from ['pt']." ,
174
+ )
175
+ model_state_dict = get_model_state_dict (
176
+ model ,
177
+ sharded_checkpoint ,
178
+ precision ,
179
+ include_keys ,
180
+ ignore_keys ,
181
+ )
182
+
183
+ destination_file_path = (
184
+ os .path .join (destination_dir , MODEL_CHECKPOINT_DIRECTORY_NAME ) if sharded_checkpoint else
185
+ os .path .join (destination_dir , MODEL_CHECKPOINT_DIRECTORY_NAME , MONOLITHIC_MODEL_CHECKPOINT_FILENAME )
186
+ )
187
+ saved_path = save_state_dict_to_disk (
188
+ state_dict = model_state_dict ,
189
+ destination_file_path = destination_file_path ,
190
+ overwrite = overwrite ,
191
+ save_format = save_format ,
192
+ )
193
+ return saved_path
194
+
195
+
196
+ def save_optim_to_disk (
197
+ model : Union [ComposerModel , torch .nn .Module ],
198
+ optimizer : torch .optim .Optimizer ,
199
+ destination_dir : str ,
200
+ sharded_checkpoint : bool = False ,
201
+ precision : str = 'fp32' ,
202
+ overwrite : bool = False ,
203
+ save_format : str = 'pt' ,
204
+ ) -> Optional [str ]:
205
+ """Saves an optimizer to disk.
206
+
207
+ Args:
208
+ model (Union[ComposerModel, torch.nn.Module]): The model to save.
209
+ optimizer (torch.optim.Optimizer): The optimizer to save.
210
+ destination_dir (str): The directory to save the optimizer to.
211
+ Optimizer will be saved as destination_dir/optim/optim.pt if sharded_checkpoint is False,
212
+ otherwise all shards will be saved as destination_dir/optim/__<rank>_0.distcp.
213
+ sharded_checkpoint (bool): Whether to save the optimizer as a sharded checkpoint.
214
+ precision (str): The precision to save the optimizer in. One of 'bf16', 'fp32', 'fp16', 'fp64'.
215
+ overwrite (bool): If True, the file will be overwritten if it exists.
216
+ save_format (str): The format to save the optimizer in. One of 'pt'.
217
+ """
218
+ optim_state_dict = get_optim_state_dict (
219
+ model ,
220
+ optimizer ,
221
+ sharded_state_dict = sharded_checkpoint ,
222
+ precision = precision ,
223
+ )
224
+ destination_file_path = os .path .join (destination_dir ,
225
+ OPTIM_CHECKPOINT_DIRECTORY_NAME ) if sharded_checkpoint else os .path .join (
226
+ destination_dir ,
227
+ OPTIM_CHECKPOINT_DIRECTORY_NAME ,
228
+ OPTIM_MONO_CHECKPOINT_FILENAME ,
229
+ )
230
+ saved_path = save_state_dict_to_disk (
231
+ state_dict = optim_state_dict ,
232
+ destination_file_path = destination_file_path ,
233
+ overwrite = overwrite ,
234
+ save_format = save_format ,
235
+ )
236
+
237
+ return saved_path
238
+
239
+
240
+ def save_composer_metadata_to_disk (
241
+ destination_dir : str ,
242
+ model : Optional [Union [ComposerModel , torch .nn .Module ]] = None ,
243
+ sharded_state_dict : Optional [bool ] = None ,
244
+ precision : Optional [Union [str , torch .dtype ]] = None ,
245
+ device : Optional [Device ] = None ,
246
+ device_train_microbatch_size : Optional [Union [int , float ]] = None ,
247
+ ):
248
+ """Saves metadata about the model to disk.
249
+
250
+ Args:
251
+ destination_dir (str): The directory to save the metadata to.
252
+ model (Optional[Union[ComposerModel, torch.nn.Module]]): The model to save metadata about.
253
+ sharded_state_dict (Optional[bool]): Whether the model is sharded.
254
+ precision (Optional[Union[str, torch.dtype]]): The precision of the model.
255
+ device (Optional[Device]): The device the model is on.
256
+ device_train_microbatch_size (Optional[Union[int, float]]): The device train microbatch size.
257
+ """
258
+ md_dict = get_metadata_state_dict (
259
+ model ,
260
+ sharded_state_dict ,
261
+ precision ,
262
+ device ,
263
+ device_train_microbatch_size ,
264
+ )
265
+ os .makedirs (destination_dir , exist_ok = True )
266
+ destination_file_path = os .path .join (destination_dir , METADATA_CHECKPOINT_FILENAME )
267
+
268
+ if dist .get_global_rank () == 0 :
269
+ with open (destination_file_path , 'w' ) as f :
270
+ json .dump (md_dict , f , indent = 4 )
271
+ return destination_file_path
272
+
273
+
274
+ def save_resumption_state_to_disk (
275
+ state : State ,
276
+ destination_dir : str ,
277
+ ):
278
+ """Saves the resumption state to disk.
279
+
280
+ Args:
281
+ state (State): The state to save.
282
+ destination_dir (str): The directory to save the resumption state to.
283
+ """
284
+ resumption_state_dict = get_resumption_state_dict (state )
285
+ destination_file_path = os .path .join (destination_dir , RESUMPTION_CHECKPOINT_FILENAME )
286
+ with open (destination_file_path , 'wb' ) as f :
287
+ pickle .dump (resumption_state_dict , f )
288
+ return destination_file_path
289
+
290
+
19
291
from composer .utils import dist
20
292
from composer .utils .checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME , _write_checkpoint_file
21
293
@@ -80,6 +352,8 @@ def _save_sharded_state_dict_to_disk(
80
352
)
81
353
destination_file_path = stripped_path
82
354
355
+ # Wait for all ranks to get here before checking if the directory exists.
356
+ dist .barrier ()
83
357
if dist .get_global_rank () == 0 and not overwrite and os .path .exists (destination_file_path ):
84
358
raise ValueError (f'Directory { destination_file_path } already exists. Set overwrite=True to overwrite it.' )
85
359
@@ -94,6 +368,9 @@ def _save_sharded_state_dict_to_disk(
94
368
else :
95
369
DCP .save (state_dict = state_dict , storage_writer = DCP .FileSystemWriter (destination_file_path ))
96
370
371
+ log .debug (
372
+ f'Finished saving of sharded state dict to { destination_file_path } /{ _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME } ' ,
373
+ )
97
374
return destination_file_path + '/' + _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
98
375
99
376
@@ -106,13 +383,14 @@ def _save_full_state_dict_to_disk(
106
383
107
384
if save_format != 'pt' :
108
385
raise NotImplementedError (
109
- f"Saving sharded state dict to disk in format { save_format } is not supported. Please choose from ['pt']." ,
386
+ f"Saving full state dict to disk in format { save_format } is not supported. Please choose from ['pt']." ,
110
387
)
111
388
112
389
if not overwrite and os .path .exists (destination_file_path ):
113
390
raise ValueError (f'File { destination_file_path } already exists. Set overwrite=True to overwrite it.' )
114
391
115
392
if dist .get_global_rank () == 0 :
393
+ os .makedirs (os .path .dirname (destination_file_path ), exist_ok = True )
116
394
_write_checkpoint_file (state_dict = state_dict , filename = destination_file_path )
117
395
return destination_file_path
118
396
return None
@@ -130,7 +408,7 @@ def is_state_dict_sharded(state_dict: Dict[str, Any]) -> bool:
130
408
for value in state_dict .values ():
131
409
if isinstance (value , ShardedTensor ) or isinstance (value , DTensor ):
132
410
return True
133
- if isinstance (value , Dict ):
411
+ elif isinstance (value , Dict ):
134
412
is_sharded = is_state_dict_sharded (value )
135
413
if is_sharded :
136
414
return True
0 commit comments