from dataclasses import dataclass
from functools import partial
from typing import Literal

import torch
from torch import nn

from . import (
    get_mlstm_kernel,
    get_mlstm_sequence_kernel,
    get_mlstm_step_kernel,
)
from .kernel_wrappers import (
    wrap_chunkwise__arbitrary_sequence_length,
    wrap_chunkwise__pad_zeros,
)

ChunkwiseKernelType = Literal[
    "chunkwise--native_autograd",
    "chunkwise--native_custbw",
    "chunkwise--triton_limit_chunk",
    "chunkwise--triton_xl_chunk",
    "parallel--native_autograd",
    "parallel--native_custbw",
    "parallel--native_stablef_autograd",
    "parallel--native_stablef_custbw",
    "parallel--triton_limit_headdim",
]
SequenceKernelType = Literal[
    "native_sequence__native", "native_sequence__triton_step_fused"
]
StepKernelType = Literal["native", "triton_fused"]

DtypeType = Literal["float32", "bfloat16", "float16"]

BackendModeType = Literal["train", "train_with_padding", "inference"]


@dataclass
class mLSTMBackendConfig:
    chunkwise_kernel: ChunkwiseKernelType = "chunkwise--native_autograd"
    """The chunkwise kernel to use for chunkwise parallel processing of the sequence.
    This kernel is used for training.
    Also supports fully parallel (i.e. quadratic) backends for comparison.
    """
    sequence_kernel: SequenceKernelType = "native_sequence__native"
    """The sequence kernel to use for processing sequences step-by-step.
    Used only for parts of the prefill sequence in inference mode.
    """
    step_kernel: StepKernelType = "native"
    """The step kernel to use for processing a single step.
    Used for generation in inference mode.
    """
    mode: BackendModeType = "train"
    """The mode of operation for the backend. Determines how the `forward` method behaves.
    """
    chunk_size: int = 64
    """The chunk size of the chunkwise kernel.
    If the mode is 'train_with_padding', the inputs are padded to multiples of this size.
    """
    return_last_states: bool = True
    """Whether to return the last states of the sequence in training mode.
    Inference mode always returns the last states.
    """
    autocast_kernel_dtype: DtypeType = "bfloat16"
    """The dtype to use for autocast behavior in the kernel.
    If autocast is enabled, all inputs are cast to this dtype before the kernel is called.
    """
    eps: float = 1e-6
    """Epsilon value for numerical stability in the kernel."""
    inference_state_dtype: DtypeType = "float32"
    """The dtype to use for the state tensors in inference mode."""

    def __post_init__(self):
        if self.return_last_states and "parallel" in self.chunkwise_kernel:
            raise ValueError(
                "return_last_states=True is not supported with parallel kernels."
            )
        if self.return_last_states and self.mode == "train_with_padding":
            raise ValueError(
                "return_last_states=True is not supported with train_with_padding mode."
            )


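# Example (an illustrative sketch, not part of the original module): a padded
# training setup. Note that return_last_states=True combined with
# mode="train_with_padding" or a "parallel--*" kernel raises a ValueError in
# __post_init__ above.
#
#   config = mLSTMBackendConfig(
#       chunkwise_kernel="chunkwise--triton_limit_chunk",
#       mode="train_with_padding",
#       return_last_states=False,
#   )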
class mLSTMBackend(nn.Module):
    """mLSTM Backend Module for PyTorch.

    This module wraps the mLSTM kernels and provides a high-level interface for training and inference.
    """

    config_class = mLSTMBackendConfig

    def __init__(self, config: mLSTMBackendConfig):
        super().__init__()
        self.config = config
        self.chunkwise_kernel_fn = get_mlstm_kernel(config.chunkwise_kernel)
        self.sequence_kernel_fn = get_mlstm_sequence_kernel(config.sequence_kernel)
        self.step_kernel_fn = get_mlstm_step_kernel(config.step_kernel)

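        # Inference path: wrap the chunkwise kernel for arbitrary sequence
        # lengths. The sequence and step kernels handle the prefill remainder
        # and single-step generation, with states kept in inference_state_dtype.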
        self._inference_fn = partial(
            wrap_chunkwise__arbitrary_sequence_length,
            mlstm_chunkwise_kernel=self.chunkwise_kernel_fn,
            mlstm_sequence_kernel=partial(
                self.sequence_kernel_fn,
                dtype_state=getattr(torch, config.inference_state_dtype),
            ),
            mlstm_step_kernel=partial(
                self.step_kernel_fn,
                dtype_state=getattr(torch, config.inference_state_dtype),
            ),
            chunk_size=config.chunk_size,
            eps=config.eps,
            autocast_kernel_dtype=getattr(torch, config.autocast_kernel_dtype),
            return_last_states=True,
        )

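        # Training path: call the chunkwise kernel directly; in
        # "train_with_padding" mode, additionally zero-pad the inputs to a
        # multiple of chunk_size before invoking the kernel.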
        train_kernel_fn = partial(
            self.chunkwise_kernel_fn,
            autocast_kernel_dtype=getattr(torch, config.autocast_kernel_dtype),
            eps=config.eps,
            chunk_size=config.chunk_size,
        )
        if "with_padding" in config.mode:
            train_kernel_fn = partial(
                wrap_chunkwise__pad_zeros, mlstm_chunkwise_kernel=train_kernel_fn
            )
        self._train_fn = train_kernel_fn

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        i: torch.Tensor,
        f: torch.Tensor,
        c_initial: torch.Tensor | None = None,
        n_initial: torch.Tensor | None = None,
        m_initial: torch.Tensor | None = None,
        return_last_states: bool | None = None,
        mode: Literal["train", "inference"] | None = None,
    ) -> (
        torch.Tensor
        | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
    ):
        """Forward pass of the mLSTM backend.

        Depending on the configured mode, this method calls the appropriate kernel function.

        Args:
            q: The query tensor of shape (B, NH, S, DHQK).
            k: The key tensor of shape (B, NH, S, DHQK).
            v: The value tensor of shape (B, NH, S, DHHV).
            i: The input gate preactivation tensor of shape (B, NH, S).
            f: The forget gate preactivation tensor of shape (B, NH, S).
            c_initial: The initial cell state tensor of shape (B, NH, DHQK, DHHV).
                Defaults to None.
            n_initial: The initial normalizer state tensor of shape (B, NH, DHQK). Defaults to None.
            m_initial: The initial max state tensor of shape (B, NH, 1). Defaults to None.
            return_last_states: Whether to return the last states of the sequence.
                Defaults to None. If None, the value from the config is used.
            mode: Optional override for the configured mode. Defaults to None.

        Returns:
            The hidden states of shape (B, NH, S, DHHV) if return_last_states is False.
            Otherwise, a tuple of the hidden states and the last states, where the last
            states are the cell state c (B, NH, DHQK, DHHV), the normalizer state n
            (B, NH, DHQK), and the max state m (B, NH, 1).
        """
        if mode is None:
            mode = self.config.mode

        if "train" in mode:
            if return_last_states is None:
                return_last_states = self.config.return_last_states

            if self.config.mode == "train_with_padding":
                assert not return_last_states, "return_last_states=True is not supported with train_with_padding mode."

            return self._train_fn(
                q=q,
                k=k,
                v=v,
                i=i,
                f=f,
                c_initial=c_initial,
                n_initial=n_initial,
                m_initial=m_initial,
                return_last_states=return_last_states,
            )

        elif "inference" in mode:
            # Inference mode always returns the last states.
            return self._inference_fn(
                q=q,
                k=k,
                v=v,
                i=i,
                f=f,
                c_initial=c_initial,
                n_initial=n_initial,
                m_initial=m_initial,
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def extra_repr(self) -> str:
        return f"{self.config}"
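

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module). Uses the default config: the chunkwise--native_autograd kernel
    # in "train" mode, which should run on CPU. Shapes follow the `forward`
    # docstring: B=batch, NH=heads, S=sequence length, DHQK/DHHV=head dims.
    B, NH, S, DHQK, DHHV = 1, 2, 128, 32, 64

    backend = mLSTMBackend(mLSTMBackendConfig())
    q = torch.randn(B, NH, S, DHQK)
    k = torch.randn(B, NH, S, DHQK)
    v = torch.randn(B, NH, S, DHHV)
    i = torch.randn(B, NH, S)
    f = torch.randn(B, NH, S)

    # With return_last_states=True (the config default), forward returns the
    # hidden states and the (c, n, m) last-state tuple.
    h, (c, n, m) = backend(q=q, k=k, v=v, i=i, f=f)
    print(h.shape)  # (B, NH, S, DHHV)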