Skip to content

Commit 845c2e1

Browse files
authored
Add contract to mlir-gen (#1027)
Adds `linalg.contract` to MLIR Gen so we can start testing contraction inputs.
1 parent 9f17b49 commit 845c2e1

File tree

4 files changed

+40
-8
lines changed

4 files changed

+40
-8
lines changed

test/Integration/mlir-gen.mlir

+11-1
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,18 @@
88
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10,10 | tpp-run -e entry -entry-point-result=void
99

1010
// Matmul only
11-
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void
11+
// RUN: mlir-gen --kernel=const --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=MATMUL
12+
// RUN: mlir-gen --kernel=const --batch=10 --layers=10,10 --output=generic | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=MATMUL
13+
// RUN: mlir-gen --kernel=const --batch=10 --layers=10,10 --output=contract | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=MATMUL
14+
// RUN: mlir-gen --kernel=const --batch=10 --layers=10,10 --output=named | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=MATMUL
15+
// RUN: mlir-gen --kernel=const --batch=10 --layers=10,10 --output=named --keep-generic-matmul | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=MATMUL
1216

1317
// Constant values
1418
// RUN: mlir-gen --kernel=const --bias --relu --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=CONSTANT
19+
// RUN: mlir-gen --kernel=const --bias --relu --batch=10 --layers=10,10 --output=generic | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=CONSTANT
20+
// RUN: mlir-gen --kernel=const --bias --relu --batch=10 --layers=10,10 --output=contract | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=CONSTANT
21+
// RUN: mlir-gen --kernel=const --bias --relu --batch=10 --layers=10,10 --output=named | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=CONSTANT
22+
// RUN: mlir-gen --kernel=const --bias --relu --batch=10 --layers=10,10 --output=named --keep-generic-matmul | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=CONSTANT
1523

1624
// Kernel - matmul
1725
// RUN: mlir-gen --kernel=args --seed=123 --float-type=f32 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL
@@ -23,6 +31,8 @@
2331
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF
2432
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10,10 --tiles=2,2,2 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF
2533

34+
// MATMUL:( 10, 10, 10, 10, 10, 10, 10, 10, 10, 10 )
35+
2636
// CONSTANT:( 11, 11, 11, 11, 11, 11, 11, 11, 11, 11 )
2737

2838
// GEN-MATMUL: ( 11, 11, 11, 11, 11, 11, 11, 11, 11, 11 )

tools/mlir-gen/MLIRGen.cpp

+23-4
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,12 @@ MLIRGenerator::MLIRGenerator(StringRef outputOpKindStr, StringRef kernelStr,
8585
auto optOutputOpKind =
8686
llvm::StringSwitch<std::optional<OutputOpKind>>(outputOpKindStr)
8787
.CaseLower("generic", OutputOpKind::Generic)
88+
.CaseLower("contract", OutputOpKind::Contract)
8889
.CaseLower("named", OutputOpKind::NamedOp)
8990
.Default(std::nullopt);
9091
assert(optOutputOpKind && "Invalid output Op kind");
92+
assert(!(optOutputOpKind == OutputOpKind::Contract && keepGenericMatmul) &&
93+
"Can't keep generic matmul with contract");
9194
outputOpKind = *optOutputOpKind;
9295

9396
// Parse kernel type
@@ -181,7 +184,7 @@ Value MLIRGenerator::createLayer(LayerArgs &args) {
181184
if (outputOpKind == OutputOpKind::Generic) {
182185
chain = lowerBiasAdd(chain, args.bias.value, args.output.value);
183186
chain = lowerRelu(chain, args.output.value);
184-
} else if (outputOpKind == OutputOpKind::NamedOp) {
187+
} else {
185188
chain = lowerNamedBiasAdd(chain, args.bias.value, args.output.value);
186189
chain = lowerNamedRelu(chain, args.output.value);
187190
}
@@ -190,7 +193,7 @@ Value MLIRGenerator::createLayer(LayerArgs &args) {
190193
if (args.index == layers.size() - 1) {
191194
if (outputOpKind == OutputOpKind::Generic) {
192195
chain = lowerSoftmax(chain, args.output.value);
193-
} else if (outputOpKind == OutputOpKind::NamedOp) {
196+
} else {
194197
chain = lowerNamedSoftmax(chain, args.output.value);
195198
}
196199
}
@@ -405,9 +408,10 @@ Value MLIRGenerator::lowerMatmul(Value input, Value weight, Value output) {
405408
reassociationIndices);
406409
}
407410

408-
if (outputOpKind == OutputOpKind::Generic ||
409-
(outputOpKind == OutputOpKind::NamedOp && keepGenericMatmul)) {
411+
if (outputOpKind == OutputOpKind::Generic || keepGenericMatmul) {
410412
chain = lowerGenericMatmul(input, weight, output);
413+
} else if (outputOpKind == OutputOpKind::Contract) {
414+
chain = lowerContract(input, weight, output);
411415
} else if (outputOpKind == OutputOpKind::NamedOp) {
412416
chain = lowerNamedMatmul(input, weight, output);
413417
}
@@ -442,6 +446,21 @@ Value MLIRGenerator::lowerGenericMatmul(Value input, Value weight,
442446
return matmul;
443447
}
444448

449+
Value MLIRGenerator::lowerContract(Value input, Value weight, Value output) {
450+
// Matmul as a linalg.contract
451+
SmallVector<Attribute> maps;
452+
maps.push_back(AffineMapAttr::get(getMap(input, MAP_MATMUL_INPUT))); // { 0, 2 }
453+
maps.push_back(AffineMapAttr::get(getMap(weight, MAP_MATMUL_WEIGHT))); // { 2, 1 }
454+
maps.push_back(AffineMapAttr::get(getMap(output, MAP_MATMUL_OUTPUT))); // { 0, 1 }
455+
auto contract = builder
456+
.create<linalg::ContractOp>(
457+
loc, output.getType(), ValueRange{input, weight}, ValueRange{output},
458+
builder.getArrayAttr(maps))
459+
.getResult(0);
460+
461+
return contract;
462+
}
463+
445464
Value MLIRGenerator::lowerBiasAdd(Value input, Value bias, Value output) {
446465
if (!enableBias)
447466
return input;

tools/mlir-gen/MLIRGen.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class MLIRGenerator {
7575
bool enableSoftmax;
7676

7777
/// List of linalg output Op kind which can be generated
78-
enum class OutputOpKind { Generic, NamedOp };
78+
enum class OutputOpKind { Generic, Contract, NamedOp };
7979

8080
/// Kind of linalg output Op to be generated
8181
OutputOpKind outputOpKind;
@@ -156,6 +156,9 @@ class MLIRGenerator {
156156
/// Creates linalg named matmul
157157
Value lowerNamedMatmul(Value, Value, Value);
158158

159+
/// Creates linalg contract
160+
Value lowerContract(Value, Value, Value);
161+
159162
/// Creates a bias add in the current function
160163
/// Args: Input, Output (same for in-place)
161164
/// Returns the chain value to be used in the next op

tools/mlir-gen/mlir-gen.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ using namespace mlir;
3333

3434
// Kind of linalg Op, generic or nameed ops
3535
llvm::cl::opt<std::string> outputOpKind(
36-
"output", llvm::cl::desc("Specifies linalg op kind generic or named"),
37-
llvm::cl::value_desc("generic,named"), llvm::cl::init("generic"));
36+
"output", llvm::cl::desc("Specifies linalg op kind generic, contract or named"),
37+
llvm::cl::value_desc("generic,contract,named"), llvm::cl::init("generic"));
3838

3939
// Enable emission of generic matmul when outputKind is named op
4040
llvm::cl::opt<bool> keepGenericMatmul(

0 commit comments

Comments
 (0)