Commit fcebe62

Aleksandar Samardžić (alexsamardzic) authored and committed

Add support for mixed 4-bit/8-bit data types GEMM

1 parent f93a691 · commit fcebe62

14 files changed: +890 −14 lines

include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
Lines changed: 70 additions & 1 deletion

@@ -268,7 +268,7 @@ struct DefaultMmaTensorOp<
     "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");

   // Data type used for internal computation - use the wider of the two data types for mma.sync operands
-  using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)),
+  using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
                                                          ElementA, ElementB>::type;

   // Operand datatypes in the internal MMA instruction - use the wider of the two data types
@@ -294,6 +294,75 @@ struct DefaultMmaTensorOp<
       Policy, PartitionsK, AccumulatorsInRowMajor>;
 };

+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
+/// (e.g. S32 <= S4 x S8 + S32, S32 <= S8 x S4 + S32)
+template <
+    /// Shape of one matrix production operation (concept: GemmShape)
+    typename WarpShape_,
+    /// Element type of A matrix
+    typename ElementA,
+    /// Layout of A matrix (concept: MatrixLayout)
+    typename LayoutA,
+    /// Element type of B matrix
+    typename ElementB,
+    /// Layout of B matrix (concept: MatrixLayout)
+    typename LayoutB,
+    /// Element type of C matrix
+    typename ElementC,
+    /// Layout of C matrix (concept: MatrixLayout)
+    typename LayoutC,
+    /// Number of partitions along K dimension
+    int PartitionsK,
+    /// Store the accumulators in row major or column major. Row major is used
+    /// when output layout is interleaved.
+    bool AccumulatorsInRowMajor>
+struct DefaultMmaTensorOp<
+    WarpShape_,
+    GemmShape<16, 8, 32>,                 // InstructionShape
+    ElementA,                             // Element type of A matrix in Global Memory
+    LayoutA,                              // Layout of A matrix in Global Memory
+    ElementB,                             // Element type of B matrix in Global Memory
+    LayoutB,                              // Layout of B matrix in Global Memory
+    ElementC,                             // Element type of C matrix in Global Memory
+    LayoutC,                              // Layout of C matrix in Global Memory
+    arch::OpMultiplyAddMixedInputUpcast,  // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
+    PartitionsK, AccumulatorsInRowMajor> {
+
+
+  // Check if the ElementA and ElementB are of different data types
+  static_assert(!platform::is_same<ElementA, ElementB>::value,
+    "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");
+
+  // Data type used for internal computation - use the wider of the two data types for mma.sync operands
+  using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
+                                                         ElementA, ElementB>::type;
+
+  // Operand datatypes in the internal MMA instruction - use the wider of the two data types
+  using MmaElementA = ElementOperand;
+  using MmaElementB = ElementOperand;
+  using MmaElementC = ElementC;
+
+  // Uses
+  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
+      cutlass::arch::Mma<
+          GemmShape<16, 8, 32>,
+          32,
+          MmaElementA, cutlass::layout::RowMajor,
+          MmaElementB, cutlass::layout::ColumnMajor,
+          MmaElementC, cutlass::layout::RowMajor,
+          arch::OpMultiplyAddSaturate
+      >,
+      cutlass::MatrixShape<1, 1> >;
+
+  // Define the warp-level tensor op
+  using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
+      WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
+      Policy, PartitionsK, AccumulatorsInRowMajor>;
+};
+
 /////////////////////////////////////////////////////////////////////////////////////////////////

 } // namespace warp
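
For reference, a minimal usage sketch of the new specialization (not part of the commit; the header list, the 64x64x64 warp shape, and the static_assert are illustrative assumptions): instantiating DefaultMmaTensorOp with S8 for A and S4 for B picks the wider of the two types, so both mma.sync operands resolve to S8 and the GemmShape<16, 8, 32> S8 tensor-core instruction with saturating accumulation is used underneath.

    #include "cutlass/arch/mma.h"
    #include "cutlass/gemm/warp/default_mma_tensor_op_sm80.h"
    #include "cutlass/layout/matrix.h"
    #include "cutlass/numeric_types.h"

    // Warp-level mixed-input MMA: A is S8 (row-major), B is S4 (column-major),
    // accumulation in S32. The specialization above is selected by the
    // GemmShape<16, 8, 32> instruction shape and the mixed-input upcast tag.
    using MixedInputWarpMma = cutlass::gemm::warp::DefaultMmaTensorOp<
        cutlass::gemm::GemmShape<64, 64, 64>,             // WarpShape (illustrative)
        cutlass::gemm::GemmShape<16, 8, 32>,              // InstructionShape
        int8_t,           cutlass::layout::RowMajor,      // ElementA, LayoutA
        cutlass::int4b_t, cutlass::layout::ColumnMajor,   // ElementB, LayoutB
        int32_t,          cutlass::layout::RowMajor,      // ElementC, LayoutC
        cutlass::arch::OpMultiplyAddMixedInputUpcast,
        1,                                                // PartitionsK
        false>;                                           // AccumulatorsInRowMajor

    // Both operands of the underlying mma.sync resolve to the wider type (S8).
    static_assert(cutlass::platform::is_same<
                      typename MixedInputWarpMma::Policy::Operator::ElementA,
                      int8_t>::value,
                  "internal mma.sync operands are upcast to S8");

MixedInputWarpMma::Type is the resulting MmaMixedInputTensorOp, which widens the S4 fragments to S8 in registers (via the FragmentShuffler and NumericArrayConverter changes elsewhere in this commit) before issuing mma.sync.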

include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
Lines changed: 10 additions & 4 deletions

@@ -104,6 +104,7 @@ struct FragmentShuffler {
 ////////////////////////////////////////////////////////////////////////////////

 /// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
+/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
 /// for operand A multiplicand going through upcasting.
 template <
   /// Element type for the operand in registers for the mma.sync
@@ -122,8 +123,10 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
           NumElementsInWarpFragment,
           NumElementsInMmaFragment,
           Operand::kA,
-          typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
-                                       (sizeof_bits<ElementLoad_>::value == 8)>::type> {
+          typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
+                                        (sizeof_bits<ElementLoad_>::value == 8)) ||
+                                       ((sizeof_bits<ElementMma_>::value == 8) &&
+                                        (sizeof_bits<ElementLoad_>::value == 4))>::type> {
 public:
   using ElementMma = ElementMma_;
   using ElementLoad = ElementLoad_;
@@ -187,6 +190,7 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
 ////////////////////////////////////////////////////////////////////////////////

 /// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
+/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
 /// for operand B multiplicand going through upcasting.
 template <
   /// Element type for the operand in registers for the mma.sync
@@ -205,8 +209,10 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
           NumElementsInWarpFragment,
           NumElementsInMmaFragment,
           Operand::kB,
-          typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
-                                       (sizeof_bits<ElementLoad_>::value == 8)>::type> {
+          typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
+                                        (sizeof_bits<ElementLoad_>::value == 8)) ||
+                                       ((sizeof_bits<ElementMma_>::value == 8) &&
+                                        (sizeof_bits<ElementLoad_>::value == 4))>::type> {
 public:
   using ElementMma = ElementMma_;
   using ElementLoad = ElementLoad_;
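
To make the broadened condition concrete, here is a small standalone sketch (the UsesShuffler helper is hypothetical, written only to mirror the enable_if predicate): it shows which (mma.sync element, loaded element) pairs now select this shuffler specialization.

    #include "cutlass/numeric_types.h"

    // Hypothetical helper mirroring the enable_if predicate above: true when the
    // FragmentShuffler specialization applies to a (register type, loaded type) pair.
    template <typename ElementMma, typename ElementLoad>
    struct UsesShuffler {
      static bool const value =
          ((cutlass::sizeof_bits<ElementMma>::value == 16) &&
           (cutlass::sizeof_bits<ElementLoad>::value == 8)) ||   // F16/BF16 <- S8/U8 (existing)
          ((cutlass::sizeof_bits<ElementMma>::value == 8) &&
           (cutlass::sizeof_bits<ElementLoad>::value == 4));     // S8/U8 <- S4/U4 (new)
    };

    static_assert(UsesShuffler<cutlass::half_t, int8_t>::value, "16b <- 8b path");
    static_assert(UsesShuffler<int8_t, cutlass::int4b_t>::value, "8b <- 4b path");
    static_assert(!UsesShuffler<cutlass::half_t, cutlass::int4b_t>::value,
                  "16b <- 4b is not handled by this specialization");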

include/cutlass/numeric_conversion.h
Lines changed: 80 additions & 0 deletions

@@ -2771,6 +2771,86 @@ struct NumericArrayConverter<uint4b_t, int, N, Round> {
   }
 };

+/// Partial specialization for Array<int8_t, 8> <= Array<int4b_t, 8>
+template <
+  FloatRoundStyle Round
+>
+struct NumericArrayConverter<int8_t, int4b_t, 8, Round> {
+
+  using result_type = Array<int8_t, 8>;
+  using source_type = Array<int4b_t, 8>;
+  static FloatRoundStyle const round_style = Round;
+
+  CUTLASS_HOST_DEVICE
+  static result_type convert(source_type const & source) {
+
+    unsigned const& storage = reinterpret_cast<unsigned const &>(source);
+    unsigned out[2];
+
+    asm volatile(
+        "{ .reg .u32 tmp0, tmp1, tmp2;"
+        "shl.b32 tmp0, %2, 4;"
+        "and.b32 tmp0, tmp0, 0xf0f0f0f0;"
+        "prmt.b32 tmp1, tmp0, tmp0, 0xba98;"
+        "and.b32 tmp1, tmp1, 0xf0f0f0f0;"
+        "shr.u32 tmp0, tmp0, 4;"
+        "or.b32 tmp2, tmp0, tmp1;"
+        "and.b32 tmp0, %2, 0xf0f0f0f0;"
+        "prmt.b32 tmp1, tmp0, tmp0, 0xba98;"
+        "and.b32 tmp1, tmp1, 0xf0f0f0f0;"
+        "shr.u32 tmp0, tmp0, 4;"
+        "or.b32 tmp0, tmp0, tmp1;"
+        "prmt.b32 %0, tmp2, tmp0, 0x5140;"
+        "prmt.b32 %1, tmp2, tmp0, 0x7362;"
+        "}"
+        : "=r"(out[0]), "=r"(out[1])
+        : "r"(storage));
+
+    return reinterpret_cast<result_type const &>(out);
+  }
+
+  CUTLASS_HOST_DEVICE
+  result_type operator()(source_type const &s) const {
+    return convert(s);
+  }
+};
+
+/// Partial specialization for Array<int8_t> <= Array<int4b_t>
+template <
+  int N,
+  FloatRoundStyle Round
+>
+struct NumericArrayConverter<int8_t, int4b_t, N, Round> {
+  static_assert(!(N % 8), "N must be multiple of 8.");
+
+  using result_type = Array<int8_t, N>;
+  using source_type = Array<int4b_t, N>;
+  static FloatRoundStyle const round_style = Round;
+
+  CUTLASS_HOST_DEVICE
+  static result_type convert(source_type const & source) {
+
+    NumericArrayConverter<int8_t, int4b_t, 8, Round> convert_vector_;
+
+    result_type result;
+
+    Array<int8_t, 8> *result_ptr = reinterpret_cast<Array<int8_t, 8> *>(&result);
+    Array<int4b_t, 8> const *source_ptr = reinterpret_cast<Array<int4b_t, 8> const *>(&source);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / 8; ++i) {
+      result_ptr[i] = convert_vector_(source_ptr[i]);
+    }
+
+    return result;
+  }
+
+  CUTLASS_HOST_DEVICE
+  result_type operator()(source_type const &s) const {
+    return convert(s);
+  }
+};
+
 #endif // Conditional guards to enable partial specialization for packed integers

 namespace detail {
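
The PTX above sign-extends the eight packed nibbles by handling even and odd nibbles separately (prmt with the sign-replicating selector 0xba98 fills the upper half of each byte with the sign bit) and then interleaves the two halves back into their original order with the final two prmt instructions. A minimal device-side usage sketch, assuming only the converters added above (the widen_s4_to_s8 function name is illustrative):

    #include "cutlass/array.h"
    #include "cutlass/numeric_conversion.h"
    #include "cutlass/numeric_types.h"

    // Widen sixteen packed S4 values to S8 in registers; the N = 16 converter
    // dispatches to the 8-element PTX specialization above twice.
    __device__ cutlass::Array<int8_t, 16>
    widen_s4_to_s8(cutlass::Array<cutlass::int4b_t, 16> const &src) {
      cutlass::NumericArrayConverter<int8_t, cutlass::int4b_t, 16> converter;
      return converter(src);
    }

The N-element specialization simply unrolls over 8-element chunks, so any multiple of 8 works.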

python/cutlass_library/generator.py
Lines changed: 147 additions & 0 deletions

@@ -2854,6 +2854,151 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
       else:
         op.C.alignment = 8

+#
+def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version):
+
+  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
+    return
+
+  layouts = [
+    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+  ]
+
+  # Upcast on Operand A
+  math_instructions = [
+    MathInstruction( \
+      [16, 8, 32], \
+      DataType.s4, DataType.s8, DataType.s32, \
+      OpcodeClass.TensorOp, \
+      MathOperation.multiply_add_mixed_input_upcast),
+  ]
+
+  min_cc = 80
+  max_cc = 1024
+
+  # For mixed-input alignment constraints are a list of lists, where the
+  # inner list contains the alignment constraints for operands/matrices
+  # [[alignA, alignB, alignC],..]
+  alignment_constraints = [[32, 16, 4],]
+
+  for math_inst in math_instructions:
+    tile_descriptions = [
+      TileDescription([256, 128,  64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256,  64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64,  64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256,  64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256,  64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128,  64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128,  64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+    ]
+
+    data_type = [
+      math_inst.element_a,
+      math_inst.element_b,
+      math_inst.element_accumulator,
+      math_inst.element_accumulator,
+    ]
+
+    # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit.
+    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
+    if math_inst.element_a != math_inst.element_accumulator:
+      alignment_constraints = [[32, 16, 16],]
+
+      data_type_mixed = [
+        math_inst.element_a,
+        math_inst.element_b,
+        math_inst.element_b,
+        math_inst.element_accumulator,
+      ]
+
+      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
+        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+#
+def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version):
+
+  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
+    return
+
+  layouts = [
+    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+  ]
+
+  # Upcast on Operand B
+  math_instructions = [
+    MathInstruction( \
+      [16, 8, 32], \
+      DataType.s8, DataType.s4, DataType.s32, \
+      OpcodeClass.TensorOp, \
+      MathOperation.multiply_add_mixed_input_upcast),
+  ]
+
+  min_cc = 80
+  max_cc = 1024
+
+  # For mixed-input alignment constraints are a list of lists, where the
+  # inner list contains the alignment constraints for operands/matrices
+  # [[alignA, alignB, alignC],..]
+  alignment_constraints = [[16, 32, 4],]
+
+  for math_inst in math_instructions:
+    tile_descriptions = [
+      TileDescription([256, 128,  64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256,  64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64,  64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256,  64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32,  64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128,  64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128,  64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32,  64], 6, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+    ]
+
+    data_type = [
+      math_inst.element_a,
+      math_inst.element_b,
+      math_inst.element_accumulator,
+      math_inst.element_accumulator,
+    ]
+
+    # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit.
+    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
+    if math_inst.element_a != math_inst.element_accumulator:
+      alignment_constraints = [[16, 32, 16],]
+
+      data_type_mixed = [
+        math_inst.element_a,
+        math_inst.element_b,
+        math_inst.element_a,
+        math_inst.element_accumulator,
+      ]
+
+      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
+        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
 #

 #
@@ -4699,6 +4844,8 @@ def GenerateSM80(manifest, cuda_version):
   GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version)
   GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version)
   GenerateSM80_TensorOp_16832_TN(manifest, cuda_version)
+  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version)
+  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version)
   GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version)
   GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version)
   GenerateSM80_TensorOp_16864_TN(manifest, cuda_version)