@@ -72,6 +72,7 @@ def get_fmha_module(
     pos_encoding_mode: PosEncodingMode,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
+    use_fp16_qk_reduction: bool = False,
 ):
     if is_sm100a_supported(torch.device("cuda")):
         return gen_fmha_cutlass_sm100a_module(
@@ -2366,9 +2367,12 @@ def plan(
                 logits_soft_cap > 0,  # use_logits_soft_cap
                 use_fp16_qk_reduction,
             )
-            self._cached_module = get_batch_prefill_module(self._backend)(
-                *get_module_args
-            )
+            if self._backend == "cutlass":
+                self._cached_module = get_cutlass_mha_module()(*get_module_args)
+            else:
+                self._cached_module = get_batch_prefill_module(self._backend)(
+                    *get_module_args
+                )
 
         self._plan_info = self._cached_module.plan(
             self._float_workspace_buffer,
@@ -2727,3 +2731,247 @@ def fmha_varlen(
         lse = lse_padded
 
     return out, lse
+
+
+def get_cutlass_mha_module():
+    def backend_module(*args):
+        modules_dict = _batch_prefill_modules
+
+        if args not in modules_dict:
+            uri = get_batch_prefill_uri("cutlass", *args)
+            module = get_fmha_module(*args)
+
+            @register_custom_op(
+                f"flashinfer::{uri}_ragged_run",
+                mutates_args=(
+                    "float_workspace_buffer",
+                    "int_workspace_buffer",
+                    "o",
+                    "maybe_lse",
+                ),
+            )
+            def ragged_run(
+                float_workspace_buffer: torch.Tensor,
+                int_workspace_buffer: torch.Tensor,
+                plan_info_vec: List[int],
+                q: torch.Tensor,
+                k: torch.Tensor,
+                v: torch.Tensor,
+                qo_indptr: torch.Tensor,
+                kv_indptr: torch.Tensor,
+                o: torch.Tensor,
+                maybe_lse: Optional[torch.Tensor],
+                mask_mode: int,
+                layout: int,
+                window_left: int,
+                maybe_custom_mask: Optional[torch.Tensor],
+                maybe_mask_indptr: Optional[torch.Tensor],
+                maybe_alibi_slopes: Optional[torch.Tensor],
+                logits_soft_cap: float,
+                sm_scale: float,
+                rope_scale: float,
+                rope_theta: float,
+            ) -> None:
+                nnz_qo, num_qo_heads, head_dim_qk = q.shape
+                nnz_kv, num_kv_heads, head_dim_vo = v.shape
+
+                sm_scale = 1.0 / math.sqrt(head_dim_qk)
+
+                qo_lens = qo_indptr[1:] - qo_indptr[:-1]
+                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
+                batch_size = qo_lens.shape[0]
+                max_qo_len = qo_lens.max()
+                max_kv_len = kv_lens.max()
+
+                q_padded = torch.cat(
+                    [
+                        torch.zeros(
+                            max(max_qo_len, 128),
+                            q.shape[1],
+                            q.shape[2],
+                            device=q.device,
+                            dtype=q.dtype,
+                        ),
+                        q,
+                    ],
+                    dim=0,
+                )[max(max_qo_len, 128) :]
+
+                qo_total_len = nnz_qo
+
+                k_padded = torch.cat(
+                    [
+                        torch.zeros(
+                            max(max_kv_len, 128),
+                            k.shape[1],
+                            k.shape[2],
+                            device=k.device,
+                            dtype=k.dtype,
+                        ),
+                        k,
+                    ],
+                    dim=0,
+                )[max(max_kv_len, 128) :]
+                v_padded = torch.cat(
+                    [
+                        torch.zeros(
+                            max(max_kv_len, 128),
+                            v.shape[1],
+                            v.shape[2],
+                            device=v.device,
+                            dtype=v.dtype,
+                        ),
+                        v,
+                    ],
+                    dim=0,
+                )[max(max_kv_len, 128) :]
+
+                if o is None:
+                    out_padded = torch.empty(
+                        qo_total_len + max(max_qo_len, 128),
+                        num_qo_heads,
+                        head_dim_vo,
+                        device=q.device,
+                        dtype=q.dtype,
+                    )[max(max_qo_len, 128) :]
+                else:
+                    out_padded = o
+
+                if maybe_lse is None:
+                    lse_padded = torch.empty(
+                        qo_total_len, num_qo_heads, device=q.device, dtype=torch.float32
+                    )
+                else:
+                    lse_padded = maybe_lse
+
+                module.run(
+                    q_padded,
+                    k_padded,
+                    v_padded,
+                    qo_lens,
+                    kv_lens,
+                    qo_indptr,
+                    kv_indptr,
+                    out_padded,
+                    lse_padded,
+                    mask_mode,
+                    sm_scale,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim_qk,
+                    batch_size,
+                    nnz_qo,
+                    nnz_kv,
+                    max_qo_len,
+                    max_kv_len,
+                )
+
+                o = out_padded
+                maybe_lse = lse_padded
+
+                return o, maybe_lse
+
+            @register_custom_op(
+                f"flashinfer::{uri}_paged_run",
+                mutates_args=(
+                    "float_workspace_buffer",
+                    "int_workspace_buffer",
+                    "paged_k_cache",
+                    "paged_v_cache",
+                    "o",
+                    "maybe_lse",
+                ),
+            )
+            def paged_run(
+                float_workspace_buffer: torch.Tensor,
+                int_workspace_buffer: torch.Tensor,
+                plan_info_vec: List[int],
+                q: torch.Tensor,
+                paged_k_cache: torch.Tensor,
+                paged_v_cache: torch.Tensor,
+                qo_indptr: torch.Tensor,
+                paged_kv_indptr: torch.Tensor,
+                paged_kv_indices: torch.Tensor,
+                paged_kv_last_page_len: torch.Tensor,
+                o: torch.Tensor,
+                maybe_lse: Optional[torch.Tensor],
+                mask_mode: int,
+                layout: int,
+                window_left: int,
+                maybe_custom_mask: Optional[torch.Tensor],
+                maybe_mask_indptr: Optional[torch.Tensor],
+                maybe_alibi_slopes: Optional[torch.Tensor],
+                logits_soft_cap: float,
+                sm_scale: float,
+                rope_scale: float,
+                rope_theta: float,
+            ) -> None:
+                pass
+
+            @register_fake_op(f"flashinfer::{uri}_ragged_run")
+            def _fake_ragged_run(
+                float_workspace_buffer: torch.Tensor,
+                int_workspace_buffer: torch.Tensor,
+                plan_info_vec: List[int],
+                q: torch.Tensor,
+                k: torch.Tensor,
+                v: torch.Tensor,
+                qo_indptr: torch.Tensor,
+                kv_indptr: torch.Tensor,
+                o: torch.Tensor,
+                maybe_lse: Optional[torch.Tensor],
+                mask_mode: int,
+                layout: int,
+                window_left: int,
+                maybe_custom_mask: Optional[torch.Tensor],
+                maybe_mask_indptr: Optional[torch.Tensor],
+                maybe_alibi_slopes: Optional[torch.Tensor],
+                logits_soft_cap: float,
+                sm_scale: float,
+                rope_scale: float,
+                rope_theta: float,
+            ) -> None:
+                pass
+
+            @register_fake_op(f"flashinfer::{uri}_paged_run")
+            def _fake_paged_run(
+                float_workspace_buffer: torch.Tensor,
+                int_workspace_buffer: torch.Tensor,
+                plan_info_vec: List[int],
+                q: torch.Tensor,
+                paged_k_cache: torch.Tensor,
+                paged_v_cache: torch.Tensor,
+                qo_indptr: torch.Tensor,
+                paged_kv_indptr: torch.Tensor,
+                paged_kv_indices: torch.Tensor,
+                paged_kv_last_page_len: torch.Tensor,
+                o: torch.Tensor,
+                maybe_lse: Optional[torch.Tensor],
+                mask_mode: int,
+                layout: int,
+                window_left: int,
+                maybe_custom_mask: Optional[torch.Tensor],
+                maybe_mask_indptr: Optional[torch.Tensor],
+                maybe_alibi_slopes: Optional[torch.Tensor],
+                logits_soft_cap: float,
+                sm_scale: float,
+                rope_scale: float,
+                rope_theta: float,
+            ) -> None:
+                pass
+
+            def plan(*args):
+                pass
+
+            # Register the module.
+            #
+            # Note that plan is not part of model logic. It should not be included in
+            # Cuda Graph or torch.compile. So, we don't provide a torch library for plan.
+            modules_dict[args] = SimpleNamespace(
+                plan=plan,
+                ragged_run=ragged_run,
+                paged_run=paged_run,
+            )
+        return modules_dict[args]
+
+    return backend_module
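A note on the padding idiom in `ragged_run` above: `q`, `k`, and `v` are re-allocated so that the visible data begins `max(max_len, 128)` rows into the backing storage while the logical shape and values stay unchanged (the intent, presumably headroom for tile-granular loads by the CUTLASS kernel, is my reading and is not stated in the diff). A minimal, self-contained sketch of the idiom with made-up sizes:

```python
import torch

# Hypothetical ragged sizes, CPU-only for illustration.
nnz, heads, dim = 7, 2, 64
q = torch.randn(nnz, heads, dim)
max_qo_len = 5          # e.g. (qo_indptr[1:] - qo_indptr[:-1]).max()
pad = max(max_qo_len, 128)

# Prepend `pad` zero rows, then slice them off again: the result views the
# same values as `q`, but its storage starts `pad` rows before the data.
q_padded = torch.cat(
    [torch.zeros(pad, heads, dim, dtype=q.dtype, device=q.device), q], dim=0
)[pad:]

assert q_padded.shape == q.shape
assert torch.equal(q_padded, q)
assert q_padded.storage_offset() == pad * heads * dim  # data sits `pad` rows in
```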
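For completeness, a hedged usage sketch of how the new cutlass branch in `plan` would be reached. The constructor and `plan`/`run` keyword names below are assumptions about the flashinfer Python API and may not match this revision exactly; it also assumes an SM100a-capable GPU, since `get_fmha_module` only takes the CUTLASS path when `is_sm100a_supported` returns True.

```python
import torch
import flashinfer

# Assumed workspace size and constructor keyword (backend="cutlass").
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, backend="cutlass")

# Two hypothetical requests with ragged query/key lengths.
qo_indptr = torch.tensor([0, 16, 48], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 64, 192], dtype=torch.int32, device="cuda")

wrapper.plan(
    qo_indptr,
    kv_indptr,
    num_qo_heads=8,
    num_kv_heads=8,
    head_dim_qk=128,        # assumed keyword; some releases spell this head_dim
    causal=True,
    q_data_type=torch.bfloat16,
)

q = torch.randn(48, 8, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(192, 8, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(192, 8, 128, dtype=torch.bfloat16, device="cuda")
out, lse = wrapper.run(q, k, v, return_lse=True)
```

With `backend="cutlass"`, `plan` now builds the cached module via `get_cutlass_mha_module()` instead of `get_batch_prefill_module`, so `ragged_run` above is what ultimately dispatches into `module.run`.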