
Commit 7a2afa5

Add back accelerate compatibility (#339)
* update
* update
* add more tests
* update
* remove unused function
1 parent 7d70dd4 commit 7a2afa5

File tree: 12 files changed (+656, −336 lines)


finetrainers/models/cogview4/base_specification.py

Lines changed: 5 additions & 1 deletion
@@ -291,7 +291,11 @@ def forward(
         latents = posterior.sample(generator=generator)
         del posterior

-        latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
+        if getattr(self.vae_config, "shift_factor", None) is not None:
+            latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
+        else:
+            latents = latents * self.vae_config.scaling_factor
+
         noise = torch.zeros_like(latents).normal_(generator=generator)
         timesteps = (sigmas.flatten() * 1000.0).long()
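
In effect, this hunk makes latent normalization tolerant of VAE configs that do not define a shift factor. A minimal standalone sketch of the same logic (the helper name `normalize_latents` and the assumption that `vae_config` is a diffusers-style config object are illustrative, not part of the commit):

import torch

def normalize_latents(latents: torch.Tensor, vae_config) -> torch.Tensor:
    # Hypothetical helper mirroring the hunk above. Some VAE configs expose both
    # a shift_factor and a scaling_factor; others only a scaling_factor.
    shift = getattr(vae_config, "shift_factor", None)
    if shift is not None:
        return (latents - shift) * vae_config.scaling_factor
    return latents * vae_config.scaling_factor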

finetrainers/parallel/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@

 from .accelerate import AccelerateParallelBackend
 from .ptd import PytorchDTensorParallelBackend
-from .utils import apply_ddp_ptd, apply_fsdp2_ptd, dist_max, dist_mean
+from .utils import dist_max, dist_mean


 ParallelBackendType = Union[AccelerateParallelBackend, PytorchDTensorParallelBackend]

finetrainers/parallel/accelerate.py

Lines changed: 169 additions & 5 deletions
@@ -1,14 +1,16 @@
 import datetime
+import os
 import pathlib
-from typing import Optional
+import shutil
+import time
+from typing import Any, Callable, Dict, Optional

 import torch
 from diffusers.utils import is_accelerate_available

 from ..logging import get_logger
 from ..utils import get_device_info
-from .base import BaseParallelBackend
-from .utils import apply_ddp_accelerate
+from .base import BaseCheckpointer, BaseParallelBackend


 if not is_accelerate_available():
@@ -23,6 +25,7 @@
     DistributedDataParallelKwargs,
     InitProcessGroupKwargs,
     ProjectConfiguration,
+    set_seed,
 )

@@ -68,9 +71,31 @@ def __init__(
         if dp_degree != world_size:
             raise ValueError("Data parallel degree must be equal to world size.")

-        self._accelerator: Accelerator = None
+        self._accelerator = None
+        if world_size == 1:
+            # Needs special handling for single GPU training
+            project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir)
+            dataloader_config = DataLoaderConfiguration(
+                split_batches=False, dispatch_batches=False, use_stateful_dataloader=True
+            )
+            init_process_group_kwargs = InitProcessGroupKwargs(
+                backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
+            )
+            self._accelerator = Accelerator(
+                project_config=project_config,
+                dataloader_config=dataloader_config,
+                gradient_accumulation_steps=gradient_accumulation_steps,
+                log_with=None,
+                kwargs_handlers=[init_process_group_kwargs],
+            )
+            if torch.backends.mps.is_available():
+                self._accelerator.native_amp = False
+
         self._mesh: torch.distributed.DeviceMesh = None

+    def enable_determinism(self, seed: int) -> None:
+        set_seed(seed)
+
     def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
         project_config = None
         ddp_kwargs = None
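
The new `enable_determinism` hook simply forwards to accelerate's `set_seed`. For readers unfamiliar with that utility, a rough hand-written equivalent (a sketch of the behavior, not accelerate's actual implementation):

import random

import numpy as np
import torch

def manual_set_seed(seed: int) -> None:
    # Approximately what accelerate.utils.set_seed does: seed every RNG the
    # training loop is likely to touch so that runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
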
@@ -84,7 +109,7 @@ def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
         init_process_group_kwargs = InitProcessGroupKwargs(
             backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
         )
-        self._accelerator, model = apply_ddp_accelerate(
+        self._accelerator, model = apply_ddp(
             model,
             project_config,
             ddp_kwargs,
@@ -96,6 +121,9 @@ def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
         logger.debug("Applied AccelerateParallel::apply_ddp to model.")
         return model

+    def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module:
+        return self._accelerator.prepare_model(model)
+
     def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
         logger.debug("AccelerateParallelBackend::prepare_dataset completed!")
         return dataset
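
For orientation, a hedged sketch of how a trainer might call the backend's hooks; the `backend` variable and the concrete objects passed in are illustrative, only `enable_determinism`, `prepare_model` and `prepare_dataset` are taken from the diff:

# `backend` stands for an already-constructed AccelerateParallelBackend.
backend.enable_determinism(seed=42)
transformer = backend.prepare_model(transformer)  # delegates to Accelerator.prepare_model
dataset = backend.prepare_dataset(dataset)        # returns the dataset unchanged for this backend
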
@@ -161,6 +189,9 @@ def _get_mesh():
             self._mesh = mesh
         return _get_mesh()

+    def get_checkpointer(self, *args, **kwargs):
+        return AccelerateCheckpointer(self._accelerator, *args, **kwargs)
+
     @property
     def world_size(self):
         return self._accelerator.num_processes
@@ -191,6 +222,8 @@ def wait_for_everyone(self):
         self._accelerator.wait_for_everyone()

     def destroy(self):
+        if self.is_main_process:
+            self.tracker.finish()
         self._accelerator.end_training()

     @property
@@ -216,3 +249,134 @@ def context_parallel_enabled(self):
     @property
     def tensor_parallel_enabled(self):
         return self._tp_degree > 1
+
+
+class AccelerateCheckpointer(BaseCheckpointer):
+    def __init__(
+        self,
+        accelerator: Accelerator,
+        states: Dict[str, Any],
+        checkpointing_steps: int,
+        checkpointing_limit: int,
+        output_dir: str,
+        enable: bool = True,
+        _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
+        _prefix: str = "finetrainers_step",
+        *args,
+        **kwargs,
+    ) -> None:
+        self.accelerator = accelerator
+        self.states = states
+
+        self.checkpointing_steps = checkpointing_steps
+        self.checkpointing_limit = checkpointing_limit
+        self.output_dir = pathlib.Path(output_dir)
+        self.enable = enable
+        self._callback_fn = _callback_fn
+        self._prefix = _prefix
+
+        def save_model_hook(models, weights, output_dir: str) -> None:
+            if not self.accelerator.is_main_process:
+                return
+
+            # TODO(aryan): this is a temporary assertion since we only support training transformer at the moment.
+            # Remove it when adding support for training text encoders/vae and more.
+            assert len(models) == 1
+
+            _callback_fn(weights[0])
+            torch.save(self.states, os.path.join(output_dir, "states.pt"))
+
+        def load_model_hook(models, input_dir) -> None:
+            self.states = torch.load(os.path.join(input_dir, "states.pt"))
+
+        self.accelerator.register_save_state_pre_hook(save_model_hook)
+        self.accelerator.register_load_state_pre_hook(load_model_hook)
+
+        logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")
+
+    def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str:
+        if not self._should_checkpoint(step, force):
+            return None
+
+        checkpoint_dir = self._get_checkpoint_dir(step)
+        begin_time = time.monotonic()
+        self.accelerator.save_state(checkpoint_dir.as_posix(), safe_serialization=True)
+        end_time = time.monotonic()
+        logger.info(
+            f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
+        )
+        self._purge_stale_checkpoints()
+
+        return checkpoint_dir.as_posix()
+
+    def load(self, step: int = -1) -> bool:
+        if not self.enable:
+            return False
+        if not self.output_dir.exists():
+            return False
+        if step != -1 and not self._get_checkpoint_dir(step).exists():
+            return False
+
+        if step == -1:
+            latest_checkpoint_dir = self._find_latest_checkpoint_dir()
+            if latest_checkpoint_dir is None:
+                return False
+            step = int(latest_checkpoint_dir.name.split("_")[-1])
+
+        checkpoint_dir = self._get_checkpoint_dir(step)
+        logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")
+
+        begin_time = time.monotonic()
+        self.accelerator.load_state(checkpoint_dir.as_posix())
+        end_time = time.monotonic()
+        logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")
+
+        return True
+
+    def _should_checkpoint(self, step: int, force: bool) -> bool:
+        if not self.enable:
+            return False
+        if not force:
+            if step % self.checkpointing_steps != 0:
+                return False
+        return True
+
+    def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
+        return self.output_dir / f"{self._prefix}_{step}"
+
+    def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]:
+        checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
+        return checkpoints[-1] if len(checkpoints) > 0 else None
+
+    def _purge_stale_checkpoints(self) -> None:
+        if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
+            return
+        checkpoints = sorted(
+            self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
+        )
+        for checkpoint in checkpoints[self.checkpointing_limit :]:
+            logger.info(f"Deleting stale checkpoint: {checkpoint}")
+            shutil.rmtree(checkpoint, ignore_errors=True)
+
+
+def apply_ddp(
+    model: torch.nn.Module,
+    project_config: Optional[ProjectConfiguration] = None,
+    ddp_kwargs: Optional[DistributedDataParallelKwargs] = None,
+    init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None,
+    dataloader_config: Optional[DataLoaderConfiguration] = None,
+    gradient_accumulation_steps: Optional[int] = None,
+    accelerator: Optional[Accelerator] = None,
+) -> torch.nn.Module:
+    if accelerator is None:
+        accelerator = Accelerator(
+            project_config=project_config,
+            dataloader_config=dataloader_config,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            log_with=None,
+            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
+        )
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+    accelerator.prepare_model(model)
+    return accelerator, model
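
The checkpointer rides on Accelerator.save_state/load_state: the registered pre-hooks write and read an extra states.pt next to accelerate's own files, and checkpoints land in <output_dir>/finetrainers_step_<step> by default. A hedged sketch of inspecting such a tracked-state file offline; the concrete directory name is an example, and the keys inside states depend entirely on what the trainer passed in, so "step" is only an illustrative key:

import os

import torch

# Assumes a checkpoint directory written by AccelerateCheckpointer.save().
checkpoint_dir = "outputs/finetrainers_step_500"
states = torch.load(os.path.join(checkpoint_dir, "states.pt"), map_location="cpu")
print(states.get("step"))  # contents mirror whatever dict the trainer passed as `states`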

finetrainers/parallel/base.py

Lines changed: 43 additions & 1 deletion
@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import torch

@@ -11,9 +11,18 @@ class BaseParallelBackend:
     Base class that contains properties and methods that should be implemented by different parallel backends.
     """

+    def enable_determinism(self, seed: int) -> None:
+        raise NotImplementedError("Method `enable_determinism` must be implemented by subclass.")
+
     def apply_ddp(self, *args, **kwargs) -> torch.nn.Module:
         raise NotImplementedError("Method `apply_ddp` must be implemented by subclass.")

+    def apply_fsdp2(self, *args, **kwargs) -> torch.nn.Module:
+        raise NotImplementedError("Method `apply_fsdp2` must be implemented by subclass.")
+
+    def prepare_model(self, *args, **kwargs) -> Any:
+        raise NotImplementedError("Method `prepare_model` must be implemented by subclass.")
+
     def prepare_dataset(self, *args, **kwargs) -> Any:
         raise NotImplementedError("Method `prepare_dataset` must be implemented by subclass.")

@@ -26,6 +35,9 @@ def prepare_optimizer(self, *args, **kwargs) -> Any:
     def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
         raise NotImplementedError("Method `get_mesh` must be implemented by subclass.")

+    def get_checkpointer(self, *args, **kwargs) -> None:
+        raise NotImplementedError("Method `get_checkpointer` must be implemented by subclass.")
+
     def initialize_trackers(
         self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
     ) -> TrackerType:
@@ -94,3 +106,33 @@
     @property
     def tensor_parallel_enabled(self):
         raise NotImplementedError("Property `tensor_parallel_enabled` must be implemented by subclass.")
+
+
+class BaseCheckpointer:
+    r"""
+    Base class that contains properties and methods that should be implemented by different parallel backends.
+    """
+
+    def __init__(
+        self,
+        dataloader: torch.utils.data.DataLoader,
+        model_parts: List[torch.nn.Module],
+        optimizers: Any,
+        schedulers: Any,
+        states: Dict[str, Any],
+        checkpointing_steps: int,
+        checkpointing_limit: int,
+        output_dir: str,
+        enable: bool = True,
+        _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
+        _prefix: str = "finetrainers_step",
+        *args,
+        **kwargs,
+    ) -> None:
+        raise NotImplementedError("Method `__init__` must be implemented by subclass.")
+
+    def save(self, step: int, force: bool, *, _device: Optional[torch.device] = None, _is_main_process: bool) -> str:
+        raise NotImplementedError("Method `save` must be implemented by subclass.")
+
+    def load(self, step: int = -1) -> bool:
+        raise NotImplementedError("Method `load` must be implemented by subclass.")
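
To tie the interface together, a hedged sketch of how a training loop might drive any BaseCheckpointer implementation; everything other than save/load and their arguments (train_step, num_steps, device, is_main_process, and the "step" key inside states) is illustrative:

# `checkpointer` is whatever backend.get_checkpointer(...) returned.
resumed = checkpointer.load(step=-1)  # resume from the latest checkpoint if one exists
start_step = checkpointer.states.get("step", 0) if resumed else 0

for step in range(start_step + 1, num_steps + 1):
    train_step(step)  # hypothetical single optimization step
    checkpointer.states["step"] = step
    # Saves only when step is a multiple of checkpointing_steps (or when force=True).
    checkpointer.save(step=step, force=False, _device=device, _is_main_process=is_main_process)

# Always write a final checkpoint at the end of training.
checkpointer.save(step=num_steps, force=True, _device=device, _is_main_process=is_main_process)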
