Skip to content

Commit 699920b

Browse files
Fix to add default value to ConstantOfShape (onnx#3174)
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent 9804476 commit 699920b

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,21 @@ Value normalizeConstantOp(
607607
return createONNX.constant(denseAttr);
608608
}
609609

610+
ElementsAttr reshapeElementsAttrToRank0WithDefaultValue(
611+
PatternRewriter &rewriter, Value shape, Attribute val) {
612+
if (!val) {
613+
// Default is 0.0 in float32. It is not created by default in the ONNX
614+
// getValue() as the ONNX td does not define a default value. So explicitly
615+
// create a dense array of 1 zero value here.
616+
Type elementType = rewriter.getF32Type();
617+
RankedTensorType tensorType = RankedTensorType::get({1}, elementType);
618+
FloatAttr floatAttr = rewriter.getFloatAttr(elementType, 0.0);
619+
val = DenseElementsAttr::get(tensorType, floatAttr.getValue());
620+
}
621+
return OnnxElementsAttrBuilder(shape.getContext())
622+
.reshape(cast<ElementsAttr>(val), {});
623+
}
624+
610625
} // namespace onnx_mlir
611626

612627
namespace {

src/Dialect/ONNX/Transforms/Decompose.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,10 @@ def createScalarDenseAttrRank0
6868
// Create a scalar DenseElementsAttr of tensor<dtype> from an ElementsAttr.
6969
// The input ElementsAttr must have only one element. Otherwise only the first
7070
// element is used to create the scalar DenseElementsAttr
71-
def ReshapeElementsAttrToRank0 : NativeCodeCall<
72-
"onnx_mlir::OnnxElementsAttrBuilder($0.getContext()).reshape(cast<ElementsAttr>($0), {})">;
71+
// When no attribute is provided ($1 is nullptr), then a default value attribute
72+
// is created (float32 0.0) according to ONNX specs.
73+
def ReshapeElementsAttrToRank0WithDefaultValue: NativeCodeCall<
74+
"onnx_mlir::reshapeElementsAttrToRank0WithDefaultValue($_builder, $0, $1)">;
7375

7476
def ReplaceSequenceAt : NativeCodeCall<
7577
"onnx_mlir::replaceSequenceAt($_builder, $_loc, $0)">;
@@ -585,7 +587,7 @@ def ConvTransposeOpPattern2: Pattern<
585587

586588
def ConstantOfShapePattern: Pat<
587589
(ONNXConstantOfShapeOp:$res $shape, $value),
588-
(ONNXExpandOp (ONNXConstantOpFromDenseAttr (ReshapeElementsAttrToRank0 $value)),
590+
(ONNXExpandOp (ONNXConstantOpFromDenseAttr (ReshapeElementsAttrToRank0WithDefaultValue $shape, $value)),
589591
$shape)
590592
>;
591593

0 commit comments

Comments
 (0)