Skip to content

Commit 409d9d2

Browse files
committed
Add shapeInference for the ops.
Signed-off-by: Haruki Imai <[email protected]>
1 parent fefae10 commit 409d9d2

File tree

4 files changed

+195
-0
lines changed

4 files changed

+195
-0
lines changed

src/Dialect/ONNX/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ add_onnx_mlir_library(OMONNXOps
3131
ONNXOps/Additional/Custom.cpp
3232
ONNXOps/Additional/Dim.cpp
3333
ONNXOps/Additional/EntryPoint.cpp
34+
ONNXOps/Additional/Fork.cpp
3435
ONNXOps/Additional/Return.cpp
3536
ONNXOps/Additional/LayoutTransform.cpp
3637
ONNXOps/Additional/None.cpp
38+
ONNXOps/Additional/Parallel.cpp
3739
ONNXOps/Additional/ShapeTransform.cpp
3840
ONNXOps/ControlFlow/If.cpp
3941
ONNXOps/ControlFlow/Loop.cpp
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//===---------------- Fork.cpp - ONNX Operations -------------------------===//
6+
//
7+
// Copyright 2019-2024 The IBM Research Authors.
8+
//
9+
// =============================================================================
10+
//
11+
// This file provides definition of ONNX dialect Fork operation.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
16+
17+
using namespace mlir;
18+
using namespace onnx_mlir;
19+
20+
//===----------------------------------------------------------------------===//
21+
// ShapeHelper
22+
//===----------------------------------------------------------------------===//
23+
24+
template <>
25+
LogicalResult ONNXForkOpShapeHelper::computeShape() {
26+
ONNXForkOp forkOp = llvm::cast<ONNXForkOp>(op);
27+
(void)forkOp.inferShapes([](Region &region) {});
28+
Operation *yieldOp = forkOp.getBody().front().getTerminator();
29+
for (unsigned i = 0; i < yieldOp->getNumOperands(); ++i) {
30+
DimsExpr outputDims;
31+
Value returnVal = yieldOp->getOperands()[i];
32+
int64_t outRank = returnVal.getType().cast<ShapedType>().getRank();
33+
for (int64_t j = 0; j < outRank; ++j)
34+
outputDims.emplace_back(createIE->getShapeAsDim(returnVal, j));
35+
setOutputDims(outputDims, i);
36+
}
37+
return success();
38+
}
39+
40+
//===----------------------------------------------------------------------===//
41+
// Type Inference
42+
//===----------------------------------------------------------------------===//
43+
44+
std::vector<Type> ONNXForkOp::resultTypeInference() {
45+
Operation *terminator = getRegion().back().getTerminator();
46+
auto bodyOutputTys = terminator->getOperandTypes();
47+
std::vector<Type> resultTypes;
48+
for (auto [i, ty] : llvm::enumerate(bodyOutputTys)) {
49+
resultTypes.push_back(ty);
50+
}
51+
return resultTypes;
52+
}
53+
54+
//===----------------------------------------------------------------------===//
55+
// Shape Inference
56+
//===----------------------------------------------------------------------===//
57+
58+
LogicalResult ONNXForkOp::inferShapes(
59+
std::function<void(Region &)> doShapeInference) {
60+
doShapeInference(getRegion());
61+
for (auto [i, ty] : llvm::enumerate(resultTypeInference()))
62+
getResult(i).setType(ty);
63+
return success();
64+
}
65+
66+
//===----------------------------------------------------------------------===//
67+
// Builder: Refer to Async ExecuteOp
68+
//===----------------------------------------------------------------------===//
69+
void ONNXForkOp::build(OpBuilder &builder, OperationState &result,
70+
TypeRange resultTypes, ValueRange operands, BodyBuilderFn bodyBuilder) {
71+
72+
result.addOperands(operands);
73+
result.addTypes(resultTypes);
74+
75+
// Add a body region with block arguments
76+
Region *bodyRegion = result.addRegion();
77+
bodyRegion->push_back(new Block);
78+
Block &bodyBlock = bodyRegion->front();
79+
for (Value operand : operands) {
80+
bodyBlock.addArgument(operand.getType(), operand.getLoc());
81+
}
82+
83+
// Create the default terminator if the builder is not provided and if the
84+
// expected result is empty. Otherwise, leave this to the caller
85+
// because we don't know which values to return from the execute op.
86+
if (resultTypes.empty() && !bodyBuilder) {
87+
OpBuilder::InsertionGuard guard(builder);
88+
builder.setInsertionPointToStart(&bodyBlock);
89+
builder.create<ONNXYieldOp>(result.location, ValueRange());
90+
} else if (bodyBuilder) {
91+
OpBuilder::InsertionGuard guard(builder);
92+
builder.setInsertionPointToStart(&bodyBlock);
93+
bodyBuilder(builder, result.location, bodyBlock.getArguments());
94+
}
95+
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//===---------------- Fork.cpp - ONNX Operations -------------------------===//
6+
//
7+
// Copyright 2019-2024 The IBM Research Authors.
8+
//
9+
// =============================================================================
10+
//
11+
// This file provides definition of ONNX dialect Fork operation.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
16+
17+
using namespace mlir;
18+
using namespace onnx_mlir;
19+
20+
//===----------------------------------------------------------------------===//
21+
// ShapeHelper
22+
//===----------------------------------------------------------------------===//
23+
24+
template <>
25+
LogicalResult ONNXParallelOpShapeHelper::computeShape() {
26+
ONNXParallelOp parallelOp = llvm::cast<ONNXParallelOp>(op);
27+
(void)parallelOp.inferShapes([](Region &region) {});
28+
Operation *yieldOp = parallelOp.getBody().front().getTerminator();
29+
for (unsigned i = 0; i < yieldOp->getNumOperands(); ++i) {
30+
DimsExpr outputDims;
31+
Value returnVal = yieldOp->getOperands()[i];
32+
int64_t outRank = returnVal.getType().cast<ShapedType>().getRank();
33+
for (int64_t j = 0; j < outRank; ++j)
34+
outputDims.emplace_back(createIE->getShapeAsDim(returnVal, j));
35+
setOutputDims(outputDims, i);
36+
}
37+
return success();
38+
}
39+
40+
//===----------------------------------------------------------------------===//
41+
// Type Inference
42+
//===----------------------------------------------------------------------===//
43+
44+
std::vector<Type> ONNXParallelOp::resultTypeInference() {
45+
Operation *terminator = getRegion().back().getTerminator();
46+
auto bodyOutputTys = terminator->getOperandTypes();
47+
48+
std::vector<Type> resultTypes;
49+
for (auto [i, ty] : llvm::enumerate(bodyOutputTys)) {
50+
resultTypes.push_back(ty);
51+
}
52+
return resultTypes;
53+
}
54+
55+
//===----------------------------------------------------------------------===//
56+
// Shape Inference
57+
//===----------------------------------------------------------------------===//
58+
59+
LogicalResult ONNXParallelOp::inferShapes(
60+
std::function<void(Region &)> doShapeInference) {
61+
doShapeInference(getRegion());
62+
for (auto [i, ty] : llvm::enumerate(resultTypeInference()))
63+
getResult(i).setType(ty);
64+
return success();
65+
}
66+
67+
//===----------------------------------------------------------------------===//
68+
// Builder: Refer to Async ExecuteOp
69+
//===----------------------------------------------------------------------===//
70+
void ONNXParallelOp::build(OpBuilder &builder, OperationState &result,
71+
TypeRange resultTypes, ValueRange operands, BodyBuilderFn bodyBuilder) {
72+
73+
result.addOperands(operands);
74+
result.addTypes(resultTypes);
75+
76+
// Add a body region with block arguments
77+
Region *bodyRegion = result.addRegion();
78+
bodyRegion->push_back(new Block);
79+
Block &bodyBlock = bodyRegion->front();
80+
for (Value operand : operands) {
81+
bodyBlock.addArgument(operand.getType(), operand.getLoc());
82+
}
83+
84+
// Create the default terminator if the builder is not provided and if the
85+
// expected result is empty. Otherwise, leave this to the caller
86+
// because we don't know which values to return from the execute op.
87+
if (resultTypes.empty() && !bodyBuilder) {
88+
OpBuilder::InsertionGuard guard(builder);
89+
builder.setInsertionPointToStart(&bodyBlock);
90+
builder.create<ONNXYieldOp>(result.location, ValueRange());
91+
} else if (bodyBuilder) {
92+
OpBuilder::InsertionGuard guard(builder);
93+
builder.setInsertionPointToStart(&bodyBlock);
94+
bodyBuilder(builder, result.location, bodyBlock.getArguments());
95+
}
96+
}

src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,8 @@ using ONNXTileOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXTileOp>;
885885
using ONNXTopKOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXTopKOp>;
886886
using ONNXTransposeOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXTransposeOp>;
887887
using ONNXUpsampleOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXUpsampleOp>;
888+
using ONNXForkOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXForkOp>;
889+
using ONNXParallelOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXParallelOp>;
888890
// clang-format on
889891

890892
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)