
Commit 40fa36e

Add seqused_q in fwd / bwd and seqused_k in bwd.
1 parent c92ca63 commit 40fa36e

12 files changed: +163 −35 lines

flash_attn/bert_padding.py

Lines changed: 14 additions & 6 deletions
@@ -95,32 +95,40 @@ def backward(ctx, grad_output, grad_residual):
 index_first_axis_residual = IndexFirstAxisResidual.apply
 
 
-def unpad_input(hidden_states, attention_mask):
+def unpad_input(hidden_states, attention_mask, unused_mask=None):
     """
     Arguments:
         hidden_states: (batch, seqlen, ...)
         attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
     Return:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
-        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (used_nnz), the indices of non-masked tokens from the flattened input sequence.
         cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
         max_seqlen_in_batch: int
+        seqused: (batch), optionally returns the number of tokens selected in attention_mask + unused_mask if unused_mask is not None.
     """
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
     # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
     # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
     # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
     # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
     # so we write custom forward and backward to make it a bit faster.
-    return (
+    res = (
         index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
         indices,
         cu_seqlens,
         max_seqlen_in_batch,
     )
+    if unused_mask is not None:
+        return res + (used_seqlens_in_batch, )
+    else:
+        return res
 
 
 def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
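For context, a minimal usage sketch of the updated unpad_input; the shapes, mask values, and variable names below are illustrative and not part of the commit. When unused_mask is passed, the packed output and cu_seqlens cover both used and unused tokens, while the extra fifth return value reports how many tokens per sequence are actually used.

import torch
from flash_attn.bert_padding import unpad_input

# Hypothetical batch of 2 sequences, each with 4 allocated slots, hidden dim 8.
hidden_states = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 0, 0],   # 2 tokens actually used
                               [1, 1, 1, 0]],  # 3 tokens actually used
                              dtype=torch.int32)
# Slots that are allocated (e.g. reserved for future tokens) but not used yet.
unused_mask = torch.tensor([[0, 0, 1, 0],
                            [0, 0, 0, 1]], dtype=torch.int32)

# With unused_mask, unpad_input returns a fifth value: per-sequence used lengths.
packed, indices, cu_seqlens, max_seqlen, seqused = unpad_input(
    hidden_states, attention_mask, unused_mask
)
# packed has 3 + 4 = 7 rows (used + unused tokens), cu_seqlens == [0, 3, 7],
# max_seqlen == 4, and seqused == [2, 3].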

hopper/epilogue_bwd_sm90_tma.hpp

Lines changed: 7 additions & 3 deletions
@@ -80,6 +80,7 @@ struct CollectiveEpilogueBwd {
         Element* ptr_dV;
         StridedKV const stride_dV;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     // Device side kernel params
@@ -91,6 +92,7 @@ struct CollectiveEpilogueBwd {
         StridedKV const stride_dV;
         TMA_dKV tma_store_dK, tma_store_dV;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     static Params
@@ -113,7 +115,7 @@ struct CollectiveEpilogueBwd {
             select<1, 2>(TileShape_MNK{}),
             _1{}); // no mcast for dKV
         return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV,
-                tma_store_dK, tma_store_dV, args.cu_seqlens};
+                tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -185,7 +187,9 @@ struct CollectiveEpilogueBwd {
         cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
         bool const is_varlen = params.cu_seqlens != nullptr;
         int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb];
-        int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb];
+        int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (
+            params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]
+        );
 
         Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
         Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)
@@ -236,7 +240,7 @@ struct CollectiveEpilogueBwd {
         auto [n_block, bidh, bidb] = block_coord;
         bool const is_varlen = Varlen && params.cu_seqlens != nullptr;
         int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb];
-        int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - offset;
+        int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset);
 
         Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
         Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)
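The new branch in the epilogue can be restated in a few lines of host-side Python. This is only an illustrative sketch of the selection logic (the names effective_seqlen and seqlen_static are made up here), not code from the commit.

def effective_seqlen(bidb, cu_seqlens, seqused, seqlen_static):
    """Illustrative restatement of the seqlen choice in the bwd epilogue."""
    is_varlen = cu_seqlens is not None
    if not is_varlen:
        # Fixed-length path: every batch element uses the full static length.
        return seqlen_static
    if seqused is not None:
        # seqused overrides the length implied by cu_seqlens, so dK/dV are
        # only written for the first seqused[bidb] rows of this sequence.
        return seqused[bidb]
    return cu_seqlens[bidb + 1] - cu_seqlens[bidb]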

hopper/flash.h

Lines changed: 3 additions & 1 deletion
@@ -68,7 +68,9 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
 
-    // If provided, the actual length of each k sequence.
+    // If provided, the actual length of each q / o sequence.
+    int * __restrict__ seqused_q;
+    // If provided, the actual length of each k / v sequence.
     int * __restrict__ seqused_k;
 
     int *__restrict__ blockmask;

hopper/flash_api.cpp

Lines changed: 40 additions & 3 deletions
@@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       at::Tensor out,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_q,
                       void *seqused_k,
                       void *p_d,
                       void *softmax_lse_d,
@@ -80,6 +81,7 @@ void set_params_fprop(Flash_fwd_params &params,
 
     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+    params.seqused_q = static_cast<int *>(seqused_q);
     params.seqused_k = static_cast<int *>(seqused_k);
 
     TORCH_CHECK(
@@ -171,6 +173,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       at::Tensor dv,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_q,
+                      void *seqused_k,
                       void *dq_accum_d,
                       void *dk_accum_d,
                       void *dv_accum_d,
@@ -187,7 +191,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                      q, k, v, out,
                      cu_seqlens_q_d,
                      cu_seqlens_k_d,
-                     nullptr,
+                     seqused_q,
+                     seqused_k,
                      nullptr,
                      softmax_lse_d,
                      p_dropout,
@@ -364,6 +369,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      q_padded, k_padded, v_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_q=*/nullptr,
                      /*seqused_k=*/nullptr,
                      nullptr,
                      softmax_lse.data_ptr(),
@@ -426,6 +432,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q, // b+1
                const at::Tensor &cu_seqlens_k, // b+1
+               c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                int max_seqlen_q,
                const int max_seqlen_k,
@@ -482,6 +489,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
 
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+    if (seqused_q.has_value()){
+        auto seqused_q_ = seqused_q.value();
+        TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
+        TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
+        TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
+        CHECK_SHAPE(seqused_q_, batch_size);
+    }
+
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
     if (seqused_k.has_value()){
         auto seqused_k_ = seqused_k.value();
@@ -537,6 +552,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      q_padded, k_padded, v_padded, out,
                      cu_seqlens_q_d,
                      cu_seqlens_k.data_ptr(),
+                     seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
                      seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      /*p_d=*/nullptr,
                      softmax_lse.data_ptr(),
@@ -730,8 +746,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
                      head_size, head_size_rounded,
                      q, k, v, out,
                      dout_padded, dq, dk_expanded, dv_expanded,
-                     nullptr,
-                     nullptr,
+                     /*cu_seqlens_q_d=*/nullptr,
+                     /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_q=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      dq_accum.data_ptr(),
                      // loop ? dk_accum.data_ptr() : nullptr,
                      // loop ? dv_accum.data_ptr() : nullptr,
@@ -787,6 +805,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
                c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
                const at::Tensor &cu_seqlens_q, // b+1
                const at::Tensor &cu_seqlens_k, // b+1
+               c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
+               c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                const int max_seqlen_q,
                const int max_seqlen_k, // max sequence length to choose the kernel
                const float softmax_scale,
@@ -854,7 +874,22 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
     CHECK_SHAPE(out, total_q, num_heads, head_size);
     CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+    if (seqused_q.has_value()){
+        auto seqused_q_ = seqused_q.value();
+        TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
+        TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
+        TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
+        CHECK_SHAPE(seqused_q_, batch_size);
+    }
+
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (seqused_k.has_value()){
+        auto seqused_k_ = seqused_k.value();
+        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
+        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
+        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
+        CHECK_SHAPE(seqused_k_, batch_size);
+    }
 
     at::Tensor dq, dk, dv;
     if (dq_.has_value()) {
@@ -927,6 +962,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
                      dout_padded, dq, dk_expanded, dv_expanded,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
+                     seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
+                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      dq_accum.data_ptr(),
                      // loop ? dk_accum.data_ptr() : nullptr,
                      // loop ? dv_accum.data_ptr() : nullptr,
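The checks above imply the following constraints on any seqused_q / seqused_k tensor a caller passes to the varlen entry points. The helper below is only a hedged Python restatement of those constraints (check_seqused is not a function in the repo):

import torch

def check_seqused(seqused, batch_size, name="seqused"):
    """Mirror the C++ TORCH_CHECKs for seqused_q / seqused_k (sketch only)."""
    if seqused is None:
        return
    assert seqused.dtype == torch.int32, f"{name} must have dtype int32"
    assert seqused.is_cuda, f"{name} must be on CUDA device"
    assert seqused.is_contiguous(), f"{name} must be contiguous"
    assert seqused.shape == (batch_size,), f"{name} must have shape (batch_size,)"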

hopper/flash_attn_interface.py

Lines changed: 26 additions & 4 deletions
@@ -72,6 +72,8 @@ def _flash_attn_varlen_forward(
     max_seqlen_k,
     softmax_scale,
     causal,
+    seqused_q=None,
+    seqused_k=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -82,7 +84,8 @@ def _flash_attn_varlen_forward(
         None,
         cu_seqlens_q,
         cu_seqlens_k,
-        None,
+        seqused_q,
+        seqused_k,
         max_seqlen_q,
         max_seqlen_k,
         softmax_scale,
@@ -110,6 +113,8 @@ def _flash_attn_varlen_backward(
     softmax_scale,
     causal,
     deterministic=False,
+    seqused_q=None,
+    seqused_k=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
@@ -132,6 +137,8 @@ def _flash_attn_varlen_backward(
         dv,
         cu_seqlens_q,
         cu_seqlens_k,
+        seqused_q,
+        seqused_k,
         max_seqlen_q,
         max_seqlen_k,
         softmax_scale,
@@ -207,6 +214,8 @@ def forward(
         softmax_scale,
         causal,
         deterministic=False,
+        seqused_q=None,
+        seqused_k=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -220,9 +229,12 @@ def forward(
             max_seqlen_k,
             softmax_scale,
             causal=causal,
+            seqused_q=seqused_q,
+            seqused_k=seqused_k,
         )
         ctx.save_for_backward(
-            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
+            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k,
+            seqused_q, seqused_k
         )
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_k = max_seqlen_k
@@ -233,7 +245,7 @@ def forward(
 
     @staticmethod
     def backward(ctx, dout, *args):
-        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_varlen_backward(
             dout,
@@ -252,11 +264,13 @@ def backward(ctx, dout, *args):
             ctx.softmax_scale,
             ctx.causal,
             ctx.deterministic,
+            seqused_q,
+            seqused_k,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_func(
@@ -336,6 +350,8 @@ def flash_attn_varlen_func(
     softmax_scale=None,
     causal=False,
     deterministic=False,
+    seqused_q=None,
+    seqused_k=None,
 ):
     """
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -366,6 +382,10 @@ def flash_attn_varlen_func(
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
+            query and output tokens in each sequence.
+        seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
+            key and value tokens in each sequence.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
@@ -383,4 +403,6 @@ def flash_attn_varlen_func(
         softmax_scale,
         causal,
         deterministic,
+        seqused_q,
+        seqused_k,
     )
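Putting it together, a call to the Hopper varlen interface with the new arguments might look like the sketch below. The module name, sequence lengths, and return handling are assumptions for illustration (the installed package may expose the function under a different import path, and some builds also return the softmax LSE alongside the output):

import torch
from flash_attn_interface import flash_attn_varlen_func  # hopper build; import path may differ

nheads, headdim = 8, 64
# Two sequences with 128 and 96 allocated slots each (illustrative numbers).
cu_seqlens_q = torch.tensor([0, 128, 224], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 128, 224], dtype=torch.int32, device="cuda")
# Only the first seqused_* tokens of each sequence actually participate in attention.
seqused_q = torch.tensor([100, 80], dtype=torch.int32, device="cuda")
seqused_k = torch.tensor([100, 80], dtype=torch.int32, device="cuda")

total_q = int(cu_seqlens_q[-1])
total_k = int(cu_seqlens_k[-1])
q = torch.randn(total_q, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(total_k, nheads, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(total_k, nheads, headdim, device="cuda", dtype=torch.bfloat16)

result = flash_attn_varlen_func(
    q, k, v, cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q=128, max_seqlen_k=128,
    causal=True,
    seqused_q=seqused_q,
    seqused_k=seqused_k,
)
out = result[0] if isinstance(result, tuple) else result  # (total_q, nheads, headdim)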

hopper/flash_bwd_launch_template.h

Lines changed: 5 additions & 2 deletions
@@ -45,7 +45,8 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
         {params.d_rounded, _1{}, params.d_rounded * (!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded), !Varlen ? params.d_rounded * params.seqlen_q_rounded * params.h : 0},  // stride_dQ
         params.b,
         params.dq_semaphore,
-        params.cu_seqlens_q
+        params.cu_seqlens_q,
+        params.seqused_q
     };
     typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
     int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
@@ -87,6 +88,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
         params.b,
         params.dq_semaphore,
         params.cu_seqlens_q, params.cu_seqlens_k,
+        params.seqused_q, params.seqused_k
     };
     typename CollectiveEpilogue::Arguments epilogue_args {
         static_cast<Element*>(params.dk_ptr),
@@ -146,7 +148,8 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
         {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? params.b : 1},  // shape_dQ
         {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride},  // stride_dQ
         params.scale_softmax,
-        params.cu_seqlens_q
+        params.cu_seqlens_q,
+        params.seqused_q
     };
     typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
     int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
