diff --git a/src/Conversion/ONNXToStablehlo/Math/Reduction.cpp b/src/Conversion/ONNXToStablehlo/Math/Reduction.cpp
index 9e67d50285..66cca0e3d3 100644
--- a/src/Conversion/ONNXToStablehlo/Math/Reduction.cpp
+++ b/src/Conversion/ONNXToStablehlo/Math/Reduction.cpp
@@ -28,41 +28,78 @@ Value getIdentityValue(
   return nullptr;
 }
 
-template <>
-Value getIdentityValue<ONNXReduceMaxV13Op>(
+Value getReduceMaxIdentityValue(
     ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
   MathBuilder createMath(rewriter, loc);
   return rewriter.create<stablehlo::ConstantOp>(
       loc, createMath.negativeInfAttr(elemType));
 }
 
-template <>
-Value getIdentityValue<ONNXReduceMinV13Op>(
+Value getReduceMinIdentityValue(
     ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
   MathBuilder createMath(rewriter, loc);
   return rewriter.create<stablehlo::ConstantOp>(
       loc, createMath.positiveInfAttr(elemType));
 }
 
-template <>
-Value getIdentityValue<ONNXReduceSumOp>(
+Value getReduceSumIdentityValue(
     ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
   return rewriter.create<stablehlo::ConstantOp>(
       loc, rewriter.getZeroAttr(elemType));
 }
 
-template <>
-Value getIdentityValue<ONNXReduceSumV11Op>(
+Value getReduceMeanIdentityValue(
     ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
   return rewriter.create<stablehlo::ConstantOp>(
       loc, rewriter.getZeroAttr(elemType));
 }
 
+template <>
+Value getIdentityValue<ONNXReduceMaxOp>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceMaxIdentityValue(rewriter, loc, elemType);
+}
+
+template <>
+Value getIdentityValue<ONNXReduceMaxV13Op>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceMaxIdentityValue(rewriter, loc, elemType);
+}
+
+template <>
+Value getIdentityValue<ONNXReduceMinOp>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceMinIdentityValue(rewriter, loc, elemType);
+}
+
+template <>
+Value getIdentityValue<ONNXReduceMinV13Op>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceMinIdentityValue(rewriter, loc, elemType);
+}
+
+template <>
+Value getIdentityValue<ONNXReduceSumOp>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceSumIdentityValue(rewriter, loc, elemType);
+}
+
+template <>
+Value getIdentityValue<ONNXReduceSumV11Op>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceSumIdentityValue(rewriter, loc, elemType);
+}
+
+template <>
+Value getIdentityValue<ONNXReduceMeanOp>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceMeanIdentityValue(rewriter, loc, elemType);
+}
+
 template <>
 Value getIdentityValue<ONNXReduceMeanV13Op>(
     ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
-  return rewriter.create<stablehlo::ConstantOp>(
-      loc, rewriter.getZeroAttr(elemType));
+  return getReduceMeanIdentityValue(rewriter, loc, elemType);
 }
 
 template <typename ONNXReductionOp>
@@ -78,12 +115,9 @@ llvm::SmallVector<int64_t, 4> getDefinedAxes(Operation *op) {
   return definedAxes;
 }
 
-template <>
-llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceSumOp>(Operation *op) {
+llvm::SmallVector<int64_t, 4> getDefinedAxesFromConstAxes(
+    Operation *op, Value axesValue, bool keepDims) {
   llvm::SmallVector<int64_t, 4> definedAxes;
-  ONNXReduceSumOp reduceSumOp = cast<ONNXReduceSumOp>(op);
-  Value axesValue = reduceSumOp.getAxes();
-
   // Assume it is verified that axes are known. Convert DenseElementsAttr to
   // ArrayAttr.
   if (!isNoneValue(axesValue) && getONNXConstantOp(axesValue)) {
@@ -104,7 +138,7 @@ llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceSumOp>(Operation *op) {
   assert(inputType != nullptr && outputType != nullptr &&
          "not implemented for dynamic axes when either input or output is not "
          "ranked");
-  bool keepDims = reduceSumOp.getKeepdims() == 1;
+
   int64_t inputRank = inputType.getRank();
   int64_t outputRank = outputType.getRank();
   llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -127,22 +161,69 @@ llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceSumOp>(Operation *op) {
   return definedAxes;
 }
 
+template <>
+llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceMaxOp>(Operation *op) {
+  ONNXReduceMaxOp reduceMaxOp = cast<ONNXReduceMaxOp>(op);
+  Value axesValue = reduceMaxOp.getAxes();
+  bool keepDims = reduceMaxOp.getKeepdims() == 1;
+  return getDefinedAxesFromConstAxes(op, axesValue, keepDims);
+}
+
+template <>
+llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceMinOp>(Operation *op) {
+  ONNXReduceMinOp reduceMinOp = cast<ONNXReduceMinOp>(op);
+  Value axesValue = reduceMinOp.getAxes();
+  bool keepDims = reduceMinOp.getKeepdims() == 1;
+  return getDefinedAxesFromConstAxes(op, axesValue, keepDims);
+}
+
+template <>
+llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceSumOp>(Operation *op) {
+  ONNXReduceSumOp reduceSumOp = cast<ONNXReduceSumOp>(op);
+  Value axesValue = reduceSumOp.getAxes();
+  bool keepDims = reduceSumOp.getKeepdims() == 1;
+  return getDefinedAxesFromConstAxes(op, axesValue, keepDims);
+}
+
+template <>
+llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceMeanOp>(Operation *op) {
+  ONNXReduceMeanOp reduceMeanOp = cast<ONNXReduceMeanOp>(op);
+  Value axesValue = reduceMeanOp.getAxes();
+  bool keepDims = reduceMeanOp.getKeepdims() == 1;
+  return getDefinedAxesFromConstAxes(op, axesValue, keepDims);
+}
+
 // Block reduce ops
 template <typename ReductionOp>
 struct BlockReduceOp {
   using Op = void;
 };
 
+template <>
+struct BlockReduceOp<ONNXReduceMaxOp> {
+  using Op = stablehlo::MaxOp;
+};
+
 template <>
 struct BlockReduceOp<ONNXReduceMaxV13Op> {
   using Op = stablehlo::MaxOp;
 };
 
+template <>
+struct BlockReduceOp<ONNXReduceMinOp> {
+  using Op = stablehlo::MinOp;
+};
+
 template <>
 struct BlockReduceOp<ONNXReduceMinV13Op> {
   using Op = stablehlo::MinOp;
 };
 
+template <>
+struct BlockReduceOp<ONNXReduceMeanOp> {
+  using Op = stablehlo::AddOp;
+};
+
 template <>
 struct BlockReduceOp<ONNXReduceMeanV13Op> {
   using Op = stablehlo::AddOp;
@@ -355,10 +436,14 @@ struct ONNXReductionOpLoweringToStablehlo : public ConversionPattern {
 
 void populateLoweringONNXReductionOpToStablehloPattern(
     RewritePatternSet &patterns, MLIRContext *ctx) {
-  patterns.insert<ONNXReductionOpLoweringToStablehlo<ONNXReduceMaxV13Op>,
+  patterns.insert<ONNXReductionOpLoweringToStablehlo<ONNXReduceMaxOp>,
+      ONNXReductionOpLoweringToStablehlo<ONNXReduceMaxV13Op>,
+      ONNXReductionOpLoweringToStablehlo<ONNXReduceMinOp>,
       ONNXReductionOpLoweringToStablehlo<ONNXReduceMinV13Op>,
       ONNXReductionOpLoweringToStablehlo<ONNXReduceSumOp>,
       ONNXReductionOpLoweringToStablehlo<ONNXReduceSumV11Op>>(ctx);
+  patterns.insert<ONNXReductionOpLoweringToStablehlo<ONNXReduceMeanOp>>(
+      ctx, /*computeMean=*/true);
   patterns
       .insert<ONNXReductionOpLoweringToStablehlo<ONNXReduceMeanV13Op>>(
          ctx, /*computeMean=*/true);
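For context on the refactoring above: opset 18 moved `axes` from an attribute to an operand for ReduceMax, ReduceMin, and ReduceMean (ReduceSum made the same move in opset 13), so the constant-axes handling that previously lived only in the ReduceSum specialization is hoisted into the shared `getDefinedAxesFromConstAxes`, and each new specialization forwards its `getAxes()` operand and `keepdims` attribute to it. A minimal MLIR sketch of the two op forms; the function name and `%x` are illustrative, and the spellings follow the conventions of the tests below:

```mlir
func.func @axes_forms(%x: tensor<2x5x9x11xf32>)
    -> (tensor<2x5x1x1xf32>, tensor<2x5x1x1xf32>) {
  // Opset-13 form: axes is a compile-time attribute, read by the generic
  // getDefinedAxes.
  %m13 = "onnx.ReduceMaxV13"(%x) {axes = [2, 3], keepdims = 1 : si64}
      : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32>
  // Opset-18 form: axes is an operand (here a constant tensor), handled by
  // getDefinedAxesFromConstAxes.
  %axes = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
  %m18 = "onnx.ReduceMax"(%x, %axes)
      : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32>
  return %m13, %m18 : tensor<2x5x1x1xf32>, tensor<2x5x1x1xf32>
}
```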
diff --git a/test/mlir/conversion/onnx_to_stablehlo/Math/Reduction.mlir b/test/mlir/conversion/onnx_to_stablehlo/Math/Reduction.mlir
index a03e7483bc..b529b5a303 100644
--- a/test/mlir/conversion/onnx_to_stablehlo/Math/Reduction.mlir
+++ b/test/mlir/conversion/onnx_to_stablehlo/Math/Reduction.mlir
@@ -175,3 +175,57 @@ func.func @test_reducemean_v13_2(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
 // CHECK:           [[VAR_6_:%.+]] = stablehlo.divide [[VAR_2_]], [[VAR_5_]] : tensor<?x?xf32>
 // CHECK:           return [[VAR_6_]] : tensor<?x?xf32>
 // CHECK:         }
+
+// -----
+
+func.func @reduce_mean(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> {
+  %0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
+  %1 = "onnx.ReduceMean"(%arg0, %0) : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32>
+  return %1 : tensor<2x5x1x1xf32>
+}
+
+// CHECK-LABEL:  func.func @reduce_mean
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> {
+// CHECK-DAG:       [[VAR_0_:%.+]] = shape.const_shape [2, 5, 1, 1] : tensor<4xindex>
+// CHECK-DAG:       [[VAR_1_:%.+]] = stablehlo.constant dense<9.900000e+01> : tensor<2x5x1x1xf32>
+// CHECK-DAG:       [[VAR_2_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK:           [[VAR_3_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_2_]]) applies stablehlo.add across dimensions = [2, 3] : (tensor<2x5x9x11xf32>, tensor<f32>) -> tensor<2x5xf32>
+// CHECK:           [[VAR_4_:%.+]] = stablehlo.dynamic_reshape [[VAR_3_]], [[VAR_0_]] : (tensor<2x5xf32>, tensor<4xindex>) -> tensor<2x5x1x1xf32>
+// CHECK:           [[VAR_5_:%.+]] = stablehlo.divide [[VAR_4_]], [[VAR_1_]] : tensor<2x5x1x1xf32>
+// CHECK:           return [[VAR_5_]] : tensor<2x5x1x1xf32>
+// CHECK:         }
+
+// -----
+
+func.func @reduce_max(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> {
+  %0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
+  %1 = "onnx.ReduceMax"(%arg0, %0) : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32>
+  return %1 : tensor<2x5x1x1xf32>
+}
+
+// CHECK-LABEL:  func.func @reduce_max
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> {
+// CHECK-DAG:       [[VAR_0_:%.+]] = shape.const_shape [2, 5, 1, 1] : tensor<4xindex>
+// CHECK-DAG:       [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
+// CHECK:           [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [2, 3] : (tensor<2x5x9x11xf32>, tensor<f32>) -> tensor<2x5xf32>
+// CHECK:           [[VAR_3_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_0_]] : (tensor<2x5xf32>, tensor<4xindex>) -> tensor<2x5x1x1xf32>
+// CHECK:           return [[VAR_3_]] : tensor<2x5x1x1xf32>
+// CHECK:         }
+
+// -----
+
+
+func.func @reduce_min(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> {
+  %0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
+  %1 = "onnx.ReduceMin"(%arg0, %0) : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32>
+  return %1 : tensor<2x5x1x1xf32>
+}
+
+// CHECK-LABEL:  func.func @reduce_min
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> {
+// CHECK-DAG:       [[VAR_0_:%.+]] = shape.const_shape [2, 5, 1, 1] : tensor<4xindex>
+// CHECK-DAG:       [[VAR_1_:%.+]] = stablehlo.constant dense<0x7F800000> : tensor<f32>
+// CHECK:           [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.minimum across dimensions = [2, 3] : (tensor<2x5x9x11xf32>, tensor<f32>) -> tensor<2x5xf32>
+// CHECK:           [[VAR_3_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_0_]] : (tensor<2x5xf32>, tensor<4xindex>) -> tensor<2x5x1x1xf32>
+// CHECK:           return [[VAR_3_]] : tensor<2x5x1x1xf32>
+// CHECK:         }
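A few notes on the new FileCheck cases, which are presumably driven by this test file's existing RUN line (outside the hunk shown): in `@reduce_mean`, reducing a `tensor<2x5x9x11xf32>` over dimensions `[2, 3]` folds 9 * 11 = 99 elements into each output position, so the final `stablehlo.divide` by the splat `dense<9.900000e+01>` turns the `stablehlo.add` reduction into a mean. In `@reduce_max` and `@reduce_min`, the init constants `0xFF800000` and `0x7F800000` are the IEEE-754 single-precision bit patterns for -inf and +inf, i.e. exactly the identities produced by `negativeInfAttr` and `positiveInfAttr` in the helpers above. In all three cases, the `shape.const_shape` plus `stablehlo.dynamic_reshape` pair reinstates the unit dimensions that `stablehlo.reduce` drops, matching the ops' default `keepdims = 1` behavior.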