
Commit 3042dc8

Authored by Maximilian Beck (maximilianmbeck) and Maximilian Beck
feat: Early refactor merge (#38)
* feat: Add chunk param selection heuristic for xl_chunk size and align last states format (#33)
* chore: rename.
* wip: change state shape.
* feat: add heuristic kernel parameter selection to xl_chunk_size kernel.
* chore: update doc.
* chore: improve readability.
* feat: Add all sequence length kernel dispatcher (#34)
* chore: rename.
* wip: change state shape.
* feat: add heuristic kernel parameter selection to xl_chunk_size kernel.
* wip: add arbitrary sequence length wrapper + test. Test fails atm.
* fix: arbitrary sequence length wrapper.
* feat: make state dtype of triton step fused kernel configurable.
* feat: add test for limit + xl chunk size kernels.
* chore: improve doc.
* fix: native sequence return states.
* feat: enable single step in arbitrary sequence length.
* chore: add mlstm sequence kernel registry.
* chore: fix typo.
* Merge branch 'refactor_further' into add_all_sequence_length_kernel_dispatcher
* chore: prettify
* chore: use tl.exp instead of tl.exp2
* const
* chore: improve error message.
* chore: control logging only via logger.
* chore: fix merge
* chore: simplify kernel call logic.
* chore: add doc strings.
* feat: Add pytorch backend module (#35)
* chore: rename.
* wip: change state shape.
* feat: add heuristic kernel parameter selection to xl_chunk_size kernel.
* wip: add arbitrary sequence length wrapper + test. Test fails atm.
* fix: arbitrary sequence length wrapper.
* feat: make state dtype of triton step fused kernel configurable.
* feat: add test for limit + xl chunk size kernels.
* chore: improve doc.
* fix: native sequence return states.
* feat: enable single step in arbitrary sequence length.
* chore: add mlstm sequence kernel registry.
* chore: fix typo.
* feat: add preliminary backend module.
* chore: support different state dtype in triton step kernel. Store outputs only from a single thread block, i.e. avoid duplicate store.
* feat: fully support state dtype in recurrent sequence loop.
* fix: handle the case where target chunk size is smaller than default block size.
* feat: add pad zeros kernel wrapper.
* feat: integrate pad with zeros into backend module.
* chore: add doc to mlstm backend config.
* chore: add doc to forward pass.
* fix: typo. add doc.
* chore: improve error message.
* chore: clean up backend module.
* fix: merge
* chore: change number format.

---------

Co-authored-by: Maximilian Beck <[email protected]>

* chore: remove deprecated argument.

---------

Co-authored-by: Maximilian Beck <[email protected]>
1 parent fa3eeee commit 3042dc8

35 files changed: +2210 -241 lines

.vscode/launch.json

+1 -1

@@ -17,7 +17,7 @@
       "justMyCode": false,
       "args": [
         // "-cn",
-        "${workspaceFolder}/tests/test_mlstm/test_parallel/test_parallel_torch.py",
+        "${workspaceFolder}/tests/torch/test_arbitrary_sequence_length.py",
       ],
       "env": {
         "CUDA_VISIBLE_DEVICES": "0",

.vscode/settings.json

+1 -1

@@ -7,7 +7,7 @@
       "source.organizeImports": "explicit"
     }
   },
-  "editor.formatOnSave": true,
+  "editor.formatOnSave": false,
   "isort.args": ["--profile", "ruff"],
   "files.watcherExclude": {
     "outputs/**": true,

README.md

+8 -8

@@ -8,19 +8,19 @@ In this repository we collect clean implementations of the different mLSTM formu
 def mlstm_interface(
     q: torch.Tensor,  # (B, NH, S, DHQK)
     k: torch.Tensor,  # (B, NH, S, DHQK)
-    v: torch.Tensor,  # (B, NH, S, DHV)
+    v: torch.Tensor,  # (B, NH, S, DHHV)
     i: torch.Tensor,  # (B, NH, S)
     f: torch.Tensor,  # (B, NH, S)
-    c_initial: torch.Tensor = None,  # (B, NH, DHQK, DHV)
+    c_initial: torch.Tensor = None,  # (B, NH, DHQK, DHHV)
     n_initial: torch.Tensor = None,  # (B, NH, DHQK)
-    m_initial: torch.Tensor = None,  # (B, NH) # TODO change the shape of this to (B, NH, 1)
+    m_initial: torch.Tensor = None,  # (B, NH, 1)
     return_last_states: bool = False,
     eps: float = 1e-6,
-    autocast_kernel_dtype: torch.dtype = torch.float16,
+    autocast_kernel_dtype: torch.dtype = torch.bfloat16,
     chunk_size: int = 64,
     **kwargs,
 ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
-    # (B, NH, S, DHV) | ((B, NH, S, DHV), ((B, NH, DHQK, DHV), (B, NH, DHQK), (B, NH)))
+    # (B, NH, S, DHHV) | ((B, NH, S, DHHV), ((B, NH, DHQK, DHHV), (B, NH, DHQK), (B, NH)))
     """
     Returns:
         torch.Tensor: matH outputs (no n and m values, no last states)
@@ -35,17 +35,17 @@ def mlstm_interface(
 def mlstm_step_interface(
     q: torch.Tensor,  # (B, NH, DHQK)
     k: torch.Tensor,  # (B, NH, DHQK)
-    v: torch.Tensor,  # (B, NH, DHV)
+    v: torch.Tensor,  # (B, NH, DHHV)
     i: torch.Tensor,  # (B, NH, 1)
     f: torch.Tensor,  # (B, NH, 1)
-    c: torch.Tensor,  # (B, NH, DHQK, DHV)
+    c: torch.Tensor,  # (B, NH, DHQK, DHHV)
     n: torch.Tensor,  # (B, NH, DHQK)
     m: torch.Tensor,  # (B, NH, 1)
     eps: float = 1e-6,
     **kwargs,
 ) -> tuple[
     torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]
-]:  # vecH, (matC_state_new (B, NH, DHQK, DHV), vecN_state_new (B, NH, DHQK), vecM_state_new (B, NH, 1))
+]:  # vecH, (matC_state_new (B, NH, DHQK, DHHV), vecN_state_new (B, NH, DHQK), vecM_state_new (B, NH, 1))
 ```
 
 ## Kernel variants
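
For orientation, here is a minimal sketch of a call that follows this interface. The accessor `get_mlstm_kernel` and the kernel name `chunkwise--triton_limit_chunk` are taken from this repository's registry; the tensor sizes, CUDA device, and random inputs are illustrative assumptions, not part of the commit:

```python
import torch

from mlstm_kernels.torch import get_mlstm_kernel

# Illustrative sizes: batch, heads, sequence length, qk and v head dims.
B, NH, S, DHQK, DHHV = 2, 4, 256, 64, 128
device = "cuda"

q = torch.randn(B, NH, S, DHQK, device=device)
k = torch.randn(B, NH, S, DHQK, device=device)
v = torch.randn(B, NH, S, DHHV, device=device)
i = torch.randn(B, NH, S, device=device)  # input gate preactivations
f = torch.randn(B, NH, S, device=device)  # forget gate preactivations

mlstm_fn = get_mlstm_kernel("chunkwise--triton_limit_chunk")
h = mlstm_fn(q=q, k=k, v=v, i=i, f=f, chunk_size=64)  # (B, NH, S, DHHV)
```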

mlstm_kernels/torch/__init__.py

+47 -9

@@ -11,9 +11,13 @@ def _create_module_sequence_backend_registry() -> dict[str, dict[str, Callable]]
     }
     return module_backend_registry
 
+
 def get_available_mlstm_kernels() -> list[str]:
     """
-    Get a list of available mlstm sequence kernel names.
+    Get a list of available mlstm sequence kernels.
+    These kernels process a sequence in the parallel or chunkwise parallel mode of the mLSTM.
+    They do not support arbitrary sequence lengths.
+    They are used for training and for prefill processing during inference of the mLSTM.
     """
     module_backend_registry = _create_module_sequence_backend_registry()
 
@@ -24,11 +28,6 @@ def get_available_mlstm_kernels() -> list[str]:
     ]
     return backend_names
 
-def get_available_mlstm_step_kernels() -> list[str]:
-    from .recurrent import registry_step as mlstm_recurrent_step_registry
-    backend_names = list(mlstm_recurrent_step_registry.keys())
-    return backend_names
-
 
 def get_mlstm_kernel(name: str) -> Callable:
     """
@@ -54,12 +53,22 @@ def get_mlstm_kernel(name: str) -> Callable:
 
     if backend_name not in module_backend_registry[module_name]:
         raise ValueError(
-            f"Unknown backend name: {backend_name}. Available backend names: {list(module_backend_registry[module_name].keys())}"
+            f"Unknown mlstm kernel backend name: {backend_name}. Available backend names: {list(module_backend_registry[module_name].keys())}"
         )
 
     return module_backend_registry[module_name][backend_name]
 
 
+def get_available_mlstm_step_kernels() -> list[str]:
+    """Returns the available mlstm step kernels.
+    These kernels can be used to compute a single time step of the mLSTM, i.e. for generation.
+    """
+    from .recurrent import registry_step as mlstm_recurrent_step_registry
+
+    backend_names = list(mlstm_recurrent_step_registry.keys())
+    return backend_names
+
+
 def get_mlstm_step_kernel(name: str) -> Callable:
     """
     Get a mlstm step kernel function by name.
@@ -73,7 +82,36 @@ def get_mlstm_step_kernel(name: str) -> Callable:
 
     if name not in mlstm_recurrent_step_registry:
         raise ValueError(
-            f"Unknown backend name: {name}. Available backend names: {list(mlstm_recurrent_step_registry.keys())}"
+            f"Unknown step kernel backend name: {name}. Available backend names: {list(mlstm_recurrent_step_registry.keys())}"
+        )
+
+    return mlstm_recurrent_step_registry[name]
+
+
+def get_available_mlstm_sequence_kernels() -> list[str]:
+    """Returns the available mlstm sequence kernels.
+    These kernels process a sequence in the recurrent mode of the mLSTM and hence support any sequence length.
+    """
+    from .recurrent import registry_sequence as mlstm_recurrent_sequence_registry
+
+    backend_names = list(mlstm_recurrent_sequence_registry.keys())
+    return backend_names
+
+
+def get_mlstm_sequence_kernel(name: str) -> Callable:
+    """
+    Get a mlstm sequence kernel function by name.
+
+    Naming convention:
+    name = "<backend_name>"
+
+    backend_name: The name of the kernel function as defined in the registry in the __init__.py file of the module.
+    """
+    from .recurrent import registry_sequence as mlstm_recurrent_sequence_registry
+
+    if name not in mlstm_recurrent_sequence_registry:
+        raise ValueError(
+            f"Unknown backend name: {name}. Available backend names: {list(mlstm_recurrent_sequence_registry.keys())}"
+        )
 
-    return mlstm_recurrent_step_registry[name]
+    return mlstm_recurrent_sequence_registry[name]
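
After this change the registry exposes three kernel families: chunkwise/parallel sequence kernels, recurrent sequence kernels, and single-step kernels. A small sketch of how they could be listed and fetched; the concrete names passed to the getters are assumptions based on the Literal types in backend_module.py below:

```python
from mlstm_kernels.torch import (
    get_available_mlstm_kernels,
    get_available_mlstm_sequence_kernels,
    get_available_mlstm_step_kernels,
    get_mlstm_sequence_kernel,
    get_mlstm_step_kernel,
)

# Sequence kernels for training / prefill (parallel or chunkwise parallel mode).
print(get_available_mlstm_kernels())
# Recurrent sequence kernels that support any sequence length.
print(get_available_mlstm_sequence_kernels())
# Step kernels for single-token generation.
print(get_available_mlstm_step_kernels())

# Fetch concrete kernels by name (names assumed from the registries):
step_fn = get_mlstm_step_kernel("triton_fused")
seq_fn = get_mlstm_sequence_kernel("native_sequence__native")
```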

mlstm_kernels/torch/backend_module.py

+205

@@ -0,0 +1,205 @@
+from dataclasses import dataclass
+from functools import partial
+from typing import Literal
+
+import torch
+from torch import nn
+
+from . import (
+    get_mlstm_kernel,
+    get_mlstm_sequence_kernel,
+    get_mlstm_step_kernel,
+)
+from .kernel_wrappers import (
+    wrap_chunkwise__arbitrary_sequence_length,
+    wrap_chunkwise__pad_zeros,
+)
+
+ChunkwiseKernelType = Literal[
+    "chunkwise--native_autograd",
+    "chunkwise--native_custbw",
+    "chunkwise--triton_limit_chunk",
+    "chunkwise--triton_xl_chunk",
+    "parallel--native_autograd",
+    "parallel--native_custbw",
+    "parallel--native_stablef_autograd",
+    "parallel--native_stablef_custbw",
+    "parallel--triton_limit_headdim",
+]
+SequenceKernelType = Literal[
+    "native_sequence__native", "native_sequence__triton_step_fused"
+]
+StepKernelType = Literal["native", "triton_fused"]
+
+DtypeType = Literal["float32", "bfloat16", "float16"]
+
+BackendModeType = Literal["train", "train_with_padding", "inference"]
+
+
+@dataclass
+class mLSTMBackendConfig:
+    chunkwise_kernel: ChunkwiseKernelType = "chunkwise--native_autograd"
+    """The chunkwise kernel to use for chunkwise parallel processing of the sequence.
+    This kernel is used for training.
+    Also supports fully parallel (i.e. quadratic) backends for comparison.
+    """
+    sequence_kernel: SequenceKernelType = "native_sequence__native"
+    """The sequence kernel to use for processing sequences step-by-step.
+    Used only for parts of the prefill sequence in inference mode.
+    """
+    step_kernel: StepKernelType = "native"
+    """The step kernel to use for processing a single step.
+    Used for generation in inference mode.
+    """
+    mode: BackendModeType = "train"
+    """The mode of operation for the backend. Determines how the `forward` method behaves.
+    """
+    chunk_size: int = 64
+    """The chunk size of the chunkwise kernel.
+    If the mode is 'train_with_padding', the inputs are padded to multiples of this size.
+    """
+    return_last_states: bool = True
+    """Whether to return the last states of the sequence in training mode.
+    Inference mode always returns the last states.
+    """
+    autocast_kernel_dtype: DtypeType = "bfloat16"
+    """The dtype to use for autocast behavior in the kernel.
+    If autocast is enabled all inputs are cast to this dtype before the kernel is called.
+    """
+    eps: float = 1e-6
+    """Epsilon value for numerical stability in the kernel."""
+    inference_state_dtype: DtypeType = "float32"
+    """The dtype to use for the state tensors in inference mode."""
+
+    def __post_init__(self):
+        if self.return_last_states and "parallel" in self.chunkwise_kernel:
+            raise ValueError(
+                "return_last_states=True is not supported with parallel kernels."
+            )
+        if self.return_last_states and self.mode == "train_with_padding":
+            raise ValueError(
+                "return_last_states=True is not supported with train_with_padding mode."
+            )
+
+
+class mLSTMBackend(nn.Module):
+    """mLSTM Backend Module for PyTorch.
+
+    This module wraps the mLSTM kernels and provides a high-level interface for training and inference.
+    """
+
+    config_class = mLSTMBackendConfig
+
+    def __init__(self, config: mLSTMBackendConfig):
+        super().__init__()
+        self.config = config
+        self.chunkwise_kernel_fn = get_mlstm_kernel(config.chunkwise_kernel)
+        self.sequence_kernel_fn = get_mlstm_sequence_kernel(config.sequence_kernel)
+        self.step_kernel_fn = get_mlstm_step_kernel(config.step_kernel)
+
+        self._inference_fn = partial(
+            wrap_chunkwise__arbitrary_sequence_length,
+            mlstm_chunkwise_kernel=self.chunkwise_kernel_fn,
+            mlstm_sequence_kernel=partial(
+                self.sequence_kernel_fn,
+                dtype_state=getattr(torch, config.inference_state_dtype),
+            ),
+            mlstm_step_kernel=partial(
+                self.step_kernel_fn,
+                dtype_state=getattr(torch, config.inference_state_dtype),
+            ),
+            chunk_size=config.chunk_size,
+            eps=config.eps,
+            autocast_kernel_dtype=getattr(torch, config.autocast_kernel_dtype),
+            return_last_states=True,
+        )
+
+        train_kernel_fn = partial(
+            self.chunkwise_kernel_fn,
+            autocast_kernel_dtype=getattr(torch, config.autocast_kernel_dtype),
+            eps=config.eps,
+            chunk_size=config.chunk_size,
+        )
+        if "with_padding" in config.mode:
+            train_kernel_fn = partial(
+                wrap_chunkwise__pad_zeros, mlstm_chunkwise_kernel=train_kernel_fn
+            )
+        self._train_fn = train_kernel_fn
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        i: torch.Tensor,
+        f: torch.Tensor,
+        c_initial: torch.Tensor = None,
+        n_initial: torch.Tensor = None,
+        m_initial: torch.Tensor = None,
+        return_last_states: bool = None,
+        mode: Literal["train", "inference"] = None,
+    ) -> (
+        torch.Tensor
+        | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
+    ):
+        """Forward pass of the mLSTM backend.
+
+        Depending on the configured mode, this method will call the appropriate kernel function.
+
+        Args:
+            q: The query tensor of shape (B, NH, S, DHQK).
+            k: The key tensor of shape (B, NH, S, DHQK).
+            v: The value tensor of shape (B, NH, S, DHHV).
+            i: The input gate preactivation tensor of shape (B, NH, S).
+            f: The forget gate preactivation tensor of shape (B, NH, S).
+            c_initial: The initial cell state tensor of shape (B, NH, DHQK, DHHV).
+                Defaults to None.
+            n_initial: The initial hidden state tensor of shape (B, NH, DHQK). Defaults to None.
+            m_initial: The initial memory tensor of shape (B, NH, 1). Defaults to None.
+            return_last_states: Whether to return the last states of the sequence. Defaults to None.
+                If None, the value from the config is used.
+
+        Returns:
+            The hidden states of shape (B, NH, S, DHHV), or a tuple of the hidden states and the
+            last states, where the last states are the cell state c (B, NH, DHQK, DHHV),
+            the normalizer state n (B, NH, DHQK), and the max state m (B, NH, 1).
+        """
+        if mode is None:
+            mode = self.config.mode
+
+        if "train" in mode:
+            if return_last_states is None:
+                return_last_states = self.config.return_last_states
+
+            if self.config.mode == "train_with_padding":
+                assert not return_last_states, "return_last_states=True is not supported with train_with_padding mode."
+
+            return self._train_fn(
+                q=q,
+                k=k,
+                v=v,
+                i=i,
+                f=f,
+                c_initial=c_initial,
+                n_initial=n_initial,
+                m_initial=m_initial,
+                return_last_states=return_last_states,
+            )
+
+        elif "inference" in mode:
+            # inference mode always returns the last states
+            return self._inference_fn(
+                q=q,
+                k=k,
+                v=v,
+                i=i,
+                f=f,
+                c_initial=c_initial,
+                n_initial=n_initial,
+                m_initial=m_initial,
+            )
+        else:
+            raise ValueError(f"Unknown mode: {self.config.mode}")
+
+    def extra_repr(self) -> str:
+        return f"{self.config}"
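
To illustrate how the pieces fit together, a hedged usage sketch of the new backend module. The kernel names come from the Literal types above; the tensor sizes, CUDA device, and random inputs are illustrative assumptions:

```python
import torch

from mlstm_kernels.torch.backend_module import mLSTMBackend, mLSTMBackendConfig

backend = mLSTMBackend(
    mLSTMBackendConfig(
        chunkwise_kernel="chunkwise--triton_xl_chunk",
        sequence_kernel="native_sequence__triton_step_fused",
        step_kernel="triton_fused",
        mode="inference",
        chunk_size=64,
        inference_state_dtype="float32",
    )
)

# In inference mode the sequence length need not be a multiple of chunk_size.
B, NH, S, DHQK, DHHV = 1, 4, 100, 64, 128
device = "cuda"
q = torch.randn(B, NH, S, DHQK, device=device)
k = torch.randn(B, NH, S, DHQK, device=device)
v = torch.randn(B, NH, S, DHHV, device=device)
i = torch.randn(B, NH, S, device=device)
f = torch.randn(B, NH, S, device=device)

# Inference mode always returns the hidden states plus the last (c, n, m) states.
h, (c, n, m) = backend(q, k, v, i, f)
```

In `mode="train"` the chunkwise kernel is called directly, while `mode="train_with_padding"` wraps it with `wrap_chunkwise__pad_zeros` so batches whose length is not a multiple of `chunk_size` can still be trained, at the cost of not returning last states.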
