
Commit 8222d3f

Support flash/flex/xformers/sage attention (#377)
* support attention providers for training/inference: flash attn, flex attn, native, xformers; for inference only: sage; refactor ConfigMixin for arguments
* update
* fix for sdpa replacement; fix for backward pass
* fix flash-attn shape when not using varlen; remove contiguous for now; remove custom block mask code
* update docs
* add basic tests
* update arg name
* update docs
* more doc updates
* dispatcher fixes
* add back context manager for external use; make style
* update date
1 parent 3c583bf commit 8222d3f

File tree

20 files changed: +1572 additions, -148 deletions


README.md

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,7 @@ Please checkout [`docs/models`](./docs/models/) and [`examples/training`](./exam
 - DDP, FSDP-2 & HSDP support for all models
 - LoRA and full-rank finetuning; Conditional Control training
 - Memory-efficient single-GPU training
+- Multiple attention backends supported - `flash`, `flex`, `sage`, `xformers` (see [attention](./docs/models/attention.md) docs)
 - Auto-detection of commonly used dataset formats
 - Combined image/video datasets, multiple chainable local/remote datasets, multi-resolution bucketing & more
 - Memory-efficient precomputation support with/without on-the-fly precomputation for large scale datasets
@@ -65,6 +66,8 @@ Please checkout [`docs/models`](./docs/models/) and [`examples/training`](./exam

 ## News

+- 🔥 **2025-04-25**: Support for different attention providers added!
+- 🔥 **2025-04-21**: Wan I2V support added!
 - 🔥 **2025-04-12**: Channel-concatenated control conditioning support added for CogView4 and Wan!
 - 🔥 **2025-04-08**: `torch.compile` support added!
 - 🔥 **2025-04-06**: Flux support added!

docs/args.md

Lines changed: 20 additions & 0 deletions
@@ -270,6 +270,26 @@ float32_matmul_precision (`str`, defaults to `highest`):
     The precision to use for float32 matmul. Choose between ['highest', 'high', 'medium'].
 ```

+### Attention Provider
+
+These arguments control the attention provider used by the different modeling components. The attention provider may be set differently for training and validation/inference.
+
+```
+attn_provider_training (`str`, defaults to "native"):
+    The attention provider to use for training. Choose between
+    [
+        'flash', 'flash_varlen', 'flex', 'native', '_native_cudnn', '_native_efficient', '_native_flash',
+        '_native_math'
+    ]
+attn_provider_inference (`str`, defaults to "native"):
+    The attention provider to use for validation/inference. Choose between
+    [
+        'flash', 'flash_varlen', 'flex', 'native', '_native_cudnn', '_native_efficient', '_native_flash',
+        '_native_math', 'sage', 'sage_varlen', '_sage_qk_int8_pv_fp8_cuda', '_sage_qk_int8_pv_fp8_cuda_sm90',
+        '_sage_qk_int8_pv_fp16_cuda', '_sage_qk_int8_pv_fp16_triton', 'xformers'
+    ]
+```
+
 ## SFT training

 If using `--training_type lora`, these arguments can be specified.

docs/environment.md

Lines changed: 8 additions & 0 deletions
@@ -26,3 +26,11 @@ NVIDIA A100-SXM4-80GB, 81920 MiB
 ```

 Other versions of dependencies may or may not work as expected. We would like to make finetrainers work on a wider range of environments, but due to the complexity of testing at the early stages of development, we are unable to do so. The long term goals include compatibility with most pytorch versions on CUDA, MPS, ROCm and XLA devices.
+
+
+## Configuration
+
+The following environment variables may be configured to change the default behaviour of finetrainers:
+
+`FINETRAINERS_ATTN_PROVIDER`: Sets the default attention provider for training/validation. Defaults to `native`, as in native PyTorch SDPA. See [attention docs](./models/attention.md) for more information.
+`FINETRAINERS_ATTN_CHECKS`: Whether or not to run basic sanity checks when using different attention providers. This is useful for debugging, but it should be left disabled for longer training runs. Defaults to `"0"`. Can be set to any truthy env value (`1`, `ON`, `YES`, `TRUE`).
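For a concrete picture of how these variables take effect, here is a minimal sketch (illustrative only): the values are read in `finetrainers/constants.py` at import time, so they must be configured before finetrainers is imported. The chosen values below are just examples.

```python
# Illustrative only: configure the environment before importing finetrainers,
# since finetrainers/constants.py reads these variables at import time.
import os

os.environ["FINETRAINERS_ATTN_PROVIDER"] = "flash"  # default attention provider for training/validation
os.environ["FINETRAINERS_ATTN_CHECKS"] = "1"        # any of "1", "ON", "YES", "TRUE" enables the checks

import finetrainers  # noqa: E402  (imported after configuring the environment)
```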

docs/models/attention.md

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# Attention backends
+
+Finetrainers supports multiple attention backends to cover different hardware and to trade off between speed and memory usage. The following attention implementations are supported:
+- Training:
+  - If the model uses attention masks: `flash_varlen`, `flex`, `native`
+  - If the model does not use attention masks: `flash`, `flex`, `native`, `xformers`
+- Inference:
+  - If the model uses attention masks: `flash_varlen`, `flex`, `native`, `sage_varlen`
+  - If the model does not use attention masks: `flash`, `flash_varlen`, `flex`, `native`, `sage`, `sage_varlen`, `xformers`
+
+Additionally, some specialized methods are available for debugging purposes: `_native_cudnn`, `_native_efficient`, `_native_flash`, `_native_math`, `_sage_qk_int8_pv_fp8_cuda`, `_sage_qk_int8_pv_fp8_cuda_sm90`, `_sage_qk_int8_pv_fp16_cuda`, `_sage_qk_int8_pv_fp16_triton`. With time, more attention-specific optimizations and custom implementations will be supported. Contributions are welcome!
+
+Unfortunately, due to limited time for testing, only specific versions of the packages that provide these implementations are supported. Other versions may work. The minimum supported versions will gradually be lowered for more flexibility, but for now, please use the following versions:
+- `flash-attn>=2.6.3`
+- `sageattention>=2.1.1`
+- `xformers>=0.0.29.post3`
+
+This guide will help you quickly install flash-attn, sageattention, and xformers so that your models run faster and use less memory for training/inference. We'll cover installation on Linux (Ubuntu 22.04) and Windows (using WSL).
+
+Before you start, make sure to use a clean Python virtual environment, so that conflicting dependencies or failed installations do not leave your system in a hard-to-recover state.
+
+### Flash attention
+
+Providers covered: `flash`, `flash_varlen`
+
+The installation steps have only been tested on Ubuntu 22.04, with CUDA 12.2 and 12.6.
+- Check your CUDA version: look at the output of `nvidia-smi` or run `nvcc --version`.
+- You might need the following packages: `pip install packaging ninja`
+- Linux: Run `pip install flash-attn --no-build-isolation`. Verify the version with `pip show flash-attn`.
+- WSL: The same instructions as above should work. Native Windows might require building from source - check the community guides and follow the instructions [here](https://github.com/Dao-AILab/flash-attention).
+
+### Sage attention
+
+Providers covered: `sage`, `sage_varlen`, `_sage_qk_int8_pv_fp8_cuda`, `_sage_qk_int8_pv_fp8_cuda_sm90`, `_sage_qk_int8_pv_fp16_cuda`, `_sage_qk_int8_pv_fp16_triton`
+
+The FP8 implementations require a CUDA compute capability of 90 or higher (H100, RTX 5090, etc.). Some may also work on compute capability 89 (RTX 4090, for example). The FP16 implementations require a compute capability of at least 80 (A100, RTX 3090, etc.). For other GPUs, the FP16 implementations may or may not work (this is untested).
+
+- Check your compute capability with the following command:
+  ```bash
+  python -c "import torch; print(torch.cuda.get_device_capability())"
+  ```
+- Check your CUDA version: look at the output of `nvidia-smi` or run `nvcc --version`.
+- You might need the following packages: `pip install triton`. For Windows, check out the [triton-windows](https://github.com/woct0rdho/triton-windows) project.
+- Linux/WSL: Run `pip install git+https://github.com/thu-ml/SageAttention`. Verify the version with `pip show sageattention`.
+- Make sure to also look at the official installation guide in [SageAttention](https://github.com/thu-ml/SageAttention)!
+
+### xformers
+
+Providers covered: `xformers`
+
+- Check your CUDA version: look at the output of `nvidia-smi` or run `nvcc --version`.
+- Linux/WSL: Run `pip install -U xformers --index-url https://download.pytorch.org/whl/cu126` (assuming CUDA 12.6). Verify the version with `pip show xformers`.
+- Make sure to also look at the official installation guide in [xformers](https://github.com/facebookresearch/xformers)!
+
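After following the installation steps above, a quick import check can confirm that each extension actually loads in your environment. This is an illustrative sketch only; the import names are the usual ones for these packages, and the version attributes for flash-attn and xformers are assumptions - adjust if your installation differs.

```python
# Illustrative sanity check: a successful import confirms each extension loads.
import flash_attn
import xformers
import sageattention  # noqa: F401

print("flash-attn:", flash_attn.__version__)  # expect >= 2.6.3
print("xformers:", xformers.__version__)      # expect >= 0.0.29.post3
print("sageattention imported successfully")  # version can be checked with `pip show sageattention`
```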
+----------
+
+All other providers are either native PyTorch implementations or require a specific PyTorch version (for example, Flex Attention requires a torch version of at least 2.5.0).
+
+----------
+
+## Usage
+
+There are two ways to use the attention dispatcher mechanism:
+- Replace `scaled_dot_product_attention` globally:
+  ```python
+  import torch.nn.functional as F
+  from finetrainers.models.attention_dispatch import attention_dispatch
+
+  F.scaled_dot_product_attention = attention_dispatch
+  ```
+- Replace all occurrences of `scaled_dot_product_attention` in your code with `attention_dispatch`.
+
+```python
+# Use dispatcher directly
+from finetrainers.models.attention_dispatch import attention_provider, AttentionProvider
+
+with attention_provider(AttentionProvider.FLASH_VARLEN):
+    model(...)
+
+# or,
+with attention_provider("sage_varlen"):
+    model(...)
+```
+
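To make the two usage patterns above concrete, here is a minimal end-to-end sketch. It assumes `attention_dispatch` is signature-compatible with `torch.nn.functional.scaled_dot_product_attention` (which the global replacement above relies on); tensor shapes follow the usual `(batch, heads, seq_len, head_dim)` convention and the chosen provider is just an example.

```python
# A minimal sketch of the dispatcher in action (assumptions noted above).
import torch
import torch.nn.functional as F

from finetrainers.models.attention_dispatch import attention_dispatch, attention_provider

# Route all SDPA calls through the dispatcher.
F.scaled_dot_product_attention = attention_dispatch

query, key, value = (torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))

# Run the dispatched call under a specific provider.
with attention_provider("_native_flash"):
    output = F.scaled_dot_product_attention(query, key, value)

print(output.shape)  # torch.Size([1, 8, 128, 64])
```

Outside the context manager, the default provider applies (`FINETRAINERS_ATTN_PROVIDER`, `native` by default).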
+## Context Parallel
+
+TODO

finetrainers/args.py

Lines changed: 108 additions & 19 deletions
@@ -2,22 +2,88 @@
 import os
 import pathlib
 import sys
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional, Union

 import torch

 from .config import SUPPORTED_MODEL_CONFIGS, ModelType, TrainingType
 from .logging import get_logger
 from .parallel import ParallelBackendEnum
-from .trainer.config_utils import ConfigMixin
-from .utils import get_non_null_items
+from .utils import ArgsConfigMixin, get_non_null_items


 logger = get_logger()

+# fmt: off
+# Must match src/finetrainers/models/attention_dispatch.py
+AttentionProviderTraining = Literal["flash", "flash_varlen", "flex", "native", "_native_cudnn", "_native_efficient", "_native_flash", "_native_math", "xformers"]
+AttentionProviderValidation = Literal["flash", "flash_varlen", "flex", "native", "_native_cudnn", "_native_efficient", "_native_flash", "_native_math", "sage", "sage_varlen", "_sage_qk_int8_pv_fp8_cuda", "_sage_qk_int8_pv_fp8_cuda_sm90", "_sage_qk_int8_pv_fp16_cuda", "_sage_qk_int8_pv_fp16_triton", "xformers"]
+
+# We do a union because every ArgsConfigMixin registered to BaseArgs can be looked up using the `__getattribute__` override
+BaseArgsType = Union["BaseArgs", "AttentionProviderArgs"]
+# fmt: on
+
+
+class AttentionProviderArgs(ArgsConfigMixin):
+    """
+    Args:
+        attn_provider_training (`List[str]`, defaults to `None`):
+            Each entry must be a string of the form `"<component_name>:<attention_provider>"`. For example, if you want to use
+            the flash varlen attention implementation on the `transformer` module, you can set this argument to
+            `"transformer:flash_varlen"`. The attention provider will be used for training.
+            Options for `<attention_provider>` are:
+                flash, flash_varlen, flex, native, _native_cudnn, _native_efficient, _native_flash, _native_math, xformers
+        attn_provider_inference (`List[str]`, defaults to `None`):
+            Each entry must be a string of the form `"<component_name>:<attention_provider>"`. For example, if you want to use
+            the flash varlen attention implementation on the `transformer` module, you can set this argument to
+            `"transformer:flash_varlen"`. The attention provider will be used for validation/inference.
+            Options for `<attention_provider>` are:
+                flash, flash_varlen, flex, native, _native_cudnn, _native_efficient, _native_flash, _native_math,
+                sage, sage_varlen, _sage_qk_int8_pv_fp8_cuda, _sage_qk_int8_pv_fp8_cuda_sm90,
+                _sage_qk_int8_pv_fp16_cuda, _sage_qk_int8_pv_fp16_triton, xformers
+    """
+
+    attn_provider_training: List[AttentionProviderTraining] = None
+    attn_provider_inference: List[AttentionProviderValidation] = None
+
+    def add_args(self, parser: argparse.ArgumentParser) -> None:
+        parser.add_argument(
+            "--attn_provider_training",
+            type=str,
+            default=None,
+            nargs="+",
+            help="Attention provider for training. Must be a string of the form `<component_name>:<attention_provider>`.",
+        )
+        parser.add_argument(
+            "--attn_provider_inference",
+            type=str,
+            default=None,
+            nargs="+",
+            help="Attention provider for inference. Must be a string of the form `<component_name>:<attention_provider>`.",
+        )
+
+    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
+        attn_training = argparse_args.attn_provider_training
+        attn_inference = argparse_args.attn_provider_inference
+        if attn_training is None:
+            attn_training = []
+        if attn_inference is None:
+            attn_inference = []
+        mapped_args.attn_provider_training = attn_training
+        mapped_args.attn_provider_inference = attn_inference
+
+    def validate_args(self, args: "BaseArgs"):
+        pass
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "attn_provider_training": self.attn_provider_training,
+            "attn_provider_inference": self.attn_provider_inference,
+        }
+

 class BaseArgs:
-    r"""
+    """
     The arguments for the finetrainers training script.

     For helpful information about arguments, run `python train.py --help`.
@@ -314,16 +380,9 @@ class BaseArgs:
     vae_dtype: torch.dtype = torch.bfloat16
     layerwise_upcasting_modules: List[str] = []
     layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn
-    layerwise_upcasting_skip_modules_pattern: List[str] = [
-        "patch_embed",
-        "pos_embed",
-        "x_embedder",
-        "context_embedder",
-        "time_embed",
-        "^proj_in$",
-        "^proj_out$",
-        "norm",
-    ]
+    # fmt: off
+    layerwise_upcasting_skip_modules_pattern: List[str] = ["patch_embed", "pos_embed", "x_embedder", "context_embedder", "time_embed", "^proj_in$", "^proj_out$", "norm"]
+    # fmt: on

     # Dataset arguments
     dataset_config: str = None
@@ -399,10 +458,21 @@ class BaseArgs:
     compile_modules: List[str] = []
     compile_scopes: List[str] = None
     allow_tf32: bool = False
-    float32_matmul_precision: Optional[str] = None
+    float32_matmul_precision: str = "highest"

-    # Additional registered arguments
-    _registered_config_mixins: List[ConfigMixin] = []
+    # Attention provider arguments
+    attention_provider_args: AttentionProviderArgs = AttentionProviderArgs()
+
+    _registered_config_mixins: List[ArgsConfigMixin] = []
+    _arg_group_map: Dict[str, ArgsConfigMixin] = {}
+
+    def __init__(self):
+        self._arg_group_map: Dict[str, ArgsConfigMixin] = {
+            "attention_provider_args": self.attention_provider_args,
+        }
+
+        for arg_config_mixin in self._arg_group_map.values():
+            self.register_args(arg_config_mixin)

     def to_dict(self) -> Dict[str, Any]:
         parallel_arguments = {
@@ -545,7 +615,7 @@ def to_dict(self) -> Dict[str, Any]:
             "torch_config_arguments": torch_config_arguments,
         }

-    def register_args(self, config: ConfigMixin) -> None:
+    def register_args(self, config: ArgsConfigMixin) -> None:
         if not hasattr(self, "_extended_add_arguments"):
             self._extended_add_arguments = []
         self._extended_add_arguments.append((config.add_args, config.validate_args, config.map_args))
@@ -583,6 +653,25 @@ def parse_args(self):

         return mapped_args

+    def __getattribute__(self, name: str):
+        try:
+            return object.__getattribute__(self, name)
+        except AttributeError:
+            for arg_group in self._arg_group_map.values():
+                if hasattr(arg_group, name):
+                    return getattr(arg_group, name)
+            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+    def __setattr__(self, name: str, value: Any):
+        if name in self.__dict__:
+            object.__setattr__(self, name, value)
+            return
+        for arg_group in self._arg_group_map.values():
+            if hasattr(arg_group, name):
+                setattr(arg_group, name, value)
+                return
+        object.__setattr__(self, name, value)
+

 def _add_args(parser: argparse.ArgumentParser) -> None:
     _add_parallel_arguments(parser)
@@ -749,7 +838,7 @@ def _add_torch_config_arguments(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--float32_matmul_precision",
         type=str,
-        default=None,
+        default="highest",
         choices=["highest", "high", "medium"],
         help="The precision to use for float32 matmul. Choose between ['highest', 'high', 'medium'].",
     )

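To illustrate how the new arguments could be consumed downstream, here is a hedged sketch: `parse_attention_providers` is a hypothetical helper (not part of this commit) that splits the documented `"<component_name>:<attention_provider>"` entries into a mapping, and the attribute access relies on the `__getattribute__`/`__setattr__` overrides shown above.

```python
# Illustrative sketch only. `parse_attention_providers` is a hypothetical helper,
# not part of this commit; BaseArgs and the argument names come from the diff above.
from typing import Dict, List

from finetrainers.args import BaseArgs


def parse_attention_providers(specs: List[str]) -> Dict[str, str]:
    """Split "<component_name>:<attention_provider>" entries into a mapping."""
    providers = {}
    for spec in specs:
        component_name, separator, provider = spec.partition(":")
        if not separator:
            raise ValueError(f"Expected '<component_name>:<attention_provider>', got {spec!r}")
        providers[component_name] = provider
    return providers


args = BaseArgs()
# Attributes of registered ArgsConfigMixin groups resolve directly on BaseArgs
# through the __getattribute__/__setattr__ overrides.
args.attn_provider_training = ["transformer:flash_varlen"]
args.attn_provider_inference = ["transformer:sage"]

print(parse_attention_providers(args.attn_provider_training))   # {'transformer': 'flash_varlen'}
print(parse_attention_providers(args.attn_provider_inference))  # {'transformer': 'sage'}
```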
finetrainers/constants.py

Lines changed: 6 additions & 3 deletions
@@ -1,6 +1,12 @@
 import os


+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
+
+FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO")
+FINETRAINERS_ATTN_PROVIDER = os.environ.get("FINETRAINERS_ATTN_PROVIDER", "native")
+FINETRAINERS_ATTN_CHECKS = os.getenv("FINETRAINERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
+
 DEFAULT_HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
 DEFAULT_WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
 DEFAULT_FRAME_BUCKETS = [49]
@@ -16,9 +22,6 @@
         for width in DEFAULT_WIDTH_BUCKETS:
             DEFAULT_VIDEO_RESOLUTION_BUCKETS.append((frames, height, width))

-
-FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO")
-
 PRECOMPUTED_DIR_NAME = "precomputed"
 PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions"
 PRECOMPUTED_LATENTS_DIR_NAME = "latents"
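A small, illustrative example of consuming the new constants from other code (only the constants themselves come from this commit; the printout is an example):

```python
# Illustrative: the constants reflect the environment at the time finetrainers was imported.
from finetrainers.constants import FINETRAINERS_ATTN_CHECKS, FINETRAINERS_ATTN_PROVIDER

print("default attention provider:", FINETRAINERS_ATTN_PROVIDER)  # "native" unless overridden
if FINETRAINERS_ATTN_CHECKS:
    print("attention sanity checks are enabled")
```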

finetrainers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
+from .attention_dispatch import AttentionProvider, attention_dispatch, attention_provider
 from .modeling_utils import ControlModelSpecification, ModelSpecification
