@@ -2835,6 +2835,151 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
       else:
         op.C.alignment = 8
 
+#
+def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version):
+
+  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
+    return
+
+  layouts = [
+    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+  ]
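+  # The single layout combination above follows BLAS "TN" naming: row-major
+  # (transposed) A with column-major (non-transposed) B, column-major output.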
+
+  # Upcast on Operand A
+  math_instructions = [
+    MathInstruction(                                  \
+      [16, 8, 32],                                    \
+      DataType.s4, DataType.s8, DataType.s32,         \
+      OpcodeClass.TensorOp,                           \
+      MathOperation.multiply_add_mixed_input_upcast),
+  ]
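+  # The MathInstruction fields above are (instruction shape [m, n, k],
+  # element A, element B, accumulator element, opcode class, math op):
+  # a 16x8x32 Tensor Core MMA whose s4 A fragments are upcast to s8 in
+  # registers before the s8 MMA is issued.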
+
+  min_cc = 80
+  max_cc = 1024
+
+  # For mixed input, alignment constraints are a list of lists, where each
+  # inner list holds the alignment constraints (in elements) for the
+  # operands/matrices: [[alignA, alignB, alignC], ...]
+  alignment_constraints = [[32, 16, 4],]
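+  # All three alignments amount to 128-bit accesses: 32 x 4-bit (A),
+  # 16 x 8-bit (B), and 4 x 32-bit (C).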
+
+  for math_inst in math_instructions:
+    tile_descriptions = [
+      TileDescription([256, 128,  64],  3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256,  64],  3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64,  64],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256,  64],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256,  64],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128,  64],  5, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128,  64],  6, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 128, 128],  3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 128],  3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 128],  4, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 128],  3, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+    ]
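+    # Each TileDescription above is (threadblock tile [M, N, K], pipeline
+    # stages, warp count per tile dimension, math instruction, min/max
+    # compute capability).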
+
+    data_type = [
+      math_inst.element_a,
+      math_inst.element_b,
+      math_inst.element_accumulator,
+      math_inst.element_accumulator,
+    ]
+
+    # Stream-K uses more registers, which can cause spills for the biggest
+    # warp tile size when the accumulators are 32-bit.
+    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
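+    # That register pressure is presumably why the plain Identity8
+    # threadblock swizzle is requested here rather than a Stream-K swizzle.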
+
+    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g., S8 accumulation)
+    if math_inst.element_a != math_inst.element_accumulator:
+      alignment_constraints = [[32, 16, 16],]
+
+      data_type_mixed = [
+        math_inst.element_a,
+        math_inst.element_b,
+        math_inst.element_b,
+        math_inst.element_accumulator,
+      ]
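+      # Here C/D take element_b (s8), so alignC rises from 4 to 16:
+      # 16 x 8-bit elements is still one 128-bit access.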
+
+      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
+        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+#
+def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version):
+
+  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
+    return
+
+  layouts = [
+    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+  ]
+
+  # Upcast on Operand B
+  math_instructions = [
+    MathInstruction(                                  \
+      [16, 8, 32],                                    \
+      DataType.s8, DataType.s4, DataType.s32,         \
+      OpcodeClass.TensorOp,                           \
+      MathOperation.multiply_add_mixed_input_upcast),
+  ]
+
+  min_cc = 80
+  max_cc = 1024
+
+  # For mixed input, alignment constraints are a list of lists, where each
+  # inner list holds the alignment constraints (in elements) for the
+  # operands/matrices: [[alignA, alignB, alignC], ...]
+  alignment_constraints = [[16, 32, 4],]
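+  # Mirror image of the upcast-A case: B is now the narrow s4 operand, so
+  # the 128-bit access widths swap to 16 x s8 for A and 32 x s4 for B.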
+
+  for math_inst in math_instructions:
+    tile_descriptions = [
+      TileDescription([256, 128,  64],  3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256,  64],  3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64,  64],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256,  64],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32,  64],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128,  64],  5, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128,  64],  6, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32,  64],  6, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 128, 128],  3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 128],  3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 128],  4, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 128],  3, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+    ]
+
+    data_type = [
+      math_inst.element_a,
+      math_inst.element_b,
+      math_inst.element_accumulator,
+      math_inst.element_accumulator,
+    ]
+
+    # Stream-K uses more registers, which can cause spills for the biggest
+    # warp tile size when the accumulators are 32-bit.
+    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g., S8 accumulation)
+    if math_inst.element_a != math_inst.element_accumulator:
+      alignment_constraints = [[16, 32, 16],]
+
+      data_type_mixed = [
+        math_inst.element_a,
+        math_inst.element_b,
+        math_inst.element_a,
+        math_inst.element_accumulator,
+      ]
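+      # In this variant the output element is element_a (s8), again
+      # allowing the wider alignC of 16 (one 128-bit store).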
+
+      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
+        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
 #
 
 #
@@ -4680,6 +4825,8 @@ def GenerateSM80(manifest, cuda_version):
   GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version)
   GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version)
   GenerateSM80_TensorOp_16832_TN(manifest, cuda_version)
+  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version)
+  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version)
   GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version)
   GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version)
   GenerateSM80_TensorOp_16864_TN(manifest, cuda_version)