|
| 1 | +/* |
| 2 | + * SPDX-License-Identifier: Apache-2.0 |
| 3 | + */ |
| 4 | + |
| 5 | +//===----------------- Dim.cpp - Lowering Dim Op ----------------===// |
| 6 | +// |
| 7 | +// Copyright 2022-2024 |
| 8 | +// |
| 9 | +// ============================================================================= |
| 10 | +// |
| 11 | +// This file lowers the ONNXDim operator to the Tensor dialect. |
| 12 | +// |
| 13 | +//===----------------------------------------------------------------------===// |
| 14 | + |
| 15 | +#include "mlir/Dialect/Shape/IR/Shape.h" |
| 16 | +#include "src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp" |
| 17 | + |
| 18 | +using namespace mlir; |
| 19 | + |
| 20 | +namespace onnx_mlir { |
| 21 | + |
| 22 | +namespace { |
| 23 | + |
| 24 | +struct ONNXDimOpLoweringToStablehlo : public ConversionPattern { |
| 25 | + ONNXDimOpLoweringToStablehlo(MLIRContext *ctx) |
| 26 | + : ConversionPattern(ONNXDimOp::getOperationName(), 1, ctx) {} |
| 27 | + |
| 28 | + LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 29 | + ConversionPatternRewriter &rewriter) const final { |
| 30 | + Location loc = op->getLoc(); |
| 31 | + ONNXDimOp dimOp = cast<ONNXDimOp>(op); |
| 32 | + int64_t axisLit = dimOp.getAxis(); |
| 33 | + |
| 34 | + // Check that axisLit is a valid dimension index |
| 35 | + Value tensorArg = operands[0]; |
| 36 | + assert(tensorArg.getType().isa<RankedTensorType>() && |
| 37 | + "Expected ranked tensor type"); |
| 38 | + |
| 39 | + int64_t rank = tensorArg.getType().cast<RankedTensorType>().getRank(); |
| 40 | + |
| 41 | + assert((axisLit >= 0 && axisLit < rank) && |
| 42 | + "Axis must be in the range [0, input tensor rank - 1]"); |
| 43 | + |
| 44 | + Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, tensorArg); |
| 45 | + Value dimValue = |
| 46 | + rewriter.create<shape::GetExtentOp>(loc, inputShape, axisLit); |
| 47 | + Type dimType = dimOp.getDim().getType(); |
| 48 | + Type indexValueType = dimType.cast<ShapedType>().getElementType(); |
| 49 | + Value castedIndex = |
| 50 | + rewriter.create<arith::IndexCastOp>(loc, indexValueType, dimValue); |
| 51 | + Value indexTensor = rewriter.create<tensor::FromElementsOp>( |
| 52 | + loc, dimType, ArrayRef<Value>{castedIndex}); |
| 53 | + rewriter.replaceOp(op, indexTensor); |
| 54 | + return success(); |
| 55 | + } |
| 56 | +}; |
| 57 | + |
| 58 | +} // namespace |
| 59 | + |
| 60 | +void populateLoweringONNXDimOpToStablehloPattern( |
| 61 | + RewritePatternSet &patterns, MLIRContext *ctx) { |
| 62 | + patterns.insert<ONNXDimOpLoweringToStablehlo>(ctx); |
| 63 | +} |
| 64 | + |
| 65 | +} // namespace onnx_mlir |
0 commit comments