Skip to content

Commit 1dfe9c0

Browse files
author
Aleksandar Samardžić
committed
Add support for mixed 4-bit/8-bit data types GEMM
1 parent c4e3e12 commit 1dfe9c0

14 files changed

+844
-14
lines changed

include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ struct DefaultMmaTensorOp<
268268
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");
269269

270270
// Data type used for internal computation - use the wider of the two data types for mma.sync operands
271-
using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)),
271+
using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
272272
ElementA, ElementB>::type;
273273

274274
// Operand datatypes in the internal MMA instruction - use the wider of the two data types
@@ -294,6 +294,75 @@ struct DefaultMmaTensorOp<
294294
Policy, PartitionsK, AccumulatorsInRowMajor>;
295295
};
296296

297+
298+
/////////////////////////////////////////////////////////////////////////////////////////////////
299+
300+
/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
301+
/// (e.g. S32 <= S4 x S8 + S32, S32 <= S8 x S4 + S32)
302+
template <
303+
/// Shape of one matrix production operation (concept: GemmShape)
304+
typename WarpShape_,
305+
/// Element type of A matrix
306+
typename ElementA,
307+
/// Layout of A matrix (concept: MatrixLayout)
308+
typename LayoutA,
309+
/// Element type of B matrix
310+
typename ElementB,
311+
/// Layout of B matrix (concept: MatrixLayout)
312+
typename LayoutB,
313+
/// Element type of C matrix
314+
typename ElementC,
315+
/// Layout of C matrix (concept: MatrixLayout)
316+
typename LayoutC,
317+
/// Number of partitions along K dimension
318+
int PartitionsK,
319+
/// Store the accumulators in row major or column major. Row major is used
320+
/// when output layout is interleaved.
321+
bool AccumulatorsInRowMajor>
322+
struct DefaultMmaTensorOp<
323+
WarpShape_,
324+
GemmShape<16, 8, 32>, // InstructionShape
325+
ElementA, // Element type of A matrix in Global Memory
326+
LayoutA, // Layout of A matrix in Global Memory
327+
ElementB, // Element type of B matrix in Global Memory
328+
LayoutB, // Layout of B matrix in Global Memory
329+
ElementC, // Element type of C matrix in Global Memory
330+
LayoutC, // Layout of C matrix in Global Memory
331+
arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
332+
PartitionsK, AccumulatorsInRowMajor> {
333+
334+
335+
// Check if the ElementA and ElementB are of different data types
336+
static_assert(!platform::is_same<ElementA, ElementB>::value,
337+
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");
338+
339+
// Data type used for internal computation - use the wider of the two data types for mma.sync operands
340+
using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
341+
ElementA, ElementB>::type;
342+
343+
// Operand datatypes in the internal MMA instruction - use the wider of the two data types
344+
using MmaElementA = ElementOperand;
345+
using MmaElementB = ElementOperand;
346+
using MmaElementC = ElementC;
347+
348+
// Uses
349+
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
350+
cutlass::arch::Mma<
351+
GemmShape<16, 8, 32>,
352+
32,
353+
MmaElementA, cutlass::layout::RowMajor,
354+
MmaElementB, cutlass::layout::ColumnMajor,
355+
MmaElementC, cutlass::layout::RowMajor,
356+
arch::OpMultiplyAdd
357+
>,
358+
cutlass::MatrixShape<1, 1> >;
359+
360+
// Define the warp-level tensor op
361+
using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
362+
WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
363+
Policy, PartitionsK, AccumulatorsInRowMajor>;
364+
};
365+
297366
/////////////////////////////////////////////////////////////////////////////////////////////////
298367

299368
} // namespace warp

include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ struct FragmentShuffler {
104104
////////////////////////////////////////////////////////////////////////////////
105105

106106
/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
107+
/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
107108
/// for operand A multiplicand going through upcasting.
108109
template <
109110
/// Element type for the operand in registers for the mma.sync
@@ -122,8 +123,10 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
122123
NumElementsInWarpFragment,
123124
NumElementsInMmaFragment,
124125
Operand::kA,
125-
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
126-
(sizeof_bits<ElementLoad_>::value == 8)>::type> {
126+
typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
127+
(sizeof_bits<ElementLoad_>::value == 8)) ||
128+
((sizeof_bits<ElementMma_>::value == 8) &&
129+
(sizeof_bits<ElementLoad_>::value == 4))>::type> {
127130
public:
128131
using ElementMma = ElementMma_;
129132
using ElementLoad = ElementLoad_;
@@ -187,6 +190,7 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
187190
////////////////////////////////////////////////////////////////////////////////
188191

189192
/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
193+
/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
190194
/// for operand B multiplicand going through upcasting.
191195
template <
192196
/// Element type for the operand in registers for the mma.sync
@@ -205,8 +209,10 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
205209
NumElementsInWarpFragment,
206210
NumElementsInMmaFragment,
207211
Operand::kB,
208-
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
209-
(sizeof_bits<ElementLoad_>::value == 8)>::type> {
212+
typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
213+
(sizeof_bits<ElementLoad_>::value == 8)) ||
214+
((sizeof_bits<ElementMma_>::value == 8) &&
215+
(sizeof_bits<ElementLoad_>::value == 4))>::type> {
210216
public:
211217
using ElementMma = ElementMma_;
212218
using ElementLoad = ElementLoad_;

include/cutlass/numeric_conversion.h

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2491,6 +2491,86 @@ struct NumericArrayConverter<uint4b_t, int, N, Round> {
24912491
}
24922492
};
24932493

2494+
/// Partial specialization for Array<int8_t, 8> <= Array<int4b_t, 8>
2495+
template <
2496+
FloatRoundStyle Round
2497+
>
2498+
struct NumericArrayConverter<int8_t, int4b_t, 8, Round> {
2499+
2500+
using result_type = Array<int8_t, 8>;
2501+
using source_type = Array<int4b_t, 8>;
2502+
static FloatRoundStyle const round_style = Round;
2503+
2504+
CUTLASS_HOST_DEVICE
2505+
static result_type convert(source_type const & source) {
2506+
2507+
unsigned const& storage = reinterpret_cast<unsigned const &>(source);
2508+
unsigned out[2];
2509+
2510+
asm volatile(
2511+
"{ .reg .u32 tmp0, tmp1, tmp2;"
2512+
"shl.b32 tmp0, %2, 4;"
2513+
"and.b32 tmp0, tmp0, 0xf0f0f0f0;"
2514+
"prmt.b32 tmp1, tmp0, tmp0, 0xba98;"
2515+
"and.b32 tmp1, tmp1, 0xf0f0f0f0;"
2516+
"shr.u32 tmp0, tmp0, 4;"
2517+
"or.b32 tmp2, tmp0, tmp1;"
2518+
"and.b32 tmp0, %2, 0xf0f0f0f0;"
2519+
"prmt.b32 tmp1, tmp0, tmp0, 0xba98;"
2520+
"and.b32 tmp1, tmp1, 0xf0f0f0f0;"
2521+
"shr.u32 tmp0, tmp0, 4;"
2522+
"or.b32 tmp0, tmp0, tmp1;"
2523+
"prmt.b32 %0, tmp2, tmp0, 0x5140;"
2524+
"prmt.b32 %1, tmp2, tmp0, 0x7362;"
2525+
"}"
2526+
: "=r"(out[0]), "=r"(out[1])
2527+
: "r"(storage));
2528+
2529+
return reinterpret_cast<result_type const &>(out);
2530+
}
2531+
2532+
CUTLASS_HOST_DEVICE
2533+
result_type operator()(source_type const &s) const {
2534+
return convert(s);
2535+
}
2536+
};
2537+
2538+
/// Partial specialization for Array<int8_t> <= Array<int4b_t>
2539+
template <
2540+
int N,
2541+
FloatRoundStyle Round
2542+
>
2543+
struct NumericArrayConverter<int8_t, int4b_t, N, Round> {
2544+
static_assert(!(N % 8), "N must be multiple of 8.");
2545+
2546+
using result_type = Array<int8_t, N>;
2547+
using source_type = Array<int4b_t, N>;
2548+
static FloatRoundStyle const round_style = Round;
2549+
2550+
CUTLASS_HOST_DEVICE
2551+
static result_type convert(source_type const & source) {
2552+
2553+
NumericArrayConverter<int8_t, int4b_t, 8, Round> convert_vector_;
2554+
2555+
result_type result;
2556+
2557+
Array<int8_t, 8> *result_ptr = reinterpret_cast<Array<int8_t, 8> *>(&result);
2558+
Array<int4b_t, 8> const *source_ptr = reinterpret_cast<Array<int4b_t, 8> const *>(&source);
2559+
2560+
CUTLASS_PRAGMA_UNROLL
2561+
for (int i = 0; i < N / 8; ++i) {
2562+
result_ptr[i] = convert_vector_(source_ptr[i]);
2563+
}
2564+
2565+
return result;
2566+
}
2567+
2568+
CUTLASS_HOST_DEVICE
2569+
result_type operator()(source_type const &s) const {
2570+
return convert(s);
2571+
}
2572+
};
2573+
24942574
#endif // Conditional guards to enable partial specialization for packed integers
24952575

24962576
namespace detail {

python/cutlass_library/generator.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2835,6 +2835,151 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
28352835
else:
28362836
op.C.alignment = 8
28372837

2838+
#
2839+
def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version):
  """Emit SM80 mixed-input tensor-op GEMMs (mma 16x8x32, TN) upcasting operand A.

  Instantiates S4 x S8 -> S32 kernels where the narrower A operand is upcast
  to the B operand's type in registers. Requires CUDA 11.0+. Adds the
  generated operations to `manifest`.
  """
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
    return

  layouts = [
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
  ]

  # Upcast on Operand A
  math_instructions = [
    MathInstruction(                                  \
      [16, 8, 32],                                    \
      DataType.s4, DataType.s8, DataType.s32,         \
      OpcodeClass.TensorOp,                           \
      MathOperation.multiply_add_mixed_input_upcast),
  ]

  min_cc = 80
  max_cc = 1024

  # For mixed-input alignment constraints are a list of lists, where the
  # inner list contains the alignment constraints for operands/matrices
  # [[alignA, alignB, alignC],..]
  alignment_constraints = [[32, 16, 4],]

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription([256, 128,  64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 256,  64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  64,  64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 256,  64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([ 32, 256,  64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 128,  64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 128,  64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
    ]

    # D = alpha * A*B + beta * C with S32 accumulation and S32 output.
    data_type = [
      math_inst.element_a,
      math_inst.element_b,
      math_inst.element_accumulator,
      math_inst.element_accumulator,
    ]

    # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit.
    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)

    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
    if math_inst.element_a != math_inst.element_accumulator:
      alignment_constraints = [[32, 16, 16],]

      # Variant whose output type matches the wider input (B) type.
      data_type_mixed = [
        math_inst.element_a,
        math_inst.element_b,
        math_inst.element_b,
        math_inst.element_accumulator,
      ]

      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
2909+
2910+
#
2911+
def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version):
  """Emit SM80 mixed-input tensor-op GEMMs (mma 16x8x32, TN) upcasting operand B.

  Instantiates S8 x S4 -> S32 kernels where the narrower B operand is upcast
  to the A operand's type in registers. Requires CUDA 11.0+. Adds the
  generated operations to `manifest`.
  """
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
    return

  layouts = [
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
  ]

  # Upcast on Operand B
  math_instructions = [
    MathInstruction(                                  \
      [16, 8, 32],                                    \
      DataType.s8, DataType.s4, DataType.s32,         \
      OpcodeClass.TensorOp,                           \
      MathOperation.multiply_add_mixed_input_upcast),
  ]

  min_cc = 80
  max_cc = 1024

  # For mixed-input alignment constraints are a list of lists, where the
  # inner list contains the alignment constraints for operands/matrices
  # [[alignA, alignB, alignC],..]
  alignment_constraints = [[16, 32, 4],]

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription([256, 128,  64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 256,  64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  64,  64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 256,  64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  32,  64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 128,  64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 128,  64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128,  32,  64], 6, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([256,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
      TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
    ]

    # D = alpha * A*B + beta * C with S32 accumulation and S32 output.
    data_type = [
      math_inst.element_a,
      math_inst.element_b,
      math_inst.element_accumulator,
      math_inst.element_accumulator,
    ]

    # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit.
    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)

    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
    if math_inst.element_a != math_inst.element_accumulator:
      alignment_constraints = [[16, 32, 16],]

      # Variant whose output type matches the wider input (A) type.
      data_type_mixed = [
        math_inst.element_a,
        math_inst.element_b,
        math_inst.element_a,
        math_inst.element_accumulator,
      ]

      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
2982+
28382983
#
28392984

28402985
#
@@ -4680,6 +4825,8 @@ def GenerateSM80(manifest, cuda_version):
46804825
GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version)
46814826
GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version)
46824827
GenerateSM80_TensorOp_16832_TN(manifest, cuda_version)
4828+
GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version)
4829+
GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version)
46834830
GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version)
46844831
GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version)
46854832
GenerateSM80_TensorOp_16864_TN(manifest, cuda_version)

0 commit comments

Comments
 (0)