diff --git a/docs/BuildOnLinuxOSX.md b/docs/BuildOnLinuxOSX.md index abc789c5cd..c0571a37fc 100644 --- a/docs/BuildOnLinuxOSX.md +++ b/docs/BuildOnLinuxOSX.md @@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project): ``` bash git clone -n https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout f142f8afe21bceb00fb495468aa0b5043e98c419 && cd .. +cd llvm-project && git checkout eaa95a1c2bd38332c1a4e634595f29d22b28ffea && cd .. ``` [same-as-file]: <> (utils/build-mlir.sh) diff --git a/docs/BuildOnWindows.md b/docs/BuildOnWindows.md index 0c4f778713..ad7283a53c 100644 --- a/docs/BuildOnWindows.md +++ b/docs/BuildOnWindows.md @@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project): ```shell git clone -n https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout f142f8afe21bceb00fb495468aa0b5043e98c419 && cd .. +cd llvm-project && git checkout eaa95a1c2bd38332c1a4e634595f29d22b28ffea && cd .. ``` [same-as-file]: <> (utils/build-mlir.cmd) diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md index 80a83f03e0..38d6eac50e 100644 --- a/docs/Dialects/onnx.md +++ b/docs/Dialects/onnx.md @@ -3589,6 +3589,63 @@ where the mean and variance are computed per instance per group of channels, and groups `num_groups` should be divisible by the number of channels so that there are an equal number of channels per group. +The overall computation has two stages: the first stage normalizes the elements to +have zero mean and unit variance for each instance in each group, and the second +stage scales and shifts the results of the first stage. The floating-point precision +used in the first stage is determined by the `stash_type` attribute. For example, +if `stash_type` is 1, the operator casts all input variables to 32-bit float, +performs the computation, and finally casts the normalized results back to the +original type of `X`. The second stage does not depend on `stash_type`. + +When the number of groups is the same as the number of channels, this operator is +equivalent to InstanceNormalization. When there is only one group, this operator +is equivalent to LayerNormalization. + +Traits: `AlwaysSpeculatableImplTrait` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + + + +
AttributeMLIR TypeDescription
epsilon::mlir::FloatAttr32-bit float attribute
num_groups::mlir::IntegerAttr64-bit signed integer attribute
stash_type::mlir::IntegerAttr64-bit signed integer attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `scale` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `bias` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values + +### `onnx.GroupNormalizationV18` (ONNXGroupNormalizationV18Op) + +_ONNX GroupNormalization operation_ + +A GroupNormalization function. Carries out group normalization as described in +the paper https://arxiv.org/abs/1803.08494 + +This operator transforms input according to +``` +y = scale * (x - mean) / sqrt(variance + epsilon) + bias, +``` +where the mean and variance are computed per instance per group of channels, and +`scale` and `bias` should be specified for each group of channels. The number of +groups `num_groups` should be divisible by the number of channels so that there are +an equal number of channels per group. + When the number of groups is the same as the number of channels, this operator is equivalent to InstanceNormalization. When there is only one group, this operator is equivalent to LayerNormalization. diff --git a/docs/Prerequisite.md b/docs/Prerequisite.md index b65fa3cce0..60a48c84cd 100644 --- a/docs/Prerequisite.md +++ b/docs/Prerequisite.md @@ -26,4 +26,4 @@ Ninja can be installed with apt on Debian/Ubuntu Linux, or brew on MacOS. On RHE Java SDK can be installed with distro specific package manager on Linux such as yum on RHEL/Fedora, apt on Debian/Ubuntu, or brew on MacOS. Java SDK is only required if you plan to use the onnx-mlir `--EmitJNI` option to compile a model into a jar file for use in a Java environment. Note that the jar file contains native model runtime library called through JNI so it is not portable across different architectures. To check the java version, run `java --version`. -All the `PyPi` package dependencies and their appropriate versions are captured in [requirements.txt](requirements.txt). +All the `PyPi` package dependencies and their appropriate versions are captured in [requirements.txt](https://github.com/onnx/onnx-mlir/blob/main/requirements.txt). diff --git a/docs/SupportedONNXOps-NNPA.md b/docs/SupportedONNXOps-NNPA.md index ab21ed5ef2..b8c1536457 100644 --- a/docs/SupportedONNXOps-NNPA.md +++ b/docs/SupportedONNXOps-NNPA.md @@ -3,11 +3,11 @@ # Supported ONNX Operation for Target *NNPA*. -Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. +Onnx-mlir currently supports ONNX operations targeting up to opset 21. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. * Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md). * **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator. - * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 20. + * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 21. NNPA has hardware limitations in dimension index size and tensor size, which are described in [NNPALimit.hpp](../src/Accelerators/NNPA/Support/NNPALimit.hpp). They are large enough for normal use cases, but if your model exceeds the limitations, CPU is used instead of NNPA. diff --git a/docs/SupportedONNXOps-cpu.md b/docs/SupportedONNXOps-cpu.md index a9206358ad..2aea8be068 100644 --- a/docs/SupportedONNXOps-cpu.md +++ b/docs/SupportedONNXOps-cpu.md @@ -3,11 +3,11 @@ # Supported ONNX Operation for Target *cpu*. -Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. +Onnx-mlir currently supports ONNX operations targeting up to opset 21. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. * Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md). * **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator. - * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 20. + * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 21. | Op |Supported Opsets (inclusive) |Limitations |Notes | @@ -36,8 +36,8 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **BitwiseOr** |18 - * | | | | **BitwiseXor** |18 - * | | | | **BlackmanWindow** |none | | | | -| **Cast** |6 - * |Cast only between float and double types. Only ppc64le and MacOS platforms support float16. | | -| **CastLike** |19 - * |CastLike only between float and double types. Only ppc64le and MacOS platforms support float16. | | +| **Cast** |6 - * |Cast only between float and double types. Only ppc64le and MacOS platforms support float16. Does not support int4 and uint4. | | +| **CastLike** |19 - * |CastLike only between float and double types. Only ppc64le and MacOS platforms support float16. Does not support int4 and uint4. | | | **CastMap** |none | | | | | **CategoryMapper** |none | | | | | **Ceil** |6 - * | | | @@ -48,8 +48,8 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **Compress** |9 - * | | | | **Concat** |6 - * | | | | **ConcatFromSequence** |none | | | | -| **Constant** |6 - * | | | -| **ConstantOfShape** |9 - * | | | +| **Constant** |6 - * |Does not support int4 and uint4. | | +| **ConstantOfShape** |9 - * |Does not support int4 and uint4. | | | **Conv** |6 - * | | | | **ConvInteger** |none | | | | | **ConvTranspose** |6 - * |Spatial dimensions (H and W in input `X`, and kH and kW in input `W`) must be static dimension. | | @@ -59,7 +59,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **DFT** |17 - * | | | | **DeformConv** |none | | | | | **DepthToSpace** |13 - * | | | -| **DequantizeLinear** |10 - * |Only support for per-tensor or layer dequantization. No support for per-axis dequantization. | | +| **DequantizeLinear** |10 - * |Only support for per-tensor or layer dequantization. No support for per-axis dequantization. Does not support int4 and uint4. | | | **Det** |none | | | | | **DictVectorizer** |none | | | | | **Div** |6 - * |No support for short integers. | | @@ -73,7 +73,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **Expand** |8 - * |Input `shape` must have static shape. | | | **EyeLike** |none | | | | | **FeatureVectorizer** |none | | | | -| **Flatten** |6 - * | | | +| **Flatten** |6 - * |Does not support int4 and uint4. | | | **Floor** |6 - * | | | | **GRU** |7 - * |W, B and R must be constants. | | | **Gather** |6 - * | | | @@ -94,8 +94,8 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **HardSigmoid** |6 - * | | | | **HardSwish** |none | | | | | **Hardmax** |6 - * | | | -| **Identity** |16 - * |Sequence identity not supported. | | -| **If** |16 - * |Sequence and Optional outputs are not supported. | | +| **Identity** |16 - * |Sequence identity not supported. Does not support int4 and uint4. | | +| **If** |16 - * |Sequence and Optional outputs are not supported. Does not support int4 and uint4. | | | **Imputer** |none | | | | | **InstanceNormalization** |6 - * | | | | **IsInf** |20 - * |Currently no support for float16 infinity value. Only for float32 and float64. | | @@ -111,7 +111,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **LinearRegressor** |none | | | | | **Log** |6 - * | | | | **LogSoftmax** |13 - * |Axis 0, 1, and default currently disabled due to changes in ONNX 1.8.1/Opset 13. |Temporally removed due to changes in onnx 1.8.1. | -| **Loop** |6 - * |Input must have static shape. | | +| **Loop** |6 - * |Input must have static shape. Does not support int4 and uint4. | | | **LpNormalization** |none | | | | | **LpPool** |none | | | | | **MatMul** |6 - * | | | @@ -142,11 +142,11 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **OptionalHasElement** |none | | | | | **Or** |7 - * | | | | **PRelu** |6 - * | | | -| **Pad** |6 - * |axes input not supported. | | +| **Pad** |6 - * |axes input not supported. Does not support int4 and uint4. | | | **Pow** |7 - * |No support for power with integer types. | | | **QLinearConv** |none | | | | | **QLinearMatMul** |none | | | | -| **QuantizeLinear** |10 - * |Do not support per-axis and i8 quantization. | | +| **QuantizeLinear** |10 - * |Does not support per-axis and i8 quantization. Does not support int4 and uint4. | | | **RNN** |7 - * |W, B and R must be constants. | | | **RandomNormal** |none | | | | | **RandomNormalLike** |none | | | | @@ -165,7 +165,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **ReduceSum** |6 - * |Default axis and do_not_keep_dim not supported. |Default axis and do_not_keep_dim temporarily removed due to changes in onnx 1.8.1. | | **ReduceSumSquare** |13 - * |Default axis and do_not_keep_dim not supported. | | | **Relu** |6 - * | | | -| **Reshape** |6 - * |allowzero not supported. Input `shape` must have static dimension. | | +| **Reshape** |6 - * |allowzero not supported. Input `shape` must have static dimension. Does not support int4 and uint4. | | | **Resize** |10 - * |Missing support for linear, cubic, crop, pytorch_half_pixel, and floor. Attributes antialias, axes and keep_aspect_ratio_policy are not supported. `scales` and `sizes` must have static dimension. | | | **ReverseSequence** |10 - * | | | | **RoiAlign** |none | | | | @@ -174,7 +174,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **SVMClassifier** |none | | | | | **SVMRegressor** |none | | | | | **Scaler** |none | | | | -| **Scan** |8 - * |Does not support dynamic shapes. |Precision issue with newer opset, maybe just unsupported. Dynamic shape?. | +| **Scan** |8 - * |Does not support dynamic shapes. Does not support int4 and uint4. |Precision issue with newer opset, maybe just unsupported. Dynamic shape?. | | **Scatter** |none | | | | | **ScatterElements** |11 - * |Does not support duplicate indices. | | | **ScatterND** |11 - * |Does not support scatternd add/multiply. | | @@ -186,13 +186,13 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **SequenceInsert** |11 - * |Does not support unranked sequence element. | | | **SequenceLength** |none | | | | | **SequenceMap** |none | | | | -| **Shape** |15 - * |Does not support start and end attributes. | | +| **Shape** |15 - * |Does not support start and end attributes. Does not support int4 and uint4. | | | **Shrink** |none | | | | | **Sigmoid** |6 - * | | | | **Sign** |9 - * | | | | **Sin** |7 - * | | | | **Sinh** |9 - * | | | -| **Size** |13 - * | | | +| **Size** |13 - * |Does not support int4 and uint4. | | | **Slice** |13 - * |Axis must be a constant argument. |Add tests to slices, currently have none. | | **Softmax** |6 - * | | | | **SoftmaxCrossEntropyLoss** |none | | | | @@ -202,7 +202,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **Split** |6 - * |Does not support static and dynamic shape, zero size splits. |Temporally removed due to changes in onnx 1.8.1. | | **SplitToSequence** |none | | | | | **Sqrt** |6 - * | | | -| **Squeeze** |6 - * |Does not support static and dynamic shape. |Temporally removed due to changes in onnx 1.8.1. | +| **Squeeze** |6 - * |Does not support static and dynamic shape. Does not support int4 and uint4. |Temporally removed due to changes in onnx 1.8.1. | | **StringNormalizer** |none | | | | | **Sub** |6 - * |Does not support short integers. | | | **Sum** |6 - * | | | @@ -212,12 +212,12 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **ThresholdedRelu** |none | | | | | **Tile** |6 - * | | | | **TopK** |10 - * |`K`, the number of top elements to retrieve, must have static shape. | | -| **Transpose** |6 - * | | | +| **Transpose** |6 - * |Does not support int4 and uint4. | | | **TreeEnsembleClassifier** |none | | | | | **TreeEnsembleRegressor** |none | | | | | **Trilu** |14 - * | | | | **Unique** |11 - * | | | -| **Unsqueeze** |6 - * |Does not support static and dynamic shape. |Temporally removed due to changes in onnx 1.8.1. | +| **Unsqueeze** |6 - * |Does not support static and dynamic shape. Does not support int4 and uint4. |Temporally removed due to changes in onnx 1.8.1. | | **Upsample** |7 - * |Input `X` and `Y` must have static shape. | | | **Where** |9 - * | | | | **Xor** |7 - * | | | diff --git a/requirements.txt b/requirements.txt index f77bb46e5b..9f7c35fa97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ lit~=15.0 # numpy 1.24 deprecates np.object, np.bool, np.float, np.complex, np.str, # and np.int which are used heavily in onnx-mlir. -numpy~=1.22.2, <=1.23.5 +numpy==2.0.1 +onnx==1.16.2 protobuf==4.21.12 -pytest~=7.2 -pytest-xdist~=3.0 +pytest==8.3.2 +pytest-xdist==3.6.1 diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc index 6122d0c205..067839b22f 100644 --- a/src/Builder/OpBuildTable.inc +++ b/src/Builder/OpBuildTable.inc @@ -80,7 +80,7 @@ op_dialect_version_map_["Gradient"] = {1}; op_dialect_version_map_["Greater"] = {13}; op_dialect_version_map_["GreaterOrEqual"] = {16}; op_dialect_version_map_["GridSample"] = {16}; -op_dialect_version_map_["GroupNormalization"] = {18}; +op_dialect_version_map_["GroupNormalization"] = {21, 18}; op_dialect_version_map_["HammingWindow"] = {17}; op_dialect_version_map_["HannWindow"] = {17}; op_dialect_version_map_["HardSigmoid"] = {6}; @@ -358,6 +358,8 @@ import_handler_map_["GridSample"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["GroupNormalization"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["GroupNormalizationV18"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["HammingWindow"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["HannWindow"] = diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp index 03975bb1d0..8a5bd9d657 100644 --- a/src/Compiler/CompilerOptions.cpp +++ b/src/Compiler/CompilerOptions.cpp @@ -42,6 +42,8 @@ bool enableONNXHybridPass; // common for both std::vector functionsToDecompose; // common for both std::string opsForCall; // common for both bool disableKrnlOpFusion; // common for both +bool disableQuantZeroPoint; // common for both +bool enableKrnlBufferReuse; // common for both bool disableMemRefPrefetch; // common for both EmissionTargetType emissionTarget; // onnx-mlir only bool invokeOnnxVersionConverter; // onnx-mlir only @@ -194,7 +196,7 @@ static llvm::cl::list> llvm::cl::cat(OnnxMlirCommonOptions)); static llvm::cl::opt enableONNXHybridPassOpt("onnx-hybrid-pass", - llvm::cl::desc("Enable ONNX hybrid pass (default=true)\n" + llvm::cl::desc("Enable ONNX hybrid pass (default=true).\n" "Set to 'false' if you want to disable ONNX hybrid pass."), llvm::cl::location(enableONNXHybridPass), llvm::cl::init(true), llvm::cl::cat(OnnxMlirCommonOptions)); @@ -207,14 +209,31 @@ static llvm::cl::list> static llvm::cl::opt disableKrnlOpFusionOpt( "disable-krnl-op-fusion", - llvm::cl::desc("disable op fusion in onnx-to-krnl pass (default=false)\n" + llvm::cl::desc("Disable op fusion in onnx-to-krnl pass (default=false).\n" "Set to 'true' if you want to disable fusion."), llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); +static llvm::cl::opt disable_quantization_zero_point( + "disable-quantization-zero-point", + llvm::cl::desc( + "Disable the use of zero-point in quantization (default=false).\n" + "Set to 'true' if you want to disable the use of zero-point\n" + "in dyn/static quantization/dequantization."), + llvm::cl::location(disableQuantZeroPoint), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirCommonOptions)); + +static llvm::cl::opt enableKrnlBufferReuseOpt( + "enable-krnl-buffer-reuse", + llvm::cl::desc("enable buffer reuse within an op in onnx-to-krnl pass" + "(default=false)\n" + "Set to 'true' if you want to enable buffer reuse."), + llvm::cl::location(enableKrnlBufferReuse), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirCommonOptions)); + static llvm::cl::opt disableMemRefPrefetchOpt( "disable-memref-prefetch", - llvm::cl::desc("disable generation of memref.prefetch (default=false)\n" + llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n" "Set to 'true' if you want to disable prefetch."), llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); @@ -1136,7 +1155,6 @@ std::string getLibraryPath() { // as lrodataScript. std::string getToolPath( const std::string &tool, bool flag /*false by default*/) { - if (!flag) { std::string execDir = llvm::sys::path::parent_path(getExecPath()).str(); llvm::SmallString<8> toolPath(execDir); diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp index fe12e4511c..3e0940d70b 100644 --- a/src/Compiler/CompilerOptions.hpp +++ b/src/Compiler/CompilerOptions.hpp @@ -87,6 +87,8 @@ extern bool enableONNXHybridPass; // common for both extern std::vector functionsToDecompose; // common for both extern std::string opsForCall; // common for both extern bool disableKrnlOpFusion; // common for both +extern bool disableQuantZeroPoint; // common for both +extern bool enableKrnlBufferReuse; // common for both extern bool disableMemRefPrefetch; // common for both extern EmissionTargetType emissionTarget; // onnx-mlir only extern bool invokeOnnxVersionConverter; // onnx-mlir only diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index 8c57a65c76..81c4b9768b 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -28,6 +28,82 @@ using namespace mlir; namespace onnx_mlir { +// Check the input, x, can be reused as the output buffer +bool isBufferReusable(Value x, MemRefType outputType) { + if (!x.hasOneUse()) + return false; + + Type xType = x.getType(); + auto inputType = dyn_cast(xType); + if (!inputType) + return false; + // Currently, only static shape could be reused. + // ToFix: use DimAnalysis to handle dynamic shape. + if (!hasStaticShape(inputType)) + return false; + if (!hasStaticShape(outputType)) + return false; + + // Currently reuse requires that the shape has to be the same. + // ToFix: If the shape is not the same, memref.cast can be used. + if (getRank(inputType) != getRank(outputType)) + return false; + for (int64_t i = 0; i < getRank(inputType); i++) { + if (inputType.getShape()[i] != outputType.getShape()[i]) + return false; + } + + // ToFix: The simd padding is not checked + // We did not record whether the memref is padded or not. + // The padding added to the memref the as an attribute, or not needed. + return true; +} + +// Traverse the operands to find the candidate for buffer reuse. +// Return -1, if no candidate is found. +int whichBufferToReuse(ValueRange values, MemRefType outputType) { + for (size_t i = 0; i < values.size(); i++) { + if (isBufferReusable(values[i], outputType)) + return i; + } + return -1; +} + +// Allocate memref (as before) if no input buffer can be reused. +// Default VL=0 is used for non SIMD allocation +Value allocOrReuse(MemRefBuilder &create, Operation *op, + ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims, + int64_t alignment, int64_t VL = 0); + +Value allocOrReuse(MemRefBuilder &create, Operation *op, + ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims, + int64_t alignment, int64_t VL) { + + int indexToReuse = -1; + // By default, enableKrnlBufferReuse is false. Simply allocate a memref. + if (enableKrnlBufferReuse) { + // Be aware to use the op->getOperands() to check the number of uses. + // After buffer reuse, the number of uses of the transformed Value, + // generatedOperands, will increase. + indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType); + } + + if (indexToReuse != -1) { + int size = getSizeInBytes(outputMemRefType); + LLVM_DEBUG({ + llvm::dbgs() << " malloc_size " << size << "\n"; + op->dump(); + }); + return generatedOperands[indexToReuse]; + } else { + if (VL == 0) + return create.alignedAlloc(outputMemRefType, dims, alignment); + else + return create.alignedAllocWithSimdPadding( + outputMemRefType, dims, VL, alignment); + } +} + // ============================================================================= /// Emit post-processing for variadic element-wise ops. @@ -1282,9 +1358,15 @@ Value emitScalarOpFor( Value scaleFloat = scalarOperands[1]; Value zeroPointInt = scalarOperands[2]; - Value zeroPointFloat = create.math.cast(elementType, zeroPointInt); Value xFloat = create.math.cast(elementType, XInt); - Value sub = create.math.sub(xFloat, zeroPointFloat); + + Value sub; + if (!disableQuantZeroPoint && !isNoneValue(zeroPointInt)) { + Value zeroPointFloat = create.math.cast(elementType, zeroPointInt); + sub = create.math.sub(xFloat, zeroPointFloat); + } else { + sub = xFloat; + } Value res = create.math.mul(sub, scaleFloat); return res; } @@ -1323,14 +1405,14 @@ static LogicalResult getPartiallyFlattenedSimdCode( IndexExprScope allocScope(create.vec, shapeHelper->getScope()); DimsExpr outputDims; getIndexExprList(shapeHelper->getOutputDims(), outputDims); - // Alloc memory with padding for SIMD. + // Reuse the buffer from the input, or Alloc memory with padding for SIMD. // For the moment, its ok to go here; if we truly have partial flattening of // the simd code, then we only do it with static memref size that are // multiples of VL * unrollVL, so there should be no padding anyway. This // will change if we do partial flattening with non-multiple of VL * // unrollVL. - Value alloc = create.mem.alignedAllocWithSimdPadding( - outputMemRefType, outputDims, VL, alignment); + Value alloc = allocOrReuse( + create.mem, op, operands, outputMemRefType, outputDims, alignment, VL); // Create flat inputs in the last innerDinNum dims. llvm::SmallVector flatOperands; for (Value oper : operands) { @@ -1975,8 +2057,9 @@ struct ONNXElementwiseUnaryOpLowering outputMemRefType = opFusionHelper.getOutputType(outputMemRefType); // Insert an allocation for the result of this operation. - Value alloc = create.mem.alignedAlloc( - outputMemRefType, shapeHelper.getOutputDims(), alignment); + Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType, + shapeHelper.getOutputDims(), alignment); + ; // Only create krnl.iterate if one of the operands is not scalar tensor. if (!isScalar) { @@ -2156,8 +2239,9 @@ struct ONNXElementwiseBinaryOpLowering outputMemRefType = opFusionHelper.getOutputType(outputMemRefType); // Insert an allocation and deallocation for the result of this operation. - Value alloc = create.mem.alignedAlloc( - outputMemRefType, shapeHelper.getOutputDims(), alignment); + Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType, + shapeHelper.getOutputDims(), alignment); + ; // Only create krnl.iterate if one of the operands is not scalar tensor. if (!isScalar) { @@ -2331,8 +2415,9 @@ struct ONNXElementwiseVariadicOpLowering outputMemRefType = opFusionHelper.getOutputType(outputMemRefType); // Insert an allocation and deallocation for the result of this operation. - Value alloc = create.mem.alignedAlloc( - outputMemRefType, shapeHelper.getOutputDims(), alignment); + Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType, + shapeHelper.getOutputDims(), alignment); + ; // Only create krnl.iterate if one of the operands is not scalar tensor. if (!isScalar) { diff --git a/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp index 8d325c1964..5484974624 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "src/Compiler/CompilerOptions.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" @@ -29,7 +30,7 @@ void emitDynamicQuantizationLinearScalarParameters( ConversionPatternRewriter &rewriter, Location loc, Operation *op, MemRefType inputType, MemRefType quantizedType, Value input, Value qMin, Value qMax, Value &scale, Value &zeroPoint, Value &quantizedZeroPoint, - bool enableSIMD, bool enableParallel) { + bool wantZeroPoint, bool enableSIMD, bool enableParallel) { MultiDialectBuilder create(rewriter, loc); // Types @@ -62,11 +63,15 @@ void emitDynamicQuantizationLinearScalarParameters( scale = create.math.div(xDiff, boundDiff); // Compute y_zero_point. - Value interZeroPoint = create.math.sub(qMin, create.math.div(xMin, scale)); - // Saturate zero point. - Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax); - // Round zero point. - zeroPoint = create.math.round(saturateZeroPoint); + if (wantZeroPoint) { + Value interZeroPoint = create.math.sub(qMin, create.math.div(xMin, scale)); + // Saturate zero point. + Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax); + // Round zero point. + zeroPoint = create.math.round(saturateZeroPoint); + } else { + zeroPoint = zero; + } quantizedZeroPoint = create.math.cast(quantizedElementType, zeroPoint); } @@ -122,15 +127,17 @@ struct ONNXDynamicQuantizeLinearOpLowering Value qMin = create.math.constant(elementType, 0.0); Value scale, zeroPoint, zeroPointInt; + bool wantZeroPoint = !disableQuantZeroPoint; emitDynamicQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType, yMemRefType, X, qMin, qMax, scale, zeroPoint, zeroPointInt, - enableSIMD, enableParallel); + wantZeroPoint, enableSIMD, enableParallel); create.krnl.store(scale, YScale); create.krnl.store(zeroPointInt, YZeroPoint); emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType, yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale, - zeroPoint, enableSIMD, enableParallel); + zeroPoint, wantZeroPoint /*wanted one, so we have a zero point*/, + enableSIMD, enableParallel); rewriter.replaceOp(op, {Y, YScale, YZeroPoint}); onnxToKrnlSimdReport(op); diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp index 124b854bde..96042bd799 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp @@ -23,7 +23,8 @@ void emitQuantizationLinearScalarParameters( mlir::Operation *op, mlir::MemRefType inputType, mlir::MemRefType quantizedType, mlir::Value alloc, DimsExpr &allocDims, mlir::Value input, mlir::Value qMin, mlir::Value qMax, mlir::Value scale, - mlir::Value zeroPoint, bool enableSIMD, bool enableParallel); + mlir::Value zeroPoint, bool hasZeroPoint, bool enableSIMD, + bool enableParallel); // Scan the input to compute scale, zeroPoint, and quantizedZeroPoint given qMin // and qMax. @@ -32,5 +33,6 @@ void emitDynamicQuantizationLinearScalarParameters( mlir::Operation *op, mlir::MemRefType inputType, mlir::MemRefType quantizedType, mlir::Value input, mlir::Value qMin, mlir::Value qMax, mlir::Value &scale, mlir::Value &zeroPoint, - mlir::Value &quantizedZeroPoint, bool enableSIMD, bool enableParallel); + mlir::Value &quantizedZeroPoint, bool wantZeroPoint, bool enableSIMD, + bool enableParallel); } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 715968583d..2567c4a1f4 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "src/Compiler/CompilerOptions.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" #include "src/Dialect/ONNX/DialectBuilder.hpp" @@ -26,7 +27,8 @@ namespace onnx_mlir { void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, Location loc, Operation *op, MemRefType inputType, MemRefType quantizedType, Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax, - Value scale, Value zeroPoint, bool enableSIMD, bool enableParallel) { + Value scale, Value zeroPoint, bool hasZeroPoint, bool enableSIMD, + bool enableParallel) { MultiDialectBuilder create( rewriter, loc); @@ -77,7 +79,11 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, // Round Value roundX = create.math.round(scaleX); // Adjust - Value adjustX = create.math.add(roundX, zeroPoint); + Value adjustX; + if (hasZeroPoint) + adjustX = create.math.add(roundX, zeroPoint); + else + adjustX = roundX; // Saturate Value saturateX = create.math.clip(adjustX, qMin, qMax); Value res = create.math.cast(quantizedElementType, saturateX); @@ -160,15 +166,21 @@ struct ONNXQuantizeLinearOpLowering // Load y_zero_point. Value zeroPoint; + bool hasZeroPoint = false; if (!isNoneValue(YZeroPoint)) { zeroPoint = create.krnl.load(adaptor.getYZeroPoint()); zeroPoint = create.math.cast(elementType, zeroPoint); - } else - zeroPoint = create.math.constant(elementType, 0.0); - + hasZeroPoint = true; + } + if (disableQuantZeroPoint) { + // TODO: should we expect to disable hasZeroPoint forcefully, or generate + // an error if we had a zero point? Right now, just forcefully assert we + // have no zero point, i.e. ignore one even if we had a zero point. + hasZeroPoint = false; + } emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType, yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale, - zeroPoint, enableSIMD, enableParallel); + zeroPoint, hasZeroPoint, enableSIMD, enableParallel); rewriter.replaceOp(op, {Y}); onnxToKrnlSimdReport(op); diff --git a/src/Dialect/ONNX/DialectBuilder.cpp b/src/Dialect/ONNX/DialectBuilder.cpp index 7b8b67bd72..4faaff6dfb 100644 --- a/src/Dialect/ONNX/DialectBuilder.cpp +++ b/src/Dialect/ONNX/DialectBuilder.cpp @@ -4,7 +4,7 @@ //===----- DialectBuilder.cpp - Helper functions for ONNX dialects -------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -164,6 +164,19 @@ Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale, toTensor(bias), axisAttr, epsilon, stashTypeAttr); return layerNormOp.getY(); } +// In the case of GroupNormalization when stashType can be specified +Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale, + Value bias, int64_t axis, FloatAttr epsilon, IntegerAttr stashType) const { + IntegerAttr axisAttr = getSignedInt64Attr(axis); + Value noneVal = none(); + Type noneType = noneVal.getType(); + ONNXLayerNormalizationOp layerNormOp = + createOpAndInferShapes( + /*Y type*/ toTensor(outputType), /*mean type*/ noneType, + /*std dev Type*/ noneType, toTensor(input), toTensor(scale), + toTensor(bias), axisAttr, epsilon, stashType); + return layerNormOp.getY(); +} Value OnnxBuilder::qlinearMatMul(Type outputType, Value a, Value aScale, Value aZeroPoint, Value b, Value bScale, Value bZeroPoint, Value yScale, diff --git a/src/Dialect/ONNX/DialectBuilder.hpp b/src/Dialect/ONNX/DialectBuilder.hpp index 9ff98a3755..8f6b0931e3 100644 --- a/src/Dialect/ONNX/DialectBuilder.hpp +++ b/src/Dialect/ONNX/DialectBuilder.hpp @@ -91,6 +91,10 @@ struct OnnxBuilder : DialectBuilder { mlir::Value layerNorm(mlir::Type outputType, mlir::Value input, mlir::Value scale, mlir::Value bias, int64_t axis, mlir::FloatAttr epsilon) const; + // In the case of GroupNormalization when stashType can be specified + mlir::Value layerNorm(mlir::Type outputType, mlir::Value input, + mlir::Value scale, mlir::Value bias, int64_t axis, + mlir::FloatAttr epsilon, mlir::IntegerAttr stashType) const; // ONNXQLinearMatMulOp mlir::Value qlinearMatMul(mlir::Type outputType, mlir::Value a, diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 8b3f76cdba..e7193e5cc2 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -3122,6 +3122,62 @@ def ONNXGroupNormalizationOp:ONNX_Op<"GroupNormalization", groups `num_groups` should be divisible by the number of channels so that there are an equal number of channels per group. + The overall computation has two stages: the first stage normalizes the elements to + have zero mean and unit variance for each instance in each group, and the second + stage scales and shifts the results of the first stage. The floating-point precision + used in the first stage is determined by the `stash_type` attribute. For example, + if `stash_type` is 1, the operator casts all input variables to 32-bit float, + performs the computation, and finally casts the normalized results back to the + original type of `X`. The second stage does not depend on `stash_type`. + + When the number of groups is the same as the number of channels, this operator is + equivalent to InstanceNormalization. When there is only one group, this operator + is equivalent to LayerNormalization. + }]; + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$scale, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$bias, + DefaultValuedAttr:$epsilon, + SI64Attr:$num_groups, + DefaultValuedAttr:$stash_type); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {30}; + } + }]; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGroupNormalizationOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; +} + +def ONNXGroupNormalizationV18Op:ONNX_Op<"GroupNormalizationV18", + [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "ONNX GroupNormalization operation"; + let description = [{ + A GroupNormalization function. Carries out group normalization as described in + the paper https://arxiv.org/abs/1803.08494 + + This operator transforms input according to + ``` + y = scale * (x - mean) / sqrt(variance + epsilon) + bias, + ``` + where the mean and variance are computed per instance per group of channels, and + `scale` and `bias` should be specified for each group of channels. The number of + groups `num_groups` should be divisible by the number of channels so that there are + an equal number of channels per group. + When the number of groups is the same as the number of channels, this operator is equivalent to InstanceNormalization. When there is only one group, this operator is equivalent to LayerNormalization. @@ -3146,11 +3202,12 @@ def ONNXGroupNormalizationOp:ONNX_Op<"GroupNormalization", let extraClassDefinition = [{ onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { - onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGroupNormalizationOpShapeHelper(op, oper, ieb, scope); + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGroupNormalizationV18OpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); return sh; } }]; + let hasVerifier = 1; } def ONNXHammingWindowOp:ONNX_Op<"HammingWindow", diff --git a/src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp b/src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp index 8a1bbf3aa1..091426074f 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp @@ -149,6 +149,21 @@ LogicalResult ONNXInstanceNormalizationOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// GroupNormalizationV18 +//===----------------------------------------------------------------------===// +LogicalResult ONNXGroupNormalizationV18Op::verify() { + ONNXGroupNormalizationV18OpAdaptor(*this); + llvm::outs() + << "Warning: The previous understanding of Opset 18 for " + "GroupNormalization " + "is incorrect. As shown in the following issue: " + "https://github.com/onnx/onnx/issues/5466.Rather, use Opset 21 for " + "GroupNormalization instead." + << "/n"; + return success(); +} + // TODO: should there be a shape inference for this one? //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp index 3b1318c15f..8a43b3e4a1 100644 --- a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp +++ b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp @@ -77,6 +77,7 @@ CONVERTED_TO_SUPPORTED_OPS(ONNXClipV12Op) CONVERTED_TO_SUPPORTED_OPS(ONNXClipV6Op) CONVERTED_TO_SUPPORTED_OPS(ONNXDFTV17Op) CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationOp) +CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationV18Op) CONVERTED_TO_SUPPORTED_OPS(ONNXPadV18Op) CONVERTED_TO_SUPPORTED_OPS(ONNXPadV13Op) CONVERTED_TO_SUPPORTED_OPS(ONNXPadV11Op) diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp index 78165e19a3..6ea3fce402 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.cpp +++ b/src/Dialect/ONNX/Transforms/Decompose.cpp @@ -624,7 +624,7 @@ struct ConcatFusePattern : public OpRewritePattern { // to determine the rank of A. // // Example of onnx.Custom: -// ``` +// ``` // "onnx.Custom"(%0, %1) {alpha = 1.250000e-01 : f32, // domain_name = "com.microsoft", // function_name = "FusedMatMul", @@ -831,93 +831,162 @@ struct InstanceNormIntoLayerNormPattern }; // Transform GroupNormalization into LayerNormalization -struct GroupNormIntoLayerNormPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ONNXGroupNormalizationOp groupNormOp, - PatternRewriter &rewriter) const final { - // Match. - Value input = groupNormOp.getX(); - if (!onnx_mlir::isRankedShapedType(input.getType())) - return failure(); - - // Get info. - Value scale = groupNormOp.getScale(); - Value bias = groupNormOp.getBias(); - ShapedType inputType = mlir::cast(input.getType()); - Type elementType = inputType.getElementType(); - auto inputShapeVal = inputType.getShape(); - int64_t C = inputShapeVal[1]; - int64_t inputRank = inputType.getRank(); - int64_t nonSpacialRank = 2; // Batch N and Channel C: 2 dimensions. - assert(inputRank > nonSpacialRank && - "expected instance norm with input ranks > 2"); - int64_t spacialRank = inputRank - nonSpacialRank; - int64_t layerNormRank = inputRank + 1; // +1 as C is split to NG and C/NG - int64_t numGroups = groupNormOp.getNumGroups(); - - // Rewrite. - onnx_mlir::MultiDialectBuilder create( - rewriter, groupNormOp.getLoc()); - int64_t axis = nonSpacialRank; - int64_t numInNorm = layerNormRank - axis; - // Unsqueeze scale/bias from [NG] to [NG x 1 x 1 x ... x 1] with numInNorm +template +constexpr bool scaleAndBiasWithNumGroupShape = + std::is_same_v; + +template +LogicalResult ONNXGroupNormalizationCommon( + OP_TYPE groupNormOp, PatternRewriter &rewriter) { + + // Match. + Value input = groupNormOp.getX(); + if (!onnx_mlir::isRankedShapedType(input.getType())) + return failure(); + + // Get info. + Value scale = groupNormOp.getScale(); + Value bias = groupNormOp.getBias(); + ShapedType inputType = mlir::cast(input.getType()); + Type elementType = inputType.getElementType(); + auto inputShapeVal = inputType.getShape(); + int64_t C = inputShapeVal[1]; + int64_t inputRank = inputType.getRank(); + int64_t nonSpacialRank = 2; // Batch N and Channel C: 2 dimensions. + assert(inputRank > nonSpacialRank && + "expected instance norm with input ranks > 2"); + int64_t spacialRank = inputRank - nonSpacialRank; + int64_t layerNormRank = inputRank + 1; // +1 as C is split to NG and C/NG + int64_t numGroups = groupNormOp.getNumGroups(); + + // Rewrite. + onnx_mlir::MultiDialectBuilder create( + rewriter, groupNormOp.getLoc()); + int64_t axis = nonSpacialRank; + int64_t numInNorm = layerNormRank - axis; + Type biasScaleType; + Value axes; + Value newBias; + Value newScale; + + //"numgroups" and "C" should have the same dimension index + llvm::SmallVector axesList, biasScaleVal; + + if constexpr (scaleAndBiasWithNumGroupShape) { + // Opset18 Uses "numgroups" the number of groups of channels for the scale + // and bias + // Unsqueeze scale/bias from [NG] to [1 x NG x 1 x ... x 1] with numInNorm // 1s. - llvm::SmallVector axesList, biasScaleShape; - biasScaleShape.emplace_back(numGroups); + biasScaleVal.emplace_back(numGroups); for (int64_t i = 1; i <= numInNorm; ++i) { - biasScaleShape.emplace_back(1); + biasScaleVal.emplace_back(1); axesList.emplace_back(i); } - Value axes = create.onnx.constantInt64(axesList); - Type biasScaleType = RankedTensorType::get(biasScaleShape, elementType); - Value newScale = create.onnx.unsqueeze(biasScaleType, scale, axes); - Value newBias = create.onnx.unsqueeze(biasScaleType, bias, axes); - // Convert input from N x C x D1...Dn to N x (NG x C/NG) x D1...Dn. - // First compute the new (possibly dynamic) shape. - Type batchShapeType = RankedTensorType::get({1}, rewriter.getI64Type()); - Value NShape = create.onnx.shape( - batchShapeType, input, /*start*/ 0, /*exclusive end*/ 1); - Value NGandMin1Shape = create.onnx.constantInt64({numGroups, -1}); - Type spacialShapeType = - RankedTensorType::get({spacialRank}, rewriter.getI64Type()); - Value spacialShape = - create.onnx.shape(spacialShapeType, input, /*start*/ nonSpacialRank); - Type layerNormShapeType = - RankedTensorType::get({layerNormRank}, rewriter.getI64Type()); - Value layerNormShape = create.onnx.concat( - layerNormShapeType, {NShape, NGandMin1Shape, spacialShape}, /*axis*/ 0); - // Compute type of converted input. - llvm::SmallVector layerNormShapeVal; - layerNormShapeVal.emplace_back(inputShapeVal[0]); - layerNormShapeVal.emplace_back(numGroups); + + axes = create.onnx.constantInt64(axesList); + biasScaleType = RankedTensorType::get(biasScaleVal, elementType); + newScale = create.onnx.unsqueeze(biasScaleType, scale, axes); + newBias = create.onnx.unsqueeze(biasScaleType, bias, axes); + } else { + // Opset21 Uses "C" the number of channels for the scale and bias + // The equivalent of "C" when split is "NG x C/NG" + // Reshape scale/bias from [C] to [NG x C/NG x 1 x ... x 1] with numInNorm + // 1s. + biasScaleVal.emplace_back(numGroups); + // C can be a dynamic or static value, account for that here if (C != ShapedType::kDynamic) { assert(C % numGroups == 0 && "expected numGroups to divide C"); - layerNormShapeVal.emplace_back(C / numGroups); - } else - layerNormShapeVal.emplace_back(ShapedType::kDynamic); - for (int64_t i = 0; i < spacialRank; ++i) - layerNormShapeVal.emplace_back(inputShapeVal[nonSpacialRank + i]); - RankedTensorType layerNormInputType = - RankedTensorType::get(layerNormShapeVal, elementType); - Value layerNormInput = - create.onnx.reshape(layerNormInputType, input, layerNormShape); - // Create output using layer norm. - Value layerNormY = create.onnx.layerNorm(layerNormInputType, layerNormInput, - newScale, newBias, axis, groupNormOp.getEpsilonAttr()); - // Resize output to original size - Type inputShapeType = + biasScaleVal.emplace_back(C / numGroups); + } else { + biasScaleVal.emplace_back(ShapedType::kDynamic); + } + + for (int64_t i = 2; i <= numInNorm; ++i) { + biasScaleVal.emplace_back(1); + } + + // Calculate the (possible) dynamic dimensions for biasScaleShape + Value NGShape = create.onnx.constantInt64({numGroups}); + Value oneDimShape = create.onnx.constantInt64({1, 1}); + Type biasScaleShapeType = RankedTensorType::get({inputRank}, rewriter.getI64Type()); - Value inputShape = create.onnx.shape(inputShapeType, input); - Type outputType = groupNormOp.getY().getType(); - Value Y = create.onnx.reshape(outputType, layerNormY, inputShape); - // Set the type of the output to be the same as the output of the original - // operation we are trying to replace. - Y.setType(groupNormOp.getResult().getType()); - // Replace operation. - rewriter.replaceOp(groupNormOp, Y); - return success(); + Value biasScaleShape = create.onnx.concat( + biasScaleShapeType, {NGShape, NGShape, oneDimShape}, /*axis*/ 0); + + // Reshape instead of unsqueeze (use biasScaleShape) + biasScaleType = RankedTensorType::get(biasScaleVal, elementType); + newScale = create.onnx.reshape(biasScaleType, scale, biasScaleShape); + newBias = create.onnx.reshape(biasScaleType, bias, biasScaleShape); + } + + // Convert input from N x C x D1...Dn to N x (NG x C/NG) x D1...Dn. + // First compute the new (possible dynamic) shape. + Type batchShapeType = RankedTensorType::get({1}, rewriter.getI64Type()); + Value NShape = create.onnx.shape( + batchShapeType, input, /*start*/ 0, /*exclusive end*/ 1); + Value NGandMin1Shape = create.onnx.constantInt64({numGroups, -1}); + Type spacialShapeType = + RankedTensorType::get({spacialRank}, rewriter.getI64Type()); + Value spacialShape = + create.onnx.shape(spacialShapeType, input, /*start*/ nonSpacialRank); + Type layerNormShapeType = + RankedTensorType::get({layerNormRank}, rewriter.getI64Type()); + Value layerNormShape = create.onnx.concat(layerNormShapeType, + {NShape, NGandMin1Shape, spacialShape}, /*axis*/ + 0); + // Compute type of converted input. + llvm::SmallVector layerNormShapeVal; + // Create a new tensor with the following dimensions: N, NG, C/NG, D1, D2, + // Dn... + layerNormShapeVal.emplace_back(inputShapeVal[0]); // N + layerNormShapeVal.emplace_back(numGroups); // NG + if (C != ShapedType::kDynamic) { + assert(C % numGroups == 0 && "expected numGroups to divide C"); + layerNormShapeVal.emplace_back(C / numGroups); // (C/NG) + } else + layerNormShapeVal.emplace_back(ShapedType::kDynamic); + for (int64_t i = 0; i < spacialRank; ++i) + layerNormShapeVal.emplace_back(inputShapeVal[nonSpacialRank + i]); // Dn + RankedTensorType layerNormInputType = + RankedTensorType::get(layerNormShapeVal, elementType); + Value layerNormInput = + create.onnx.reshape(layerNormInputType, input, layerNormShape); + // Create output using layer norm. + Value layerNormY = create.onnx.layerNorm(layerNormInputType, layerNormInput, + newScale, newBias, axis, groupNormOp.getEpsilonAttr()); + // Resize output to original size + Type inputShapeType = + RankedTensorType::get({inputRank}, rewriter.getI64Type()); + Value inputShape = create.onnx.shape(inputShapeType, input); + Type outputType = groupNormOp.getY().getType(); + Value Y = create.onnx.reshape(outputType, layerNormY, inputShape); + // Set the type of the output to be the same as the output of the original + // operation we are trying to replace. + Y.setType(groupNormOp.getResult().getType()); + // Replace operation. + rewriter.replaceOp(groupNormOp, Y); + return success(); +} + +struct GroupNormIntoLayerNormPattern1 + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ONNXGroupNormalizationOp groupNormOp, + PatternRewriter &rewriter) const final { + return ONNXGroupNormalizationCommon( + groupNormOp, rewriter); + } +}; + +struct GroupNormIntoLayerNormPattern2 + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ONNXGroupNormalizationV18Op groupNormOp, + PatternRewriter &rewriter) const final { + return ONNXGroupNormalizationCommon( + groupNormOp, rewriter); } }; @@ -1003,6 +1072,7 @@ void DecomposeONNXToONNXPass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -1100,7 +1170,8 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns( // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul patterns.insert(context); patterns.insert(context); - patterns.insert(context); + patterns.insert(context); + patterns.insert(context); // TODO: consider whether to include SoftmaxPattern here } diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py index 86b9ea00d1..3a49b460d1 100644 --- a/test/backend/inference_backend.py +++ b/test/backend/inference_backend.py @@ -395,7 +395,7 @@ def get_test_models(): }, # ==OP== Cast # ==MIN== 6 - # ==LIM== Cast only between float and double types. Only ppc64le and MacOS platforms support float16. + # ==LIM== Cast only between float and double types. Only ppc64le and MacOS platforms support float16. Does not support int4 and uint4. "test_cast_FLOAT_to_DOUBLE_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -434,7 +434,7 @@ def get_test_models(): "test_cast_STRING_to_FLOAT_cpu": {}, # appears unsupported at this time # ==OP== CastLike # ==MIN== 19 - # ==LIM== CastLike only between float and double types. Only ppc64le and MacOS platforms support float16. + # ==LIM== CastLike only between float and double types. Only ppc64le and MacOS platforms support float16. Does not support int4 and uint4. "test_castlike_FLOAT_to_DOUBLE_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -615,10 +615,12 @@ def get_test_models(): }, # ==OP== Constant # ==MIN== 1 + # ==LIM== Does not support int4 and uint4. # By def, no dynamic shapes. "test_constant_cpu": {STATIC_SHAPE: {}}, # ==OP== ConstantOfShape # ==MIN== 9 + # ==LIM== Does not support int4 and uint4. # By def, no dynamic shapes. "test_constantofshape_float_ones_cpu": {STATIC_SHAPE: {}}, "test_constantofshape_int_zeros_cpu": {STATIC_SHAPE: {}}, @@ -790,7 +792,7 @@ def get_test_models(): }, # ==OP== DequantizeLinear # ==MIN== 10 - # ==LIM== Only support for per-tensor or layer dequantization. No support for per-axis dequantization. + # ==LIM== Only support for per-tensor or layer dequantization. No support for per-axis dequantization. Does not support int4 and uint4. # "test_dequantizelinear_axis_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_dequantizelinear_cpu": { STATIC_SHAPE: {}, @@ -981,6 +983,7 @@ def get_test_models(): }, # ==OP== Flatten # ==MIN== 1 + # ==LIM== Does not support int4 and uint4. "test_flatten_axis0_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -1258,21 +1261,21 @@ def get_test_models(): }, # ==OP== GroupNormalization # ==MIN== 18 - # "test_group_normalization_epsilon_cpu": { - # STATIC_SHAPE: {}, - # DYNAMIC_SHAPE: {-1: {-1}}, - # CONSTANT_INPUT: {-1}, - # }, + "test_group_normalization_epsilon_cpu": { + STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, + CONSTANT_INPUT: {-1}, + }, "test_group_normalization_epsilon_expanded_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, - # "test_group_normalization_example_cpu": { - # STATIC_SHAPE: {}, - # DYNAMIC_SHAPE: {-1: {-1}}, - # CONSTANT_INPUT: {-1}, - # }, + "test_group_normalization_example_cpu": { + STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, + CONSTANT_INPUT: {-1}, + }, "test_group_normalization_example_expanded_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -1358,7 +1361,7 @@ def get_test_models(): }, # ==OP== Identity # ==MIN== 16 - # ==LIM== Sequence identity not supported. + # ==LIM== Sequence identity not supported. Does not support int4 and uint4. "test_identity_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -1368,7 +1371,7 @@ def get_test_models(): # "test_identity_opt_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # ==OP== If # ==MIN== 16 - # ==LIM== Sequence and Optional outputs are not supported. + # ==LIM== Sequence and Optional outputs are not supported. Does not support int4 and uint4. "test_if_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -1785,7 +1788,7 @@ def get_test_models(): }, # ==OP== Loop # ==MIN== 1 - # ==LIM== Input must have static shape. + # ==LIM== Input must have static shape. Does not support int4 and uint4. "test_loop11_cpu": { STATIC_SHAPE: {}, # Need to enable ConvertSeqToMemrefPass for dynamic test. @@ -2264,7 +2267,7 @@ def get_test_models(): }, # ==OP== Pad # ==MIN== 2 - # ==LIM== axes input not supported + # ==LIM== axes input not supported. Does not support int4 and uint4. "test_constant_pad_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {0: {-1}}, @@ -2317,7 +2320,7 @@ def get_test_models(): }, # ==OP== QuantizeLinear # ==MIN== 10 - # ==LIM== Do not support per-axis and i8 quantization. + # ==LIM== Does not support per-axis and i8 quantization. Does not support int4 and uint4. # "test_quantizelinear_axis_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_quantizelinear_cpu": { STATIC_SHAPE: {}, @@ -2623,7 +2626,7 @@ def get_test_models(): }, # ==OP== Reshape # ==MIN== 5 - # ==LIM== allowzero not supported. Input `shape` must have static dimension. + # ==LIM== allowzero not supported. Input `shape` must have static dimension. Does not support int4 and uint4. "test_reshape_extended_dims_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {0: {-1}}, @@ -2802,7 +2805,7 @@ def get_test_models(): }, # ==OP== Scan # ==MIN== 8 - # ==LIM== Does not support dynamic shapes. + # ==LIM== Does not support dynamic shapes. Does not support int4 and uint4. # ==TODO== Precision issue with newer opset, maybe just unsupported. Dynamic shape? # "test_scan_sum_cpu": {STATIC_SHAPE:{}}, "test_scan9_sum_cpu": {STATIC_SHAPE: {}}, @@ -2859,7 +2862,7 @@ def get_test_models(): # "test_sequence_insert_at_back_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # ==OP== Shape # ==MIN== 15 - # ==LIM== Does not support start and end attributes. + # ==LIM== Does not support start and end attributes. Does not support int4 and uint4. "test_shape_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -2915,6 +2918,7 @@ def get_test_models(): }, # ==OP== Size # ==MIN== 13 + # ==LIM== Does not support int4 and uint4. "test_size_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -3042,7 +3046,7 @@ def get_test_models(): }, # ==OP== Squeeze # ==MIN== 1 - # ==LIM== Does not support static and dynamic shape. + # ==LIM== Does not support static and dynamic shape. Does not support int4 and uint4. # ==TODO== Temporally removed due to changes in onnx 1.8.1 # "test_squeeze_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_squeeze_negative_axes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, @@ -3141,6 +3145,7 @@ def get_test_models(): }, # ==OP== Transpose # ==MIN== 1 + # ==LIM== Does not support int4 and uint4. "test_transpose_default_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -3286,7 +3291,7 @@ def get_test_models(): }, # ==OP== Unsqueeze # ==MIN== 1 - # ==LIM== Does not support static and dynamic shape. + # ==LIM== Does not support static and dynamic shape. Does not support int4 and uint4. # ==TODO== Temporally removed due to changes in onnx 1.8.1 # "test_unsqueeze_axis_0_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_unsqueeze_axis_1_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir index f6b022444a..93d38fc77a 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir @@ -3,6 +3,8 @@ // Adding canonicalize is important here as this is the only way to check the values of the map, // which are otherwise before the function, and thus are hard to test. +// ----- + func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor, %arg2: tensor) -> tensor<4xf32> { %0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> return %0 : tensor<4xf32> @@ -29,10 +31,12 @@ func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor, %ar // ----- + func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor, %arg2: tensor) -> tensor<4xf32> { %0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xui8>, tensor, tensor) -> tensor<4xf32> return %0 : tensor<4xf32> +// mlir2FileCheck.py // CHECK-LABEL: func.func @test_dequantizelinear_ui8 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xui8>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<4xf32> { // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4xf32> @@ -42,11 +46,11 @@ func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor, % // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xui8> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref -// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 +// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 // CHECK-DAG: [[VAR_6_:%.+]] = arith.uitofp [[VAR_5_]] : i8 to f32 -// CHECK-DAG: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 +// CHECK-DAG: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 // CHECK: [[VAR_8_:%.+]] = arith.uitofp [[VAR_7_]] : i8 to f32 -// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_8_]], [[VAR_6_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : f32 // CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[LOAD_PARAM_1_MEM_]] : f32 // CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir new file mode 100644 index 0000000000..e456311773 --- /dev/null +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir @@ -0,0 +1,176 @@ +// RUN: onnx-mlir-opt --disable-quantization-zero-point --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +// Test quantization with disabled zero point + +// Adding canonicalize is important here as this is the only way to check the values of the map, +// which are otherwise before the function, and thus are hard to test. + +// ----- + + +func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor, %arg2: tensor) -> tensor<4xf32> { + %0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xui8>, tensor, tensor) -> tensor<4xf32> + return %0 : tensor<4xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_dequantizelinear_ui8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xui8>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<4xf32> { +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4xf32> +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4){ +// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xui8> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref +// CHECK: [[VAR_4_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 +// CHECK: [[VAR_5_:%.+]] = arith.uitofp [[VAR_4_]] : i8 to f32 +// CHECK: [[VAR_6_:%.+]] = arith.mulf [[VAR_5_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: krnl.store [[VAR_6_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> +// CHECK: } +// CHECK: return [[RES_]] : memref<4xf32> +// CHECK: } +} + +// ----- + + +func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor, tensor, tensor) { + %y, %y_scale, %y_zero_point = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor) -> (tensor, tensor, tensor) + return %y, %y_scale, %y_zero_point: tensor, tensor, tensor + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 2)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 2)> +// CHECK-LABEL: func.func @test_dynamic_quantize_linear +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> (memref, memref, memref) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i8 +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[CST_0_3_:%.+]] = arith.constant 0 : index +// CHECK: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_3_]] : memref +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]]) {{.*}}: memref +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_3_]], [[CST_0_2_]] : memref +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK-DAG: [[VAR_dim_9_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_3_]] : memref +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_9_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ +// CHECK: [[VAR_12_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_12_]]#0, [[VAR_12_]]#1] : memref +// CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = krnl.load [[RES_3_]][] : memref +// CHECK: [[VAR_15_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_15_]], [[RES_3_]][] : memref +// CHECK: } +// CHECK: [[RES_4_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_4_]], [[CST_0_1_]] : memref +// CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 +// CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_3_]] : memref +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){ +// CHECK: [[VAR_12_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_12_1_]]#0, [[VAR_12_1_]]#1] : memref +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_4_]][] : memref +// CHECK: [[VAR_15_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_15_1_]], [[RES_4_]][] : memref +// CHECK: } +// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = krnl.load [[RES_3_]][] : memref +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = arith.maxnumf [[LOAD_RES_4_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_2_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.divf [[VAR_6_]], [[CST_2_dot_550000_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = builtin.unrealized_conversion_cast [[CST_0_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_7_]], [[RES_1_]][] : memref +// CHECK: krnl.store [[VAR_8_]], [[RES_2_]][] : memref +// CHECK-DAG: [[VAR_9_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[VAR_9_]], [[RES_5_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_5_]]) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[VAR_10_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[VAR_10_]], [[RES_6_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ +// CHECK: [[VAR_12_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_12_2_]]{{.}} : memref +// CHECK: [[LOAD_RES_3_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[VAR_7_]] : f32 +// CHECK: [[VAR_15_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_15_2_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf ogt, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_15_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_15_2_]] : f32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.mulf [[VAR_15_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = math.floor [[VAR_20_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.mulf [[VAR_21_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.subf [[VAR_15_2_]], [[VAR_22_]] : f32 +// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_23_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_15_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_2_]] : f32 +// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 +// CHECK: [[VAR_29_:%.+]] = arith.maxnumf [[VAR_28_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_30_:%.+]] = arith.minnumf [[VAR_29_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_31_:%.+]] = arith.fptoui [[VAR_30_]] : f32 to i8 +// CHECK: [[VAR_32_:%.+]] = builtin.unrealized_conversion_cast [[VAR_31_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_32_]], [[VAR_reshape_14_]]{{.}}[[VAR_12_2_]]{{.}} : memref +// CHECK: } +// CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref, memref, memref +// CHECK: } +} + +// ----- + + +func.func @test_quantize_linear_ui8(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: tensor) -> tensor<6xui8> { + %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<6xf32>, tensor, tensor) -> tensor<6xui8> + return %0 : tensor<6xui8> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_quantize_linear_ui8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<6xf32>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<6xui8> { +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<6xui8> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ +// CHECK: [[VAR_2_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_4_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_5_:%.+]] = math.floor [[VAR_4_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpf ogt, [[VAR_6_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.addf [[VAR_5_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_7_]], [[VAR_8_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.mulf [[VAR_5_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_11_:%.+]] = math.floor [[VAR_10_]] : f32 +// CHECK: [[VAR_12_:%.+]] = arith.mulf [[VAR_11_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.subf [[VAR_5_]], [[VAR_12_]] : f32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.cmpf oeq, [[VAR_13_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.addf [[VAR_5_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[VAR_15_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf oeq, [[VAR_6_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_18_:%.+]] = arith.select [[VAR_17_]], [[VAR_16_]], [[VAR_9_]] : f32 +// CHECK: [[VAR_19_:%.+]] = arith.maxnumf [[VAR_18_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_20_:%.+]] = arith.minnumf [[VAR_19_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = arith.fptoui [[VAR_20_]] : f32 to i8 +// CHECK: [[VAR_22_:%.+]] = builtin.unrealized_conversion_cast [[VAR_21_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_22_]], [[RES_]]{{.}}[[VAR_2_]]{{.}} : memref<6xui8> +// CHECK: } +// CHECK: return [[RES_]] : memref<6xui8> +// CHECK: } +} + diff --git a/test/mlir/conversion/onnx_to_krnl/onnx_lowering_reuse.mlir b/test/mlir/conversion/onnx_to_krnl/onnx_lowering_reuse.mlir new file mode 100644 index 0000000000..2279a7c901 --- /dev/null +++ b/test/mlir/conversion/onnx_to_krnl/onnx_lowering_reuse.mlir @@ -0,0 +1,11 @@ +// RUN: onnx-mlir-opt --disable-krnl-op-fusion=true --enable-krnl-buffer-reuse=true --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +// ----- +func.func @test_reuse(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<1024xf32>, tensor<1024xf32>) -> tensor<1024xf32> + %1 = "onnx.Sqrt"(%0) : (tensor<1024xf32>) -> tensor<1024xf32> + %2 = "onnx.Sqrt"(%1) : (tensor<1024xf32>) -> tensor<1024xf32> + return %2 : tensor<1024xf32> +} +// CHECK-LABEL: func.func @test_reuse +// CHECK-NOT: memref.alloc diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir index 6fcdd0bbd1..f4de9145b7 100644 --- a/test/mlir/onnx/onnx_decompose.mlir +++ b/test/mlir/onnx/onnx_decompose.mlir @@ -542,11 +542,11 @@ func.func @test_constantofshape(%arg0: tensor) -> tensor<*xi32> { // ----- -func.func @test_groupnorm(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x2x2xf32> { - %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x2x2xf32> +func.func @test_groupnorm_v18(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x2x2xf32> { + %0 = "onnx.GroupNormalizationV18"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x2x2xf32> onnx.Return %0 : tensor<3x4x2x2xf32> // mlir2FileCheck.py -// CHECK-LABEL: func.func @test_groupnorm +// CHECK-LABEL: func.func @test_groupnorm_v18 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x2x2xf32>, [[PARAM_1_:%.+]]: tensor<2xf32>, [[PARAM_2_:%.+]]: tensor<2xf32>) -> tensor<3x4x2x2xf32> { // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 2, 3]> : tensor<3xi64> // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Unsqueeze"([[PARAM_1_]], [[VAR_0_]]) : (tensor<2xf32>, tensor<3xi64>) -> tensor<2x1x1x1xf32> @@ -565,10 +565,36 @@ func.func @test_groupnorm(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2xf32>, %arg } // ----- -func.func @group_norm5d(%arg0: tensor<3x4x6x8x16xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x6x8x16xf32> { - %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x6x8x16xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x6x8x16xf32> +func.func @test_groupnorm_v21(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x2x2xf32> { + %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x2x2xf32> + onnx.Return %0 : tensor<3x4x2x2xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_groupnorm_v21 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x2x2xf32>, [[PARAM_1_:%.+]]: tensor<2xf32>, [[PARAM_2_:%.+]]: tensor<2xf32>) -> tensor<3x4x2x2xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<2> : tensor<1xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<2xi64> +// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<2xi64>) -> tensor<4xi64> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<2xf32>, tensor<4xi64>) -> tensor<2x2x1x1xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_2_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<2xf32>, tensor<4xi64>) -> tensor<2x2x1x1xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {end = 1 : si64, start = 0 : si64} : (tensor<3x4x2x2xf32>) -> tensor<1xi64> +// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<[2, -1]> : tensor<2xi64> +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 2 : si64} : (tensor<3x4x2x2xf32>) -> tensor<2xi64> +// CHECK: [[VAR_8_:%.+]] = "onnx.Concat"([[VAR_5_]], [[VAR_6_]], [[VAR_7_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5xi64> +// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_8_]]) {allowzero = 0 : si64} : (tensor<3x4x2x2xf32>, tensor<5xi64>) -> tensor<3x2x2x2x2xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[Y_]], [[Mean_]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_9_]], [[VAR_3_]], [[VAR_4_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x2x2xf32>, tensor<2x2x1x1xf32>, tensor<2x2x1x1xf32>) -> (tensor<3x2x2x2x2xf32>, none, none) +// CHECK: [[VAR_11_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<3x4x2x2xf32>) -> tensor<4xi64> +// CHECK: [[VAR_12_:%.+]] = "onnx.Reshape"([[Y_]], [[VAR_11_]]) {allowzero = 0 : si64} : (tensor<3x2x2x2x2xf32>, tensor<4xi64>) -> tensor<3x4x2x2xf32> +// CHECK: onnx.Return [[VAR_12_]] : tensor<3x4x2x2xf32> +// CHECK: } +} + +// ----- + +func.func @group_norm5d_v18(%arg0: tensor<3x4x6x8x16xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x6x8x16xf32> { + %0 = "onnx.GroupNormalizationV18"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x6x8x16xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x6x8x16xf32> onnx.Return %0 : tensor<3x4x6x8x16xf32> -// CHECK-LABEL: func.func @group_norm5d +// CHECK-LABEL: func.func @group_norm5d_v18 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x6x8x16xf32>, [[PARAM_1_:%.+]]: tensor<2xf32>, [[PARAM_2_:%.+]]: tensor<2xf32>) -> tensor<3x4x6x8x16xf32> { // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 2, 3, 4]> : tensor<4xi64> // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Unsqueeze"([[PARAM_1_]], [[VAR_0_]]) : (tensor<2xf32>, tensor<4xi64>) -> tensor<2x1x1x1x1xf32> @@ -588,6 +614,32 @@ func.func @group_norm5d(%arg0: tensor<3x4x6x8x16xf32>, %arg1: tensor<2xf32>, %ar // ----- +func.func @group_norm5d_v21(%arg0: tensor<3x4x6x8x16xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x6x8x16xf32> { + %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x6x8x16xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x6x8x16xf32> + onnx.Return %0 : tensor<3x4x6x8x16xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @group_norm5d_v21 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x6x8x16xf32>, [[PARAM_1_:%.+]]: tensor<2xf32>, [[PARAM_2_:%.+]]: tensor<2xf32>) -> tensor<3x4x6x8x16xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<2> : tensor<1xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<2xi64> +// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<2xi64>) -> tensor<5xi64> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<2xf32>, tensor<5xi64>) -> tensor<2x2x1x1x1xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_2_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<2xf32>, tensor<5xi64>) -> tensor<2x2x1x1x1xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {end = 1 : si64, start = 0 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<1xi64> +// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<[2, -1]> : tensor<2xi64> +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 2 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<3xi64> +// CHECK: [[VAR_8_:%.+]] = "onnx.Concat"([[VAR_5_]], [[VAR_6_]], [[VAR_7_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>, tensor<3xi64>) -> tensor<6xi64> +// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_8_]]) {allowzero = 0 : si64} : (tensor<3x4x6x8x16xf32>, tensor<6xi64>) -> tensor<3x2x2x6x8x16xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[Y_]], [[Mean_]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_9_]], [[VAR_3_]], [[VAR_4_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<2x2x1x1x1xf32>, tensor<2x2x1x1x1xf32>) -> (tensor<3x2x2x6x8x16xf32>, none, none) +// CHECK: [[VAR_11_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<5xi64> +// CHECK: [[VAR_12_:%.+]] = "onnx.Reshape"([[Y_]], [[VAR_11_]]) {allowzero = 0 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<5xi64>) -> tensor<3x4x6x8x16xf32> +// CHECK: onnx.Return [[VAR_12_]] : tensor<3x4x6x8x16xf32> +// CHECK: } +} + +// ----- + func.func @test_instancenorm(%arg0: tensor<2x3x4x5x6xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<2x3x4x5x6xf32> { %0 = "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32} : (tensor<2x3x4x5x6xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<2x3x4x5x6xf32> onnx.Return %0 : tensor<2x3x4x5x6xf32> diff --git a/utils/clone-mlir.sh b/utils/clone-mlir.sh index e9dfb24e72..b01bbf0b1f 100644 --- a/utils/clone-mlir.sh +++ b/utils/clone-mlir.sh @@ -1,3 +1,3 @@ git clone -n https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout f142f8afe21bceb00fb495468aa0b5043e98c419 && cd .. +cd llvm-project && git checkout eaa95a1c2bd38332c1a4e634595f29d22b28ffea && cd .. diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 8b72899c6c..348ffbf7d8 100755 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -161,7 +161,7 @@ "Greater": [13], "GreaterOrEqual": [16], "GridSample": [16], - "GroupNormalization": [18], + "GroupNormalization": [21, 18], "HammingWindow": [17], "HannWindow": [17], "HardSigmoid": [6], @@ -396,6 +396,7 @@ "Gelu", "Greater", "GreaterOrEqual", + "GroupNormalizationV18", "Hardmax", "If", "IsInf",