
Commit d835e6f

refactor: move triton dependency to flashinfer.triton (#918)
Some platforms do not support Triton, but users still need other functionality in flashinfer (e.g., JIT compilation). This PR moves the Triton dependency into flashinfer.triton and defers the import so that flashinfer can still be used without installing Triton.
1 parent bf2fdc5 commit d835e6f
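
The pattern this commit applies throughout is a function-local ("deferred") import: each module that launches a Triton kernel imports it inside the launch function rather than at module scope, so importing flashinfer itself never requires Triton. Below is a minimal sketch of the idea; the module and kernel names (my_module, my_kernel) are illustrative, not part of this commit:

import torch

def launch_my_kernel(x: torch.Tensor) -> None:
    # Deferred import: triton is only loaded when this function runs,
    # so `import flashinfer` succeeds even if triton is not installed.
    from flashinfer.triton.my_module import my_kernel  # hypothetical names

    my_kernel[(x.numel(),)](x)

With this arrangement, a missing Triton installation surfaces as an ImportError at the first Triton-backed call rather than at package import time.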

File tree

4 files changed: +152 -111 lines changed


flashinfer/gemm.py (+4 -88)

@@ -18,8 +18,6 @@
 from typing import Optional

 import torch
-import triton
-import triton.language as tl

 from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
 from .utils import (
@@ -223,92 +221,6 @@ def _fake_cutlass_segment_gemm_sm90(
     return _gemm_module_sm90


-@triton.jit
-def compute_sm80_group_gemm_args(
-    all_problems_ptr,
-    x_ptr,
-    w_ptr,
-    y_ptr,
-    x_ld_ptr,
-    w_ld_ptr,
-    y_ld_ptr,
-    x,
-    w,
-    y,
-    xy_indptr,
-    w_indices,
-    d_in,
-    d_out,
-    w_column_major,
-):
-
-    pid = tl.program_id(0)
-
-    m = tl.load(xy_indptr + pid + 1) - tl.load(xy_indptr + pid)
-    k, n = d_in, d_out
-
-    tl.store(all_problems_ptr + pid * 3, m)
-    tl.store(all_problems_ptr + pid * 3 + 1, n)
-    tl.store(all_problems_ptr + pid * 3 + 2, k)
-
-    w_i = tl.load(w_indices + pid) if w_indices else tl.cast(pid, tl.int64)
-    w_curr_ptr = w + w_i * k * n
-    tl.store(w_ptr + pid, w_curr_ptr)
-
-    x_curr_ptr = x + tl.load(xy_indptr + pid) * k
-    tl.store(x_ptr + pid, x_curr_ptr)
-
-    y_curr_ptr = y + tl.load(xy_indptr + pid) * n
-    tl.store(y_ptr + pid, y_curr_ptr)
-
-    tl.store(x_ld_ptr + pid, k)
-    tl.store(w_ld_ptr + pid, k if w_column_major else n)
-    tl.store(y_ld_ptr + pid, n)
-
-
-@triton.jit
-def compute_sm90_group_gemm_args(
-    all_problems_ptr,
-    x_ptr,
-    w_ptr,
-    y_ptr,
-    x_stride_ptr,
-    w_stride_ptr,
-    y_stride_ptr,
-    x,
-    w,
-    y,
-    xy_indptr,
-    w_indices,
-    d_in,
-    d_out,
-    w_column_major,
-):
-
-    pid = tl.program_id(0)
-
-    m = tl.load(xy_indptr + pid + 1) - tl.load(xy_indptr + pid)
-    k, n = d_in, d_out
-
-    tl.store(all_problems_ptr + pid * 3, m)
-    tl.store(all_problems_ptr + pid * 3 + 1, n)
-    tl.store(all_problems_ptr + pid * 3 + 2, k)
-
-    w_i = tl.load(w_indices + pid) if w_indices else tl.cast(pid, tl.int64)
-    w_curr_ptr = w + w_i * k * n
-    tl.store(w_ptr + pid, w_curr_ptr)
-
-    x_curr_ptr = x + tl.load(xy_indptr + pid) * k
-    tl.store(x_ptr + pid, x_curr_ptr)
-
-    y_curr_ptr = y + tl.load(xy_indptr + pid) * n
-    tl.store(y_ptr + pid, y_curr_ptr)
-
-    tl.store(x_stride_ptr + pid, k)
-    tl.store(w_stride_ptr + pid, k if w_column_major else n)
-    tl.store(y_stride_ptr + pid, n)
-
-
 def launch_compute_sm80_group_gemm_args(
     x: torch.Tensor,
     weights: torch.Tensor,
@@ -340,6 +252,8 @@ def launch_compute_sm80_group_gemm_args(
     w_stride_data = torch.empty(batch_size, dtype=ld_type, device=device)
     y_stride_data = torch.empty(batch_size, dtype=ld_type, device=device)

+    from .triton.gemm import compute_sm80_group_gemm_args
+
     compute_sm80_group_gemm_args[(batch_size,)](
         all_problems,
         x_data,
@@ -400,6 +314,8 @@ def launch_compute_sm90_group_gemm_args(
     w_stride_data = torch.empty(batch_size, dtype=stride_type, device=device)
     y_stride_data = torch.empty(batch_size, dtype=stride_type, device=device)

+    from .triton.gemm import compute_sm90_group_gemm_args
+
     compute_sm90_group_gemm_args[(batch_size,)](
         all_problems,
         x_data,
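
An aside on the launch syntax retained above: kernel[(batch_size,)](...) is Triton's launch-grid convention, where the tuple sets the number of program instances and each instance reads its grid index via tl.program_id(0). A standalone illustration (not from this commit):

import torch
import triton
import triton.language as tl

@triton.jit
def write_pid_kernel(out_ptr):
    # Each of the grid's program instances stores its own index.
    pid = tl.program_id(0)
    tl.store(out_ptr + pid, pid)

out = torch.empty(4, dtype=torch.int32, device="cuda")
write_pid_kernel[(4,)](out)  # grid of 4 programs -> out == [0, 1, 2, 3]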

flashinfer/page.py (+2 -23)

@@ -17,8 +17,6 @@
 from typing import Optional, Tuple, Union

 import torch
-import triton
-import triton.language as tl

 from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
 from .utils import (
@@ -142,27 +140,6 @@ def _fake_append_paged_kv_cache_kernel(
     pass


-@triton.jit
-def get_batch_indices_positions_kernel(
-    append_indptr,
-    seq_lens_ptr,
-    batch_indices_ptr,
-    positions_ptr,
-    num_stages: tl.constexpr,
-):
-    batch_idx = tl.program_id(0)
-
-    batch_start = tl.load(append_indptr + batch_idx)
-    batch_end = tl.load(append_indptr + batch_idx + 1)
-    seq_len = tl.load(seq_lens_ptr + batch_idx)
-
-    for i in tl.range(batch_start, batch_end, 128, num_stages=num_stages):
-        offsets = tl.arange(0, 128) + i
-        mask = offsets < batch_end
-        tl.store(batch_indices_ptr + offsets, batch_idx, mask)
-        tl.store(positions_ptr + offsets, offsets + seq_len - batch_end, mask)
-
-
 def get_batch_indices_positions(
     append_indptr: torch.Tensor, seq_lens: torch.Tensor, nnz: int
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -210,6 +187,8 @@ def get_batch_indices_positions(
     batch_size = append_indptr.size(0) - 1
     batch_indices = torch.empty((nnz,), device=append_indptr.device, dtype=torch.int32)
     positions = torch.empty((nnz,), device=append_indptr.device, dtype=torch.int32)
+    from .triton.page import get_batch_indices_positions_kernel
+
     get_batch_indices_positions_kernel[(batch_size,)](
         append_indptr, seq_lens, batch_indices, positions, num_stages=2
     )

flashinfer/triton/gemm.py (+104, new file)

@@ -0,0 +1,104 @@
+"""
+Copyright (c) 2024 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def compute_sm80_group_gemm_args(
+    all_problems_ptr,
+    x_ptr,
+    w_ptr,
+    y_ptr,
+    x_ld_ptr,
+    w_ld_ptr,
+    y_ld_ptr,
+    x,
+    w,
+    y,
+    xy_indptr,
+    w_indices,
+    d_in,
+    d_out,
+    w_column_major,
+):
+
+    pid = tl.program_id(0)
+
+    m = tl.load(xy_indptr + pid + 1) - tl.load(xy_indptr + pid)
+    k, n = d_in, d_out
+
+    tl.store(all_problems_ptr + pid * 3, m)
+    tl.store(all_problems_ptr + pid * 3 + 1, n)
+    tl.store(all_problems_ptr + pid * 3 + 2, k)
+
+    w_i = tl.load(w_indices + pid) if w_indices else tl.cast(pid, tl.int64)
+    w_curr_ptr = w + w_i * k * n
+    tl.store(w_ptr + pid, w_curr_ptr)
+
+    x_curr_ptr = x + tl.load(xy_indptr + pid) * k
+    tl.store(x_ptr + pid, x_curr_ptr)
+
+    y_curr_ptr = y + tl.load(xy_indptr + pid) * n
+    tl.store(y_ptr + pid, y_curr_ptr)
+
+    tl.store(x_ld_ptr + pid, k)
+    tl.store(w_ld_ptr + pid, k if w_column_major else n)
+    tl.store(y_ld_ptr + pid, n)
+
+
+@triton.jit
+def compute_sm90_group_gemm_args(
+    all_problems_ptr,
+    x_ptr,
+    w_ptr,
+    y_ptr,
+    x_stride_ptr,
+    w_stride_ptr,
+    y_stride_ptr,
+    x,
+    w,
+    y,
+    xy_indptr,
+    w_indices,
+    d_in,
+    d_out,
+    w_column_major,
+):
+
+    pid = tl.program_id(0)
+
+    m = tl.load(xy_indptr + pid + 1) - tl.load(xy_indptr + pid)
+    k, n = d_in, d_out
+
+    tl.store(all_problems_ptr + pid * 3, m)
+    tl.store(all_problems_ptr + pid * 3 + 1, n)
+    tl.store(all_problems_ptr + pid * 3 + 2, k)
+
+    w_i = tl.load(w_indices + pid) if w_indices else tl.cast(pid, tl.int64)
+    w_curr_ptr = w + w_i * k * n
+    tl.store(w_ptr + pid, w_curr_ptr)
+
+    x_curr_ptr = x + tl.load(xy_indptr + pid) * k
+    tl.store(x_ptr + pid, x_curr_ptr)
+
+    y_curr_ptr = y + tl.load(xy_indptr + pid) * n
+    tl.store(y_ptr + pid, y_curr_ptr)
+
+    tl.store(x_stride_ptr + pid, k)
+    tl.store(w_stride_ptr + pid, k if w_column_major else n)
+    tl.store(y_stride_ptr + pid, n)
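
Both kernels fill one grouped-GEMM descriptor per program instance: the problem shape (m, n, k), base pointers into x, w, and y, and the leading dimensions (SM80) or strides (SM90). The following plain-Python mirror of that bookkeeping is an illustration only; the function name and return layout are hypothetical, and the pointer arrays are omitted:

import torch

def reference_group_gemm_args(xy_indptr, d_in, d_out, w_column_major):
    # CPU mirror of compute_sm80_group_gemm_args' shape/ld bookkeeping.
    batch_size = xy_indptr.numel() - 1
    problems, lds = [], []
    for pid in range(batch_size):
        m = int(xy_indptr[pid + 1] - xy_indptr[pid])  # rows in group pid
        k, n = d_in, d_out
        problems.append((m, n, k))
        # x and y are row-major with k and n columns; w's leading
        # dimension is k when column-major, n otherwise.
        lds.append((k, k if w_column_major else n, n))
    return problems, lds

# Two groups of 2 and 3 rows, d_in=4, d_out=8, row-major w:
print(reference_group_gemm_args(torch.tensor([0, 2, 5]), 4, 8, False))
# -> ([(2, 8, 4), (3, 8, 4)], [(4, 8, 8), (4, 8, 8)])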

flashinfer/triton/page.py (+42, new file)

@@ -0,0 +1,42 @@
+"""
+Copyright (c) 2025 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def get_batch_indices_positions_kernel(
+    append_indptr,
+    seq_lens_ptr,
+    batch_indices_ptr,
+    positions_ptr,
+    num_stages: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+
+    batch_start = tl.load(append_indptr + batch_idx)
+    batch_end = tl.load(append_indptr + batch_idx + 1)
+    seq_len = tl.load(seq_lens_ptr + batch_idx)
+
+    for i in tl.range(batch_start, batch_end, 128, num_stages=num_stages):
+        offsets = tl.arange(0, 128) + i
+        mask = offsets < batch_end
+        tl.store(batch_indices_ptr + offsets, batch_idx, mask)
+        tl.store(positions_ptr + offsets, offsets + seq_len - batch_end, mask)
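
This kernel is reached through the get_batch_indices_positions wrapper shown in the page.py hunk above; it expands a ragged batch into one (batch index, position) pair per appended token. A usage sketch, assuming three requests whose KV lengths equal their appended lengths, and assuming the wrapper is re-exported at the package root (otherwise import it from flashinfer.page):

import torch
import flashinfer

seq_lens = torch.tensor([1, 2, 3], dtype=torch.int32, device="cuda")
append_indptr = torch.cat(
    [
        torch.zeros(1, dtype=torch.int32, device="cuda"),
        torch.cumsum(seq_lens, dim=0, dtype=torch.int32),
    ]
)  # [0, 1, 3, 6]
batch_indices, positions = flashinfer.get_batch_indices_positions(
    append_indptr, seq_lens, nnz=int(append_indptr[-1].item())
)
# batch_indices -> [0, 1, 1, 2, 2, 2]  (owning request of each token)
# positions     -> [0, 0, 1, 0, 1, 2]  (token's position in its sequence)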
