@@ -2836,6 +2836,151 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
       else:
         op.C.alignment = 8
 
+#
+def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version):
+
+  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
+    return
+
+  layouts = [
+    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+  ]
+
+  # Upcast on Operand A
+  math_instructions = [
+    MathInstruction( \
+      [16, 8, 32], \
+      DataType.s4, DataType.s8, DataType.s32, \
+      OpcodeClass.TensorOp, \
+      MathOperation.multiply_add_mixed_input_upcast),
+  ]
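+  # A (s4) is upcast to s8 before the 16x8x32 integer Tensor Core MMA with s32 accumulation.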
+
+  min_cc = 80
+  max_cc = 1024
+
+  # For mixed-input kernels, alignment constraints are a list of lists, where the
+  # inner list contains the alignment constraints for the operands/matrices:
+  # [[alignA, alignB, alignC], ...]
+  alignment_constraints = [[32, 16, 4],]
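+  # Alignments are in elements: 32 x s4, 16 x s8, and 4 x s32 are each 128-bit accesses.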
+
+  for math_inst in math_instructions:
+    tile_descriptions = [
+      TileDescription([256, 128,  64],  3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256,  64],  3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64,  64],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256,  64],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256,  64],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128,  64],  5, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128,  64],  6, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 128, 128],  3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 128],  3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 128],  4, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 128],  3, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+    ]
+
+    data_type = [
+      math_inst.element_a,
+      math_inst.element_b,
+      math_inst.element_accumulator,
+      math_inst.element_accumulator,
+    ]
+
+    # streamk uses more registers, which can cause spills for the biggest warp tile size when the accumulators are 32-bit.
+    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
+    if math_inst.element_a != math_inst.element_accumulator:
+      alignment_constraints = [[32, 16, 16],]
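+      # C/D alignment grows to 16 elements: 16 x s8 is still a 128-bit access.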
+
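+      # Also emit kernels whose C/D use the B operand type (s8) instead of the s32 accumulator.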
+      data_type_mixed = [
+        math_inst.element_a,
+        math_inst.element_b,
+        math_inst.element_b,
+        math_inst.element_accumulator,
+      ]
+
+      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
+        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+#
+def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version):
+
+  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
+    return
+
+  layouts = [
+    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+  ]
+
+  # Upcast on Operand B
+  math_instructions = [
+    MathInstruction( \
+      [16, 8, 32], \
+      DataType.s8, DataType.s4, DataType.s32, \
+      OpcodeClass.TensorOp, \
+      MathOperation.multiply_add_mixed_input_upcast),
+  ]
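+  # B (s4) is upcast to s8 before the 16x8x32 integer Tensor Core MMA with s32 accumulation.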
+
+  min_cc = 80
+  max_cc = 1024
+
+  # For mixed-input kernels, alignment constraints are a list of lists, where the
+  # inner list contains the alignment constraints for the operands/matrices:
+  # [[alignA, alignB, alignC], ...]
+  alignment_constraints = [[16, 32, 4],]
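+  # Alignments are in elements: 16 x s8, 32 x s4, and 4 x s32 are each 128-bit accesses.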
+
+  for math_inst in math_instructions:
+    tile_descriptions = [
+      TileDescription([256, 128,  64],  3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256,  64],  3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64,  64],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256,  64],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32,  64],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128,  64],  5, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128,  64],  6, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32,  64],  6, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 128, 128],  3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 128],  3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 128],  4, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 128],  3, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+    ]
+
+    data_type = [
+      math_inst.element_a,
+      math_inst.element_b,
+      math_inst.element_accumulator,
+      math_inst.element_accumulator,
+    ]
+
+    # streamk uses more registers, which can cause spills for the biggest warp tile size when the accumulators are 32-bit.
+    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
+    if math_inst.element_a != math_inst.element_accumulator:
+      alignment_constraints = [[16, 32, 16],]
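+      # C/D alignment grows to 16 elements: 16 x s8 is still a 128-bit access.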
+
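+      # Also emit kernels whose C/D use the A operand type (s8) instead of the s32 accumulator.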
+      data_type_mixed = [
+        math_inst.element_a,
+        math_inst.element_b,
+        math_inst.element_a,
+        math_inst.element_accumulator,
+      ]
+
+      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
+        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
 #
 
 #
@@ -4681,6 +4826,8 @@ def GenerateSM80(manifest, cuda_version):
   GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version)
   GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version)
   GenerateSM80_TensorOp_16832_TN(manifest, cuda_version)
+  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version)
+  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version)
   GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version)
   GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version)
   GenerateSM80_TensorOp_16864_TN(manifest, cuda_version)