@@ -2854,6 +2854,151 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
       else:
         op.C.alignment = 8
 
+#
+def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version):
+
+  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
+    return
+
+  layouts = [
+    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+  ]
+
+  # Upcast on Operand A
+  math_instructions = [
+    MathInstruction(                                  \
+      [16, 8, 32],                                    \
+      DataType.s4, DataType.s8, DataType.s32,         \
+      OpcodeClass.TensorOp,                           \
+      MathOperation.multiply_add_mixed_input_upcast),
+  ]
+
+  min_cc = 80
+  max_cc = 1024
+
+  # For mixed-input kernels, the alignment constraints are a list of lists,
+  # where each inner list contains the alignment constraints for the
+  # operands/matrices: [[alignA, alignB, alignC], ...]
+  alignment_constraints = [[32, 16, 4],]
+
+  for math_inst in math_instructions:
+    tile_descriptions = [
+      TileDescription([256, 128,  64],  3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256,  64],  3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64,  64],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256,  64],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256,  64],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128,  64],  5, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128,  64],  6, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 128, 128],  3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 128],  3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 128],  4, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 128],  3, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+    ]
+
+    data_type = [
+      math_inst.element_a,
+      math_inst.element_b,
+      math_inst.element_accumulator,
+      math_inst.element_accumulator,
+    ]
+
+    # stream-K uses more registers, which can cause spills for the largest warp tile size when the accumulators are 32-bit.
+    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g., S8 accumulation)
+    if math_inst.element_a != math_inst.element_accumulator:
+      alignment_constraints = [[32, 16, 16],]
+
+      data_type_mixed = [
+        math_inst.element_a,
+        math_inst.element_b,
+        math_inst.element_b,
+        math_inst.element_accumulator,
+      ]
+
+      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
+        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+#
+def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version):
+
+  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
+    return
+
+  layouts = [
+    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+  ]
+
+  # Upcast on Operand B
+  math_instructions = [
+    MathInstruction(                                  \
+      [16, 8, 32],                                    \
+      DataType.s8, DataType.s4, DataType.s32,         \
+      OpcodeClass.TensorOp,                           \
+      MathOperation.multiply_add_mixed_input_upcast),
+  ]
+
+  min_cc = 80
+  max_cc = 1024
+
+  # For mixed-input kernels, the alignment constraints are a list of lists,
+  # where each inner list contains the alignment constraints for the
+  # operands/matrices: [[alignA, alignB, alignC], ...]
+  alignment_constraints = [[16, 32, 4],]
+
+  for math_inst in math_instructions:
+    tile_descriptions = [
+      TileDescription([256, 128,  64],  3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256,  64],  3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64,  64],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256,  64],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32,  64],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128,  64],  5, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128,  64],  6, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32,  64],  6, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 128, 128],  3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 128],  3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 128],  4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 128],  4, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 128],  3, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32, 128],  4, [4, 1, 1], math_inst, min_cc, max_cc),
+    ]
+
+    data_type = [
+      math_inst.element_a,
+      math_inst.element_b,
+      math_inst.element_accumulator,
+      math_inst.element_accumulator,
+    ]
+
2984
+ # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit.
2985
+ operations = CreateGemmOperator (manifest , layouts , tile_descriptions , \
2986
+ data_type , alignment_constraints , None , EpilogueFunctor .LinearCombination , SwizzlingFunctor .Identity8 )
2987
+
2988
+ # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
2989
+ if math_inst .element_a != math_inst .element_accumulator :
2990
+ alignment_constraints = [[16 , 32 , 16 ],]
2991
+
2992
+ data_type_mixed = [
2993
+ math_inst .element_a ,
2994
+ math_inst .element_b ,
2995
+ math_inst .element_a ,
2996
+ math_inst .element_accumulator ,
2997
+ ]
2998
+
2999
+ operations += CreateGemmOperator (manifest , layouts , tile_descriptions , \
3000
+ data_type_mixed , alignment_constraints , None , EpilogueFunctor .LinearCombination , SwizzlingFunctor .Identity8 )
3001
+
 #
 
 #
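Note on the alignment values in the hunk above: CUTLASS alignment constraints are counted in elements, not bytes, so the mixed-input triples (32 for s4, 16 for s8, 4 for s32) all describe the same 128-bit vectorized access. A minimal sketch of that arithmetic, outside the diff; the `bits` table of element widths is an assumption for illustration, not generator.py API:

# Hedged sketch, not part of the change: check that the per-operand element
# alignments used above all correspond to 128-bit (16-byte) accesses.
bits = {'s4': 4, 's8': 8, 's32': 32}   # assumed element widths in bits

def access_bytes(align_in_elements, element):
  # Alignments in the generator are counted in elements; convert to bytes.
  return align_in_elements * bits[element] // 8

# upcast-on-A kernels: [[alignA, alignB, alignC]] == [[32, 16, 4]]
assert access_bytes(32, 's4') == access_bytes(16, 's8') == access_bytes(4, 's32') == 16

The same arithmetic explains why alignC jumps from 4 to 16 in the second CreateGemmOperator call of each function: once the output matrix uses the narrow s8 type instead of the s32 accumulator, 16 elements are needed per 128-bit epilogue store.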
@@ -4699,6 +4844,8 @@ def GenerateSM80(manifest, cuda_version):
   GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version)
   GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version)
   GenerateSM80_TensorOp_16832_TN(manifest, cuda_version)
+  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version)
+  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version)
   GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version)
   GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version)
   GenerateSM80_TensorOp_16864_TN(manifest, cuda_version)
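For context, the second hunk is all the wiring the new generators need: GenerateSM80 calls each SM80 generator in turn, and every generator appends its kernels to the shared manifest. A rough usage sketch of that flow, assuming the surrounding generator.py scaffolding (the Manifest construction and the 'args'/version values shown are assumptions for illustration):

# Hedged sketch of the call flow added to GenerateSM80 above; 'args' stands
# for the options generator.py normally parses on the command line.
manifest = Manifest(args)
cuda_version = '11.0'
GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version)
GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version)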