Skip to content

Commit 5adec0d

Browse files
committed
misc: jit: Import jit_env as a module
1 parent 20892f7 commit 5adec0d

File tree

16 files changed: +161 -143 lines changed

flashinfer/cascade.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import torch
2121

2222
from .decode import BatchDecodeWithPagedKVCacheWrapper
23-
from .jit import FLASHINFER_CSRC_DIR, JitSpec, gen_jit_spec, has_prebuilt_ops
23+
from .jit import JitSpec
24+
from .jit import env as jit_env
25+
from .jit import gen_jit_spec, has_prebuilt_ops
2426
from .prefill import BatchPrefillWithPagedKVCacheWrapper, single_prefill_with_kv_cache
2527
from .utils import register_custom_op, register_fake_op
2628

@@ -31,8 +33,8 @@ def gen_cascade_module() -> JitSpec:
3133
return gen_jit_spec(
3234
"cascade",
3335
[
34-
FLASHINFER_CSRC_DIR / "cascade.cu",
35-
FLASHINFER_CSRC_DIR / "flashinfer_cascade_ops.cu",
36+
jit_env.FLASHINFER_CSRC_DIR / "cascade.cu",
37+
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_cascade_ops.cu",
3638
],
3739
)
3840

flashinfer/custom_all_reduce.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121

2222
import torch
2323

24-
from .jit import FLASHINFER_CSRC_DIR, JitSpec, gen_jit_spec, has_prebuilt_ops
24+
from .jit import JitSpec
25+
from .jit import env as jit_env
26+
from .jit import gen_jit_spec, has_prebuilt_ops
2527
from .utils import register_custom_op
2628

2729
_comm_module = None
@@ -31,8 +33,8 @@ def gen_comm_module() -> JitSpec:
3133
return gen_jit_spec(
3234
"comm",
3335
[
34-
FLASHINFER_CSRC_DIR / "flashinfer_comm_ops.cu",
35-
FLASHINFER_CSRC_DIR / "custom_all_reduce.cu",
36+
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_comm_ops.cu",
37+
jit_env.FLASHINFER_CSRC_DIR / "custom_all_reduce.cu",
3638
],
3739
)
3840

flashinfer/gemm.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,9 @@
2121
import torch
2222
import torch.nn.functional as F
2323

24-
from .jit import (
25-
FLASHINFER_CSRC_DIR,
26-
JitSpec,
27-
gen_jit_spec,
28-
has_prebuilt_ops,
29-
sm90a_nvcc_flags,
30-
sm100a_nvcc_flags,
31-
)
24+
from .jit import JitSpec
25+
from .jit import env as jit_env
26+
from .jit import gen_jit_spec, has_prebuilt_ops, sm90a_nvcc_flags, sm100a_nvcc_flags
3227
from .utils import (
3328
_get_cache_buf,
3429
determine_gemm_backend,
@@ -46,9 +41,9 @@ def gen_gemm_module() -> JitSpec:
4641
return gen_jit_spec(
4742
"gemm",
4843
[
49-
FLASHINFER_CSRC_DIR / "bmm_fp8.cu",
50-
FLASHINFER_CSRC_DIR / "group_gemm.cu",
51-
FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu",
44+
jit_env.FLASHINFER_CSRC_DIR / "bmm_fp8.cu",
45+
jit_env.FLASHINFER_CSRC_DIR / "group_gemm.cu",
46+
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu",
5247
],
5348
extra_ldflags=["-lcublas", "-lcublasLt"],
5449
)
@@ -157,10 +152,10 @@ def gen_gemm_sm100_module() -> JitSpec:
157152
return gen_jit_spec(
158153
"gemm_sm100",
159154
[
160-
FLASHINFER_CSRC_DIR / "gemm_groupwise_sm100.cu",
161-
FLASHINFER_CSRC_DIR / "group_gemm_groupwise_sm100.cu",
162-
FLASHINFER_CSRC_DIR / "gemm_sm100_pybind.cu",
163-
FLASHINFER_CSRC_DIR / "group_gemm_sm100_pybind.cu",
155+
jit_env.FLASHINFER_CSRC_DIR / "gemm_groupwise_sm100.cu",
156+
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_groupwise_sm100.cu",
157+
jit_env.FLASHINFER_CSRC_DIR / "gemm_sm100_pybind.cu",
158+
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_sm100_pybind.cu",
164159
],
165160
extra_cuda_cflags=sm100a_nvcc_flags,
166161
)
@@ -181,14 +176,14 @@ def gen_gemm_sm90_module() -> JitSpec:
181176
return gen_jit_spec(
182177
"gemm_sm90",
183178
[
184-
FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu",
185-
FLASHINFER_CSRC_DIR / "flashinfer_gemm_sm90_ops.cu",
186-
FLASHINFER_CSRC_DIR / "group_gemm_f16_f16_sm90.cu",
187-
FLASHINFER_CSRC_DIR / "group_gemm_bf16_bf16_sm90.cu",
188-
FLASHINFER_CSRC_DIR / "group_gemm_e4m3_f16_sm90.cu",
189-
FLASHINFER_CSRC_DIR / "group_gemm_e5m2_f16_sm90.cu",
190-
FLASHINFER_CSRC_DIR / "group_gemm_e4m3_bf16_sm90.cu",
191-
FLASHINFER_CSRC_DIR / "group_gemm_e5m2_bf16_sm90.cu",
179+
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu",
180+
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_gemm_sm90_ops.cu",
181+
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_f16_f16_sm90.cu",
182+
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_bf16_bf16_sm90.cu",
183+
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_e4m3_f16_sm90.cu",
184+
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_e5m2_f16_sm90.cu",
185+
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_e4m3_bf16_sm90.cu",
186+
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_e5m2_bf16_sm90.cu",
192187
],
193188
extra_cuda_cflags=sm90a_nvcc_flags,
194189
)

flashinfer/jit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919

2020
# Re-export
21+
from . import env as env
2122
from .activation import gen_act_and_mul_module as gen_act_and_mul_module
2223
from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str
2324
from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module
@@ -61,7 +62,6 @@
6162
from .core import gen_jit_spec as gen_jit_spec
6263
from .core import sm90a_nvcc_flags as sm90a_nvcc_flags
6364
from .core import sm100a_nvcc_flags as sm100a_nvcc_flags
64-
from .env import *
6565

6666
cuda_lib_path = os.environ.get(
6767
"CUDA_LIB_PATH", "/usr/local/cuda/targets/x86_64-linux/lib/"

flashinfer/jit/activation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
import jinja2
2020

21+
from . import env as jit_env
2122
from .core import JitSpec, gen_jit_spec
22-
from .env import FLASHINFER_GEN_SRC_DIR
2323
from .utils import write_if_different
2424

2525
activation_templ = r"""
@@ -77,7 +77,7 @@ def get_act_and_mul_cu_str(act_func_name: str, act_func_def: str) -> str:
7777

7878

7979
def gen_act_and_mul_module(act_func_name: str, act_func_def: str) -> JitSpec:
80-
gen_directory = FLASHINFER_GEN_SRC_DIR
80+
gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR
8181
os.makedirs(gen_directory, exist_ok=True)
8282
sources = [gen_directory / f"{act_func_name}_and_mul.cu"]
8383
write_if_different(

0 commit comments

Comments
 (0)