Skip to content

Commit 9fac96d

Browse files
committed
Generate messages for legality check for MatMul op.
Signed-off-by: Haruki Imai <[email protected]>
1 parent 733dfac commit 9fac96d

File tree

3 files changed

+67
-16
lines changed

3 files changed

+67
-16
lines changed

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"
2121
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
2222
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
23+
#include "llvm/Support/Debug.h"
24+
25+
#define DEBUG_TYPE "legality-check"
2326

2427
using namespace mlir;
2528
using namespace onnx_mlir;
@@ -266,6 +269,11 @@ bool meetPoolParamRestrictions(int64_t inputShape, int64_t kernelShape,
266269
return true;
267270
}
268271

272+
void emitLegalityCheckMessage(Operation *op, StringRef message) {
273+
LLVM_DEBUG(llvm::outs() << "[NNPA Legality Check] Warning: " << op->getLoc()
274+
<< " runs on CPU. Reason: " << message << "\n");
275+
}
276+
269277
/// Default legality check.
270278
template <typename OP_TYPE>
271279
bool isSuitableForZDNN(OP_TYPE op, const DimAnalysis *dimAnalysis) {
@@ -489,23 +497,37 @@ bool isSuitableForZDNN<ONNXExpOp>(
489497
template <>
490498
bool isSuitableForZDNN<ONNXMatMulOp>(
491499
ONNXMatMulOp op, const DimAnalysis *dimAnalysis) {
500+
492501
// Check NNPA level.
493-
if (!isCompatibleWithNNPALevel(NNPA_Z16))
502+
if (!isCompatibleWithNNPALevel(NNPA_Z16)) {
503+
emitLegalityCheckMessage(
504+
op.getOperation(), "Not compatible with NNPA level.");
494505
return false;
506+
}
495507
int64_t opnum = op.getNumOperands();
496508
if (opnum != 2) {
509+
emitLegalityCheckMessage(
510+
op.getOperation(), "The number of operands not 2.");
497511
return false;
498512
}
499-
if (!isValidElementTypeAndRank(op.getOperand(0)))
513+
if (!isValidElementTypeAndRank(op.getOperand(0))) {
514+
emitLegalityCheckMessage(
515+
op.getOperation(), "Operand 0 not valid element type and rank.");
500516
return false;
501-
if (!isValidElementTypeAndRank(op.getOperand(1)))
517+
}
518+
if (!isValidElementTypeAndRank(op.getOperand(1))) {
519+
emitLegalityCheckMessage(
520+
op.getOperation(), "Operand 1 not valid element type and rank.");
502521
return false;
522+
}
503523
ShapedType aType = op.getOperand(0).getType().cast<ShapedType>();
504524
ShapedType bType = op.getOperand(1).getType().cast<ShapedType>();
505525

506526
// Illegal if A or B is unranked.
507-
if (!aType.hasRank() || !bType.hasRank())
527+
if (!aType.hasRank() || !bType.hasRank()) {
528+
emitLegalityCheckMessage(op.getOperation(), "A or B is unranked.");
508529
return false;
530+
}
509531

510532
auto shapeA = aType.getShape();
511533
auto shapeB = bType.getShape();
@@ -518,21 +540,34 @@ bool isSuitableForZDNN<ONNXMatMulOp>(
518540
// by using broadcasting etc.
519541
if ((shapeA.size() == 2) && (shapeB.size() == 2)) {
520542
// unstacked case
521-
if (aType.hasStaticShape() && bType.hasStaticShape())
522-
return (shapeA[1] == shapeB[0]);
523-
else
543+
if (aType.hasStaticShape() && bType.hasStaticShape()) {
544+
bool returnVal = (shapeA[1] == shapeB[0]);
545+
if (!returnVal)
546+
emitLegalityCheckMessage(
547+
op.getOperation(), "Unstacked case, dim A 1 not equal dim B 0.");
548+
return returnVal;
549+
} else
524550
return true;
525551
} else if ((shapeA.size() == 3) && (shapeB.size() == 3)) {
526552
// stacked w/o bcast case
527-
if (aType.hasStaticShape() && bType.hasStaticShape())
528-
return ((shapeA[0] == shapeB[0]) && (shapeA[2] == shapeB[1]));
529-
else
553+
if (aType.hasStaticShape() && bType.hasStaticShape()) {
554+
bool returnVal = ((shapeA[0] == shapeB[0]) && (shapeA[2] == shapeB[1]));
555+
if (!returnVal)
556+
emitLegalityCheckMessage(
557+
op.getOperation(), "Stacked w/o bcast, dim A 0 not equal dim B 0 "
558+
"or dim A 2 not equal dim B 1.");
559+
return returnVal;
560+
} else
530561
return true;
531562
} else if ((shapeA.size() == 3) && (shapeB.size() == 2)) {
532563
// stacked w/ bcast
533-
if (aType.hasStaticShape() && bType.hasStaticShape())
534-
return (shapeA[2] == shapeB[0]);
535-
else
564+
if (aType.hasStaticShape() && bType.hasStaticShape()) {
565+
bool returnVal = (shapeA[2] == shapeB[0]);
566+
if (!returnVal)
567+
emitLegalityCheckMessage(
568+
op.getOperation(), "Stacked w/ bcast, dim A 2 not equal dim B 0.");
569+
return returnVal;
570+
} else
536571
return true;
537572
}
538573
return false; // unsupported case

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,5 @@ mlir::StringRef getStrPaddingType(OP op);
4545
/// See "MaxPool2D/AvgPool2D Parameter Restrictions" in "zDNN API Reference"
4646
bool meetPoolParamRestrictions(int64_t inputShape, int64_t kernelShape,
4747
int64_t strides, int64_t outputShape, mlir::StringRef paddingType);
48+
49+
void emitLegalityCheckMessage(mlir::Operation *op, mlir::StringRef message);

src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -589,14 +589,24 @@ void getRewriteONNXForZHighDynamicallyLegal(
589589
// No input is N-D (N > 3) but dimension N or M (NxK * KxM) is dynamic
590590
// or exceeds NNPA limitation.
591591
bool nExceeded, mExceeded;
592-
if (SplitLargeMatMulPattern::canBeRewritten(op, nExceeded, mExceeded))
592+
if (SplitLargeMatMulPattern::canBeRewritten(op, nExceeded, mExceeded)) {
593+
emitLegalityCheckMessage(op.getOperation(),
594+
"No input is N-D (N > 3) but dimension N or M (NxK * KxM) is "
595+
"dynamic or exceeds NNPA limitation.");
593596
return false;
597+
}
594598

595599
// - one input is N-D (N > 3) and the other is 2-D.
596-
if (aRank == 2 && bRank > 3)
600+
if (aRank == 2 && bRank > 3) {
601+
emitLegalityCheckMessage(op.getOperation(),
602+
"one input is N-D (N > 3) and the other is 2-D.");
597603
return false;
598-
if (bRank == 2 && aRank > 3)
604+
}
605+
if (bRank == 2 && aRank > 3) {
606+
emitLegalityCheckMessage(op.getOperation(),
607+
"one input is N-D (N > 3) and the other is 2-D.");
599608
return false;
609+
}
600610

601611
// - both inputs are *the same* N-D, N > 3 and there is no broadcasting
602612
if (aRank > 3 && (aRank == bRank)) {
@@ -607,6 +617,10 @@ void getRewriteONNXForZHighDynamicallyLegal(
607617
sameBatchDims =
608618
dimAnalysis->sameDynDim(op.getA(), i, op.getB(), i);
609619
}
620+
if (!sameBatchDims)
621+
emitLegalityCheckMessage(
622+
op.getOperation(), "both inputs are *the same* N-D, N > 3 and "
623+
"there is no broadcasting.");
610624
return !sameBatchDims;
611625
}
612626

0 commit comments

Comments
 (0)