
Commit 95a7d30

Aleksandar Samardžić (alexsamardzic) authored and committed
Add support for mixed 4-bit/8-bit data types GEMM
1 parent 56b46e2 commit 95a7d30

14 files changed: +844 −14 lines

include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h

Lines changed: 70 additions & 1 deletion
@@ -268,7 +268,7 @@ struct DefaultMmaTensorOp<
     "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");
 
   // Data type used for internal computation - use the wider of the two data types for mma.sync operands
-  using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)),
+  using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
                                                         ElementA, ElementB>::type;
 
   // Operand datatypes in the internal MMA instruction - use the wider of the two data types
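The change above is the crux of this hunk: sizeof reports bytes, so a sub-byte type such as int4b_t is indistinguishable from int8_t and the "wider of the two" selection could pick the wrong operand type, while cutlass::sizeof_bits reports the logical width in bits. A minimal host-side illustration (a sketch, assuming only that the CUTLASS headers are on the include path):

#include <cstdio>
#include <cutlass/numeric_types.h>   // cutlass::int4b_t, cutlass::sizeof_bits

int main() {
  // sizeof() is measured in bytes, so the packed 4-bit type cannot report less than 1.
  std::printf("sizeof: int4b_t=%zu int8_t=%zu\n",
              sizeof(cutlass::int4b_t), sizeof(int8_t));
  // sizeof_bits reports the logical width, so S8 is correctly the wider operand.
  std::printf("sizeof_bits: int4b_t=%d int8_t=%d\n",
              int(cutlass::sizeof_bits<cutlass::int4b_t>::value),
              int(cutlass::sizeof_bits<int8_t>::value));
  return 0;
}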
@@ -294,6 +294,75 @@
       Policy, PartitionsK, AccumulatorsInRowMajor>;
 };
 
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
+/// (e.g. S32 <= S4 x S8 + S32, S32 <= S8 x S4 + S32)
+template <
+    /// Shape of one matrix product operation (concept: GemmShape)
+    typename WarpShape_,
+    /// Element type of A matrix
+    typename ElementA,
+    /// Layout of A matrix (concept: MatrixLayout)
+    typename LayoutA,
+    /// Element type of B matrix
+    typename ElementB,
+    /// Layout of B matrix (concept: MatrixLayout)
+    typename LayoutB,
+    /// Element type of C matrix
+    typename ElementC,
+    /// Layout of C matrix (concept: MatrixLayout)
+    typename LayoutC,
+    /// Number of partitions along K dimension
+    int PartitionsK,
+    /// Store the accumulators in row major or column major. Row major is used
+    /// when output layout is interleaved.
+    bool AccumulatorsInRowMajor>
+struct DefaultMmaTensorOp<
+    WarpShape_,
+    GemmShape<16, 8, 32>,                 // InstructionShape
+    ElementA,                             // Element type of A matrix in Global Memory
+    LayoutA,                              // Layout of A matrix in Global Memory
+    ElementB,                             // Element type of B matrix in Global Memory
+    LayoutB,                              // Layout of B matrix in Global Memory
+    ElementC,                             // Element type of C matrix in Global Memory
+    LayoutC,                              // Layout of C matrix in Global Memory
+    arch::OpMultiplyAddMixedInputUpcast,  // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
+    PartitionsK, AccumulatorsInRowMajor> {
+
+  // Check if the ElementA and ElementB are of different data types
+  static_assert(!platform::is_same<ElementA, ElementB>::value,
+                "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");
+
+  // Data type used for internal computation - use the wider of the two data types for mma.sync operands
+  using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
+                                                        ElementA, ElementB>::type;
+
+  // Operand datatypes in the internal MMA instruction - use the wider of the two data types
+  using MmaElementA = ElementOperand;
+  using MmaElementB = ElementOperand;
+  using MmaElementC = ElementC;
+
+  // Uses the wider data type for the underlying mma.sync instruction
+  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
+      cutlass::arch::Mma<
+          GemmShape<16, 8, 32>,
+          32,
+          MmaElementA, cutlass::layout::RowMajor,
+          MmaElementB, cutlass::layout::ColumnMajor,
+          MmaElementC, cutlass::layout::RowMajor,
+          arch::OpMultiplyAdd
+      >,
+      cutlass::MatrixShape<1, 1> >;
+
+  // Define the warp-level tensor op
+  using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
+      WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
+      Policy, PartitionsK, AccumulatorsInRowMajor>;
+};
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace warp
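For orientation, the sketch below shows what selects this new specialization: ElementA and ElementB of different widths, the arch::OpMultiplyAddMixedInputUpcast tag, and the m16n8k32 instruction shape, with ::Type resolving to MmaMixedInputTensorOp. In the library this trait is normally instantiated indirectly by the threadblock-level DefaultMmaCore, which supplies tensor-op shared-memory operand layouts; the warp shape and the plain RowMajor/ColumnMajor layouts below are illustrative assumptions only, not values fixed by the commit.

// Hypothetical warp-level instantiation: S8 x S4 -> S32, with S4 upcast to S8
// before the mma.sync. The argument order matches the specialization above.
using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
    cutlass::gemm::GemmShape<64, 64, 64>,            // WarpShape (assumed)
    cutlass::gemm::GemmShape<16, 8, 32>,             // InstructionShape
    int8_t,           cutlass::layout::RowMajor,     // ElementA, LayoutA (assumed layouts)
    cutlass::int4b_t, cutlass::layout::ColumnMajor,  // ElementB, LayoutB
    int32_t,          cutlass::layout::RowMajor,     // ElementC, LayoutC
    cutlass::arch::OpMultiplyAddMixedInputUpcast,    // mixed-input tag
    1,                                               // PartitionsK
    false                                            // AccumulatorsInRowMajor
  >::Type;                                           // == MmaMixedInputTensorOp<...>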

include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h

Lines changed: 10 additions & 4 deletions
@@ -104,6 +104,7 @@ struct FragmentShuffler {
 ////////////////////////////////////////////////////////////////////////////////
 
 /// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
+/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
 /// for operand A multiplicand going through upcasting.
 template <
   /// Element type for the operand in registers for the mma.sync
@@ -122,8 +123,10 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
                          NumElementsInWarpFragment,
                          NumElementsInMmaFragment,
                          Operand::kA,
-                         typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
-                                                      (sizeof_bits<ElementLoad_>::value == 8)>::type> {
+                         typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
+                                                       (sizeof_bits<ElementLoad_>::value == 8)) ||
+                                                      ((sizeof_bits<ElementMma_>::value == 8) &&
+                                                       (sizeof_bits<ElementLoad_>::value == 4))>::type> {
 public:
   using ElementMma = ElementMma_;
   using ElementLoad = ElementLoad_;
@@ -187,6 +190,7 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
 ////////////////////////////////////////////////////////////////////////////////
 
 /// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
+/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
 /// for operand B multiplicand going through upcasting.
 template <
   /// Element type for the operand in registers for the mma.sync
@@ -205,8 +209,10 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
                          NumElementsInWarpFragment,
                          NumElementsInMmaFragment,
                          Operand::kB,
-                         typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
-                                                      (sizeof_bits<ElementLoad_>::value == 8)>::type> {
+                         typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
+                                                       (sizeof_bits<ElementLoad_>::value == 8)) ||
+                                                      ((sizeof_bits<ElementMma_>::value == 8) &&
+                                                       (sizeof_bits<ElementLoad_>::value == 4))>::type> {
 public:
   using ElementMma = ElementMma_;
   using ElementLoad = ElementLoad_;
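The widened enable_if selects the shuffler whenever the mma.sync operand type is exactly twice as wide as the type delivered by ldmatrix: 16b fed by 8b loads as before, and now also 8b fed by 4b loads. A standalone restatement of that predicate, for illustration only (kNeedsShuffle is not a name used by the library):

#include <cutlass/half.h>            // cutlass::half_t
#include <cutlass/numeric_types.h>   // cutlass::int4b_t, cutlass::sizeof_bits

// Restates the enable_if condition guarding the two FragmentShuffler specializations.
template <typename ElementMma, typename ElementLoad>
constexpr bool kNeedsShuffle =
    ((cutlass::sizeof_bits<ElementMma>::value == 16) &&
     (cutlass::sizeof_bits<ElementLoad>::value == 8)) ||
    ((cutlass::sizeof_bits<ElementMma>::value == 8) &&
     (cutlass::sizeof_bits<ElementLoad>::value == 4));

static_assert(kNeedsShuffle<cutlass::half_t, int8_t>, "F16 mma fed by S8 loads is shuffled");
static_assert(kNeedsShuffle<int8_t, cutlass::int4b_t>, "S8 mma fed by S4 loads is shuffled");
static_assert(!kNeedsShuffle<int8_t, int8_t>, "same-width operands are not shuffled");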

include/cutlass/numeric_conversion.h

Lines changed: 80 additions & 0 deletions
@@ -2596,6 +2596,86 @@ struct NumericArrayConverter<uint4b_t, int, N, Round> {
   }
 };
 
+/// Partial specialization for Array<int8_t, 8> <= Array<int4b_t, 8>
+template <
+  FloatRoundStyle Round
+>
+struct NumericArrayConverter<int8_t, int4b_t, 8, Round> {
+
+  using result_type = Array<int8_t, 8>;
+  using source_type = Array<int4b_t, 8>;
+  static FloatRoundStyle const round_style = Round;
+
+  CUTLASS_HOST_DEVICE
+  static result_type convert(source_type const & source) {
+
+    unsigned const& storage = reinterpret_cast<unsigned const &>(source);
+    unsigned out[2];
+
+    asm volatile(
+        "{ .reg .u32 tmp0, tmp1, tmp2;"
+        "shl.b32 tmp0, %2, 4;"
+        "and.b32 tmp0, tmp0, 0xf0f0f0f0;"
+        "prmt.b32 tmp1, tmp0, tmp0, 0xba98;"
+        "and.b32 tmp1, tmp1, 0xf0f0f0f0;"
+        "shr.u32 tmp0, tmp0, 4;"
+        "or.b32 tmp2, tmp0, tmp1;"
+        "and.b32 tmp0, %2, 0xf0f0f0f0;"
+        "prmt.b32 tmp1, tmp0, tmp0, 0xba98;"
+        "and.b32 tmp1, tmp1, 0xf0f0f0f0;"
+        "shr.u32 tmp0, tmp0, 4;"
+        "or.b32 tmp0, tmp0, tmp1;"
+        "prmt.b32 %0, tmp2, tmp0, 0x5140;"
+        "prmt.b32 %1, tmp2, tmp0, 0x7362;"
+        "}"
+        : "=r"(out[0]), "=r"(out[1])
+        : "r"(storage));
+
+    return reinterpret_cast<result_type const &>(out);
+  }
+
+  CUTLASS_HOST_DEVICE
+  result_type operator()(source_type const &s) const {
+    return convert(s);
+  }
+};
+
+/// Partial specialization for Array<int8_t> <= Array<int4b_t>
+template <
+  int N,
+  FloatRoundStyle Round
+>
+struct NumericArrayConverter<int8_t, int4b_t, N, Round> {
+  static_assert(!(N % 8), "N must be multiple of 8.");
+
+  using result_type = Array<int8_t, N>;
+  using source_type = Array<int4b_t, N>;
+  static FloatRoundStyle const round_style = Round;
+
+  CUTLASS_HOST_DEVICE
+  static result_type convert(source_type const & source) {
+
+    NumericArrayConverter<int8_t, int4b_t, 8, Round> convert_vector_;
+
+    result_type result;
+
+    Array<int8_t, 8> *result_ptr = reinterpret_cast<Array<int8_t, 8> *>(&result);
+    Array<int4b_t, 8> const *source_ptr = reinterpret_cast<Array<int4b_t, 8> const *>(&source);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / 8; ++i) {
+      result_ptr[i] = convert_vector_(source_ptr[i]);
+    }
+
+    return result;
+  }
+
+  CUTLASS_HOST_DEVICE
+  result_type operator()(source_type const &s) const {
+    return convert(s);
+  }
+};
+
 #endif // Conditional guards to enable partial specialization for packed integers
 
 namespace detail {
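The inline PTX above sign-extends eight packed 4-bit lanes to 8 bits: it splits even and odd nibbles, shifts each value nibble into place, ORs in a sign mask produced by prmt in sign-replicate mode, and then re-interleaves the bytes so the output should match a straightforward per-lane sign extension. A plain host-side reference sketch for checking the numerics (convert_s4x8_reference is an illustrative name, not part of the commit):

#include <array>
#include <cstdint>

// Reference sign-extension of eight packed S4 lanes (lane i occupies bits [4*i, 4*i+4))
// to S8. Useful as a host-side cross-check of the PTX path above.
inline std::array<int8_t, 8> convert_s4x8_reference(uint32_t packed) {
  std::array<int8_t, 8> out{};
  for (int i = 0; i < 8; ++i) {
    uint8_t nibble = (packed >> (4 * i)) & 0xF;
    // Shift the 4-bit value into the top of an 8-bit lane, then arithmetic-shift
    // back down to replicate the sign bit.
    out[i] = static_cast<int8_t>(static_cast<int8_t>(nibble << 4) >> 4);
  }
  return out;
}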

python/cutlass_library/generator.py

Lines changed: 147 additions & 0 deletions
@@ -2836,6 +2836,151 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
       else:
         op.C.alignment = 8
 
+#
+def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version):
+
+  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
+    return
+
+  layouts = [
+    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+  ]
+
+  # Upcast on Operand A
+  math_instructions = [
+    MathInstruction( \
+      [16, 8, 32], \
+      DataType.s4, DataType.s8, DataType.s32, \
+      OpcodeClass.TensorOp, \
+      MathOperation.multiply_add_mixed_input_upcast),
+  ]
+
+  min_cc = 80
+  max_cc = 1024
+
+  # For mixed-input alignment constraints are a list of lists, where the
+  # inner list contains the alignment constraints for operands/matrices
+  # [[alignA, alignB, alignC],..]
+  alignment_constraints = [[32, 16, 4],]
+
+  for math_inst in math_instructions:
+    tile_descriptions = [
+      TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+    ]
+
+    data_type = [
+      math_inst.element_a,
+      math_inst.element_b,
+      math_inst.element_accumulator,
+      math_inst.element_accumulator,
+    ]
+
+    # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit.
+    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
+    if math_inst.element_a != math_inst.element_accumulator:
+      alignment_constraints = [[32, 16, 16],]
+
+      data_type_mixed = [
+        math_inst.element_a,
+        math_inst.element_b,
+        math_inst.element_b,
+        math_inst.element_accumulator,
+      ]
+
+      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
+        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+#
+def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version):
+
+  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
+    return
+
+  layouts = [
+    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+  ]
+
+  # Upcast on Operand B
+  math_instructions = [
+    MathInstruction( \
+      [16, 8, 32], \
+      DataType.s8, DataType.s4, DataType.s32, \
+      OpcodeClass.TensorOp, \
+      MathOperation.multiply_add_mixed_input_upcast),
+  ]
+
+  min_cc = 80
+  max_cc = 1024
+
+  # For mixed-input alignment constraints are a list of lists, where the
+  # inner list contains the alignment constraints for operands/matrices
+  # [[alignA, alignB, alignC],..]
+  alignment_constraints = [[16, 32, 4],]
+
+  for math_inst in math_instructions:
+    tile_descriptions = [
+      TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+    ]
+
+    data_type = [
+      math_inst.element_a,
+      math_inst.element_b,
+      math_inst.element_accumulator,
+      math_inst.element_accumulator,
+    ]
+
+    # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit.
+    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
+    if math_inst.element_a != math_inst.element_accumulator:
+      alignment_constraints = [[16, 32, 16],]
+
+      data_type_mixed = [
+        math_inst.element_a,
+        math_inst.element_b,
+        math_inst.element_a,
+        math_inst.element_accumulator,
+      ]
+
+      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
+        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
 #
 
 #
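The [[alignA, alignB, alignC]] rows above give per-operand alignments in elements, each chosen so that a vectorized global-memory access is 128 bits wide (the data_type_mixed variants raise alignC from 4 to 16 because C/D become 8-bit). Expressed as a C++ sanity check against CUTLASS's own sizeof_bits (illustrative only, not part of the generator):

#include <cstdint>
#include <cutlass/numeric_types.h>   // cutlass::int4b_t, cutlass::sizeof_bits

// Every alignment row maps to a 128-bit (16-byte) access for its element type.
static_assert(32 * cutlass::sizeof_bits<cutlass::int4b_t>::value == 128, "align 32 for S4");
static_assert(16 * cutlass::sizeof_bits<int8_t>::value == 128,           "align 16 for S8");
static_assert( 4 * cutlass::sizeof_bits<int32_t>::value == 128,          "align 4 for S32 C/D");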
@@ -4681,6 +4826,8 @@ def GenerateSM80(manifest, cuda_version):
   GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version)
   GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version)
   GenerateSM80_TensorOp_16832_TN(manifest, cuda_version)
+  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version)
+  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version)
   GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version)
   GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version)
   GenerateSM80_TensorOp_16864_TN(manifest, cuda_version)
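Each generated configuration ultimately maps onto the 2.x device-level GEMM. The sketch below shows roughly what one "upcast on B" kernel (threadblock 128x128x64, 5 stages, warp count [2, 2, 1]) could look like if written by hand; the epilogue, swizzle, and the exact cutlass::gemm::device::GemmUniversal template-argument order are assumptions to verify against the library headers, not the generator's literal output.

#include <cutlass/gemm/device/gemm_universal.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/numeric_types.h>

// Hand-written analogue of one generated S8 x S4 -> S32 "TN" kernel (sketch only).
using GemmS8xS4 = cutlass::gemm::device::GemmUniversal<
    int8_t,           cutlass::layout::RowMajor,     // A: S8  ("T")
    cutlass::int4b_t, cutlass::layout::ColumnMajor,  // B: S4  ("N"), upcast to S8 in registers
    int32_t,          cutlass::layout::ColumnMajor,  // C/D: S32
    int32_t,                                         // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,          // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 64>,            // warp tile ([2, 2, 1] warps)
    cutlass::gemm::GemmShape<16, 8, 32>,             // mma.sync m16n8k32
    cutlass::epilogue::thread::LinearCombination<int32_t, 4, int32_t, int32_t>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
    5,                                               // stages
    16,                                              // AlignmentA: 16 x S8 = 128b
    32,                                              // AlignmentB: 32 x S4 = 128b
    cutlass::arch::OpMultiplyAddMixedInputUpcast>;   // mixed-input math tag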
