Skip to content

misc: jit: Import jit_env as a module #1073

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions flashinfer/cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import torch

from .decode import BatchDecodeWithPagedKVCacheWrapper
from .jit import FLASHINFER_CSRC_DIR, JitSpec, gen_jit_spec, has_prebuilt_ops
from .jit import JitSpec
from .jit import env as jit_env
from .jit import gen_jit_spec, has_prebuilt_ops
from .prefill import BatchPrefillWithPagedKVCacheWrapper, single_prefill_with_kv_cache
from .utils import register_custom_op, register_fake_op

Expand All @@ -31,8 +33,8 @@ def gen_cascade_module() -> JitSpec:
return gen_jit_spec(
"cascade",
[
FLASHINFER_CSRC_DIR / "cascade.cu",
FLASHINFER_CSRC_DIR / "flashinfer_cascade_ops.cu",
jit_env.FLASHINFER_CSRC_DIR / "cascade.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_cascade_ops.cu",
],
)

Expand Down
8 changes: 5 additions & 3 deletions flashinfer/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

import torch

from .jit import FLASHINFER_CSRC_DIR, JitSpec, gen_jit_spec, has_prebuilt_ops
from .jit import JitSpec
from .jit import env as jit_env
from .jit import gen_jit_spec, has_prebuilt_ops
from .utils import register_custom_op

_comm_module = None
Expand All @@ -31,8 +33,8 @@ def gen_comm_module() -> JitSpec:
return gen_jit_spec(
"comm",
[
FLASHINFER_CSRC_DIR / "flashinfer_comm_ops.cu",
FLASHINFER_CSRC_DIR / "custom_all_reduce.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_comm_ops.cu",
jit_env.FLASHINFER_CSRC_DIR / "custom_all_reduce.cu",
],
)

Expand Down
41 changes: 18 additions & 23 deletions flashinfer/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,9 @@
import torch
import torch.nn.functional as F

from .jit import (
FLASHINFER_CSRC_DIR,
JitSpec,
gen_jit_spec,
has_prebuilt_ops,
sm90a_nvcc_flags,
sm100a_nvcc_flags,
)
from .jit import JitSpec
from .jit import env as jit_env
from .jit import gen_jit_spec, has_prebuilt_ops, sm90a_nvcc_flags, sm100a_nvcc_flags
from .utils import (
_get_cache_buf,
determine_gemm_backend,
Expand All @@ -46,9 +41,9 @@ def gen_gemm_module() -> JitSpec:
return gen_jit_spec(
"gemm",
[
FLASHINFER_CSRC_DIR / "bmm_fp8.cu",
FLASHINFER_CSRC_DIR / "group_gemm.cu",
FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu",
jit_env.FLASHINFER_CSRC_DIR / "bmm_fp8.cu",
jit_env.FLASHINFER_CSRC_DIR / "group_gemm.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu",
],
extra_ldflags=["-lcublas", "-lcublasLt"],
)
Expand Down Expand Up @@ -157,10 +152,10 @@ def gen_gemm_sm100_module() -> JitSpec:
return gen_jit_spec(
"gemm_sm100",
[
FLASHINFER_CSRC_DIR / "gemm_groupwise_sm100.cu",
FLASHINFER_CSRC_DIR / "group_gemm_groupwise_sm100.cu",
FLASHINFER_CSRC_DIR / "gemm_sm100_pybind.cu",
FLASHINFER_CSRC_DIR / "group_gemm_sm100_pybind.cu",
jit_env.FLASHINFER_CSRC_DIR / "gemm_groupwise_sm100.cu",
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_groupwise_sm100.cu",
jit_env.FLASHINFER_CSRC_DIR / "gemm_sm100_pybind.cu",
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_sm100_pybind.cu",
],
extra_cuda_cflags=sm100a_nvcc_flags,
)
Expand All @@ -181,14 +176,14 @@ def gen_gemm_sm90_module() -> JitSpec:
return gen_jit_spec(
"gemm_sm90",
[
FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu",
FLASHINFER_CSRC_DIR / "flashinfer_gemm_sm90_ops.cu",
FLASHINFER_CSRC_DIR / "group_gemm_f16_f16_sm90.cu",
FLASHINFER_CSRC_DIR / "group_gemm_bf16_bf16_sm90.cu",
FLASHINFER_CSRC_DIR / "group_gemm_e4m3_f16_sm90.cu",
FLASHINFER_CSRC_DIR / "group_gemm_e5m2_f16_sm90.cu",
FLASHINFER_CSRC_DIR / "group_gemm_e4m3_bf16_sm90.cu",
FLASHINFER_CSRC_DIR / "group_gemm_e5m2_bf16_sm90.cu",
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_gemm_sm90_ops.cu",
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_f16_f16_sm90.cu",
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_bf16_bf16_sm90.cu",
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_e4m3_f16_sm90.cu",
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_e5m2_f16_sm90.cu",
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_e4m3_bf16_sm90.cu",
jit_env.FLASHINFER_CSRC_DIR / "group_gemm_e5m2_bf16_sm90.cu",
],
extra_cuda_cflags=sm90a_nvcc_flags,
)
Expand Down
2 changes: 1 addition & 1 deletion flashinfer/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os

# Re-export
from . import env as env
from .activation import gen_act_and_mul_module as gen_act_and_mul_module
from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str
from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module
Expand Down Expand Up @@ -61,7 +62,6 @@
from .core import gen_jit_spec as gen_jit_spec
from .core import sm90a_nvcc_flags as sm90a_nvcc_flags
from .core import sm100a_nvcc_flags as sm100a_nvcc_flags
from .env import *

cuda_lib_path = os.environ.get(
"CUDA_LIB_PATH", "/usr/local/cuda/targets/x86_64-linux/lib/"
Expand Down
4 changes: 2 additions & 2 deletions flashinfer/jit/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

import jinja2

from . import env as jit_env
from .core import JitSpec, gen_jit_spec
from .env import FLASHINFER_GEN_SRC_DIR
from .utils import write_if_different

activation_templ = r"""
Expand Down Expand Up @@ -77,7 +77,7 @@ def get_act_and_mul_cu_str(act_func_name: str, act_func_def: str) -> str:


def gen_act_and_mul_module(act_func_name: str, act_func_def: str) -> JitSpec:
gen_directory = FLASHINFER_GEN_SRC_DIR
gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR
os.makedirs(gen_directory, exist_ok=True)
sources = [gen_directory / f"{act_func_name}_and_mul.cu"]
write_if_different(
Expand Down
Loading