import os
import pathlib
import sys
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional, Union

import torch

from .config import SUPPORTED_MODEL_CONFIGS, ModelType, TrainingType
from .logging import get_logger
from .parallel import ParallelBackendEnum
-from .trainer.config_utils import ConfigMixin
-from .utils import get_non_null_items
+from .utils import ArgsConfigMixin, get_non_null_items


logger = get_logger()

+# fmt: off
+# Must match src/finetrainers/models/attention_dispatch.py
+AttentionProviderTraining = Literal["flash", "flash_varlen", "flex", "native", "_native_cudnn", "_native_efficient", "_native_flash", "_native_math", "xformers"]
+AttentionProviderValidation = Literal["flash", "flash_varlen", "flex", "native", "_native_cudnn", "_native_efficient", "_native_flash", "_native_math", "sage", "sage_varlen", "_sage_qk_int8_pv_fp8_cuda", "_sage_qk_int8_pv_fp8_cuda_sm90", "_sage_qk_int8_pv_fp16_cuda", "_sage_qk_int8_pv_fp16_triton", "xformers"]
+
+# We use a union because every ArgsConfigMixin registered on BaseArgs can be looked up through the `__getattribute__` override
+BaseArgsType = Union["BaseArgs", "AttentionProviderArgs"]
+# fmt: on
+
+
+class AttentionProviderArgs(ArgsConfigMixin):
+    """
+    Args:
+        attn_provider_training (`List[str]`, defaults to `None`):
+            Must be a string of the form `"<component_name>:<attention_provider>"`. For example, to use the
+            flash varlen attention implementation on the `transformer` module, set this argument to
+            `"transformer:flash_varlen"`. This attention provider is used during training.
+            Options for `<attention_provider>` are:
+                flash, flash_varlen, flex, native, _native_cudnn, _native_efficient, _native_flash, _native_math, xformers
+        attn_provider_inference (`List[str]`, defaults to `None`):
+            Must be a string of the form `"<component_name>:<attention_provider>"`. For example, to use the
+            flash varlen attention implementation on the `transformer` module, set this argument to
+            `"transformer:flash_varlen"`. This attention provider is used during validation/inference.
+            Options for `<attention_provider>` are:
+                flash, flash_varlen, flex, native, _native_cudnn, _native_efficient, _native_flash, _native_math,
+                sage, sage_varlen, _sage_qk_int8_pv_fp8_cuda, _sage_qk_int8_pv_fp8_cuda_sm90,
+                _sage_qk_int8_pv_fp16_cuda, _sage_qk_int8_pv_fp16_triton, xformers
+    """
+
+    attn_provider_training: List[AttentionProviderTraining] = None
+    attn_provider_inference: List[AttentionProviderValidation] = None
+
+    def add_args(self, parser: argparse.ArgumentParser) -> None:
+        parser.add_argument(
+            "--attn_provider_training",
+            type=str,
+            default=None,
+            nargs="+",
+            help="Attention provider for training. Must be a string of the form `<component_name>:<attention_provider>`.",
+        )
+        parser.add_argument(
+            "--attn_provider_inference",
+            type=str,
+            default=None,
+            nargs="+",
+            help="Attention provider for inference. Must be a string of the form `<component_name>:<attention_provider>`.",
+        )
+
+    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
+        attn_training = argparse_args.attn_provider_training
+        attn_inference = argparse_args.attn_provider_inference
+        if attn_training is None:
+            attn_training = []
+        if attn_inference is None:
+            attn_inference = []
+        mapped_args.attn_provider_training = attn_training
+        mapped_args.attn_provider_inference = attn_inference
+
+    def validate_args(self, args: "BaseArgs"):
+        pass
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "attn_provider_training": self.attn_provider_training,
+            "attn_provider_inference": self.attn_provider_inference,
+        }
+

class BaseArgs:
-    r"""
+    """
    The arguments for the finetrainers training script.

    For helpful information about arguments, run `python train.py --help`.
@@ -314,16 +380,9 @@ class BaseArgs:
    vae_dtype: torch.dtype = torch.bfloat16
    layerwise_upcasting_modules: List[str] = []
    layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn
-    layerwise_upcasting_skip_modules_pattern: List[str] = [
-        "patch_embed",
-        "pos_embed",
-        "x_embedder",
-        "context_embedder",
-        "time_embed",
-        "^proj_in$",
-        "^proj_out$",
-        "norm",
-    ]
+    # fmt: off
+    layerwise_upcasting_skip_modules_pattern: List[str] = ["patch_embed", "pos_embed", "x_embedder", "context_embedder", "time_embed", "^proj_in$", "^proj_out$", "norm"]
+    # fmt: on

    # Dataset arguments
    dataset_config: str = None
@@ -399,10 +458,21 @@ class BaseArgs:
    compile_modules: List[str] = []
    compile_scopes: List[str] = None
    allow_tf32: bool = False
-    float32_matmul_precision: Optional[str] = None
+    float32_matmul_precision: str = "highest"

-    # Additional registered arguments
-    _registered_config_mixins: List[ConfigMixin] = []
+    # Attention provider arguments
+    attention_provider_args: AttentionProviderArgs = AttentionProviderArgs()
+
+    _registered_config_mixins: List[ArgsConfigMixin] = []
+    _arg_group_map: Dict[str, ArgsConfigMixin] = {}
+
+    def __init__(self):
+        self._arg_group_map: Dict[str, ArgsConfigMixin] = {
+            "attention_provider_args": self.attention_provider_args,
+        }
+
+        for arg_config_mixin in self._arg_group_map.values():
+            self.register_args(arg_config_mixin)

    def to_dict(self) -> Dict[str, Any]:
        parallel_arguments = {
@@ -545,7 +615,7 @@ def to_dict(self) -> Dict[str, Any]:
            "torch_config_arguments": torch_config_arguments,
        }

-    def register_args(self, config: ConfigMixin) -> None:
+    def register_args(self, config: ArgsConfigMixin) -> None:
        if not hasattr(self, "_extended_add_arguments"):
            self._extended_add_arguments = []
        self._extended_add_arguments.append((config.add_args, config.validate_args, config.map_args))
@@ -583,6 +653,25 @@ def parse_args(self):

        return mapped_args

+    def __getattribute__(self, name: str):
+        try:
+            return object.__getattribute__(self, name)
+        except AttributeError:
+            for arg_group in self._arg_group_map.values():
+                if hasattr(arg_group, name):
+                    return getattr(arg_group, name)
+            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+    def __setattr__(self, name: str, value: Any):
+        if name in self.__dict__:
+            object.__setattr__(self, name, value)
+            return
+        for arg_group in self._arg_group_map.values():
+            if hasattr(arg_group, name):
+                setattr(arg_group, name, value)
+                return
+        object.__setattr__(self, name, value)
+

def _add_args(parser: argparse.ArgumentParser) -> None:
    _add_parallel_arguments(parser)
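
Editor's note: the `__getattribute__`/`__setattr__` pair added above is what lets callers read and write fields such as `attn_provider_training` directly on a `BaseArgs` instance even though they live on a registered argument group (hence the `BaseArgsType` union near the top of the file). The sketch below is a stripped-down, self-contained model of that fall-through delegation using hypothetical `Parent`/`Group` classes, not the repository's types.

# Minimal model of the fall-through attribute delegation (hypothetical classes).
class Group:
    value: int = 1


class Parent:
    _groups = []  # class-level default, mirroring `_arg_group_map = {}`

    def __init__(self):
        self._groups = [Group()]

    def __getattribute__(self, name):
        try:
            return object.__getattribute__(self, name)
        except AttributeError:
            # Unknown attributes fall through to the registered groups.
            for group in object.__getattribute__(self, "_groups"):
                if hasattr(group, name):
                    return getattr(group, name)
            raise

    def __setattr__(self, name, value):
        if name in self.__dict__:
            object.__setattr__(self, name, value)
            return
        for group in self._groups:
            if hasattr(group, name):
                setattr(group, name, value)
                return
        object.__setattr__(self, name, value)


p = Parent()
assert p.value == 1  # read falls through to the registered group
p.value = 2          # write is routed to the group, not stored on Parent
assert p._groups[0].value == 2 and "value" not in p.__dict__
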
@@ -749,7 +838,7 @@ def _add_torch_config_arguments(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--float32_matmul_precision",
        type=str,
-        default=None,
+        default="highest",
        choices=["highest", "high", "medium"],
        help="The precision to use for float32 matmul. Choose between ['highest', 'high', 'medium'].",
    )
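
Editor's note: the last hunk changes the CLI default for `--float32_matmul_precision` from `None` to `"highest"`. For reference, `"highest"` is also PyTorch's own default for `torch.set_float32_matmul_precision`; the wiring from the parsed argument to that call is assumed here and is not part of the hunk above.

# Standard PyTorch API, shown only for reference.
import torch

torch.set_float32_matmul_precision("highest")  # full-precision fp32 matmuls (PyTorch's default)
assert torch.get_float32_matmul_precision() == "highest"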