File tree Expand file tree Collapse file tree 2 files changed +20
-3
lines changed
src/Dialect/ONNX/Transforms Expand file tree Collapse file tree 2 files changed +20
-3
lines changed Original file line number Diff line number Diff line change @@ -607,6 +607,21 @@ Value normalizeConstantOp(
607
607
return createONNX.constant (denseAttr);
608
608
}
609
609
610
+ ElementsAttr reshapeElementsAttrToRank0WithDefaultValue (
611
+ PatternRewriter &rewriter, Value shape, Attribute val) {
612
+ if (!val) {
613
+ // Default is 0.0 in float32. It is not created by default in the ONNX
614
+ // getValue() as the ONNX td does not define a default value. So explicitly
615
+ // create a dense array of 1 zero value here.
616
+ Type elementType = rewriter.getF32Type ();
617
+ RankedTensorType tensorType = RankedTensorType::get ({1 }, elementType);
618
+ FloatAttr floatAttr = rewriter.getFloatAttr (elementType, 0.0 );
619
+ val = DenseElementsAttr::get (tensorType, floatAttr.getValue ());
620
+ }
621
+ return OnnxElementsAttrBuilder (shape.getContext ())
622
+ .reshape (cast<ElementsAttr>(val), {});
623
+ }
624
+
610
625
} // namespace onnx_mlir
611
626
612
627
namespace {
Original file line number Diff line number Diff line change @@ -68,8 +68,10 @@ def createScalarDenseAttrRank0
68
68
// Create a scalar DenseElementsAttr of tensor<dtype> from an ElementsAttr.
69
69
// The input ElementsAttr must have only one element. Otherwise only the first
70
70
// element is used to create the scalar DenseElementsAttr
71
- def ReshapeElementsAttrToRank0 : NativeCodeCall<
72
- "onnx_mlir::OnnxElementsAttrBuilder($0.getContext()).reshape(cast<ElementsAttr>($0), {})">;
71
+ // When no attribute is provided ($1 is nullptr), then a default value attribute
72
+ // is created (float32 0.0) according to ONNX specs.
73
+ def ReshapeElementsAttrToRank0WithDefaultValue: NativeCodeCall<
74
+ "onnx_mlir::reshapeElementsAttrToRank0WithDefaultValue($_builder, $0, $1)">;
73
75
74
76
def ReplaceSequenceAt : NativeCodeCall<
75
77
"onnx_mlir::replaceSequenceAt($_builder, $_loc, $0)">;
@@ -585,7 +587,7 @@ def ConvTransposeOpPattern2: Pattern<
585
587
586
588
def ConstantOfShapePattern: Pat<
587
589
(ONNXConstantOfShapeOp:$res $shape, $value),
588
- (ONNXExpandOp (ONNXConstantOpFromDenseAttr (ReshapeElementsAttrToRank0 $value)),
590
+ (ONNXExpandOp (ONNXConstantOpFromDenseAttr (ReshapeElementsAttrToRank0WithDefaultValue $shape, $value)),
589
591
$shape)
590
592
>;
591
593
You can’t perform that action at this time.
0 commit comments