@@ -1,14 +1,16 @@
 import datetime
+import os
 import pathlib
-from typing import Optional
+import shutil
+import time
+from typing import Any, Callable, Dict, Optional

 import torch
 from diffusers.utils import is_accelerate_available

 from ..logging import get_logger
 from ..utils import get_device_info
-from .base import BaseParallelBackend
-from .utils import apply_ddp_accelerate
+from .base import BaseCheckpointer, BaseParallelBackend


 if not is_accelerate_available():
@@ -23,6 +25,7 @@
     DistributedDataParallelKwargs,
     InitProcessGroupKwargs,
     ProjectConfiguration,
+    set_seed,
 )


@@ -68,9 +71,31 @@ def __init__(
         if dp_degree != world_size:
             raise ValueError("Data parallel degree must be equal to world size.")

-        self._accelerator: Accelerator = None
+        self._accelerator = None
+        if world_size == 1:
+            # Needs special handling for single GPU training
+            project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir)
+            dataloader_config = DataLoaderConfiguration(
+                split_batches=False, dispatch_batches=False, use_stateful_dataloader=True
+            )
+            init_process_group_kwargs = InitProcessGroupKwargs(
+                backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
+            )
+            self._accelerator = Accelerator(
+                project_config=project_config,
+                dataloader_config=dataloader_config,
+                gradient_accumulation_steps=gradient_accumulation_steps,
+                log_with=None,
+                kwargs_handlers=[init_process_group_kwargs],
+            )
+            if torch.backends.mps.is_available():
+                self._accelerator.native_amp = False
+
         self._mesh: torch.distributed.DeviceMesh = None

+    def enable_determinism(self, seed: int) -> None:
+        set_seed(seed)
+
     def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
         project_config = None
         ddp_kwargs = None
@@ -84,7 +109,7 @@ def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
             init_process_group_kwargs = InitProcessGroupKwargs(
                 backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
             )
-        self._accelerator, model = apply_ddp_accelerate(
+        self._accelerator, model = apply_ddp(
             model,
             project_config,
             ddp_kwargs,
@@ -96,6 +121,9 @@ def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
         logger.debug("Applied AccelerateParallel::apply_ddp to model.")
         return model

+    def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module:
+        return self._accelerator.prepare_model(model)
+
     def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
         logger.debug("AccelerateParallelBackend::prepare_dataset completed!")
         return dataset
@@ -161,6 +189,9 @@ def _get_mesh():
             self._mesh = mesh
         return _get_mesh()

+    def get_checkpointer(self, *args, **kwargs):
+        return AccelerateCheckpointer(self._accelerator, *args, **kwargs)
+
     @property
     def world_size(self):
         return self._accelerator.num_processes
@@ -191,6 +222,8 @@ def wait_for_everyone(self):
         self._accelerator.wait_for_everyone()

     def destroy(self):
+        if self.is_main_process:
+            self.tracker.finish()
         self._accelerator.end_training()

     @property
@@ -216,3 +249,134 @@ def context_parallel_enabled(self):
     @property
     def tensor_parallel_enabled(self):
         return self._tp_degree > 1
+
+
+class AccelerateCheckpointer(BaseCheckpointer):
+    def __init__(
+        self,
+        accelerator: Accelerator,
+        states: Dict[str, Any],
+        checkpointing_steps: int,
+        checkpointing_limit: int,
+        output_dir: str,
+        enable: bool = True,
+        _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
+        _prefix: str = "finetrainers_step",
+        *args,
+        **kwargs,
+    ) -> None:
+        self.accelerator = accelerator
+        self.states = states
+
+        self.checkpointing_steps = checkpointing_steps
+        self.checkpointing_limit = checkpointing_limit
+        self.output_dir = pathlib.Path(output_dir)
+        self.enable = enable
+        self._callback_fn = _callback_fn
+        self._prefix = _prefix
+
+        def save_model_hook(models, weights, output_dir: str) -> None:
+            if not self.accelerator.is_main_process:
+                return
+
+            # TODO(aryan): this is a temporary assertion since we only support training transformer at the moment.
+            # Remove it when adding support for training text encoders/vae and more.
+            assert len(models) == 1
+
+            _callback_fn(weights[0])
+            torch.save(self.states, os.path.join(output_dir, "states.pt"))
+
+        def load_model_hook(models, input_dir) -> None:
+            self.states = torch.load(os.path.join(input_dir, "states.pt"))
+
+        self.accelerator.register_save_state_pre_hook(save_model_hook)
+        self.accelerator.register_load_state_pre_hook(load_model_hook)
+
+        logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")
+
+    def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str:
+        if not self._should_checkpoint(step, force):
+            return None
+
+        checkpoint_dir = self._get_checkpoint_dir(step)
+        begin_time = time.monotonic()
+        self.accelerator.save_state(checkpoint_dir.as_posix(), safe_serialization=True)
+        end_time = time.monotonic()
+        logger.info(
+            f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
+        )
+        self._purge_stale_checkpoints()
+
+        return checkpoint_dir.as_posix()
+
+    def load(self, step: int = -1) -> bool:
+        if not self.enable:
+            return False
+        if not self.output_dir.exists():
+            return False
+        if step != -1 and not self._get_checkpoint_dir(step).exists():
+            return False
+
+        if step == -1:
+            latest_checkpoint_dir = self._find_latest_checkpoint_dir()
+            if latest_checkpoint_dir is None:
+                return False
+            step = int(latest_checkpoint_dir.name.split("_")[-1])
+
+        checkpoint_dir = self._get_checkpoint_dir(step)
+        logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")
+
+        begin_time = time.monotonic()
+        self.accelerator.load_state(checkpoint_dir.as_posix())
+        end_time = time.monotonic()
+        logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")
+
+        return True
+
+    def _should_checkpoint(self, step: int, force: bool) -> bool:
+        if not self.enable:
+            return False
+        if not force:
+            if step % self.checkpointing_steps != 0:
+                return False
+        return True
+
+    def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
+        return self.output_dir / f"{self._prefix}_{step}"
+
+    def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]:
+        checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
+        return checkpoints[-1] if len(checkpoints) > 0 else None
+
+    def _purge_stale_checkpoints(self) -> None:
+        if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
+            return
+        checkpoints = sorted(
+            self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
+        )
+        for checkpoint in checkpoints[self.checkpointing_limit :]:
+            logger.info(f"Deleting stale checkpoint: {checkpoint}")
+            shutil.rmtree(checkpoint, ignore_errors=True)
+
+
+def apply_ddp(
+    model: torch.nn.Module,
+    project_config: Optional[ProjectConfiguration] = None,
+    ddp_kwargs: Optional[DistributedDataParallelKwargs] = None,
+    init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None,
+    dataloader_config: Optional[DataLoaderConfiguration] = None,
+    gradient_accumulation_steps: Optional[int] = None,
+    accelerator: Optional[Accelerator] = None,
+) -> torch.nn.Module:
+    if accelerator is None:
+        accelerator = Accelerator(
+            project_config=project_config,
+            dataloader_config=dataloader_config,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            log_with=None,
+            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
+        )
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+    accelerator.prepare_model(model)
+    return accelerator, model
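
A minimal sketch of exercising the new module-level apply_ddp helper on its own; the model, directories, and timeout below are illustrative placeholders, not values from this diff. The helper returns the Accelerator together with the model, matching the `self._accelerator, model = apply_ddp(...)` call site above.

import datetime

import torch
from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs, ProjectConfiguration

# `apply_ddp` is the module-level helper introduced at the end of this diff.
model = torch.nn.Linear(8, 8)  # placeholder model
accelerator, model = apply_ddp(
    model,
    project_config=ProjectConfiguration(project_dir="outputs", logging_dir="outputs/logs"),
    ddp_kwargs=DistributedDataParallelKwargs(find_unused_parameters=False),
    init_process_group_kwargs=InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=600)),
    gradient_accumulation_steps=1,
)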
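Likewise, a rough sketch of the checkpointing lifecycle enabled by AccelerateCheckpointer, assuming `backend` is an already-constructed AccelerateParallelBackend; the step counts, output directory, and callback are illustrative.

import torch

backend.enable_determinism(seed=42)

states = {"global_step": 0}
checkpointer = backend.get_checkpointer(
    states=states,
    checkpointing_steps=500,    # save every 500 steps
    checkpointing_limit=3,      # keep only the 3 most recent finetrainers_step_* directories
    output_dir="outputs/my-run",
    enable=True,
    _callback_fn=lambda state_dict: state_dict,  # hook for filtering/converting weights before saving
)

# Resume from the latest finetrainers_step_* checkpoint if one exists; the registered
# load hook restores `checkpointer.states` from states.pt.
resumed = checkpointer.load(step=-1)
start_step = checkpointer.states.get("global_step", 0) + 1 if resumed else 1

for step in range(start_step, 10_001):
    ...  # forward/backward/optimizer step
    checkpointer.states["global_step"] = step
    checkpointer.save(
        step,
        force=(step == 10_000),           # force a final save regardless of cadence
        _device=torch.device("cuda"),     # illustrative; not used by this backend's save()
        _is_main_process=backend.is_main_process,
    )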