diff --git a/flashinfer/cascade.py b/flashinfer/cascade.py index 5f2821dd6..8d75d6198 100644 --- a/flashinfer/cascade.py +++ b/flashinfer/cascade.py @@ -20,7 +20,9 @@ import torch from .decode import BatchDecodeWithPagedKVCacheWrapper -from .jit import FLASHINFER_CSRC_DIR, JitSpec, gen_jit_spec, has_prebuilt_ops +from .jit import JitSpec +from .jit import env as jit_env +from .jit import gen_jit_spec, has_prebuilt_ops from .prefill import BatchPrefillWithPagedKVCacheWrapper, single_prefill_with_kv_cache from .utils import register_custom_op, register_fake_op @@ -31,8 +33,8 @@ def gen_cascade_module() -> JitSpec: return gen_jit_spec( "cascade", [ - FLASHINFER_CSRC_DIR / "cascade.cu", - FLASHINFER_CSRC_DIR / "flashinfer_cascade_ops.cu", + jit_env.FLASHINFER_CSRC_DIR / "cascade.cu", + jit_env.FLASHINFER_CSRC_DIR / "flashinfer_cascade_ops.cu", ], ) diff --git a/flashinfer/custom_all_reduce.py b/flashinfer/custom_all_reduce.py index ac7a65a09..e11038da3 100644 --- a/flashinfer/custom_all_reduce.py +++ b/flashinfer/custom_all_reduce.py @@ -21,7 +21,9 @@ import torch -from .jit import FLASHINFER_CSRC_DIR, JitSpec, gen_jit_spec, has_prebuilt_ops +from .jit import JitSpec +from .jit import env as jit_env +from .jit import gen_jit_spec, has_prebuilt_ops from .utils import register_custom_op _comm_module = None @@ -31,8 +33,8 @@ def gen_comm_module() -> JitSpec: return gen_jit_spec( "comm", [ - FLASHINFER_CSRC_DIR / "flashinfer_comm_ops.cu", - FLASHINFER_CSRC_DIR / "custom_all_reduce.cu", + jit_env.FLASHINFER_CSRC_DIR / "flashinfer_comm_ops.cu", + jit_env.FLASHINFER_CSRC_DIR / "custom_all_reduce.cu", ], ) diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index ecc68bffc..f8476225e 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -21,14 +21,9 @@ import torch import torch.nn.functional as F -from .jit import ( - FLASHINFER_CSRC_DIR, - JitSpec, - gen_jit_spec, - has_prebuilt_ops, - sm90a_nvcc_flags, - sm100a_nvcc_flags, -) +from .jit import JitSpec +from .jit import env as jit_env +from .jit import gen_jit_spec, has_prebuilt_ops, sm90a_nvcc_flags, sm100a_nvcc_flags from .utils import ( _get_cache_buf, determine_gemm_backend, @@ -46,9 +41,9 @@ def gen_gemm_module() -> JitSpec: return gen_jit_spec( "gemm", [ - FLASHINFER_CSRC_DIR / "bmm_fp8.cu", - FLASHINFER_CSRC_DIR / "group_gemm.cu", - FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu", + jit_env.FLASHINFER_CSRC_DIR / "bmm_fp8.cu", + jit_env.FLASHINFER_CSRC_DIR / "group_gemm.cu", + jit_env.FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu", ], extra_ldflags=["-lcublas", "-lcublasLt"], ) @@ -157,10 +152,10 @@ def gen_gemm_sm100_module() -> JitSpec: return gen_jit_spec( "gemm_sm100", [ - FLASHINFER_CSRC_DIR / "gemm_groupwise_sm100.cu", - FLASHINFER_CSRC_DIR / "group_gemm_groupwise_sm100.cu", - FLASHINFER_CSRC_DIR / "gemm_sm100_pybind.cu", - FLASHINFER_CSRC_DIR / "group_gemm_sm100_pybind.cu", + jit_env.FLASHINFER_CSRC_DIR / "gemm_groupwise_sm100.cu", + jit_env.FLASHINFER_CSRC_DIR / "group_gemm_groupwise_sm100.cu", + jit_env.FLASHINFER_CSRC_DIR / "gemm_sm100_pybind.cu", + jit_env.FLASHINFER_CSRC_DIR / "group_gemm_sm100_pybind.cu", ], extra_cuda_cflags=sm100a_nvcc_flags, ) @@ -181,14 +176,14 @@ def gen_gemm_sm90_module() -> JitSpec: return gen_jit_spec( "gemm_sm90", [ - FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu", - FLASHINFER_CSRC_DIR / "flashinfer_gemm_sm90_ops.cu", - FLASHINFER_CSRC_DIR / "group_gemm_f16_f16_sm90.cu", - FLASHINFER_CSRC_DIR / "group_gemm_bf16_bf16_sm90.cu", - FLASHINFER_CSRC_DIR / "group_gemm_e4m3_f16_sm90.cu", - FLASHINFER_CSRC_DIR / "group_gemm_e5m2_f16_sm90.cu", - FLASHINFER_CSRC_DIR / "group_gemm_e4m3_bf16_sm90.cu", - FLASHINFER_CSRC_DIR / "group_gemm_e5m2_bf16_sm90.cu", + jit_env.FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu", + jit_env.FLASHINFER_CSRC_DIR / "flashinfer_gemm_sm90_ops.cu", + jit_env.FLASHINFER_CSRC_DIR / "group_gemm_f16_f16_sm90.cu", + jit_env.FLASHINFER_CSRC_DIR / "group_gemm_bf16_bf16_sm90.cu", + jit_env.FLASHINFER_CSRC_DIR / "group_gemm_e4m3_f16_sm90.cu", + jit_env.FLASHINFER_CSRC_DIR / "group_gemm_e5m2_f16_sm90.cu", + jit_env.FLASHINFER_CSRC_DIR / "group_gemm_e4m3_bf16_sm90.cu", + jit_env.FLASHINFER_CSRC_DIR / "group_gemm_e5m2_bf16_sm90.cu", ], extra_cuda_cflags=sm90a_nvcc_flags, ) diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index a42330cbc..4e05a9656 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -18,6 +18,7 @@ import os # Re-export +from . import env as env from .activation import gen_act_and_mul_module as gen_act_and_mul_module from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module @@ -61,7 +62,6 @@ from .core import gen_jit_spec as gen_jit_spec from .core import sm90a_nvcc_flags as sm90a_nvcc_flags from .core import sm100a_nvcc_flags as sm100a_nvcc_flags -from .env import * cuda_lib_path = os.environ.get( "CUDA_LIB_PATH", "/usr/local/cuda/targets/x86_64-linux/lib/" diff --git a/flashinfer/jit/activation.py b/flashinfer/jit/activation.py index 78b1f1391..4d78616e5 100644 --- a/flashinfer/jit/activation.py +++ b/flashinfer/jit/activation.py @@ -18,8 +18,8 @@ import jinja2 +from . import env as jit_env from .core import JitSpec, gen_jit_spec -from .env import FLASHINFER_GEN_SRC_DIR from .utils import write_if_different activation_templ = r""" @@ -77,7 +77,7 @@ def get_act_and_mul_cu_str(act_func_name: str, act_func_def: str) -> str: def gen_act_and_mul_module(act_func_name: str, act_func_def: str) -> JitSpec: - gen_directory = FLASHINFER_GEN_SRC_DIR + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR os.makedirs(gen_directory, exist_ok=True) sources = [gen_directory / f"{act_func_name}_and_mul.cu"] write_if_different( diff --git a/flashinfer/jit/attention/pytorch.py b/flashinfer/jit/attention/pytorch.py index 682f5fb50..3e35b1f7a 100644 --- a/flashinfer/jit/attention/pytorch.py +++ b/flashinfer/jit/attention/pytorch.py @@ -20,8 +20,8 @@ import jinja2 import torch +from .. import env as jit_env from ..core import JitSpec, gen_jit_spec, logger, sm90a_nvcc_flags, sm100a_nvcc_flags -from ..env import FLASHINFER_CSRC_DIR, FLASHINFER_GEN_SRC_DIR from ..utils import ( dtype_map, filename_safe_dtype_map, @@ -121,11 +121,11 @@ def gen_batch_mla_module( head_dim_kpe, use_profiler, ) - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri os.makedirs(gen_directory, exist_ok=True) if backend == "fa2": - with open(FLASHINFER_CSRC_DIR / "batch_mla_config.jinja") as f: + with open(jit_env.FLASHINFER_CSRC_DIR / "batch_mla_config.jinja") as f: config_templ = jinja2.Template(f.read()) generated_config_path = gen_directory / "batch_mla_config.inc" write_if_different( @@ -146,14 +146,14 @@ def gen_batch_mla_module( "batch_mla_run.cu", "batch_mla_pybind.cu", ]: - src_path = FLASHINFER_CSRC_DIR / filename + src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) elif backend == "fa3": - with open(FLASHINFER_CSRC_DIR / "batch_mla_config.jinja") as f: + with open(jit_env.FLASHINFER_CSRC_DIR / "batch_mla_config.jinja") as f: config_templ = jinja2.Template(f.read()) generated_config_path = gen_directory / "batch_mla_sm90_config.inc" write_if_different( @@ -173,7 +173,7 @@ def gen_batch_mla_module( "batch_mla_sm90_run.cu", "batch_mla_sm90_pybind.cu", ]: - src_path = FLASHINFER_CSRC_DIR / filename + src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -259,10 +259,10 @@ def gen_batch_decode_mla_module( use_logits_soft_cap, arc, ) - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri os.makedirs(gen_directory, exist_ok=True) - with open(FLASHINFER_CSRC_DIR / "batch_decode_mla_config.jinja") as f: + with open(jit_env.FLASHINFER_CSRC_DIR / "batch_decode_mla_config.jinja") as f: config_templ = jinja2.Template(f.read()) generated_config_path = gen_directory / "mla_config.inc" write_if_different( @@ -295,7 +295,7 @@ def gen_batch_decode_mla_module( source_paths = [] for filename in filenames: - src_path = FLASHINFER_CSRC_DIR / filename + src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -599,7 +599,7 @@ def gen_customize_pod_module( use_logits_soft_cap_d: bool = False, use_fp16_qk_reduction: bool = False, ) -> JitSpec: - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri ( additional_params_decl, @@ -612,10 +612,10 @@ def gen_customize_pod_module( additional_scalar_dtypes, ) - with open(FLASHINFER_CSRC_DIR / "pod_customize_config.jinja") as f: + with open(jit_env.FLASHINFER_CSRC_DIR / "pod_customize_config.jinja") as f: config_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / "pod_kernel_inst.jinja") as f: + with open(jit_env.FLASHINFER_CSRC_DIR / "pod_kernel_inst.jinja") as f: kernel_inst_templ = jinja2.Template(f.read()) kwargs = { @@ -665,7 +665,7 @@ def gen_customize_pod_module( "pod.cu", "pod_jit_pybind.cu", ]: - src_path = FLASHINFER_CSRC_DIR / filename + src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -848,7 +848,7 @@ def gen_customize_single_decode_module( use_sliding_window: bool = False, use_logits_soft_cap: bool = False, ) -> JitSpec: - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri ( additional_params_decl, @@ -861,10 +861,12 @@ def gen_customize_single_decode_module( additional_scalar_dtypes, ) - with open(FLASHINFER_CSRC_DIR / "single_decode_customize_config.jinja") as f: + with open( + jit_env.FLASHINFER_CSRC_DIR / "single_decode_customize_config.jinja" + ) as f: config_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / "single_decode_kernel_inst.jinja") as f: + with open(jit_env.FLASHINFER_CSRC_DIR / "single_decode_kernel_inst.jinja") as f: kernel_inst_templ = jinja2.Template(f.read()) kwargs = { @@ -902,7 +904,7 @@ def gen_customize_single_decode_module( "single_decode.cu", "single_decode_jit_pybind.cu", ]: - src_path = FLASHINFER_CSRC_DIR / filename + src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -951,7 +953,7 @@ def gen_customize_single_prefill_module( if backend == "auto": raise ValueError("backend should not be auto when jit_args is provided") elif backend == "fa2": - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri additional_params_decl, additional_func_params, additional_params_setter = ( generate_additional_params( additional_tensor_names, @@ -961,10 +963,14 @@ def gen_customize_single_prefill_module( ) ) - with open(FLASHINFER_CSRC_DIR / "single_prefill_customize_config.jinja") as f: + with open( + jit_env.FLASHINFER_CSRC_DIR / "single_prefill_customize_config.jinja" + ) as f: config_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / "single_prefill_kernel_inst.jinja") as f: + with open( + jit_env.FLASHINFER_CSRC_DIR / "single_prefill_kernel_inst.jinja" + ) as f: kernel_inst_templ = jinja2.Template(f.read()) kwargs |= { @@ -993,7 +999,7 @@ def gen_customize_single_prefill_module( "single_prefill.cu", "single_prefill_jit_pybind.cu", ]: - src_path = FLASHINFER_CSRC_DIR / filename + src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -1005,7 +1011,7 @@ def gen_customize_single_prefill_module( return gen_jit_spec(uri, source_paths) elif backend == "fa3": - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri (additional_params_decl, additional_func_params, additional_params_setter) = ( generate_additional_params( @@ -1025,10 +1031,10 @@ def gen_customize_single_prefill_module( _file_kernel_inst = "single_prefill_sm90_kernel_inst.jinja" _file_csrc = "single_prefill_sm90.cu" - with open(FLASHINFER_CSRC_DIR / _file_config) as f: + with open(jit_env.FLASHINFER_CSRC_DIR / _file_config) as f: config_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / _file_kernel_inst) as f: + with open(jit_env.FLASHINFER_CSRC_DIR / _file_kernel_inst) as f: kernel_inst_templ = jinja2.Template(f.read()) kwargs |= { @@ -1057,7 +1063,7 @@ def gen_customize_single_prefill_module( _file_csrc, "single_prefill_sm90_jit_pybind.cu", ]: - src_path = FLASHINFER_CSRC_DIR / filename + src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -1093,7 +1099,7 @@ def gen_customize_batch_decode_module( use_sliding_window: bool = False, use_logits_soft_cap: bool = False, ) -> JitSpec: - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri (additional_params_decl, additional_func_params, additional_params_setter) = ( generate_additional_params( additional_tensor_names, @@ -1120,10 +1126,10 @@ def gen_customize_batch_decode_module( "use_logits_soft_cap": str(use_logits_soft_cap).lower(), } - with open(FLASHINFER_CSRC_DIR / "batch_decode_customize_config.jinja") as f: + with open(jit_env.FLASHINFER_CSRC_DIR / "batch_decode_customize_config.jinja") as f: config_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / "batch_decode_kernel_inst.jinja") as f: + with open(jit_env.FLASHINFER_CSRC_DIR / "batch_decode_kernel_inst.jinja") as f: kernel_inst_templ = jinja2.Template(f.read()) generated_inc_str = config_templ.render( @@ -1143,7 +1149,7 @@ def gen_customize_batch_decode_module( "batch_decode.cu", "batch_decode_jit_pybind.cu", ]: - src_path = FLASHINFER_CSRC_DIR / filename + src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -1193,7 +1199,7 @@ def gen_customize_batch_prefill_module( if backend == "auto": raise ValueError("backend should not be auto when jit_args is provided") elif backend == "fa2": - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri (additional_params_decl, additional_func_params, additional_params_setter) = ( generate_additional_params( additional_tensor_names, @@ -1203,13 +1209,19 @@ def gen_customize_batch_prefill_module( ) ) - with open(FLASHINFER_CSRC_DIR / "batch_prefill_customize_config.jinja") as f: + with open( + jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_customize_config.jinja" + ) as f: config_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / "batch_prefill_paged_kernel_inst.jinja") as f: + with open( + jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_paged_kernel_inst.jinja" + ) as f: paged_kernel_inst_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / "batch_prefill_ragged_kernel_inst.jinja") as f: + with open( + jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_ragged_kernel_inst.jinja" + ) as f: ragged_kernel_inst_templ = jinja2.Template(f.read()) kwargs |= { @@ -1249,7 +1261,7 @@ def gen_customize_batch_prefill_module( "batch_prefill.cu", "batch_prefill_jit_pybind.cu", ]: - src_path = FLASHINFER_CSRC_DIR / filename + src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -1260,7 +1272,7 @@ def gen_customize_batch_prefill_module( write_if_different(generated_config_path, generated_inc_str) return gen_jit_spec(uri, source_paths) elif backend == "fa3": - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri (additional_params_decl, additional_func_params, additional_params_setter) = ( generate_additional_params( additional_tensor_names, @@ -1281,13 +1293,13 @@ def gen_customize_batch_prefill_module( _file_ragged_kernel_inst = "batch_prefill_ragged_sm90_kernel_inst.jinja" _file_csrc = "batch_prefill_sm90.cu" - with open(FLASHINFER_CSRC_DIR / _file_config) as f: + with open(jit_env.FLASHINFER_CSRC_DIR / _file_config) as f: config_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / _file_paged_kernel_inst) as f: + with open(jit_env.FLASHINFER_CSRC_DIR / _file_paged_kernel_inst) as f: paged_kernel_inst_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / _file_ragged_kernel_inst) as f: + with open(jit_env.FLASHINFER_CSRC_DIR / _file_ragged_kernel_inst) as f: ragged_kernel_inst_templ = jinja2.Template(f.read()) kwargs |= { @@ -1321,7 +1333,7 @@ def gen_customize_batch_prefill_module( _file_csrc, "batch_prefill_sm90_jit_pybind.cu", ]: - src_path = FLASHINFER_CSRC_DIR / filename + src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -1389,8 +1401,8 @@ def gen_fmha_cutlass_sm100a_module( ) source_paths = [ - FLASHINFER_CSRC_DIR / "fmha_cutlass_sm100.cu", - FLASHINFER_CSRC_DIR / "fmha_cutlass_sm100_pybind.cu", + jit_env.FLASHINFER_CSRC_DIR / "fmha_cutlass_sm100.cu", + jit_env.FLASHINFER_CSRC_DIR / "fmha_cutlass_sm100_pybind.cu", ] return gen_jit_spec( uri, diff --git a/flashinfer/jit/attention/tvm.py b/flashinfer/jit/attention/tvm.py index b90a4cd86..b52a15a98 100644 --- a/flashinfer/jit/attention/tvm.py +++ b/flashinfer/jit/attention/tvm.py @@ -21,11 +21,7 @@ import jinja2 import torch -from ..env import ( - FLASHINFER_CSRC_DIR, - FLASHINFER_GEN_SRC_DIR, - FLASHINFER_TVM_BINDING_DIR, -) +from .. import env as jit_env from ..utils import ( dtype_map, mask_mode_literal, @@ -36,12 +32,12 @@ def gen_sampling_tvm_binding(uri: str): - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri os.makedirs(gen_directory, exist_ok=True) source_paths = [] for filename in ["sampling.cu", "sampling_jit_tvm_binding.cu"]: - src_path = FLASHINFER_TVM_BINDING_DIR / filename + src_path = jit_env.FLASHINFER_TVM_BINDING_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -91,7 +87,7 @@ def gen_customize_batch_prefill_tvm_binding( if backend == "auto": raise ValueError("backend should not be auto when jit_args is provided") elif backend == "fa2": - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri (additional_params_decl, additional_func_params, additional_params_setter) = ( generate_additional_params( additional_tensor_names, @@ -102,14 +98,18 @@ def gen_customize_batch_prefill_tvm_binding( ) with open( - FLASHINFER_TVM_BINDING_DIR / "batch_prefill_customize_config.jinja" + jit_env.FLASHINFER_TVM_BINDING_DIR / "batch_prefill_customize_config.jinja" ) as f: config_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / "batch_prefill_paged_kernel_inst.jinja") as f: + with open( + jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_paged_kernel_inst.jinja" + ) as f: paged_kernel_inst_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / "batch_prefill_ragged_kernel_inst.jinja") as f: + with open( + jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_ragged_kernel_inst.jinja" + ) as f: ragged_kernel_inst_templ = jinja2.Template(f.read()) kwargs |= { @@ -156,7 +156,7 @@ def gen_customize_batch_prefill_tvm_binding( "batch_prefill.cu", "batch_prefill_jit_tvm_binding.cu", ]: - src_path = FLASHINFER_TVM_BINDING_DIR / filename + src_path = jit_env.FLASHINFER_TVM_BINDING_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -167,7 +167,7 @@ def gen_customize_batch_prefill_tvm_binding( write_if_different(generated_config_path, generated_inc_str) return uri, source_paths elif backend == "fa3": - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri (additional_params_decl, additional_func_params, additional_params_setter) = ( generate_additional_params( additional_tensor_names, @@ -179,17 +179,18 @@ def gen_customize_batch_prefill_tvm_binding( ) with open( - FLASHINFER_TVM_BINDING_DIR / "batch_prefill_sm90_customize_config.jinja" + jit_env.FLASHINFER_TVM_BINDING_DIR + / "batch_prefill_sm90_customize_config.jinja" ) as f: config_templ = jinja2.Template(f.read()) with open( - FLASHINFER_CSRC_DIR / "batch_prefill_paged_sm90_kernel_inst.jinja" + jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_paged_sm90_kernel_inst.jinja" ) as f: paged_kernel_inst_templ = jinja2.Template(f.read()) with open( - FLASHINFER_CSRC_DIR / "batch_prefill_ragged_sm90_kernel_inst.jinja" + jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_ragged_sm90_kernel_inst.jinja" ) as f: ragged_kernel_inst_templ = jinja2.Template(f.read()) @@ -232,7 +233,7 @@ def gen_customize_batch_prefill_tvm_binding( "batch_prefill_sm90.cu", "batch_prefill_sm90_jit_tvm_binding.cu", ]: - src_path = FLASHINFER_TVM_BINDING_DIR / filename + src_path = jit_env.FLASHINFER_TVM_BINDING_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -275,7 +276,7 @@ def gen_customize_batch_decode_tvm_binding( "use_sliding_window": str(use_sliding_window).lower(), "use_logits_soft_cap": str(use_logits_soft_cap).lower(), } - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri (additional_params_decl, additional_func_params, additional_params_setter) = ( generate_additional_params( additional_tensor_names, @@ -285,10 +286,12 @@ def gen_customize_batch_decode_tvm_binding( ) ) - with open(FLASHINFER_TVM_BINDING_DIR / "batch_decode_customize_config.jinja") as f: + with open( + jit_env.FLASHINFER_TVM_BINDING_DIR / "batch_decode_customize_config.jinja" + ) as f: config_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / "batch_decode_kernel_inst.jinja") as f: + with open(jit_env.FLASHINFER_CSRC_DIR / "batch_decode_kernel_inst.jinja") as f: kernel_inst_templ = jinja2.Template(f.read()) kwargs |= { @@ -313,7 +316,7 @@ def gen_customize_batch_decode_tvm_binding( "batch_decode.cu", "batch_decode_jit_tvm_binding.cu", ]: - src_path = FLASHINFER_TVM_BINDING_DIR / filename + src_path = jit_env.FLASHINFER_TVM_BINDING_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: @@ -334,10 +337,10 @@ def gen_batch_mla_tvm_binding( head_dim_ckv: int, head_dim_kpe: int, ): - gen_directory = FLASHINFER_GEN_SRC_DIR / uri + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri os.makedirs(gen_directory, exist_ok=True) - with open(FLASHINFER_TVM_BINDING_DIR / "batch_mla_config.jinja") as f: + with open(jit_env.FLASHINFER_TVM_BINDING_DIR / "batch_mla_config.jinja") as f: config_templ = jinja2.Template(f.read()) generated_config_path = gen_directory / "batch_mla_config.inc" write_if_different( @@ -358,7 +361,7 @@ def gen_batch_mla_tvm_binding( "batch_mla_run.cu", "batch_mla_jit_tvm_binding.cu", ]: - src_path = FLASHINFER_TVM_BINDING_DIR / filename + src_path = jit_env.FLASHINFER_TVM_BINDING_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index 0ad6c556b..ebcf1923d 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -10,17 +10,12 @@ import torch.utils.cpp_extension as torch_cpp_ext from filelock import FileLock +from . import env as jit_env from .cpp_ext import generate_ninja_build_for_op, run_ninja -from .env import CUTLASS_INCLUDE_DIRS as CUTLASS_INCLUDE_DIRS -from .env import FLASHINFER_CSRC_DIR as FLASHINFER_CSRC_DIR -from .env import FLASHINFER_GEN_SRC_DIR as FLASHINFER_GEN_SRC_DIR -from .env import FLASHINFER_INCLUDE_DIR as FLASHINFER_INCLUDE_DIR -from .env import FLASHINFER_JIT_DIR as FLASHINFER_JIT_DIR -from .env import FLASHINFER_WORKSPACE_DIR as FLASHINFER_WORKSPACE_DIR from .utils import write_if_different -os.makedirs(FLASHINFER_WORKSPACE_DIR, exist_ok=True) -os.makedirs(FLASHINFER_CSRC_DIR, exist_ok=True) +os.makedirs(jit_env.FLASHINFER_WORKSPACE_DIR, exist_ok=True) +os.makedirs(jit_env.FLASHINFER_CSRC_DIR, exist_ok=True) class FlashInferJITLogger(logging.Logger): @@ -28,7 +23,7 @@ def __init__(self, name): super().__init__(name) self.setLevel(logging.INFO) self.addHandler(logging.StreamHandler()) - log_path = FLASHINFER_WORKSPACE_DIR / "flashinfer_jit.log" + log_path = jit_env.FLASHINFER_WORKSPACE_DIR / "flashinfer_jit.log" if not os.path.exists(log_path): # create an empty file with open(log_path, "w") as f: # noqa: F841 @@ -58,10 +53,10 @@ def check_cuda_arch(): def clear_cache_dir(): - if os.path.exists(FLASHINFER_JIT_DIR): + if os.path.exists(jit_env.FLASHINFER_JIT_DIR): import shutil - shutil.rmtree(FLASHINFER_JIT_DIR) + shutil.rmtree(jit_env.FLASHINFER_JIT_DIR) sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"] @@ -79,11 +74,11 @@ class JitSpec: @property def ninja_path(self) -> Path: - return FLASHINFER_JIT_DIR / self.name / "build.ninja" + return jit_env.FLASHINFER_JIT_DIR / self.name / "build.ninja" @property def library_path(self) -> Path: - return FLASHINFER_JIT_DIR / self.name / f"{self.name}.so" + return jit_env.FLASHINFER_JIT_DIR / self.name / f"{self.name}.so" def write_ninja(self) -> None: ninja_path = self.ninja_path @@ -101,7 +96,7 @@ def write_ninja(self) -> None: def build(self, verbose: bool) -> None: tmpdir = get_tmpdir() with FileLock(tmpdir / f"{self.name}.lock", thread_local=False): - run_ninja(FLASHINFER_JIT_DIR, self.ninja_path, verbose) + run_ninja(jit_env.FLASHINFER_JIT_DIR, self.ninja_path, verbose) def build_and_load(self): verbose = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") == "1" @@ -166,7 +161,7 @@ def gen_jit_spec( def get_tmpdir() -> Path: # TODO(lequn): Try /dev/shm first. This should help Lock on NFS. - tmpdir = FLASHINFER_JIT_DIR / "tmp" + tmpdir = jit_env.FLASHINFER_JIT_DIR / "tmp" if not tmpdir.exists(): tmpdir.mkdir(parents=True, exist_ok=True) return tmpdir @@ -182,7 +177,7 @@ def build_jit_specs(specs: List[JitSpec], verbose: bool) -> None: with FileLock(tmpdir / "flashinfer_jit.lock", thread_local=False): ninja_path = tmpdir / "flashinfer_jit.ninja" write_if_different(ninja_path, "\n".join(lines)) - run_ninja(FLASHINFER_JIT_DIR, ninja_path, verbose) + run_ninja(jit_env.FLASHINFER_JIT_DIR, ninja_path, verbose) def load_cuda_ops( diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index b67ca2fcc..2e088cc16 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -16,7 +16,7 @@ _get_pybind11_abi_build_flags, ) -from .env import FLASHINFER_DATA +from . import env as jit_env def _get_glibcxx_abi_build_flags() -> List[str]: @@ -105,7 +105,7 @@ def generate_ninja_build_for_op( f"name = {name}", f"cuda_home = {cuda_home}", f"torch_home = {_TORCH_PATH}", - f"flashinfer_data = {FLASHINFER_DATA.resolve()}", + f"flashinfer_data = {jit_env.FLASHINFER_DATA.resolve()}", f"cxx = {cxx}", f"nvcc = {nvcc}", "", diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py index 7e746aff7..af4316992 100644 --- a/flashinfer/jit/env.py +++ b/flashinfer/jit/env.py @@ -14,6 +14,10 @@ limitations under the License. """ +# NOTE(lequn): Do not "from .jit.env import xxx". +# Do "from .jit import env as jit_env" and use "jit_env.xxx" instead. +# This helps AOT script to override envs. + import os import pathlib import re diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 78f2c2f7f..e3cc8958c 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -20,14 +20,9 @@ import torch -from .jit import ( - FLASHINFER_CSRC_DIR, - JitSpec, - gen_batch_mla_module, - gen_jit_spec, - sm100a_nvcc_flags, -) -from .jit.env import CUTLASS_INCLUDE_DIRS as CUTLASS_INCLUDE_DIRS +from .jit import JitSpec +from .jit import env as jit_env +from .jit import gen_batch_mla_module, gen_jit_spec, sm100a_nvcc_flags from .utils import ( MaskMode, _check_shape_dtype_device, @@ -73,12 +68,12 @@ def gen_mla_module() -> JitSpec: return gen_jit_spec( "mla", [ - FLASHINFER_CSRC_DIR / "cutlass_mla.cu", - FLASHINFER_CSRC_DIR / "flashinfer_mla_ops.cu", + jit_env.FLASHINFER_CSRC_DIR / "cutlass_mla.cu", + jit_env.FLASHINFER_CSRC_DIR / "flashinfer_mla_ops.cu", ], extra_include_paths=[ - CUTLASS_INCLUDE_DIRS[0] / ".." / "examples" / "77_blackwell_fmha", - CUTLASS_INCLUDE_DIRS[0] / ".." / "examples" / "common", + jit_env.CUTLASS_INCLUDE_DIRS[0] / ".." / "examples" / "77_blackwell_fmha", + jit_env.CUTLASS_INCLUDE_DIRS[0] / ".." / "examples" / "common", ], extra_cuda_cflags=sm100a_nvcc_flags, ) diff --git a/flashinfer/norm.py b/flashinfer/norm.py index 20baed866..97452f595 100644 --- a/flashinfer/norm.py +++ b/flashinfer/norm.py @@ -19,7 +19,9 @@ import torch -from .jit import FLASHINFER_CSRC_DIR, JitSpec, gen_jit_spec, has_prebuilt_ops +from .jit import JitSpec +from .jit import env as jit_env +from .jit import gen_jit_spec, has_prebuilt_ops from .utils import register_custom_op, register_fake_op _norm_module = None @@ -29,8 +31,8 @@ def gen_norm_module() -> JitSpec: return gen_jit_spec( "norm", [ - FLASHINFER_CSRC_DIR / "norm.cu", - FLASHINFER_CSRC_DIR / "flashinfer_norm_ops.cu", + jit_env.FLASHINFER_CSRC_DIR / "norm.cu", + jit_env.FLASHINFER_CSRC_DIR / "flashinfer_norm_ops.cu", ], ) diff --git a/flashinfer/page.py b/flashinfer/page.py index 4a08eb64e..bd0628d22 100644 --- a/flashinfer/page.py +++ b/flashinfer/page.py @@ -19,7 +19,9 @@ import torch -from .jit import FLASHINFER_CSRC_DIR, JitSpec, gen_jit_spec, has_prebuilt_ops +from .jit import JitSpec +from .jit import env as jit_env +from .jit import gen_jit_spec, has_prebuilt_ops from .utils import ( TensorLayout, _check_kv_layout, @@ -35,8 +37,8 @@ def gen_page_module() -> JitSpec: return gen_jit_spec( "page", [ - FLASHINFER_CSRC_DIR / "page.cu", - FLASHINFER_CSRC_DIR / "flashinfer_page_ops.cu", + jit_env.FLASHINFER_CSRC_DIR / "page.cu", + jit_env.FLASHINFER_CSRC_DIR / "flashinfer_page_ops.cu", ], ) diff --git a/flashinfer/quantization.py b/flashinfer/quantization.py index 5fb928106..3830310e1 100644 --- a/flashinfer/quantization.py +++ b/flashinfer/quantization.py @@ -19,7 +19,9 @@ import torch -from .jit import FLASHINFER_CSRC_DIR, JitSpec, gen_jit_spec, has_prebuilt_ops +from .jit import JitSpec +from .jit import env as jit_env +from .jit import gen_jit_spec, has_prebuilt_ops from .utils import register_custom_op, register_fake_op _quantization_module = None @@ -29,8 +31,8 @@ def gen_quantization_module() -> JitSpec: return gen_jit_spec( "quantization", [ - FLASHINFER_CSRC_DIR / "quantization.cu", - FLASHINFER_CSRC_DIR / "flashinfer_quantization_ops.cu", + jit_env.FLASHINFER_CSRC_DIR / "quantization.cu", + jit_env.FLASHINFER_CSRC_DIR / "flashinfer_quantization_ops.cu", ], ) diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 1c3abb0f1..0199dc6b5 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -19,7 +19,9 @@ import torch -from .jit import FLASHINFER_CSRC_DIR, JitSpec, gen_jit_spec, has_prebuilt_ops +from .jit import JitSpec +from .jit import env as jit_env +from .jit import gen_jit_spec, has_prebuilt_ops from .utils import register_custom_op, register_fake_op _rope_module = None @@ -29,8 +31,8 @@ def gen_rope_module() -> JitSpec: return gen_jit_spec( "rope", [ - FLASHINFER_CSRC_DIR / "rope.cu", - FLASHINFER_CSRC_DIR / "flashinfer_rope_ops.cu", + jit_env.FLASHINFER_CSRC_DIR / "rope.cu", + jit_env.FLASHINFER_CSRC_DIR / "flashinfer_rope_ops.cu", ], ) diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 02d9caf47..43a44f17b 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -19,7 +19,9 @@ import torch -from .jit import FLASHINFER_CSRC_DIR, JitSpec, gen_jit_spec, has_prebuilt_ops +from .jit import JitSpec +from .jit import env as jit_env +from .jit import gen_jit_spec, has_prebuilt_ops from .utils import register_custom_op, register_fake_op _sampling_module = None @@ -29,9 +31,9 @@ def gen_sampling_module() -> JitSpec: return gen_jit_spec( "sampling", [ - FLASHINFER_CSRC_DIR / "sampling.cu", - FLASHINFER_CSRC_DIR / "renorm.cu", - FLASHINFER_CSRC_DIR / "flashinfer_sampling_ops.cu", + jit_env.FLASHINFER_CSRC_DIR / "sampling.cu", + jit_env.FLASHINFER_CSRC_DIR / "renorm.cu", + jit_env.FLASHINFER_CSRC_DIR / "flashinfer_sampling_ops.cu", ], )