20
20
#include " src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"
21
21
#include " src/Dialect/ONNX/ONNXDimAnalysis.hpp"
22
22
#include " src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
23
+ #include " llvm/Support/Debug.h"
24
+
25
+ #define DEBUG_TYPE " legality-check"
23
26
24
27
using namespace mlir ;
25
28
using namespace onnx_mlir ;
@@ -266,6 +269,11 @@ bool meetPoolParamRestrictions(int64_t inputShape, int64_t kernelShape,
266
269
return true ;
267
270
}
268
271
272
+ void emitLegalityCheckMessage (Operation *op, StringRef message) {
273
+ LLVM_DEBUG (llvm::outs () << " [NNPA Legality Check] Warning: " << op->getLoc ()
274
+ << " runs on CPU. Reason: " << message << " \n " );
275
+ }
276
+
269
277
// / Default legality check.
270
278
template <typename OP_TYPE>
271
279
bool isSuitableForZDNN (OP_TYPE op, const DimAnalysis *dimAnalysis) {
@@ -489,23 +497,37 @@ bool isSuitableForZDNN<ONNXExpOp>(
489
497
template <>
490
498
bool isSuitableForZDNN<ONNXMatMulOp>(
491
499
ONNXMatMulOp op, const DimAnalysis *dimAnalysis) {
500
+
492
501
// Check NNPA level.
493
- if (!isCompatibleWithNNPALevel (NNPA_Z16))
502
+ if (!isCompatibleWithNNPALevel (NNPA_Z16)) {
503
+ emitLegalityCheckMessage (
504
+ op.getOperation (), " Not compatible with NNPA level." );
494
505
return false ;
506
+ }
495
507
int64_t opnum = op.getNumOperands ();
496
508
if (opnum != 2 ) {
509
+ emitLegalityCheckMessage (
510
+ op.getOperation (), " The number of operands not 2." );
497
511
return false ;
498
512
}
499
- if (!isValidElementTypeAndRank (op.getOperand (0 )))
513
+ if (!isValidElementTypeAndRank (op.getOperand (0 ))) {
514
+ emitLegalityCheckMessage (
515
+ op.getOperation (), " Operand 0 not valid element type and rank." );
500
516
return false ;
501
- if (!isValidElementTypeAndRank (op.getOperand (1 )))
517
+ }
518
+ if (!isValidElementTypeAndRank (op.getOperand (1 ))) {
519
+ emitLegalityCheckMessage (
520
+ op.getOperation (), " Operand 1 not valid element type and rank." );
502
521
return false ;
522
+ }
503
523
ShapedType aType = op.getOperand (0 ).getType ().cast <ShapedType>();
504
524
ShapedType bType = op.getOperand (1 ).getType ().cast <ShapedType>();
505
525
506
526
// Illegal if A or B is unranked.
507
- if (!aType.hasRank () || !bType.hasRank ())
527
+ if (!aType.hasRank () || !bType.hasRank ()) {
528
+ emitLegalityCheckMessage (op.getOperation (), " A or B is unranked." );
508
529
return false ;
530
+ }
509
531
510
532
auto shapeA = aType.getShape ();
511
533
auto shapeB = bType.getShape ();
@@ -518,21 +540,34 @@ bool isSuitableForZDNN<ONNXMatMulOp>(
518
540
// by using broadcasting etc.
519
541
if ((shapeA.size () == 2 ) && (shapeB.size () == 2 )) {
520
542
// unstacked case
521
- if (aType.hasStaticShape () && bType.hasStaticShape ())
522
- return (shapeA[1 ] == shapeB[0 ]);
523
- else
543
+ if (aType.hasStaticShape () && bType.hasStaticShape ()) {
544
+ bool returnVal = (shapeA[1 ] == shapeB[0 ]);
545
+ if (!returnVal)
546
+ emitLegalityCheckMessage (
547
+ op.getOperation (), " Unstacked case, dim A 1 not equal dim B 0." );
548
+ return returnVal;
549
+ } else
524
550
return true ;
525
551
} else if ((shapeA.size () == 3 ) && (shapeB.size () == 3 )) {
526
552
// stacked w/o bcast case
527
- if (aType.hasStaticShape () && bType.hasStaticShape ())
528
- return ((shapeA[0 ] == shapeB[0 ]) && (shapeA[2 ] == shapeB[1 ]));
529
- else
553
+ if (aType.hasStaticShape () && bType.hasStaticShape ()) {
554
+ bool returnVal = ((shapeA[0 ] == shapeB[0 ]) && (shapeA[2 ] == shapeB[1 ]));
555
+ if (!returnVal)
556
+ emitLegalityCheckMessage (
557
+ op.getOperation (), " Stacked w/o bcast, dim A 0 not equal dim B 0 "
558
+ " or dim A 2 not equal dim B 1." );
559
+ return returnVal;
560
+ } else
530
561
return true ;
531
562
} else if ((shapeA.size () == 3 ) && (shapeB.size () == 2 )) {
532
563
// stacked w/ bcast
533
- if (aType.hasStaticShape () && bType.hasStaticShape ())
534
- return (shapeA[2 ] == shapeB[0 ]);
535
- else
564
+ if (aType.hasStaticShape () && bType.hasStaticShape ()) {
565
+ bool returnVal = (shapeA[2 ] == shapeB[0 ]);
566
+ if (!returnVal)
567
+ emitLegalityCheckMessage (
568
+ op.getOperation (), " Stacked w/ bcast, dim A 2 not equal dim B 0." );
569
+ return returnVal;
570
+ } else
536
571
return true ;
537
572
}
538
573
return false ; // unsupported case
0 commit comments