@@ -85,9 +85,12 @@ MLIRGenerator::MLIRGenerator(StringRef outputOpKindStr, StringRef kernelStr,
   auto optOutputOpKind =
       llvm::StringSwitch<std::optional<OutputOpKind>>(outputOpKindStr)
           .CaseLower("generic", OutputOpKind::Generic)
+          .CaseLower("contract", OutputOpKind::Contract)
          .CaseLower("named", OutputOpKind::NamedOp)
           .Default(std::nullopt);
   assert(optOutputOpKind && "Invalid output Op kind");
+  assert(!(optOutputOpKind == OutputOpKind::Contract && keepGenericMatmul) &&
+         "Can't keep generic matmul with contract");
   outputOpKind = *optOutputOpKind;
 
   // Parse kernel type
@@ -181,7 +184,7 @@ Value MLIRGenerator::createLayer(LayerArgs &args) {
   if (outputOpKind == OutputOpKind::Generic) {
     chain = lowerBiasAdd(chain, args.bias.value, args.output.value);
     chain = lowerRelu(chain, args.output.value);
-  } else if (outputOpKind == OutputOpKind::NamedOp) {
+  } else {
     chain = lowerNamedBiasAdd(chain, args.bias.value, args.output.value);
     chain = lowerNamedRelu(chain, args.output.value);
   }
@@ -190,7 +193,7 @@ Value MLIRGenerator::createLayer(LayerArgs &args) {
   if (args.index == layers.size() - 1) {
     if (outputOpKind == OutputOpKind::Generic) {
       chain = lowerSoftmax(chain, args.output.value);
-    } else if (outputOpKind == OutputOpKind::NamedOp) {
+    } else {
       chain = lowerNamedSoftmax(chain, args.output.value);
     }
   }
@@ -405,9 +408,10 @@ Value MLIRGenerator::lowerMatmul(Value input, Value weight, Value output) {
                                         reassociationIndices);
   }
 
-  if (outputOpKind == OutputOpKind::Generic ||
-      (outputOpKind == OutputOpKind::NamedOp && keepGenericMatmul)) {
+  if (outputOpKind == OutputOpKind::Generic || keepGenericMatmul) {
     chain = lowerGenericMatmul(input, weight, output);
+  } else if (outputOpKind == OutputOpKind::Contract) {
+    chain = lowerContract(input, weight, output);
   } else if (outputOpKind == OutputOpKind::NamedOp) {
     chain = lowerNamedMatmul(input, weight, output);
   }
@@ -442,6 +446,21 @@ Value MLIRGenerator::lowerGenericMatmul(Value input, Value weight,
   return matmul;
 }
 
+Value MLIRGenerator::lowerContract(Value input, Value weight, Value output) {
+  // Matmul as a linalg.contract
+  SmallVector<Attribute> maps;
+  maps.push_back(AffineMapAttr::get(getMap(input, MAP_MATMUL_INPUT)));   // { 0, 2 }
+  maps.push_back(AffineMapAttr::get(getMap(weight, MAP_MATMUL_WEIGHT))); // { 2, 1 }
+  maps.push_back(AffineMapAttr::get(getMap(output, MAP_MATMUL_OUTPUT))); // { 0, 1 }
+  auto contract = builder
+                      .create<linalg::ContractOp>(
+                          loc, output.getType(), ValueRange{input, weight},
+                          ValueRange{output}, builder.getArrayAttr(maps))
+                      .getResult(0);
+
+  return contract;
+}
+
 Value MLIRGenerator::lowerBiasAdd(Value input, Value bias, Value output) {
   if (!enableBias)
     return input;
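
For reference, a minimal sketch of the IR lowerContract should emit — SSA names and tensor shapes here are hypothetical, and the three affine maps spell out the { 0, 2 }, { 2, 1 }, { 0, 1 } dimension lists noted in the comments above (d0 = M, d1 = N, d2 = K):

    // Plain matmul expressed as linalg.contract: input { M, K } x weight { K, N } -> output { M, N }
    %res = linalg.contract
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                         affine_map<(d0, d1, d2) -> (d2, d1)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>]
        ins(%input, %weight : tensor<128x256xf32>, tensor<256x512xf32>)
        outs(%output : tensor<128x512xf32>) -> tensor<128x512xf32>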