diff --git a/CMakeLists.txt b/CMakeLists.txt
index bec9c49027..b9ef78d799 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -173,7 +173,11 @@ else()
add_subdirectory(third_party/rapidcheck)
if (ONNX_MLIR_ENABLE_STABLEHLO)
+ # Workaround for stablehlo failing to link libStablehloOptimizationPasses.so if built with shared libs
+ set(ONNX_MLIR_BUILD_SHARED_LIBS_STORE ${BUILD_SHARED_LIBS})
+ set(BUILD_SHARED_LIBS OFF)
add_subdirectory(third_party/stablehlo EXCLUDE_FROM_ALL)
+ set(BUILD_SHARED_LIBS ${ONNX_MLIR_BUILD_SHARED_LIBS_STORE})
endif()
if (NOT TARGET benchmark)
diff --git a/docs/BuildOnLinuxOSX.md b/docs/BuildOnLinuxOSX.md
index 21fbc37e0c..3cdfdc81d2 100644
--- a/docs/BuildOnLinuxOSX.md
+++ b/docs/BuildOnLinuxOSX.md
@@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
+cd llvm-project && git checkout 43d71baae36c8d8b5a9995aa35efebe09cc9c2d6 && cd ..
```
[same-as-file]: <> (utils/build-mlir.sh)
diff --git a/docs/BuildOnWindows.md b/docs/BuildOnWindows.md
index 13e2a002ec..339b7df662 100644
--- a/docs/BuildOnWindows.md
+++ b/docs/BuildOnWindows.md
@@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
+cd llvm-project && git checkout 43d71baae36c8d8b5a9995aa35efebe09cc9c2d6 && cd ..
```
[same-as-file]: <> (utils/build-mlir.cmd)
diff --git a/docs/Dialects/krnl.md b/docs/Dialects/krnl.md
index 797aef9ae6..ed608672ab 100644
--- a/docs/Dialects/krnl.md
+++ b/docs/Dialects/krnl.md
@@ -1,21 +1,29 @@
+
### `krnl.acos` (KrnlAcosOp)
_Krnl acos scalar operation_
Krnl acos scalar operation.
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+
+Effects: `MemoryEffects::Effect{}`
+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `in` | floating-point
+| `in` | floating-point |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `out` | floating-point
+| `out` | floating-point |
+
### `krnl.acosh` (KrnlAcoshOp)
@@ -23,17 +31,24 @@ _Krnl acosh scalar operation_
Krnl acosh scalar operation.
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+
+Effects: `MemoryEffects::Effect{}`
+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `in` | floating-point
+| `in` | floating-point |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `out` | floating-point
+| `out` | floating-point |
+
### `krnl.asin` (KrnlAsinOp)
@@ -41,17 +56,24 @@ _Krnl asin scalar operation_
Krnl asin scalar operation.
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+
+Effects: `MemoryEffects::Effect{}`
+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `in` | floating-point
+| `in` | floating-point |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `out` | floating-point
+| `out` | floating-point |
+
### `krnl.asinh` (KrnlAsinhOp)
@@ -59,17 +81,24 @@ _Krnl asinh scalar operation_
Krnl asinh scalar operation.
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+
+Effects: `MemoryEffects::Effect{}`
+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `in` | floating-point
+| `in` | floating-point |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `out` | floating-point
+| `out` | floating-point |
+
### `krnl.atan` (KrnlAtanOp)
@@ -77,17 +106,24 @@ _Krnl atan scalar operation_
Krnl atan scalar operation.
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+
+Effects: `MemoryEffects::Effect{}`
+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `in` | floating-point
+| `in` | floating-point |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `out` | floating-point
+| `out` | floating-point |
+
### `krnl.atanh` (KrnlAtanhOp)
@@ -95,23 +131,29 @@ _Krnl atanh scalar operation_
Krnl atanh scalar operation.
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+
+Effects: `MemoryEffects::Effect{}`
+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `in` | floating-point
+| `in` | floating-point |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `out` | floating-point
+| `out` | floating-point |
+
### `krnl.block` (KrnlBlockOp)
_Krnl block operation_
-
Syntax:
```
@@ -135,14 +177,15 @@ means to block the for loop referred to by %i using a tile size of 4.
| Operand | Description |
| :-----: | ----------- |
-| `loop` | any type
+| `loop` | any type |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `loop_block` | any type
-| `loop_local` | any type
+| `loop_block` | any type |
+| `loop_local` | any type |
+
### `krnl.call` (KrnlCallOp)
@@ -189,19 +232,19 @@ Interfaces: `MemoryEffectOpInterface`
| Operand | Description |
| :-----: | ----------- |
-| `parameters` | variadic of any type
+| `parameters` | variadic of any type |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `returnValue` | variadic of floating-point or integer
+| `returnValue` | variadic of floating-point or integer |
+
### `krnl.copy_from_tile_buffer` (KrnlCopyFromBufferOp)
_Copy from buffer._
-
Syntax:
```
@@ -231,15 +274,15 @@ Traits: `MemRefsNormalizable`
| Operand | Description |
| :-----: | ----------- |
-| `buffer` | memref of any type values
-| `dest` | memref of any type values
-| `starts` | variadic of index
+| `buffer` | memref of any type values |
+| `dest` | memref of any type values |
+| `starts` | variadic of index |
+
### `krnl.copy_to_tile_buffer` (KrnlCopyToBufferOp)
_Copy to buffer._
-
Syntax:
```
@@ -300,10 +343,11 @@ Traits: `MemRefsNormalizable`
| Operand | Description |
| :-----: | ----------- |
-| `buffer` | memref of any type values
-| `source` | memref of any type values
-| `starts` | variadic of index
-| `padValue` | any type
+| `buffer` | memref of any type values |
+| `source` | memref of any type values |
+| `starts` | variadic of index |
+| `padValue` | any type |
+
### `krnl.define_loops` (KrnlDefineLoopsOp)
@@ -321,7 +365,8 @@ Effects: `MemoryEffects::Effect{}`
| Result | Description |
| :----: | ----------- |
-«unnamed» | variadic of any type
+| «unnamed» | variadic of any type |
+
### `krnl.entry_point` (KrnlEntryPointOp)
@@ -330,23 +375,31 @@ _Indicate ONNX entry point_
The "krnl.entry_point" function indicates the main entry
point of ONNX model.
+
### `krnl.erf` (KrnlErfOp)
_Krnl erf scalar operation_
Krnl erf scalar operation.
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+
+Effects: `MemoryEffects::Effect{}`
+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `in` | floating-point
+| `in` | floating-point |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `out` | floating-point
+| `out` | floating-point |
+
### `krnl.find_index` (KrnlFindIndexOp)
@@ -367,22 +420,22 @@ Effects: `MemoryEffects::Effect{}`
| Operand | Description |
| :-----: | ----------- |
-| `input` | string type or 64-bit signless integer
-| `G` | memref of 32-bit signless integer values
-| `V` | memref of 32-bit signless integer values
-| `len` | 32-bit signless integer
+| `input` | string type or 64-bit signless integer |
+| `G` | memref of 32-bit signless integer values |
+| `V` | memref of 32-bit signless integer values |
+| `len` | 32-bit signless integer |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `index` | index
+| `index` | index |
+
### `krnl.get_induction_var_value` (KrnlGetInductionVariableValueOp)
_Krnl_
-
Syntax:
```
@@ -405,13 +458,14 @@ Effects: `MemoryEffects::Effect{}`
| Operand | Description |
| :-----: | ----------- |
-| `loops` | variadic of any type
+| `loops` | variadic of any type |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `ind_var_vals` | variadic of any type
+| `ind_var_vals` | variadic of any type |
+
### `krnl.get_linear_offset_index` (KrnlGetLinearOffsetIndexOp)
@@ -435,14 +489,15 @@ Interfaces: `AffineMapAccessInterface`, `AffineReadOpInterface`
| Operand | Description |
| :-----: | ----------- |
-| `memref` | memref of any type values
-| `indices` | variadic of index
+| `memref` | memref of any type values |
+| `indices` | variadic of index |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `result` | index
+| `result` | index |
+
### `krnl.global` (KrnlGlobalOp)
@@ -473,7 +528,8 @@ Effects: `MemoryEffects::Effect{}`
| Result | Description |
| :----: | ----------- |
-| `output` | memref of any type values
+| `output` | memref of any type values |
+
### `krnl.runtime_instrument` (KrnlInstrumentOp)
@@ -491,23 +547,31 @@ May be used for gdb.
nodeName | ::mlir::StringAttr | string attribute |
+
### `krnl.isinf` (KrnlIsInfOp)
_Krnl isinf scalar operation_
Krnl isinf scalar operation.
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+
+Effects: `MemoryEffects::Effect{}`
+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `in` | floating-point
+| `in` | floating-point |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `out` | 1-bit signless integer
+| `out` | 1-bit signless integer |
+
### `krnl.isnan` (KrnlIsNaNOp)
@@ -515,17 +579,24 @@ _Krnl isnan scalar operation_
Krnl isnan scalar operation.
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+
+Effects: `MemoryEffects::Effect{}`
+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `in` | floating-point
+| `in` | floating-point |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `out` | 1-bit signless integer
+| `out` | 1-bit signless integer |
+
### `krnl.iterate` (KrnlIterateOp)
@@ -564,19 +635,19 @@ Interfaces: `LoopLikeOpInterface`
| Operand | Description |
| :-----: | ----------- |
-«unnamed» | variadic of any type
+| «unnamed» | variadic of any type |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `results` | variadic of any type
+| `results` | variadic of any type |
+
### `krnl.load` (KrnlLoadOp)
_A Krnl operation to load data from the memref._
-
Syntax:
```
@@ -595,20 +666,20 @@ Traits: `MemRefsNormalizable`
| Operand | Description |
| :-----: | ----------- |
-| `memref` | memref of any type values
-| `indices` | variadic of index
+| `memref` | memref of any type values |
+| `indices` | variadic of index |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `result` | any type
+| `result` | any type |
+
### `krnl.matmul` (KrnlMatMulOp)
_Matmul operation for a single pannel._
-
Syntax:
```
@@ -803,19 +874,20 @@ Interfaces: `SpecializedKernelOpInterface`
| Operand | Description |
| :-----: | ----------- |
-| `A` | memref of any type values
-| `aGlobalIndexMemStart` | variadic of index
-| `B` | memref of any type values
-| `bGlobalIndexMemStart` | variadic of index
-| `C` | memref of any type values
-| `cGlobalIndexMemStart` | variadic of index
-| `loops` | variadic of any type
-| `iGlobalIndexComputeStart` | index
-| `jGlobalIndexComputeStart` | index
-| `kGlobalIndexComputeStart` | index
-| `iGlobalUB` | index
-| `jGlobalUB` | index
-| `kGlobalUB` | index
+| `A` | memref of any type values |
+| `aGlobalIndexMemStart` | variadic of index |
+| `B` | memref of any type values |
+| `bGlobalIndexMemStart` | variadic of index |
+| `C` | memref of any type values |
+| `cGlobalIndexMemStart` | variadic of index |
+| `loops` | variadic of any type |
+| `iGlobalIndexComputeStart` | index |
+| `jGlobalIndexComputeStart` | index |
+| `kGlobalIndexComputeStart` | index |
+| `iGlobalUB` | index |
+| `jGlobalUB` | index |
+| `kGlobalUB` | index |
+
### `krnl.memcpy` (KrnlMemcpyOp)
@@ -836,17 +908,17 @@ Interfaces: `MemoryEffectOpInterface`
| Operand | Description |
| :-----: | ----------- |
-| `dest` | memref of any type values
-| `src` | memref of any type values
-| `num_elems` | 64-bit signless integer
-| `dest_offset` | index
-| `src_offset` | index
+| `dest` | memref of any type values |
+| `src` | memref of any type values |
+| `num_elems` | 64-bit signless integer |
+| `dest_offset` | index |
+| `src_offset` | index |
+
### `krnl.memset` (KrnlMemsetOp)
_Set buffer to a given value._
-
Syntax:
```
@@ -889,14 +961,14 @@ Interfaces: `MemoryEffectOpInterface`
| Operand | Description |
| :-----: | ----------- |
-| `dest` | memref of any type values
-| `value` | any type
+| `dest` | memref of any type values |
+| `value` | any type |
+
### `krnl.movable` (KrnlMovableOp)
_Krnl movable operation_
-
Syntax:
```
@@ -914,6 +986,7 @@ are nested imperfectly between an "eager" and a "lazy" loop.
Traits: `SingleBlockImplicitTerminator`, `SingleBlock`
+
### `krnl.noValue` (KrnlNoneOp)
_An operation representing the absence of a value._
@@ -934,13 +1007,13 @@ Typically it is used for optional arguments used in KrnlCallop.
| Result | Description |
| :----: | ----------- |
-| `none_val` | none type
+| `none_val` | none type |
+
### `krnl.parallel_clause` (KrnlParallelClauseOp)
_Attach OpenMP clauses to an index varialbe_
-
Syntax:
```
@@ -962,14 +1035,14 @@ is used to uniquely associate a parallel loop with its clauses.
| Operand | Description |
| :-----: | ----------- |
-| `parallel_loop_index` | index
-| `num_threads` | 32-bit signless integer
+| `parallel_loop_index` | index |
+| `num_threads` | 32-bit signless integer |
+
### `krnl.parallel` (KrnlParallelOp)
_Mark Krnl loops as parallel loops_
-
Syntax:
```
@@ -1003,14 +1076,14 @@ Traits: `AttrSizedOperandSegments`
| Operand | Description |
| :-----: | ----------- |
-| `loops` | variadic of any type
-| `num_threads` | 32-bit signless integer
+| `loops` | variadic of any type |
+| `num_threads` | 32-bit signless integer |
+
### `krnl.permute` (KrnlPermuteOp)
_Krnl permute operation_
-
Syntax:
```
@@ -1083,7 +1156,8 @@ affine.for %arg0 = 0 to 1024 step 4 {
| Operand | Description |
| :-----: | ----------- |
-| `loops` | variadic of any type
+| `loops` | variadic of any type |
+
### `krnl.prefetch` (KrnlPrefetchOp)
@@ -1110,8 +1184,9 @@ Interfaces: `AffineMapAccessInterface`
| Operand | Description |
| :-----: | ----------- |
-| `memref` | memref of any type values
-| `indices` | variadic of index
+| `memref` | memref of any type values |
+| `indices` | variadic of index |
+
### `krnl.print` (KrnlPrintOp)
@@ -1134,7 +1209,8 @@ Traits: `MemRefsNormalizable`
| Operand | Description |
| :-----: | ----------- |
-| `input` | any type
+| `input` | any type |
+
### `krnl.print_tensor` (KrnlPrintTensorOp)
@@ -1163,7 +1239,8 @@ Traits: `MemRefsNormalizable`
| Operand | Description |
| :-----: | ----------- |
-| `input` | memref of any type values
+| `input` | memref of any type values |
+
### `krnl.random_normal` (KrnlRandomNormalOp)
@@ -1177,11 +1254,12 @@ Traits: `MemRefsNormalizable`
| Operand | Description |
| :-----: | ----------- |
-| `output` | memref of any type values
-| `numberOfValues` | index
-| `mean` | floating-point
-| `scale` | floating-point
-| `seed` | floating-point
+| `output` | memref of any type values |
+| `numberOfValues` | index |
+| `mean` | floating-point |
+| `scale` | floating-point |
+| `seed` | floating-point |
+
### `krnl.region` (KrnlRegionOp)
@@ -1200,6 +1278,7 @@ create a new memref inside the region and use it outside of the region.
Traits: `AffineScope`, `NoTerminator`, `SingleBlock`
+
### `krnl.round_even` (KrnlRoundEvenOp)
_Krnl round to nearest even operation_
@@ -1207,17 +1286,24 @@ _Krnl round to nearest even operation_
Krnl round to nearest even operation. Accept scalar or vector float values.
Vector must be 1D of a size that is a multiple of the hardware vector size.
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+
+Effects: `MemoryEffects::Effect{}`
+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `in` | floating-point-like
+| `in` | floating-point-like |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `out` | floating-point-like
+| `out` | floating-point-like |
+
### `krnl.seqalloc` (KrnlSeqAllocOp)
@@ -1236,13 +1322,14 @@ Interfaces: `AllocationOpInterface`, `MemoryEffectOpInterface`
| Operand | Description |
| :-----: | ----------- |
-| `length` | variadic of index
+| `length` | variadic of index |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `output` | memref of any type values
+| `output` | memref of any type values |
+
### `krnl.seqdealloc` (KrnlSeqDeallocOp)
@@ -1257,7 +1344,8 @@ Traits: `MemRefsNormalizable`
| Operand | Description |
| :-----: | ----------- |
-| `input_sequence` | memref of any type values
+| `input_sequence` | memref of any type values |
+
### `krnl.seqextract` (KrnlSeqExtractOp)
@@ -1292,14 +1380,15 @@ Interfaces: `AllocationOpInterface`, `MemoryEffectOpInterface`
| Operand | Description |
| :-----: | ----------- |
-| `seq` | memref of any type values
-| `index` | index
+| `seq` | memref of any type values |
+| `index` | index |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `output` | any type
+| `output` | any type |
+
### `krnl.seqstore` (KrnlSeqStoreOp)
@@ -1320,15 +1409,15 @@ Interfaces: `MemoryEffectOpInterface`
| Operand | Description |
| :-----: | ----------- |
-| `input` | any type
-| `seq` | memref of any type values
-| `index` | index
+| `input` | any type |
+| `seq` | memref of any type values |
+| `index` | index |
+
### `krnl.specialized_kernel` (KrnlSpecializedKernel)
_Krnl specialized kernel op_
-
Syntax:
```
@@ -1343,13 +1432,13 @@ Interfaces: `SpecializedKernelOpInterface`
| Operand | Description |
| :-----: | ----------- |
-| `loops` | variadic of any type
+| `loops` | variadic of any type |
+
### `krnl.store` (KrnlStoreOp)
_A Krnl operation to store data to the memref._
-
Syntax:
```
@@ -1367,9 +1456,10 @@ Traits: `MemRefsNormalizable`
| Operand | Description |
| :-----: | ----------- |
-| `value` | any type
-| `memref` | memref of any type values
-| `indices` | variadic of index
+| `value` | any type |
+| `memref` | memref of any type values |
+| `indices` | variadic of index |
+
### `krnl.strlen` (KrnlStrlenOp)
@@ -1387,13 +1477,14 @@ Effects: `MemoryEffects::Effect{}`
| Operand | Description |
| :-----: | ----------- |
-| `str` | string type
+| `str` | string type |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `res` | 64-bit signless integer
+| `res` | 64-bit signless integer |
+
### `krnl.strncmp` (KrnlStrncmpOp)
@@ -1411,15 +1502,16 @@ Effects: `MemoryEffects::Effect{}`
| Operand | Description |
| :-----: | ----------- |
-| `str1` | string type
-| `str2` | string type
-| `len` | 64-bit signless integer
+| `str1` | string type |
+| `str2` | string type |
+| `len` | 64-bit signless integer |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `res` | 32-bit signless integer
+| `res` | 32-bit signless integer |
+
### `krnl.tan` (KrnlTanOp)
@@ -1427,17 +1519,24 @@ _Krnl tan scalar operation_
Krnl tan scalar operation.
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+
+Effects: `MemoryEffects::Effect{}`
+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `in` | floating-point
+| `in` | floating-point |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `out` | floating-point
+| `out` | floating-point |
+
### `krnl.terminate` (KrnlTerminatorOp)
@@ -1456,11 +1555,11 @@ Interfaces: `NoMemoryEffect (MemoryEffectOpInterface)`, `RegionBranchTerminatorO
Effects: `MemoryEffects::Effect{}`
+
### `krnl.unroll` (KrnlUnrollOp)
_Krnl unroll operation_
-
Syntax:
```
@@ -1477,13 +1576,13 @@ unrolls the loop referred to by %i fully.
| Operand | Description |
| :-----: | ----------- |
-| `loop` | any type
+| `loop` | any type |
+
### `krnl.vector_type_cast` (KrnlVectorTypeCastOp)
_Vector type cast operation_
-
Syntax:
```
@@ -1511,19 +1610,19 @@ Effects: `MemoryEffects::Effect{}`
| Operand | Description |
| :-----: | ----------- |
-| `source` | memref of any type values
+| `source` | memref of any type values |
#### Results:
| Result | Description |
| :----: | ----------- |
-| `result` | memref of any type values
+| `result` | memref of any type values |
+
### `krnl.yield` (KrnlYieldOp)
_Yield values to parent operation_
-
Syntax:
```
@@ -1550,5 +1649,5 @@ Effects: `MemoryEffects::Effect{}`
| Operand | Description |
| :-----: | ----------- |
-| `operands` | variadic of any type
+| `operands` | variadic of any type |
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
index c98938481d..a23bbc7e1f 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
@@ -1680,7 +1680,7 @@ void ONNXToZHighLoweringPass::runOnOperation() {
onnx_mlir::getONNXToZHighMultipleOpPatterns(combinedPatterns);
// It's ok to fail.
- (void)applyPatternsAndFoldGreedily(module, std::move(combinedPatterns));
+ (void)applyPatternsGreedily(module, std::move(combinedPatterns));
// Run the unknown dimension analysis to help check equality of unknown
// dimensions at compile time.
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp
index 60e11ca41d..61dc1b4651 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp
@@ -62,7 +62,7 @@ void ZHighToONNXLoweringPass::runOnOperation() {
zhigh::ZHighStickOp::getCanonicalizationPatterns(patterns, &getContext());
zhigh::ZHighUnstickOp::getCanonicalizationPatterns(patterns, &getContext());
- (void)applyPatternsAndFoldGreedily(function, std::move(patterns));
+ (void)applyPatternsGreedily(function, std::move(patterns));
}
std::unique_ptr<mlir::Pass> createZHighToONNXPass() {
diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
index 2dfd5e4ac3..ed533d8682 100644
--- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
+++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
@@ -688,7 +688,6 @@ struct ZHighToZLowQuantizedStickOpLowering : public ConversionPattern {
ZMemRefType zMemRefType =
convertZTensorToMemRefType(*op->result_type_begin());
- Type si64Ty = rewriter.getIntegerType(64, true);
Type i8Ty = rewriter.getIntegerType(8);
Type f32Ty = rewriter.getF32Type();
MemRefType scalarF32MemRefTy = MemRefType::get({}, f32Ty);
diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp
index 114c19d618..6853f9d070 100644
--- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp
+++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp
@@ -35,7 +35,7 @@ ApiRegistry RegisterAllApis(MLIRContext *context) {
auto int16Ty = IntegerType::get(context, 16);
auto int32Ty = IntegerType::get(context, 32);
auto int64Ty = IntegerType::get(context, 64);
- auto float32Ty = FloatType::getF32(context);
+ auto float32Ty = Float32Type::get(context);
// Declare API type as an enum value, its string name and an LLVM Type
// specifying its signature.
@@ -570,7 +570,7 @@ Type getZTensorStructTy(MLIRContext *context) {
Type llvmI64Ty = IntegerType::get(context, 64);
Type llvmI1Ty = IntegerType::get(context, 1);
Type llvmI8Ty = IntegerType::get(context, 8);
- Type llvmF32Ty = FloatType::getF32(context);
+ Type llvmF32Ty = Float32Type::get(context);
Type llvmArray3I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 3);
Type llvmArray20I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 20);
Type llvmI8PtrTy = krnl::getPointerType(context, llvmI8Ty);
@@ -662,7 +662,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
scaleTy.isF32() && "Wrong type for zTensor's rec_scale. Must be float");
create.llvm.store(recScale, recScalePtr);
} else {
- Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
+ Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
create.llvm.store(zero, recScalePtr);
}
@@ -675,7 +675,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
offsetTy.isF32() && "Wrong type for zTensor's offset. Must be float");
create.llvm.store(offset, offsetPtr);
} else {
- Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
+ Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
create.llvm.store(zero, offsetPtr);
}
diff --git a/src/Accelerators/NNPA/Transform/FoldStdAlloc.cpp b/src/Accelerators/NNPA/Transform/FoldStdAlloc.cpp
index 3c63406d4b..dc97665232 100644
--- a/src/Accelerators/NNPA/Transform/FoldStdAlloc.cpp
+++ b/src/Accelerators/NNPA/Transform/FoldStdAlloc.cpp
@@ -211,8 +211,7 @@ class FoldStdAllocPass
RewritePatternSet patterns(&getContext());
patterns.insert(&getContext());
- static_cast<void>(
- applyPatternsAndFoldGreedily(function, std::move(patterns)));
+ static_cast<void>(applyPatternsGreedily(function, std::move(patterns)));
}
};
diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp
index 0e8067789c..723439b0df 100644
--- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp
+++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp
@@ -503,7 +503,7 @@ struct ZHighConstPropagationPass
patterns.insert(patterns.getContext());
patterns.insert(patterns.getContext());
patterns.insert(patterns.getContext());
- (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
+ (void)applyPatternsGreedily(moduleOp, std::move(patterns));
}
};
} // anonymous namespace
diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighDecomposeStickUnstick.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighDecomposeStickUnstick.cpp
index 1dfe212e65..df555ac9c3 100644
--- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighDecomposeStickUnstick.cpp
+++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighDecomposeStickUnstick.cpp
@@ -71,7 +71,7 @@ struct ZHighDecomposeStickUnstickPass
ZHighDLF16ToF32Op::getCanonicalizationPatterns(patterns, &getContext());
ZHighF32ToDLF16Op::getCanonicalizationPatterns(patterns, &getContext());
ONNXLayoutTransformOp::getCanonicalizationPatterns(patterns, &getContext());
- (void)applyPatternsAndFoldGreedily(function, std::move(patterns));
+ (void)applyPatternsGreedily(function, std::move(patterns));
}
};
diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighLayoutPropagation.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighLayoutPropagation.cpp
index ceb4d6459a..23a3d40ead 100644
--- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighLayoutPropagation.cpp
+++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighLayoutPropagation.cpp
@@ -367,7 +367,7 @@ struct ZHighLayoutPropagationPass
// rules in this pass.
ZHighStickOp::getCanonicalizationPatterns(patterns, &getContext());
ZHighUnstickOp::getCanonicalizationPatterns(patterns, &getContext());
- (void)applyPatternsAndFoldGreedily(function, std::move(patterns));
+ (void)applyPatternsGreedily(function, std::move(patterns));
}
};
} // anonymous namespace
diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighRecomposeToStickUnstick.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighRecomposeToStickUnstick.cpp
index 1cbed79fe0..6d4bf93522 100644
--- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighRecomposeToStickUnstick.cpp
+++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighRecomposeToStickUnstick.cpp
@@ -71,7 +71,7 @@ struct ZHighRecomposeToStickUnstickPass
ZHighDLF16ToF32Op::getCanonicalizationPatterns(patterns, &getContext());
ZHighF32ToDLF16Op::getCanonicalizationPatterns(patterns, &getContext());
ONNXLayoutTransformOp::getCanonicalizationPatterns(patterns, &getContext());
- (void)applyPatternsAndFoldGreedily(function, std::move(patterns));
+ (void)applyPatternsGreedily(function, std::move(patterns));
}
};
diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp
index 2328f70265..1e6eb92347 100644
--- a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp
+++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp
@@ -696,7 +696,7 @@ class ZLowRewritePass
patterns.insert(
&getContext(), removableStickOps);
- if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns))))
+ if (failed(applyPatternsGreedily(function, std::move(patterns))))
return signalPassFailure();
// Remove ZLowStickOp that were marked "removable".
diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
index ff8dca97ab..77edb6b0c8 100644
--- a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
+++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
@@ -355,7 +355,7 @@ class ZLowStickExpansionPass
patterns.insert(&getContext(), enableParallel);
// patterns.insert(&getContext());
- if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns))))
+ if (failed(applyPatternsGreedily(function, std::move(patterns))))
return signalPassFailure();
}
};
diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp
index 4e04c93b3f..907fce1ac9 100644
--- a/src/Compiler/CompilerOptions.cpp
+++ b/src/Compiler/CompilerOptions.cpp
@@ -94,7 +94,6 @@ std::vector<std::string> extraLibPaths; // onnx-mlir only
std::vector<std::string> extraLibs; // onnx-mlir only
ProfileIRs profileIR; // onnx-mlir only
OptReport optReport; // onnx-mlir only
-bool useOldBufferization; // onnx-mlir only
bool enableTiming; // onnx-mlir only
bool enableBoundCheck; // onnx-mlir only
bool split_input_file; // onnx-mlir-opt only
@@ -260,11 +259,12 @@ static llvm::cl::opt<bool> enableSafeCodeGenOpt("enable-safe-code-gen",
llvm::cl::location(enableSafeCodeGen), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));
+// TODO(alexe) re-enable prefetch.
static llvm::cl::opt<bool> disableMemRefPrefetchOpt(
"disable-memref-prefetch",
llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n"
"Set to 'true' if you want to disable prefetch."),
- llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(false),
+ llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(true),
llvm::cl::cat(OnnxMlirCommonOptions));
static llvm::cl::list>
@@ -781,15 +781,6 @@ static llvm::cl::opt<bool> allowUnregisteredDialectsOpt(
llvm::cl::location(allowUnregisteredDialects), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptOptions));
-// Removed once the new LLVM bufferization works without performance regression.
-static llvm::cl::opt<bool> useOldBufferizationOpt("use-old-bufferization",
- llvm::cl::desc(
- "Enable the old LLVM bufferization mechanism (default=true).\n"
- "This option should be removed once the new LLVM bufferization works "
- "well in onnx-mlir."),
- llvm::cl::location(useOldBufferization), llvm::cl::init(true),
- llvm::cl::cat(OnnxMlirOptions));
-
// Configuration states associated with certain options.
// For example, when maccel is specified, NNPA can register
// dependent libdnn.
diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp
index d8aa17dc5a..d8f8f084ae 100644
--- a/src/Compiler/CompilerOptions.hpp
+++ b/src/Compiler/CompilerOptions.hpp
@@ -139,7 +139,6 @@ extern std::vector<std::string> extraLibPaths; // onnx-mlir only
extern std::vector<std::string> extraLibs; // onnx-mlir only
extern ProfileIRs profileIR; // onnx-mlir only
extern OptReport optReport; // onnx-mlir only
-extern bool useOldBufferization; // onnx-mlir only
extern bool enableTiming; // onnx-mlir only
extern bool enableBoundCheck; // onnx-mlir only
extern bool debugTestCompilerOpt; // onnx-mlir only
diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp
index 02ecde0241..2c0999ee02 100644
--- a/src/Compiler/CompilerPasses.cpp
+++ b/src/Compiler/CompilerPasses.cpp
@@ -251,15 +251,10 @@ void addKrnlToLLVMPasses(
// Currently this has to be done *after* lowering the affine dialect because
// operations in that dialect do not conform to the requirements explained
// in https://mlir.llvm.org/docs/BufferDeallocationInternals.
- if (useOldBufferization) {
- pm.addNestedPass(
- mlir::bufferization::createBufferDeallocationPass());
- } else {
- bufferization::BufferDeallocationPipelineOptions bufferDeallocOptions;
- mlir::bufferization::buildBufferDeallocationPipeline(
- pm, bufferDeallocOptions);
- pm.addPass(mlir::createBufferizationToMemRefPass());
- }
+ bufferization::BufferDeallocationPipelineOptions bufferDeallocOptions;
+ mlir::bufferization::buildBufferDeallocationPipeline(
+ pm, bufferDeallocOptions);
+ pm.addPass(mlir::createConvertBufferizationToMemRefPass());
// Late introduction of OpenMP, after bufferization.
if (enableParallel) {
diff --git a/src/Conversion/KrnlToLLVM/CMakeLists.txt b/src/Conversion/KrnlToLLVM/CMakeLists.txt
index 92948137be..d81dd48bfd 100644
--- a/src/Conversion/KrnlToLLVM/CMakeLists.txt
+++ b/src/Conversion/KrnlToLLVM/CMakeLists.txt
@@ -37,5 +37,6 @@ add_onnx_mlir_library(OMKrnlToLLVM
MLIRSCFToControlFlow
MLIRShapeToStandard
MLIRVectorToLLVMPass
+ MLIRUBToLLVM
onnx
)
diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
index 950eb65236..2da391cd30 100644
--- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
+++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
@@ -28,6 +28,7 @@
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
+#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -210,6 +211,8 @@ void populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns,
vector::populateVectorInsertExtractStridedSliceTransforms(patterns);
vector::populateVectorStepLoweringPatterns(patterns);
vector::populateVectorRankReducingFMAPattern(patterns);
+ // Some vector ops are lower to UB. Hence, lower UB to LLVM.
+ ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
populateAffineToStdConversionPatterns(patterns);
populateSCFToControlFlowConversionPatterns(patterns);
diff --git a/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp b/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp
index e976b42b7f..5a4c494f14 100644
--- a/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp
+++ b/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp
@@ -80,10 +80,10 @@ class KrnlRandomNormalOpLowering : public ConversionPattern {
// or
// (memref<3x4x5xf64>, index, f64, f64, f64)
Type llvmVoidTy = LLVM::LLVMVoidType::get(context);
- Type llvmOptionsTy = FloatType::getF32(context);
+ Type llvmOptionsTy = Float32Type::get(context);
Type llvmOutputTy = getPointerType(context, llvmOptionsTy);
if (inType.isF64()) {
- llvmOptionsTy = FloatType::getF64(context);
+ llvmOptionsTy = Float64Type::get(context);
llvmOutputTy = getPointerType(context, llvmOptionsTy);
}
Type llvmI64Ty = IntegerType::get(context, 64);
diff --git a/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp b/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp
index 2a0ee747c7..a50acf402f 100644
--- a/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp
+++ b/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp
@@ -172,19 +172,19 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
Type outType = op->getResultTypes().front();
Type llvmInType, llvmOutType;
if (inType.isF16())
- llvmInType = FloatType::getF16(context);
+ llvmInType = Float16Type::get(context);
else if (inType.isF32())
- llvmInType = FloatType::getF32(context);
+ llvmInType = Float32Type::get(context);
else if (inType.isF64())
- llvmInType = FloatType::getF64(context);
+ llvmInType = Float64Type::get(context);
else if (inType.isBF16())
- llvmInType = FloatType::getBF16(context);
+ llvmInType = BFloat16Type::get(context);
if (outType.isInteger(1))
llvmOutType = IntegerType::get(context, 1);
else if (outType.isF32())
- llvmOutType = FloatType::getF32(context);
+ llvmOutType = Float32Type::get(context);
else if (outType.isF64())
- llvmOutType = FloatType::getF64(context);
+ llvmOutType = Float64Type::get(context);
// Insert and/or get reference to elementary math function declaration.
assert(
@@ -214,7 +214,6 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
return SymbolRefAttr::get(context, mathFuncName);
// Create function declaration.
- // auto llvmF32Ty = FloatType::get(context);
auto llvmFnType =
LLVM::LLVMFunctionType::get(llvmOutType, ArrayRef({llvmInType}));
diff --git a/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp b/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp
index 62d7c25de3..a52e57afe7 100644
--- a/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp
+++ b/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp
@@ -62,7 +62,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern {
// Get memRefDescriptor, the new memref descriptor.
MemRefDescriptor memRefDescriptor =
- MemRefDescriptor::undef(rewriter, loc, targetStructType);
+ MemRefDescriptor::poison(rewriter, loc, targetStructType);
auto targetElementPtrType = memRefDescriptor.getElementPtrType();
// Set the new memref to the same buffer as the source memref.
@@ -78,7 +78,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern {
int64_t offset;
SmallVector strides;
- if (failed(getStridesAndOffset(targetType, strides, offset)))
+ if (failed(targetType.getStridesAndOffset(strides, offset)))
return failure();
// Unhandled dynamic offset.
diff --git a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
index 565e63a7d7..00e252fdb6 100644
--- a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
+++ b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
@@ -281,7 +281,7 @@ struct ONNXCategoryMapperOpLowering
SmallVector strides;
int64_t alignmentOffset; // not used, just to make the function call
// completed.
- if (getStridesAndOffset(memRefType, strides, alignmentOffset)
+ if (memRefType.getStridesAndOffset(strides, alignmentOffset)
.failed())
llvm_unreachable("Failed to get strides");
Value stringMemRef =
diff --git a/src/Conversion/ONNXToKrnl/Math/LRN.cpp b/src/Conversion/ONNXToKrnl/Math/LRN.cpp
index 1b08661a2d..12a596d08c 100644
--- a/src/Conversion/ONNXToKrnl/Math/LRN.cpp
+++ b/src/Conversion/ONNXToKrnl/Math/LRN.cpp
@@ -52,7 +52,7 @@ struct ONNXLRNOpLowering : public OpConversionPattern {
float alphaLit = adaptor.getAlpha().convertToFloat();
float betaLit = adaptor.getBeta().convertToFloat();
int sizeLit = adaptor.getSize();
- auto f32Type = FloatType::getF32(rewriter.getContext());
+ auto f32Type = Float32Type::get(rewriter.getContext());
Value biasValue = create.math.constant(f32Type, biasLit);
Value alphaDivSizeValue =
create.math.constant(f32Type, alphaLit / static_cast(sizeLit));
diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp
index adf494c88e..86c861e115 100644
--- a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp
+++ b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -147,14 +148,16 @@ Value TosaBuilder::transpose(Value &value, llvm::ArrayRef perm) {
Value TosaBuilder::slice(Value &inputConst, llvm::ArrayRef size,
llvm::ArrayRef start) {
- DenseI64ArrayAttr sizeAttr = rewriter().getDenseI64ArrayAttr(size);
- DenseI64ArrayAttr startAttr = rewriter().getDenseI64ArrayAttr(start);
+ auto startVal =
+ mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(start));
+ auto sizeVal =
+ mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(size));
Value newSliceInput =
tosa::CreateOpAndInfer(rewriter(), loc(),
RankedTensorType::get(
llvm::SmallVector(size.size(), ShapedType::kDynamic),
mlir::cast(inputConst.getType()).getElementType()),
- inputConst, startAttr, sizeAttr);
+ inputConst, startVal, sizeVal);
return newSliceInput;
}
@@ -164,11 +167,12 @@ Value TosaBuilder::reshape(Value &value, llvm::ArrayRef shape) {
Type newValueType = RankedTensorType::get(
llvm::SmallVector(shape.size(), ShapedType::kDynamic),
valueType.getElementType());
- return tosa::CreateOpAndInfer(
- rewriter(), loc(), newValueType, value, shapeAttr);
+ return tosa::CreateOpAndInfer(rewriter(), loc(),
+ newValueType, value,
+ mlir::tosa::getTosaConstShape(rewriter(), loc(), shapeAttr));
}
-Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
+Value TosaBuilder::mul(Value &lhs, Value &rhs, int8_t shift) {
if (needsRankBroadcast({lhs, rhs})) {
llvm::SmallVector valueVec = equalizeRanks({lhs, rhs});
lhs = valueVec[0];
@@ -178,8 +182,12 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
Type newValueType = RankedTensorType::get(
llvm::SmallVector(lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());
+
+ auto int8Type = rewriter().getI8Type();
+ auto shiftValue =
+ TosaBuilder::createConst(ArrayRef{shift}, {1}, int8Type);
return tosa::CreateOpAndInfer(
- rewriter(), loc(), newValueType, lhs, rhs, shift);
+ rewriter(), loc(), newValueType, lhs, rhs, shiftValue);
}
Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {
@@ -236,8 +244,8 @@ template Value TosaBuilder::binaryOp(Value &lhs, Value &rhs);
// Return null if none is found.
ElementsAttr IndexExprBuilderForTosa::getConst(Value value) {
auto definingOp = value.getDefiningOp();
- // If we have a cast between index/integer, skip it, i.e. get the defining op
- // that is the input to the cast.
+ // If we have a cast between index/integer, skip it, i.e. get the defining
+ // op that is the input to the cast.
if (auto castOp = dyn_cast_or_null(definingOp)) {
Value input = castOp.getIn();
definingOp = input.getDefiningOp();
diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.hpp b/src/Conversion/ONNXToTOSA/DialectBuilder.hpp
index 1050d97053..46fdeaa93a 100644
--- a/src/Conversion/ONNXToTOSA/DialectBuilder.hpp
+++ b/src/Conversion/ONNXToTOSA/DialectBuilder.hpp
@@ -40,7 +40,7 @@ struct TosaBuilder : DialectBuilder {
template
mlir::Value binaryOp(mlir::Value &lhs, mlir::Value &rhs);
- mlir::Value mul(mlir::Value &lhs, mlir::Value &rhs, int32_t shift = 0);
+ mlir::Value mul(mlir::Value &lhs, mlir::Value &rhs, int8_t shift = 0);
mlir::Value intdiv(mlir::Value &lhs, mlir::Value &rhs);
mlir::Value transpose(mlir::Value &value, llvm::ArrayRef perm);
diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
index 2e105d2dc5..ab8b9a43a0 100644
--- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
+++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
@@ -121,11 +121,21 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern {
// Quantized types are not supported right now (in type conversion).
// Once they are, the input should be rescaled for quantized types. (TBD)
// Maps to `tosa.clamp` which has both int and fp limits.
- rewriter.replaceOpWithNewOp(op, op.getType(), input,
- rewriter.getI64IntegerAttr(0),
- rewriter.getI64IntegerAttr(std::numeric_limits::max()),
- rewriter.getF32FloatAttr(0.0f),
- rewriter.getF32FloatAttr(std::numeric_limits::max()));
+ auto inputElementType =
+ llvm::cast(op.getType()).getElementType();
+ if (llvm::isa<IntegerType>(inputElementType)) {
+ auto minClamp = rewriter.getI64IntegerAttr(0);
+ auto maxClamp =
+ rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max());
+ rewriter.replaceOpWithNewOp(
+ op, op.getType(), input, minClamp, maxClamp);
+ } else {
+ auto minClamp = rewriter.getF32FloatAttr(0.0f);
+ auto maxClamp =
+ rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
+ rewriter.replaceOpWithNewOp(
+ op, op.getType(), input, minClamp, maxClamp);
+ }
return success();
}
};
diff --git a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp
index 4f1028002c..599a67bbff 100644
--- a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp
+++ b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
@@ -31,9 +32,6 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern {
LogicalResult matchAndRewrite(ONNXGemmOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
TosaBuilder tosaBuilder(rewriter, op->getLoc());
- // If legal, create a FullyConnected operator instead
- if (rewriteToTosaFC(op, adaptor, rewriter, tosaBuilder))
- return success();
return rewriteToTosaMatMul(op, adaptor, rewriter, tosaBuilder);
}
@@ -67,13 +65,14 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern {
llvm::SmallVector dynamicTensorShape = {
ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic};
+
A = tosa::CreateOpAndInfer(rewriter, op->getLoc(),
RankedTensorType::get(dynamicTensorShape, AType.getElementType()), A,
- rewriter.getDenseI64ArrayAttr(newShapeA))
+ mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeA))
.getResult();
B = tosa::CreateOpAndInfer(rewriter, op->getLoc(),
RankedTensorType::get(dynamicTensorShape, BType.getElementType()), B,
- rewriter.getDenseI64ArrayAttr(newShapeB))
+ mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeB))
.getResult();
// If transA or transB are present, create Transpose operators.
@@ -149,73 +148,6 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern {
// only need to check C[0].
return CShape[0] == AShape[0] || CShape[0] == BShape[0];
}
-
- /// The GEMM can be described as a FullyConnected operator.
- /// Y = AB^T + C if we perform a transpose on B only with.
- /// alpha and beta factors set to 1.
- /// Input A must be of rank 2 (input).
- /// Input B must be of rank 2 (weights).
- /// Input C must be of rank 1 (bias).
- bool rewriteToTosaFC(ONNXGemmOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter, TosaBuilder &tosaBuilder) const {
- Value A = op.getA();
- Value B = op.getB();
- Value C = op.getC();
-
- auto AType = mlir::cast(A.getType());
- auto BType = mlir::cast(B.getType());
-
- bool isCPresent = !mlir::isa(C.getType());
- // If C is present, it can only be of rank 1, if the rank is not 1, return
- // false.
- if (mlir::isa(C.getType()) &&
- mlir::cast(C.getType()).getRank() != 1)
- return false;
-
- // Input tensor must be of rank 2.
- // Weights must also be of rank 2.
- if (AType.getRank() != 2 || BType.getRank() != 2)
- return false;
-
- // Both alpha and beta must be 1.
- if ((adaptor.getAlpha().convertToFloat() != 1.0F) ||
- (adaptor.getBeta().convertToFloat() != 1.0F))
- return false;
-
- // Only Transpose B must be enabled.
- if (adaptor.getTransA() != 0 || adaptor.getTransB() != 1)
- return false;
-
- // If all check passed, we replace the GEMM by a FC operator
- Type resultType = getTypeConverter()->convertType(op.getResult().getType());
-
- // Because the bias is not broadcastable for TOSA while it is for ONNX,
- // we create an empty bias and use an add (broadcastable for tosa)
- // afterwards.
- // Base dummy C shape on B[0] shape.
- bool needsBroadcasting = !hasCCorrectShape(AType, BType, C);
- Value dummyC = C;
- if (!isCPresent || needsBroadcasting) {
- ArrayRef cformat(
- mlir::cast(resultType).getShape()[1]);
- std::vector elements = {};
- for (int i = 0; i < cformat[0]; ++i)
- elements.push_back(0.0F);
- dummyC = tosaBuilder.getConst(elements, cformat);
- }
-
- Value fcRes = tosa::CreateOpAndInfer(
- rewriter, op->getLoc(), resultType, A, B, dummyC)
- .getResult();
- // If C was present in the original GEMM, we create an add to take the bias
- // into account.
- if (isCPresent && needsBroadcasting)
- fcRes = tosaBuilder.binaryOp(fcRes, C);
-
- rewriter.replaceOp(op, fcRes);
-
- return true;
- }
};
} // namespace
diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
index 321a2b35e2..1ec1dec493 100644
--- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
+++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
@@ -14,6 +14,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
@@ -60,8 +61,8 @@ Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter,
}
tosaPads.insert(tosaPads.end(), lastVals.begin(), lastVals.end());
TosaBuilder tosaBuilder(rewriter, loc);
- return tosaBuilder.getConst(
- tosaPads, {static_cast(tosaPads.size())});
+
+ return mlir::tosa::getTosaConstShape(rewriter, loc, tosaPads);
}
} // namespace tosa
diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
index bcd5c7c128..6b00198e17 100644
--- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
+++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
@@ -45,6 +45,7 @@ T getValueFromTosaConst(mlir::Value &val) {
template
TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc,
mlir::Type result_ty, Args &&... args) {
+
auto op = rewriter.create(loc, result_ty, args...);
mlir::InferShapedTypeOpInterface shapeInterface =
@@ -64,6 +65,7 @@ TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc,
// the new result shaped type. This is because rescale can include a cast to
// different bit-width types and does not have a TypeAttr to define the
// target type.
+ assert(returnedShapes.size() >= 1 && "Expected at least one returned shape");
auto predictedShape = returnedShapes[0];
if (predictedShape.hasRank())
updateType(nullptr, op, predictedShape.getDims(),
diff --git a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp
index 3f3269a029..abc5afd443 100644
--- a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp
+++ b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
@@ -35,23 +36,24 @@ struct ScaleHelper {
};
// Adapted from TFL to TOSA.
-ScaleHelper normalize(int64_t output, int64_t input, bool pytorchHalfPixel,
- bool alignCorners, bool halfPixel, bool isNearest,
- bool isNearestModeFloor) {
- int64_t numerator, denominator, offset, border;
+ScaleHelper normalize(int64_t numerator, int64_t denominator, int64_t inputSize,
+ int64_t outputSize, bool pytorchHalfPixel, bool alignCorners,
+ bool halfPixel, bool isNearest, bool isNearestModeFloor) {
+ int64_t offset, border;
// Test if pytorch_half_pixel needs special handling
- if (pytorchHalfPixel && output == 1) {
+ if (pytorchHalfPixel && outputSize == 1) {
numerator = 1;
denominator = 1;
offset = -1;
- border = denominator * (output - 1) - numerator * (input - 1) + offset;
+ border =
+ denominator * (outputSize - 1) - numerator * (inputSize - 1) + offset;
return ScaleHelper(numerator, denominator, offset, border);
}
// Apply if aligned and capable to be aligned.
- bool applyAligned = alignCorners && (output > 1);
- numerator = applyAligned ? (output - 1) : output;
- denominator = applyAligned ? (input - 1) : input;
+ bool applyAligned = alignCorners && (numerator > 1);
+ numerator = applyAligned ? (numerator - 1) : numerator;
+ denominator = applyAligned ? (denominator - 1) : denominator;
// Simplify the scalers, make sure they are even values.
int gcd = std::gcd(numerator, denominator);
@@ -67,7 +69,8 @@ ScaleHelper normalize(int64_t output, int64_t input, bool pytorchHalfPixel,
}
// We can compute this directly based on previous values.
- border = denominator * (output - 1) - numerator * (input - 1) + offset;
+ border =
+ denominator * (outputSize - 1) - numerator * (inputSize - 1) + offset;
return ScaleHelper(numerator, denominator, offset, border);
};
@@ -197,7 +200,8 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern {
// With only static dimensions, scales and sizes as inputs are not relevant
// anymore.
- if (inputType.isDynamicDim(2) || inputType.isDynamicDim(3)) {
+ if (inputType.isDynamicDim(2) || inputType.isDynamicDim(3) ||
+ resultType.isDynamicDim(2) || resultType.isDynamicDim(3)) {
return rewriter.notifyMatchFailure(
resizeOp, "Only static sized tensors are supported.");
}
@@ -219,6 +223,11 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern {
"TOSA does not support ceil and round_prefer_floor as nearestMode.");
}
+ if (mode == "linear" && isa<IntegerType>(elementType)) {
+ return rewriter.notifyMatchFailure(resizeOp,
+ "linear interpolation for integer types is not implemented");
+ }
+
// This also makes roi as an input irrelevant.
if (coordinateTransformationMode == "tf_crop_and_resize") {
return rewriter.notifyMatchFailure(
@@ -245,10 +254,15 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern {
}
// Set these explicitly just out of convenience.
- int64_t inputHeight = inputShape[2];
- int64_t inputWidth = inputShape[3];
- int64_t outputHeight = outputShape[2];
- int64_t outputWidth = outputShape[3];
+ const int64_t inputHeight = inputShape[2];
+ const int64_t inputWidth = inputShape[3];
+ const int64_t outputHeight = outputShape[2];
+ const int64_t outputWidth = outputShape[3];
+
+ int64_t denominatorHeight = inputHeight;
+ int64_t numeratorHeight = outputHeight;
+ int64_t denominatorWidth = inputWidth;
+ int64_t numeratorWidth = outputWidth;
// Check if scales are set. We need to get those float values, because they
// make a difference in linear interpolation.
@@ -262,10 +276,10 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern {
// In TOSA the scale is a fraction of two integer numbers.
FractionNumber height(scales[2]);
FractionNumber width(scales[3]);
- outputHeight = height.numerator;
- inputHeight = height.denominator;
- outputWidth = width.numerator;
- inputWidth = width.denominator;
+ numeratorHeight = height.numerator;
+ denominatorHeight = height.denominator;
+ numeratorWidth = width.numerator;
+ denominatorWidth = width.denominator;
}
bool alignCorners = coordinateTransformationMode == "align_corners";
@@ -284,23 +298,25 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern {
"TOSA does not support float offsets which are required "
"for symmetric mode.");
- ScaleHelper yDimension =
- normalize(outputHeight, inputHeight, pytorchHalfPixel, alignCorners,
- halfPixel, isNearest, isNearestModeFloor);
- ScaleHelper xDimension =
- normalize(outputWidth, inputWidth, pytorchHalfPixel, alignCorners,
- halfPixel, isNearest, isNearestModeFloor);
+ ScaleHelper yDimension = normalize(numeratorHeight, denominatorHeight,
+ inputHeight, outputHeight, pytorchHalfPixel, alignCorners, halfPixel,
+ isNearest, isNearestModeFloor);
+ ScaleHelper xDimension = normalize(numeratorWidth, denominatorWidth,
+ inputWidth, outputWidth, pytorchHalfPixel, alignCorners, halfPixel,
+ isNearest, isNearestModeFloor);
// Convert input [N,IC,IH,IW] -> [N,IH,IW,IC]
Value newInput = tosaBuilder.transpose(input, {0, 2, 3, 1});
// Create resizeOp
- auto scale = rewriter.getDenseI64ArrayAttr({yDimension.numerator,
- yDimension.denominator, xDimension.numerator, xDimension.denominator});
- auto offset =
- rewriter.getDenseI64ArrayAttr({yDimension.offset, xDimension.offset});
- auto border =
- rewriter.getDenseI64ArrayAttr({yDimension.border, xDimension.border});
+ Value scale = mlir::tosa::getTosaConstShape(rewriter, loc,
+ {yDimension.numerator, yDimension.denominator, xDimension.numerator,
+ xDimension.denominator});
+ Value offset = mlir::tosa::getTosaConstShape(
+ rewriter, loc, {yDimension.offset, xDimension.offset});
+ Value border = mlir::tosa::getTosaConstShape(
+ rewriter, loc, {yDimension.border, xDimension.border});
+
auto resizeModeAttr = rewriter.getStringAttr(resizeMode);
Type newOutputType =
RankedTensorType::get(llvm::SmallVector(
@@ -327,4 +343,4 @@ void populateLoweringONNXResizeOpToTOSAPattern(ConversionTarget &target,
patterns.insert(ctx);
}
-} // namespace onnx_mlir
\ No newline at end of file
+} // namespace onnx_mlir
diff --git a/src/Dialect/Krnl/Krnl.td b/src/Dialect/Krnl/Krnl.td
index c8220dfc53..92aaf330e4 100644
--- a/src/Dialect/Krnl/Krnl.td
+++ b/src/Dialect/Krnl/Krnl.td
@@ -567,7 +567,7 @@ def KrnlParallelClauseOp : Op {
}];
}
-def KrnlRoundEvenOp : Op {
+def KrnlRoundEvenOp : Op {
let summary = "Krnl round to nearest even operation";
let description = [{
Krnl round to nearest even operation. Accept scalar or vector float values.
@@ -578,7 +578,7 @@ def KrnlRoundEvenOp : Op {
let results = (outs FloatLike:$out);
}
-def KrnlErfOp : Op {
+def KrnlErfOp : Op {
let summary = "Krnl erf scalar operation";
let description = [{
Krnl erf scalar operation.
@@ -588,7 +588,7 @@ def KrnlErfOp : Op {
let results = (outs AnyFloat:$out);
}
-def KrnlIsInfOp : Op {
+def KrnlIsInfOp : Op {
let summary = "Krnl isinf scalar operation";
let description = [{
Krnl isinf scalar operation.
@@ -598,7 +598,7 @@ def KrnlIsInfOp : Op {
let results = (outs I1:$out);
}
-def KrnlIsNaNOp : Op {
+def KrnlIsNaNOp : Op {
let summary = "Krnl isnan scalar operation";
let description = [{
Krnl isnan scalar operation.
@@ -608,7 +608,7 @@ def KrnlIsNaNOp : Op {
let results = (outs I1:$out);
}
-def KrnlAcosOp : Op {
+def KrnlAcosOp : Op {
let summary = "Krnl acos scalar operation";
let description = [{
Krnl acos scalar operation.
@@ -618,7 +618,7 @@ def KrnlAcosOp : Op {
let results = (outs AnyFloat:$out);
}
-def KrnlAcoshOp : Op {
+def KrnlAcoshOp : Op {
let summary = "Krnl acosh scalar operation";
let description = [{
Krnl acosh scalar operation.
@@ -628,7 +628,7 @@ def KrnlAcoshOp : Op {
let results = (outs AnyFloat:$out);
}
-def KrnlAsinOp : Op {
+def KrnlAsinOp : Op {
let summary = "Krnl asin scalar operation";
let description = [{
Krnl asin scalar operation.
@@ -638,7 +638,7 @@ def KrnlAsinOp : Op {
let results = (outs AnyFloat:$out);
}
-def KrnlAsinhOp : Op {
+def KrnlAsinhOp : Op {
let summary = "Krnl asinh scalar operation";
let description = [{
Krnl asinh scalar operation.
@@ -648,7 +648,7 @@ def KrnlAsinhOp : Op {
let results = (outs AnyFloat:$out);
}
-def KrnlAtanOp : Op {
+def KrnlAtanOp : Op {
let summary = "Krnl atan scalar operation";
let description = [{
Krnl atan scalar operation.
@@ -658,7 +658,7 @@ def KrnlAtanOp : Op {
let results = (outs AnyFloat:$out);
}
-def KrnlAtanhOp : Op {
+def KrnlAtanhOp : Op {
let summary = "Krnl atanh scalar operation";
let description = [{
Krnl atanh scalar operation.
@@ -668,7 +668,7 @@ def KrnlAtanhOp : Op {
let results = (outs AnyFloat:$out);
}
-def KrnlTanOp : Op {
+def KrnlTanOp : Op {
let summary = "Krnl tan scalar operation";
let description = [{
Krnl tan scalar operation.
diff --git a/src/Dialect/ONNX/ElementsAttr/BType.cpp b/src/Dialect/ONNX/ElementsAttr/BType.cpp
index 8073d2a4e2..a6aa4b17f5 100644
--- a/src/Dialect/ONNX/ElementsAttr/BType.cpp
+++ b/src/Dialect/ONNX/ElementsAttr/BType.cpp
@@ -55,10 +55,10 @@ Type mlirTypeOfBType(BType btype, MLIRContext *ctx) {
case BType::FLOAT : return b.getF32Type();
case BType::FLOAT16 : return b.getF16Type();
case BType::BFLOAT16 : return b.getBF16Type();
- case BType::FLOAT8E4M3FN : return b.getFloat8E4M3FNType();
- case BType::FLOAT8E4M3FNUZ : return b.getFloat8E4M3FNUZType();
- case BType::FLOAT8E5M2 : return b.getFloat8E5M2Type();
- case BType::FLOAT8E5M2FNUZ : return b.getFloat8E5M2FNUZType();
+ case BType::FLOAT8E4M3FN : return b.getType<Float8E4M3FNType>();
+ case BType::FLOAT8E4M3FNUZ : return b.getType<Float8E4M3FNUZType>();
+ case BType::FLOAT8E5M2 : return b.getType<Float8E5M2Type>();
+ case BType::FLOAT8E5M2FNUZ : return b.getType<Float8E5M2FNUZType>();
default: llvm_unreachable("unsupported data type");
}
// clang-format on
@@ -104,4 +104,4 @@ BType wideBTypeOfBType(BType d) {
[](auto btype) { return toBType::widetype>; });
}
-} // namespace onnx_mlir
\ No newline at end of file
+} // namespace onnx_mlir
diff --git a/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp b/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp
index 47a74a0093..56fd3c5ca8 100644
--- a/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp
+++ b/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp
@@ -96,7 +96,7 @@ LogicalResult ONNXOneHotEncoderOp::inferShapes(
return success();
ONNXOneHotEncoderOpShapeHelper shapeHelper(getOperation(), {});
- return shapeHelper.computeShapeAndUpdateType(FloatType::getF32(getContext()));
+ return shapeHelper.computeShapeAndUpdateType(Float32Type::get(getContext()));
return success();
}
diff --git a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp
index a38ddfcb11..13308602cd 100644
--- a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp
@@ -452,7 +452,7 @@ LogicalResult ONNXScalerOp::inferShapes(
ONNXUnaryOpShapeHelper shapeHelper(getOperation(), {});
RankedTensorType xType = mlir::dyn_cast(getX().getType());
return shapeHelper.computeShapeAndUpdateType(
- FloatType::getF32(getContext()), xType.getEncoding());
+ Float32Type::get(getContext()), xType.getEncoding());
}
//===----------------------------------------------------------------------===//
diff --git a/src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp b/src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp
index 72399e09f0..53c7f50587 100644
--- a/src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp
@@ -50,21 +50,21 @@ Type getRandomNormalElementType(ONNXRandomNormalOp op) {
static_cast(op.getDtype());
if (elementTypeID ==
onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16) {
- return FloatType::getF16(op.getContext());
+ return Float16Type::get(op.getContext());
} else if (elementTypeID ==
onnx::TensorProto_DataType::TensorProto_DataType_FLOAT) {
- return FloatType::getF32(op.getContext());
+ return Float32Type::get(op.getContext());
} else if (elementTypeID ==
onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE) {
- return FloatType::getF64(op.getContext());
+ return Float64Type::get(op.getContext());
} else if (elementTypeID ==
onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16) {
- return FloatType::getBF16(op.getContext());
+ return BFloat16Type::get(op.getContext());
} else {
llvm_unreachable("dtype not supported for RandomNormal");
}
}
- return FloatType::getF32(op.getContext());
+ return Float32Type::get(op.getContext());
}
} // namespace
diff --git a/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp b/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp
index b02cc1f5ac..9f7942f17b 100644
--- a/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp
@@ -52,19 +52,19 @@ LogicalResult ONNXRandomNormalLikeOp::verify() {
}
if (elementTypeID ==
onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16 &&
- outputType != FloatType::getF16(getContext()))
+ outputType != Float16Type::get(getContext()))
return emitOpError("output tensor does not match float16 dtype.");
else if (elementTypeID ==
onnx::TensorProto_DataType::TensorProto_DataType_FLOAT &&
- outputType != FloatType::getF32(getContext()))
+ outputType != Float32Type::get(getContext()))
return emitOpError("output tensor does not match float dtype.");
else if (elementTypeID ==
onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE &&
- outputType != FloatType::getF64(getContext()))
+ outputType != Float64Type::get(getContext()))
return emitOpError("output tensor does not match double dtype.");
else if (elementTypeID ==
onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16 &&
- outputType != FloatType::getBF16(getContext()))
+ outputType != BFloat16Type::get(getContext()))
return emitOpError("output tensor does not match bfloat16 dtype.");
} else if (inputType != outputType) {
return emitOpError("output and input element types do not match.");
diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
index 2260d778d3..389a89903b 100644
--- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
+++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
@@ -608,13 +608,13 @@ Type convertONNXTypeToMLIRType(
Builder &builder, onnx::TensorProto_DataType onnxType) {
switch (onnxType) {
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN:
- return builder.getFloat8E4M3FNType();
+ return builder.getType<Float8E4M3FNType>();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ:
- return builder.getFloat8E4M3FNUZType();
+ return builder.getType<Float8E4M3FNUZType>();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2:
- return builder.getFloat8E5M2Type();
+ return builder.getType<Float8E5M2Type>();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ:
- return builder.getFloat8E5M2FNUZType();
+ return builder.getType<Float8E5M2FNUZType>();
case onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16:
return builder.getBF16Type();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
diff --git a/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp b/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp
index 7f27d19ebb..ae1ea165fd 100644
--- a/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp
@@ -61,7 +61,7 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes(
IntegerType ui8Type =
IntegerType::get(getContext(), 8, IntegerType::Unsigned);
- FloatType f32Type = FloatType::getF32(getContext());
+ FloatType f32Type = Float32Type::get(getContext());
ONNXDynamicQuantizeLinearOpShapeHelper shapeHelper(getOperation(), {});
return shapeHelper.computeShapeAndUpdateTypes(
diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp
index 70ee132682..bfa487d74a 100644
--- a/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp
@@ -54,10 +54,10 @@ std::vector<Type> ONNXConstantOp::resultTypeInference() {
} else if (auto attr = getSparseValueAttr()) {
type = mlir::cast<SparseElementsAttr>(attr).getShapedType();
} else if (auto attr = getValueFloatAttr()) {
- type = RankedTensorType::get({}, FloatType::getF32(getContext()));
+ type = RankedTensorType::get({}, Float32Type::get(getContext()));
} else if (auto attr = getValueFloatsAttr()) {
int64_t size = attr.size();
- type = RankedTensorType::get({size}, FloatType::getF32(getContext()));
+ type = RankedTensorType::get({size}, Float32Type::get(getContext()));
} else if (auto attr = getValueIntAttr()) {
type = RankedTensorType::get({}, IntegerType::get(getContext(), 64));
} else if (auto attr = getValueIntsAttr()) {
diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp
index 6058adfcdb..773152fc52 100644
--- a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp
@@ -99,7 +99,7 @@ std::vector<Type> ONNXConstantOfShapeOp::resultTypeInference() {
if (auto attr = getValueAttr()) {
elementType = mlir::cast<ElementsAttr>(attr).getElementType();
} else {
- elementType = FloatType::getF32(getContext());
+ elementType = Float32Type::get(getContext());
}
return {UnrankedTensorType::get(elementType)};
}
@@ -125,7 +125,7 @@ LogicalResult ONNXConstantOfShapeOp::inferShapes(
} else {
// If 'value' attribute is not specified, it defaults to a tensor of
// value 0 and datatype float32.
- elementType = FloatType::getF32(getContext());
+ elementType = Float32Type::get(getContext());
llvm::SmallVector<int64_t, 1> dims(1, 1);
auto tensorType = RankedTensorType::get(dims, elementType);
diff --git a/src/Dialect/ONNX/Transforms/ConstProp.cpp b/src/Dialect/ONNX/Transforms/ConstProp.cpp
index d3c2698273..2d9f60a3c1 100644
--- a/src/Dialect/ONNX/Transforms/ConstProp.cpp
+++ b/src/Dialect/ONNX/Transforms/ConstProp.cpp
@@ -1227,7 +1227,7 @@ void ConstPropONNXToONNXPass::runOnOperation() {
RewritePatternSet patterns(context);
getConstPropONNXToONNXPatterns(patterns);
- if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns))))
+ if (failed(applyPatternsGreedily(function, std::move(patterns))))
signalPassFailure();
}
diff --git a/src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp b/src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp
index 0e58963512..c802048103 100644
--- a/src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp
+++ b/src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp
@@ -145,7 +145,7 @@ struct ONNXHybridTransformPass
config.maxNumRewrites =
maxNumRewritesOffset + maxNumRewritesMultiplier * numOps;
}
- if (failed(applyPatternsAndFoldGreedily(body, patterns, config))) {
+ if (failed(applyPatternsGreedily(body, patterns, config))) {
llvm::errs() << "\nWarning: onnx-hybrid-transform didn't converge with "
<< "max-num-rewrites-offset="
<< maxNumRewritesOffset.getValue() << ", "
diff --git a/src/Dialect/ONNX/Transforms/ShapeInferencePass.cpp b/src/Dialect/ONNX/Transforms/ShapeInferencePass.cpp
index bc43bfbb2b..64963115ca 100644
--- a/src/Dialect/ONNX/Transforms/ShapeInferencePass.cpp
+++ b/src/Dialect/ONNX/Transforms/ShapeInferencePass.cpp
@@ -48,7 +48,7 @@ class ShapeInferencePass
func::FuncOp f = getOperation();
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
- (void)applyPatternsAndFoldGreedily(f.getBody(), patterns, config);
+ (void)applyPatternsGreedily(f.getBody(), patterns, config);
inferFunctionReturnShapes(f);
}
diff --git a/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp b/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp
index 397fe18cd5..0baec219a8 100644
--- a/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp
+++ b/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp
@@ -508,7 +508,7 @@ void SimplifyShapeRelatedOpsPass::topDownShapeSimplification(
config.useTopDownTraversal = true;
// Simplify shape-related ops.
- if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns))))
+ if (failed(applyPatternsGreedily(moduleOp, std::move(patterns))))
signalPassFailure();
}
diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
index 37d6d8b095..1af70b866e 100644
--- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
+++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
@@ -157,7 +157,7 @@ void registerMLIRPasses() {
return mlir::createLowerAffinePass();
});
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
- return mlir::createConvertSCFToCFPass();
+ return mlir::createSCFToControlFlowPass();
});
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return mlir::createConvertVectorToLLVMPass();
diff --git a/src/Transform/LowerKrnlRegion.cpp b/src/Transform/LowerKrnlRegion.cpp
index d2d2733c71..19f1fcbc6e 100644
--- a/src/Transform/LowerKrnlRegion.cpp
+++ b/src/Transform/LowerKrnlRegion.cpp
@@ -70,7 +70,7 @@ class LowerKrnlRegionPass
RewritePatternSet patterns(&getContext());
patterns.insert<LowerKrnlRegion>(&getContext());
- if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns))))
+ if (failed(applyPatternsGreedily(function, std::move(patterns))))
signalPassFailure();
}
};
diff --git a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir
index 2b56c8db2b..1db01a656a 100644
--- a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir
+++ b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-krnl-to-llvm %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-krnl-to-llvm -cse %s -split-input-file | FileCheck %s
// -----
@@ -15,11 +15,31 @@ func.func @test_zlow_softmax_constant_shape() -> () {
%work_area = memref.alloc() {alignment = 4096 : i64} : memref<8192xi8>
"zlow.softmax"(%input, %work_area, %shape, %res) {act_func = "ACT_NONE"} : (memref<1x1x1x1x32x64xf16>, memref<8192xi8>, memref<3xi64>, memref<1x1x1x1x32x64xf16>) -> ()
return
+}
+// CHECK: llvm.mlir.global internal constant @[[SHAPE_CONST_GLOBAL:.*]](dense<[1, 5, 10]> : tensor<3xi64>) {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<3 x i64>
+// CHECK-LABEL: llvm.func @test_zlow_softmax_constant_shape
+// CHECK-DAG: [[SHAPE_MEMREF_0:%.+]] = llvm.mlir.addressof @[[SHAPE_CONST_GLOBAL]] : !llvm.ptr
+// CHECK-DAG: [[SHAPE_MEMREF_1:%.+]] = llvm.bitcast [[SHAPE_MEMREF_0]] : !llvm.ptr to !llvm.ptr
+// CHECK-DAG: [[SHAPE_MEMREF_2:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-NEXT: [[SHAPE_MEMREF_3:%.+]] = llvm.insertvalue [[SHAPE_MEMREF_1]], [[SHAPE_MEMREF_2]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-NEXT: [[SHAPE_MEMREF_4:%.+]] = llvm.insertvalue [[SHAPE_MEMREF_1]], [[SHAPE_MEMREF_3]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-NEXT: [[SHAPE_MEMREF_5:%.+]] = llvm.mlir.constant(0 : index) : i64
+// CHECK-NEXT: [[SHAPE_MEMREF_6:%.+]] = llvm.insertvalue [[SHAPE_MEMREF_5]], [[SHAPE_MEMREF_4]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-NEXT: [[SHAPE_MEMREF_7:%.+]] = llvm.mlir.constant(3 : index) : i64
+// CHECK-NEXT: [[SHAPE_MEMREF_8:%.+]] = llvm.insertvalue [[SHAPE_MEMREF_7]], [[SHAPE_MEMREF_6]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-NEXT: [[SHAPE_MEMREF_9:%.+]] = llvm.mlir.constant(1 : index) : i64
+// CHECK-NEXT: [[SHAPE_MEMREF_10:%.+]] = llvm.insertvalue [[SHAPE_MEMREF_9]], [[SHAPE_MEMREF_8]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-LABEL: llvm.func @test_zlow_softmax_constant_shape() {{.*}} {
- // CHECK: %[[DIM0:.*]] = llvm.mlir.constant(1 : i64) : i64
- // CHECK: %[[DIM1:.*]] = llvm.mlir.constant(5 : i64) : i64
- // CHECK: %[[DIM2:.*]] = llvm.mlir.constant(10 : i64) : i64
- // CHECK: llvm.call @zdnn_init_pre_transformed_desc({{.*}}, {{.*}}, {{.*}}, %[[DIM0]], %[[DIM1]], %[[DIM2]]) vararg(!llvm.func) : (i64, i64, !llvm.ptr, i64, i64, i64) -> ()
+// ...
-}
+// CHECK: %[[SHAPE:.*]] = llvm.extractvalue [[SHAPE_MEMREF_10]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-NEXT: %[[DIM0_0:.*]] = llvm.getelementptr %[[SHAPE]][0] : (!llvm.ptr) -> !llvm.ptr, i64
+// CHECK-NEXT: %[[DIM0_1:.*]] = llvm.load %[[DIM0_0]] : !llvm.ptr -> i64
+// CHECK-NEXT: %[[DIM1_0:.*]] = llvm.getelementptr %[[SHAPE]][1] : (!llvm.ptr) -> !llvm.ptr, i64
+// CHECK-NEXT: %[[DIM1_1:.*]] = llvm.load %[[DIM1_0]] : !llvm.ptr -> i64
+// CHECK-NEXT: %[[DIM2_0:.*]] = llvm.getelementptr %[[SHAPE]][2] : (!llvm.ptr) -> !llvm.ptr, i64
+// CHECK-NEXT: %[[DIM2_1:.*]] = llvm.load %[[DIM2_0]] : !llvm.ptr -> i64
+
+// ...
+
+// CHECK: llvm.call @zdnn_init_pre_transformed_desc({{.*}}, {{.*}}, {{.*}}, %[[DIM0_1]], %[[DIM1_1]], %[[DIM2_1]]) vararg(!llvm.func) : (i64, i64, !llvm.ptr, i64, i64, i64) -> ()
diff --git a/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir b/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir
index f403151578..bb88e862a2 100644
--- a/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir
+++ b/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --zlow-stick-expansion %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --zlow-stick-expansion --disable-memref-prefetch=false %s -split-input-file | FileCheck %s
// -----
diff --git a/test/mlir/conversion/krnl_to_llvm/krnl_category_mapper.mlir b/test/mlir/conversion/krnl_to_llvm/krnl_category_mapper.mlir
index 9711b01c79..4d871b5260 100644
--- a/test/mlir/conversion/krnl_to_llvm/krnl_category_mapper.mlir
+++ b/test/mlir/conversion/krnl_to_llvm/krnl_category_mapper.mlir
@@ -179,7 +179,7 @@ func.func private @test_category_mapper_int64_to_string(%arg0: memref<2x2xi64>)
// CHECK-LABEL: @test_category_mapper_int64_to_string(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-DAG: [[LEN:%.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[MALLOC:%.+]] = llvm.call @malloc({{.*}}) : (i64) -> !llvm.ptr
- // CHECK: [[UNDEF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)>
+ // CHECK: [[UNDEF:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
// CHECK: [[EV_1:%.+]] = llvm.insertvalue {{.*}}, [[UNDEF]][0]
// CHECK: [[EV_2:%.+]] = llvm.insertvalue {{.*}}, [[EV_1]][1]
// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : index) : i64
@@ -222,7 +222,7 @@ func.func private @test_krnl_global_with_129_elements() -> memref<129x!krnl.stri
// CHECK: llvm.func @test_krnl_global_with_129_elements() -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> attributes {llvm.emit_c_interface, sym_visibility = "private"} {
// CHECK: [[VAR_0_1_:%.+]] = llvm.mlir.addressof @cats_strings : !llvm.ptr
// CHECK-DAG: [[VAR_1_1_:%.+]] = llvm.bitcast [[VAR_0_1_]] : !llvm.ptr to !llvm.ptr
- // CHECK-DAG: [[VAR_2_1_:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: [[VAR_2_1_:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[VAR_3_1_:%.+]] = llvm.insertvalue [[VAR_1_1_]], [[VAR_2_1_]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-DAG: [[VAR_4_1_:%.+]] = llvm.insertvalue [[VAR_1_1_]], [[VAR_3_1_]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-DAG: [[VAR_5_1_:%.+]] = llvm.mlir.constant(0 : index) : i64
diff --git a/test/mlir/conversion/krnl_to_llvm/krnl_global_with_alignment_lowering.mlir b/test/mlir/conversion/krnl_to_llvm/krnl_global_with_alignment_lowering.mlir
index c736b12cf4..65a9bf0f73 100644
--- a/test/mlir/conversion/krnl_to_llvm/krnl_global_with_alignment_lowering.mlir
+++ b/test/mlir/conversion/krnl_to_llvm/krnl_global_with_alignment_lowering.mlir
@@ -11,7 +11,7 @@ func.func @test_krnl_global_constant_alignment() -> memref<3xf32> {
// CHECK-LABEL: llvm.func @test_krnl_global_constant_alignment() -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> attributes {llvm.emit_c_interface} {
// CHECK: [[VAR_0_:%.+]] = llvm.mlir.addressof @constant : !llvm.ptr
// CHECK-DAG: [[VAR_1_:%.+]] = llvm.bitcast [[VAR_0_]] : !llvm.ptr to !llvm.ptr
-// CHECK-DAG: [[VAR_2_:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-DAG: [[VAR_2_:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[VAR_3_:%.+]] = llvm.insertvalue [[VAR_1_]], [[VAR_2_]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-DAG: [[VAR_4_:%.+]] = llvm.insertvalue [[VAR_1_]], [[VAR_3_]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-DAG: [[VAR_5_:%.+]] = llvm.mlir.constant(0 : index) : i64
@@ -37,7 +37,7 @@ func.func @test_krnl_global_constant_no_alignment() -> memref<2xi64> {
// CHECK-LABEL: llvm.func @test_krnl_global_constant_no_alignment() -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> attributes {llvm.emit_c_interface} {
// CHECK: [[VAR_0_:%.+]] = llvm.mlir.addressof @constant : !llvm.ptr
// CHECK-DAG: [[VAR_1_:%.+]] = llvm.bitcast [[VAR_0_]] : !llvm.ptr to !llvm.ptr
-// CHECK-DAG: [[VAR_2_:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-DAG: [[VAR_2_:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[VAR_3_:%.+]] = llvm.insertvalue [[VAR_1_]], [[VAR_2_]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-DAG: [[VAR_4_:%.+]] = llvm.insertvalue [[VAR_1_]], [[VAR_3_]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-DAG: [[VAR_5_:%.+]] = llvm.mlir.constant(0 : index) : i64
diff --git a/test/mlir/conversion/krnl_to_llvm/reshape.mlir b/test/mlir/conversion/krnl_to_llvm/reshape.mlir
index 97d5374ec5..80edf5c5a6 100644
--- a/test/mlir/conversion/krnl_to_llvm/reshape.mlir
+++ b/test/mlir/conversion/krnl_to_llvm/reshape.mlir
@@ -7,7 +7,7 @@ func.func @test_reshape(%arg0 : tensor, %arg1 : tensor<4xi64>) -> tens
"func.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: llvm.func @test_reshape
-// CHECK: [[OLD_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: [[OLD_MEMREF:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[INSERT_1_:%.+]] = llvm.insertvalue {{.*}}, [[OLD_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[INSERT_2_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_1_]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[INSERT_3_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_2_]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -17,7 +17,7 @@ func.func @test_reshape(%arg0 : tensor, %arg1 : tensor<4xi64>) -> tens
// CHECK-DAG:[[INSERT_7_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_6_]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// COM: Check that there is no copy but only a new MemRef with a new view, i.e. new sizes and strides.
-// CHECK-DAG: [[NEW_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+// CHECK-DAG: [[NEW_MEMREF:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
// CHECK: [[INSERT_8_:%.+]] = llvm.insertvalue {{.*}}, [[NEW_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
// CHECK-DAG: [[INSERT_9_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_8_]][1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : index) : i64
diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize.mlir
index 6b2e395407..0baf28f0ab 100644
--- a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize.mlir
+++ b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize.mlir
@@ -290,14 +290,11 @@ func.func private @test_reducesum1(%arg0: tensor<3x2x2xf32>, %arg1: tensor<?xi64>) -> tensor<3x1x2xf32>
// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_1_]]#1 -> [[I_2_:%.+]] = 0 to 2, [[LOOP_1_]]#2 -> [[I_3_:%.+]] = 0 to 2){
// CHECK-DAG: [[VAR_2_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_1_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_4_1_:%.+]] = arith.cmpi eq, [[LOAD_PARAM_1_MEM_1_]], [[VAR_true_]] : i1
-// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.select [[VAR_4_1_]], [[CST_0_1_]], [[VAR_2_1_]]#0 : index
+// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.select [[LOAD_PARAM_1_MEM_1_]], [[CST_0_1_]], [[VAR_2_1_]]#0 : index
// CHECK-DAG: [[VAR_6_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_1_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_7_1_:%.+]] = arith.cmpi eq, [[VAR_6_1_]], [[VAR_true_]] : i1
-// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_1_]], [[CST_0_1_]], [[VAR_2_1_]]#1 : index
+// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_6_1_]], [[CST_0_1_]], [[VAR_2_1_]]#1 : index
// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_10_:%.+]] = arith.cmpi eq, [[LOAD_RES_MEM_]], [[VAR_true_]] : i1
-// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_10_]], [[CST_0_1_]], [[VAR_2_1_]]#2 : index
+// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[LOAD_RES_MEM_]], [[CST_0_1_]], [[VAR_2_1_]]#2 : index
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_1_]]#0, [[VAR_2_1_]]#1, [[VAR_2_1_]]#2] : memref<3x2x2xf32>
// CHECK: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_5_1_]], [[VAR_8_]], [[VAR_11_]]{{.}} : memref<3x1x2xf32>
// CHECK: [[VAR_14_:%.+]] = arith.addf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32
@@ -348,14 +345,11 @@ func.func @test_reducesum2(%arg0: tensor<3x2x2xf32>, %arg1: tensor) -> te
// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_1_]]#1 -> [[I_2_:%.+]] = 0 to 2, [[LOOP_1_]]#2 -> [[I_3_:%.+]] = 0 to 2){
// CHECK-DAG: [[VAR_3_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_1_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_5_1_:%.+]] = arith.cmpi eq, [[LOAD_PARAM_1_MEM_1_]], [[VAR_true_]] : i1
-// CHECK-DAG: [[VAR_6_1_:%.+]] = arith.select [[VAR_5_1_]], [[CST_0_1_]], [[VAR_3_1_]]#0 : index
+// CHECK-DAG: [[VAR_6_1_:%.+]] = arith.select [[LOAD_PARAM_1_MEM_1_]], [[CST_0_1_]], [[VAR_3_1_]]#0 : index
// CHECK-DAG: [[VAR_7_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_1_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_8_1_:%.+]] = arith.cmpi eq, [[VAR_7_1_]], [[VAR_true_]] : i1
-// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_8_1_]], [[CST_0_1_]], [[VAR_3_1_]]#1 : index
+// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_7_1_]], [[CST_0_1_]], [[VAR_3_1_]]#1 : index
// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_11_:%.+]] = arith.cmpi eq, [[LOAD_RES_MEM_]], [[VAR_true_]] : i1
-// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[VAR_11_]], [[CST_0_1_]], [[VAR_3_1_]]#2 : index
+// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[LOAD_RES_MEM_]], [[CST_0_1_]], [[VAR_3_1_]]#2 : index
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_3_1_]]#0, [[VAR_3_1_]]#1, [[VAR_3_1_]]#2] : memref<3x2x2xf32>
// CHECK: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_6_1_]], [[VAR_9_]], [[VAR_12_]]{{.}} : memref<3x1x2xf32>
// CHECK: [[VAR_15_:%.+]] = arith.addf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32
diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir
index f35603fc9e..59c67b8377 100644
--- a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir
+++ b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir
@@ -245,14 +245,11 @@ func.func private @test_reducesum1(%arg0: tensor<3x2x2xf32>, %arg1: tensor<?xi64>) -> tensor<3x1x2xf32>
// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_1_]]#1 -> [[I_2_:%.+]] = 0 to 2, [[LOOP_1_]]#2 -> [[I_3_:%.+]] = 0 to 2){
// CHECK-DAG: [[VAR_2_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_1_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_4_1_:%.+]] = arith.cmpi eq, [[LOAD_PARAM_1_MEM_1_]], [[VAR_true_]] : i1
-// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.select [[VAR_4_1_]], [[CST_0_1_]], [[VAR_2_1_]]#0 : index
+// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.select [[LOAD_PARAM_1_MEM_1_]], [[CST_0_1_]], [[VAR_2_1_]]#0 : index
// CHECK-DAG: [[VAR_6_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_1_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_7_1_:%.+]] = arith.cmpi eq, [[VAR_6_1_]], [[VAR_true_]] : i1
-// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_1_]], [[CST_0_1_]], [[VAR_2_1_]]#1 : index
-// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_10_:%.+]] = arith.cmpi eq, [[LOAD_RES_MEM_]], [[VAR_true_]] : i1
-// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_10_]], [[CST_0_1_]], [[VAR_2_1_]]#2 : index
+// CHECK: [[VAR_8_:%.+]] = arith.select [[VAR_6_1_]], [[CST_0_1_]], [[VAR_2_1_]]#1 : index
+// CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<3xi1>
+// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[LOAD_RES_MEM_]], [[CST_0_1_]], [[VAR_2_1_]]#2 : index
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_1_]]#0, [[VAR_2_1_]]#1, [[VAR_2_1_]]#2] : memref<3x2x2xf32>
// CHECK: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_5_1_]], [[VAR_8_]], [[VAR_11_]]{{.}} : memref<3x1x2xf32>
// CHECK: [[VAR_14_:%.+]] = arith.addf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32
@@ -303,14 +300,11 @@ func.func @test_reducesum2(%arg0: tensor<3x2x2xf32>, %arg1: tensor) -> te
// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_1_]]#1 -> [[I_2_:%.+]] = 0 to 2, [[LOOP_1_]]#2 -> [[I_3_:%.+]] = 0 to 2){
// CHECK-DAG: [[VAR_3_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_1_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_5_1_:%.+]] = arith.cmpi eq, [[LOAD_PARAM_1_MEM_1_]], [[VAR_true_]] : i1
-// CHECK-DAG: [[VAR_6_1_:%.+]] = arith.select [[VAR_5_1_]], [[CST_0_1_]], [[VAR_3_1_]]#0 : index
+// CHECK-DAG: [[VAR_6_1_:%.+]] = arith.select [[LOAD_PARAM_1_MEM_1_]], [[CST_0_1_]], [[VAR_3_1_]]#0 : index
// CHECK-DAG: [[VAR_7_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_1_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_8_1_:%.+]] = arith.cmpi eq, [[VAR_7_1_]], [[VAR_true_]] : i1
-// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_8_1_]], [[CST_0_1_]], [[VAR_3_1_]]#1 : index
+// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_7_1_]], [[CST_0_1_]], [[VAR_3_1_]]#1 : index
// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_11_:%.+]] = arith.cmpi eq, [[LOAD_RES_MEM_]], [[VAR_true_]] : i1
-// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[VAR_11_]], [[CST_0_1_]], [[VAR_3_1_]]#2 : index
+// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[LOAD_RES_MEM_]], [[CST_0_1_]], [[VAR_3_1_]]#2 : index
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_3_1_]]#0, [[VAR_3_1_]]#1, [[VAR_3_1_]]#2] : memref<3x2x2xf32>
// CHECK: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_6_1_]], [[VAR_9_]], [[VAR_12_]]{{.}} : memref<3x1x2xf32>
// CHECK: [[VAR_15_:%.+]] = arith.addf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32
diff --git a/test/mlir/conversion/onnx_to_krnl/Tensor/Compress_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Tensor/Compress_with_canonicalize.mlir
index 336fa49a2a..c309eacffe 100644
--- a/test/mlir/conversion/onnx_to_krnl/Tensor/Compress_with_canonicalize.mlir
+++ b/test/mlir/conversion/onnx_to_krnl/Tensor/Compress_with_canonicalize.mlir
@@ -14,15 +14,13 @@ func.func @compress_axis0(%arg0: tensor<3x2xf32>, %arg1: tensor<3xi1>) -> tensor
// CHECK-SAME: ([[INPUT_:%.+]]: memref<3x2xf32>, [[CONDITION_:%.+]]: memref<3xi1>) -> memref {
// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index
-// CHECK-DAG: [[VAR_false_:%.+]] = arith.constant false
// CHECK-DAG: [[RES_:%.+]] = memref.alloca() : memref
// CHECK: krnl.store [[VAR_c0_]], [[RES_]][] : memref
// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 3){
// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_CONDITION_MEM_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_5_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_7_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_]], [[VAR_false_]] : i1
-// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[VAR_c1_]], [[VAR_c0_]] : index
+// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[LOAD_CONDITION_MEM_]], [[VAR_c1_]], [[VAR_c0_]] : index
// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref
// CHECK: [[VAR_10_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[VAR_8_]] : index
// CHECK: krnl.store [[VAR_10_]], [[RES_]][] : memref
@@ -34,8 +32,7 @@ func.func @compress_axis0(%arg0: tensor<3x2xf32>, %arg1: tensor<3xi1>) -> tensor
// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 3){
// CHECK: [[VAR_5_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_CONDITION_MEM_1_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_5_1_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_7_1_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_1_]], [[VAR_false_]] : i1
-// CHECK: scf.if [[VAR_7_1_]] {
+// CHECK: scf.if [[LOAD_CONDITION_MEM_1_]] {
// CHECK-DAG: [[LOAD_RES_MEM_2_:%.+]] = krnl.load [[RES_]][] : memref
// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 2){
@@ -64,15 +61,13 @@ func.func @compress_axis0_not_enough(%arg0: tensor<3x2xf32>, %arg1: tensor<2xi1>
// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index
-// CHECK-DAG: [[VAR_false_:%.+]] = arith.constant false
// CHECK-DAG: [[RES_:%.+]] = memref.alloca() : memref
// CHECK: krnl.store [[VAR_c0_]], [[RES_]][] : memref
// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 2){
// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_CONDITION_MEM_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_5_]]{{.}} : memref<2xi1>
-// CHECK: [[VAR_7_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_]], [[VAR_false_]] : i1
-// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[VAR_c1_]], [[VAR_c0_]] : index
+// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[LOAD_CONDITION_MEM_]], [[VAR_c1_]], [[VAR_c0_]] : index
// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref
// CHECK: [[VAR_10_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[VAR_8_]] : index
// CHECK: krnl.store [[VAR_10_]], [[RES_]][] : memref
@@ -86,8 +81,7 @@ func.func @compress_axis0_not_enough(%arg0: tensor<3x2xf32>, %arg1: tensor<2xi1>
// CHECK: [[LOAD_CONDITION_MEM_1_:%.+]] = arith.cmpi slt, [[VAR_5_1_]], [[VAR_c2_]] : index
// CHECK: scf.if [[LOAD_CONDITION_MEM_1_]] {
// CHECK: [[LOAD_CONDITION_MEM_2_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_5_1_]]{{.}} : memref<2xi1>
-// CHECK: [[VAR_8_1_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_2_]], [[VAR_false_]] : i1
-// CHECK: scf.if [[VAR_8_1_]] {
+// CHECK: scf.if [[LOAD_CONDITION_MEM_2_]] {
// CHECK-DAG: [[LOAD_RES_MEM_2_:%.+]] = krnl.load [[RES_]][] : memref
// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 2){
@@ -116,15 +110,13 @@ func.func @compress_axis1(%arg0: tensor<3x2xf32>, %arg1: tensor<3xi1>) -> tensor
// CHECK-SAME: ([[INPUT_:%.+]]: memref<3x2xf32>, [[CONDITION_:%.+]]: memref<3xi1>) -> memref<3x?xf32> {
// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index
-// CHECK-DAG: [[VAR_false_:%.+]] = arith.constant false
// CHECK-DAG: [[RES_:%.+]] = memref.alloca() : memref
// CHECK: krnl.store [[VAR_c0_]], [[RES_]][] : memref
// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 3){
// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_CONDITION_MEM_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_5_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_7_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_]], [[VAR_false_]] : i1
-// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[VAR_c1_]], [[VAR_c0_]] : index
+// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[LOAD_CONDITION_MEM_]], [[VAR_c1_]], [[VAR_c0_]] : index
// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref
// CHECK: [[VAR_10_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[VAR_8_]] : index
// CHECK: krnl.store [[VAR_10_]], [[RES_]][] : memref
@@ -136,8 +128,7 @@ func.func @compress_axis1(%arg0: tensor<3x2xf32>, %arg1: tensor<3xi1>) -> tensor
// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 2){
// CHECK: [[VAR_5_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_CONDITION_MEM_1_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_5_1_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_7_1_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_1_]], [[VAR_false_]] : i1
-// CHECK: scf.if [[VAR_7_1_]] {
+// CHECK: scf.if [[LOAD_CONDITION_MEM_1_]] {
// CHECK-DAG: [[LOAD_RES_MEM_2_:%.+]] = krnl.load [[RES_]][] : memref
// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 3){
@@ -166,15 +157,13 @@ func.func @compress_no_axis_not_elided(%arg0: tensor<3x2xf32>, %arg1: tensor<3xi
// CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index
-// CHECK-DAG: [[VAR_false_:%.+]] = arith.constant false
// CHECK-DAG: [[RES_:%.+]] = memref.alloca() : memref
// CHECK: krnl.store [[VAR_c0_]], [[RES_]][] : memref
// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 3){
// CHECK: [[VAR_6_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_CONDITION_MEM_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_6_]]{{.}} : memref<3xi1>
-// CHECK: [[VAR_8_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_]], [[VAR_false_]] : i1
-// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_8_]], [[VAR_c1_]], [[VAR_c0_]] : index
+// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[LOAD_CONDITION_MEM_]], [[VAR_c1_]], [[VAR_c0_]] : index
// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref
// CHECK: [[VAR_11_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[VAR_9_]] : index
// CHECK: krnl.store [[VAR_11_]], [[RES_]][] : memref
@@ -191,8 +180,7 @@ func.func @compress_no_axis_not_elided(%arg0: tensor<3x2xf32>, %arg1: tensor<3xi
// CHECK: [[VAR_8_1_:%.+]] = arith.cmpi slt, [[LOAD_CONDITION_MEM_1_]], [[VAR_c3_]] : index
// CHECK: scf.if [[VAR_8_1_]] {
// CHECK: [[LOAD_CONDITION_MEM_2_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[LOAD_CONDITION_MEM_1_]]{{.}} : memref<3xi1>
-// CHECK: [[LOAD_RES_MEM_2_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_2_]], [[VAR_false_]] : i1
-// CHECK: scf.if [[LOAD_RES_MEM_2_]] {
+// CHECK: scf.if [[LOAD_CONDITION_MEM_2_]] {
// CHECK-DAG: [[LOAD_INPUT_MEM_:%.+]] = krnl.load [[INPUT_]]{{.}}[[VAR_6_1_]]#0, [[VAR_6_1_]]#1] : memref<3x2xf32>
// CHECK-DAG: [[LOAD_RES_MEM_3_:%.+]] = krnl.load [[RES_]][] : memref
// CHECK: krnl.store [[LOAD_INPUT_MEM_]], [[RES_1_]]{{.}}[[LOAD_RES_MEM_3_]]{{.}} : memref
@@ -218,15 +206,13 @@ func.func @compress_no_axis_enough_cond(%arg0: tensor<3x2xf32>, %arg1: tensor<6x
// CHECK-SAME: ([[INPUT_:%.+]]: memref<3x2xf32>, [[CONDITION_:%.+]]: memref<6xi1>) -> memref {
// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index
-// CHECK-DAG: [[VAR_false_:%.+]] = arith.constant false
// CHECK-DAG: [[RES_:%.+]] = memref.alloca() : memref
// CHECK: krnl.store [[VAR_c0_]], [[RES_]][] : memref
// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){
// CHECK: [[VAR_6_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_CONDITION_MEM_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_6_]]{{.}} : memref<6xi1>
-// CHECK: [[VAR_8_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_]], [[VAR_false_]] : i1
-// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_8_]], [[VAR_c1_]], [[VAR_c0_]] : index
+// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[LOAD_CONDITION_MEM_]], [[VAR_c1_]], [[VAR_c0_]] : index
// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref
// CHECK: [[VAR_11_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[VAR_9_]] : index
// CHECK: krnl.store [[VAR_11_]], [[RES_]][] : memref
@@ -241,8 +227,7 @@ func.func @compress_no_axis_enough_cond(%arg0: tensor<3x2xf32>, %arg1: tensor<6x
// CHECK-DAG: [[VAR_6_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK-DAG: [[LOAD_CONDITION_MEM_1_:%.+]] = krnl.load [[RES_2_]][] : memref
// CHECK: [[LOAD_CONDITION_MEM_2_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[LOAD_CONDITION_MEM_1_]]{{.}} : memref<6xi1>
-// CHECK: [[VAR_9_1_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_2_]], [[VAR_false_]] : i1
-// CHECK: scf.if [[VAR_9_1_]] {
+// CHECK: scf.if [[LOAD_CONDITION_MEM_2_]] {
// CHECK-DAG: [[VAR_11_1_:%.+]] = krnl.load [[INPUT_]]{{.}}[[VAR_6_1_]]#0, [[VAR_6_1_]]#1] : memref<3x2xf32>
// CHECK-DAG: [[LOAD_RES_MEM_2_:%.+]] = krnl.load [[RES_]][] : memref
// CHECK: krnl.store [[VAR_11_1_]], [[RES_1_]]{{.}}[[LOAD_RES_MEM_2_]]{{.}} : memref
diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Conv.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Conv.mlir
index 0437de4625..d754e00f0c 100644
--- a/test/mlir/conversion/onnx_to_tosa/Math/Conv.mlir
+++ b/test/mlir/conversion/onnx_to_tosa/Math/Conv.mlir
@@ -11,7 +11,8 @@ func.func @test_onnx_conv2d_stride_13(%arg0: tensor<5x3x256x256xf32>, %arg1 : te
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_3]] : (tensor<5x3x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x3xf32>
// CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_3]] : (tensor<2x3x64x64xf32>, tensor<4xi32>) -> tensor<2x64x64x3xf32>
-// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_4]], %[[VAL_5]], %[[VAL_2]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>) -> tensor<5x15x15x2xf32>
+// CHECK: %[[VAL_6_0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_4]], %[[VAL_5]], %[[VAL_2]], %[[VAL_6_0]], %[[VAL_6_0]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x15x15x2xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x15x15x2xf32>, tensor<4xi32>) -> tensor<5x2x15x15xf32>
// CHECK: return %[[VAL_8]] : tensor<5x2x15x15xf32>
@@ -29,7 +30,8 @@ func.func @test_onnx_conv2d_novalue(%arg0: tensor<5x3x256x256xf32>, %arg1 : tens
// CHECK: %[[VAL_3:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_2]] : (tensor<5x3x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_2]] : (tensor<2x3x64x64xf32>, tensor<4xi32>) -> tensor<2x64x64x3xf32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
-// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>) -> tensor<5x197x199x2xf32>
+// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_9]], %[[VAL_9]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x197x199x2xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x197x199x2xf32>, tensor<4xi32>) -> tensor<5x2x197x199xf32>
// CHECK: return %[[VAL_8]] : tensor<5x2x197x199xf32>
@@ -47,7 +49,8 @@ func.func @test_onnx_conv2d_no_dilation_pad(%arg0: tensor<5x3x256x256xf32>, %arg
// CHECK: %[[VAL_3:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_2]] : (tensor<5x3x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_2]] : (tensor<7x3x64x64xf32>, tensor<4xi32>) -> tensor<7x64x64x3xf32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<7xf32>}> : () -> tensor<7xf32>
-// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<7x64x64x3xf32>, tensor<7xf32>) -> tensor<5x15x15x7xf32>
+// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_9]], %[[VAL_9]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<7x64x64x3xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x15x15x7xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x15x15x7xf32>, tensor<4xi32>) -> tensor<5x7x15x15xf32>
// CHECK: return %[[VAL_8]] : tensor<5x7x15x15xf32>
@@ -65,7 +68,8 @@ func.func @test_onnx_conv2d_no_dilation_pad_stride(%arg0: tensor<5x3x256x260xf32
// CHECK: %[[VAL_3:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_2]] : (tensor<5x3x256x260xf32>, tensor<4xi32>) -> tensor<5x256x260x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_2]] : (tensor<2x3x60x64xf32>, tensor<4xi32>) -> tensor<2x60x64x3xf32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
-// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x260x3xf32>, tensor<2x60x64x3xf32>, tensor<2xf32>) -> tensor<5x197x197x2xf32>
+// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_9]], %[[VAL_9]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x260x3xf32>, tensor<2x60x64x3xf32>, tensor<2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x197x197x2xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x197x197x2xf32>, tensor<4xi32>) -> tensor<5x2x197x197xf32>
// CHECK: return %[[VAL_8]] : tensor<5x2x197x197xf32>
@@ -82,22 +86,36 @@ func.func @test_onnx_conv2d_group(%arg0: tensor<5x64x256x256xf32>, %arg1 : tenso
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_3]] : (tensor<5x64x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x64xf32>
// CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_3]] : (tensor<12x16x45x45xf32>, tensor<4xi32>) -> tensor<12x45x45x16xf32>
-// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_4]] {size = array<i64: 5, 256, 256, 16>, start = array<i64: 0, 0, 0, 0>} : (tensor<5x256x256x64xf32>) -> tensor<5x256x256x16xf32>
-// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_5]] {size = array<i64: 3, 45, 45, 16>, start = array<i64: 0, 0, 0, 0>} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32>
-// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 3>, start = array<i64: 0>} : (tensor<12xf32>) -> tensor<3xf32>
-// CHECK: %[[VAL_9:.*]] = tosa.conv2d %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32>
-// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_4]] {size = array<i64: 5, 256, 256, 16>, start = array<i64: 0, 0, 0, 16>} : (tensor<5x256x256x64xf32>) -> tensor<5x256x256x16xf32>
-// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_5]] {size = array<i64: 3, 45, 45, 16>, start = array<i64: 3, 0, 0, 0>} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32>
-// CHECK: %[[VAL_12:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 3>, start = array<i64: 3>} : (tensor<12xf32>) -> tensor<3xf32>
-// CHECK: %[[VAL_13:.*]] = tosa.conv2d %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32>
-// CHECK: %[[VAL_14:.*]] = tosa.slice %[[VAL_4]] {size = array<i64: 5, 256, 256, 16>, start = array<i64: 0, 0, 0, 32>} : (tensor<5x256x256x64xf32>) -> tensor<5x256x256x16xf32>
-// CHECK: %[[VAL_15:.*]] = tosa.slice %[[VAL_5]] {size = array<i64: 3, 45, 45, 16>, start = array<i64: 6, 0, 0, 0>} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32>
-// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 3>, start = array<i64: 6>} : (tensor<12xf32>) -> tensor<3xf32>
-// CHECK: %[[VAL_17:.*]] = tosa.conv2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32>
-// CHECK: %[[VAL_18:.*]] = tosa.slice %[[VAL_4]] {size = array<i64: 5, 256, 256, 16>, start = array<i64: 0, 0, 0, 48>} : (tensor<5x256x256x64xf32>) -> tensor<5x256x256x16xf32>
-// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_5]] {size = array<i64: 3, 45, 45, 16>, start = array<i64: 9, 0, 0, 0>} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32>
-// CHECK: %[[VAL_20:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 3>, start = array<i64: 9>} : (tensor<12xf32>) -> tensor<3xf32>
-// CHECK: %[[VAL_21:.*]] = tosa.conv2d %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32>
+// CHECK: %[[STARTS_0:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[SIZES_0:.*]] = tosa.const_shape {value = dense<[5, 256, 256, 16]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_4]], %[[STARTS_0]], %[[SIZES_0]] : (tensor<5x256x256x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<5x256x256x16xf32>
+// CHECK: %[[SIZES_1:.*]] = tosa.const_shape {value = dense<[3, 45, 45, 16]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_5]], %[[STARTS_0]], %[[SIZES_1]] : (tensor<12x45x45x16xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<3x45x45x16xf32>
+// CHECK: %[[STARTS_2:.*]] = tosa.const_shape {value = dense<0> : tensor<1xindex>} : () -> !tosa.shape<1>
+// CHECK: %[[THREE:.*]] = tosa.const_shape {value = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1>
+// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_2]], %[[STARTS_2]], %[[THREE]] : (tensor<12xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<3xf32>
+// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_9:.*]] = tosa.conv2d %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[ZERO]], %[[ZERO]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x17x17x3xf32>
+// CHECK: %[[STARTS_3:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 16]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_4]], %[[STARTS_3]], %[[SIZES_0]] : (tensor<5x256x256x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<5x256x256x16xf32>
+// CHECK: %[[STARTS_4:.*]] = tosa.const_shape {value = dense<[3, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_5]], %[[STARTS_4]], %[[SIZES_1]] : (tensor<12x45x45x16xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<3x45x45x16xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.slice %[[VAL_2]], %[[THREE]], %[[THREE]] : (tensor<12xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<3xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.conv2d %[[VAL_10]], %[[VAL_11]], %[[VAL_12]], %[[ZERO]], %[[ZERO]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x17x17x3xf32>
+// CHECK: %[[STARTS_5:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 32]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_14:.*]] = tosa.slice %[[VAL_4]], %[[STARTS_5]], %[[SIZES_0]] : (tensor<5x256x256x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<5x256x256x16xf32>
+// CHECK: %[[STARTS_6:.*]] = tosa.const_shape {value = dense<[6, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_15:.*]] = tosa.slice %[[VAL_5]], %[[STARTS_6]], %[[SIZES_1]] : (tensor<12x45x45x16xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<3x45x45x16xf32>
+// CHECK: %[[SIX:.*]] = tosa.const_shape {value = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
+// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_2]], %[[SIX]], %[[THREE]] : (tensor<12xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<3xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.conv2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]], %[[ZERO]], %[[ZERO]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x17x17x3xf32>
+// CHECK: %[[STARTS_7:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 48]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_18:.*]] = tosa.slice %[[VAL_4]], %[[STARTS_7]], %[[SIZES_0]] : (tensor<5x256x256x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<5x256x256x16xf32>
+// CHECK: %[[STARTS_8:.*]] = tosa.const_shape {value = dense<[9, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_5]], %[[STARTS_8]], %[[SIZES_1]] : (tensor<12x45x45x16xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<3x45x45x16xf32>
+// CHECK: %[[NINE:.*]] = tosa.const_shape {value = dense<9> : tensor<1xindex>} : () -> !tosa.shape<1>
+// CHECK: %[[VAL_20:.*]] = tosa.slice %[[VAL_2]], %[[NINE]], %[[THREE]] : (tensor<12xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<3xf32>
+// CHECK: %[[VAL_21:.*]] = tosa.conv2d %[[VAL_18]], %[[VAL_19]], %[[VAL_20]], %[[ZERO]], %[[ZERO]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x17x17x3xf32>
// CHECK: %[[VAL_22:.*]] = tosa.concat %[[VAL_9]], %[[VAL_13]], %[[VAL_17]], %[[VAL_21]] {axis = 3 : i32} : (tensor<5x17x17x3xf32>, tensor<5x17x17x3xf32>, tensor<5x17x17x3xf32>, tensor<5x17x17x3xf32>) -> tensor<5x17x17x12xf32>
// CHECK: %[[VAL_23:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_24:.*]] = tosa.transpose %[[VAL_22]], %[[VAL_23]] : (tensor<5x17x17x12xf32>, tensor<4xi32>) -> tensor<5x12x17x17xf32>
@@ -115,8 +133,9 @@ func.func @test_onnx_conv2d_autopad(%arg0: tensor<5x3x125x256xf32>, %arg1 : tens
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_3]] : (tensor<5x3x125x256xf32>, tensor<4xi32>) -> tensor<5x125x256x3xf32>
// CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_3]] : (tensor<2x3x64x64xf32>, tensor<4xi32>) -> tensor<2x64x64x3xf32>
-// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_4]], %[[VAL_5]], %[[VAL_2]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x125x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>) -> tensor<5x125x256x2xf32>
+// CHECK-DAG: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_4]], %[[VAL_5]], %[[VAL_2]], %[[ZERO]], %[[ZERO]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x125x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x125x256x2xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x125x256x2xf32>, tensor<4xi32>) -> tensor<5x2x125x256xf32>
// CHECK: return %[[VAL_8]] : tensor<5x2x125x256xf32>
-}
\ No newline at end of file
+}
diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
index 623ef3fe5f..50957fd565 100644
--- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
+++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
@@ -5,7 +5,7 @@ func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> {
"func.return"(%0) : (tensor<10x10xf32>) -> ()
// CHECK-LABEL: func @test_relu
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
-// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32>
+// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<10x10xf32>) -> tensor<10x10xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32>
// CHECK-NEXT: }
}
@@ -17,7 +17,7 @@ func.func @test_relu_dynamic(%arg0 : tensor) -> tensor<*xf32> {
"func.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: func @test_relu_dynamic
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor {
-// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor) -> tensor
+// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor) -> tensor
// CHECK-NEXT: return [[VAR_0_]] : tensor
// CHECK-NEXT: }
}
@@ -60,7 +60,8 @@ func.func @test_add_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>)
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func.func @test_add_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
-// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xf32>) -> tensor<1x1x1xf32>
+// CHECK: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
// CHECK: [[VAR_1_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
// CHECK: return [[VAR_1_]] : tensor<13x21x1xf32>
}
@@ -83,7 +84,8 @@ func.func @test_sub_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>)
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func.func @test_sub_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
-// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xf32>) -> tensor<1x1x1xf32>
+// CHECK: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
// CHECK: [[VAR_1_:%.+]] = tosa.sub [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
// CHECK: return [[VAR_1_]] : tensor<13x21x1xf32>
}
@@ -106,7 +108,8 @@ func.func @test_div_broadcast(%arg0: tensor<13x21x1xi32>, %arg1: tensor<1xi32>)
"func.return"(%0) : (tensor<13x21x1xi32>) -> ()
// CHECK-LABEL: func @test_div_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi32>, [[PARAM_1_:%.+]]: tensor<1xi32>) -> tensor<13x21x1xi32> {
-// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xi32>) -> tensor<1x1x1xi32>
+// CHECK-NEXT: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE]] : (tensor<1xi32>, !tosa.shape<3>) -> tensor<1x1x1xi32>
// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.int_div [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xi32>, tensor<1x1x1xi32>) -> tensor<13x21x1xi32>
}
@@ -118,7 +121,8 @@ func.func @test_div_decomposed(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1
// CHECK-LABEL: func @test_div_decomposed
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
-// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
+// CHECK-NEXT: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[ZERO]] : (tensor<13x21x1xf32>, tensor<13x21x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
}
// -----
@@ -129,6 +133,8 @@ func.func @test_div_decomposed_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tens
// CHECK-LABEL: func @test_div_decomposed_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<1xf32>) -> tensor<1xf32>
-// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xf32>) -> tensor<1x1x1xf32>
-// CHECK-NEXT: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
+// CHECK-NEXT: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
+// CHECK-NEXT: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK-NEXT: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]], [[ZERO]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
}
diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_linear.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_linear.mlir
deleted file mode 100644
index 5ccbd32a28..0000000000
--- a/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_linear.mlir
+++ /dev/null
@@ -1,44 +0,0 @@
-// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s
-
-func.func @gemm_to_fc(%arg0: tensor<1x5xf32>, %arg1: tensor<4x5xf32>, %arg2: tensor<4xf32>) -> tensor<1x4xf32> {
- %0 = "onnx.Gemm"(%arg0, %arg1, %arg2) {transB = 1 : si64} : (tensor<1x5xf32>, tensor<4x5xf32>, tensor<4xf32>) -> tensor<1x4xf32>
- return %0 : tensor<1x4xf32>
-// CHECK-LABEL: func.func @gemm_to_fc(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x5xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x5xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<4xf32>) -> tensor<1x4xf32> {
-// CHECK: %[[VAL_3:.*]] = tosa.fully_connected %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : (tensor<1x5xf32>, tensor<4x5xf32>, tensor<4xf32>) -> tensor<1x4xf32>
-// CHECK: return %[[VAL_3]] : tensor<1x4xf32>
-// CHECK: }
-}
-
-// -----
-
-func.func @gemm_to_fc_broadcast(%arg0: tensor<2x5xf32>, %arg1: tensor<4x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4xf32> {
- %0 = "onnx.Gemm"(%arg0, %arg1, %arg2) {transB = 1 : si64} : (tensor<2x5xf32>, tensor<4x5xf32>, tensor<1xf32>) -> tensor<2x4xf32>
- return %0 : tensor<2x4xf32>
-// CHECK-LABEL: func.func @gemm_to_fc_broadcast
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x5xf32>, [[PARAM_1_:%.+]]: tensor<4x5xf32>, [[PARAM_2_:%.+]]: tensor<1xf32>) -> tensor<2x4xf32> {
-// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32>
-// CHECK-DAG: [[VAR_1_:%.+]] = tosa.fully_connected [[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]] : (tensor<2x5xf32>, tensor<4x5xf32>, tensor<4xf32>) -> tensor<2x4xf32>
-// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 1>} : (tensor<1xf32>) -> tensor<1x1xf32>
-// CHECK: [[VAR_3_:%.+]] = tosa.add [[VAR_1_]], [[VAR_2_]] : (tensor<2x4xf32>, tensor<1x1xf32>) -> tensor<2x4xf32>
-// CHECK: return [[VAR_3_]] : tensor<2x4xf32>
-// CHECK: }
-}
-
-// -----
-
-func.func @gemm_to_fc_opt(%arg0: tensor<1x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<1x4xf32> {
- %none = "onnx.NoValue"() {value} : () -> none
- %0 = "onnx.Gemm"(%arg0, %arg1, %none) {transB = 1 : si64} : (tensor<1x5xf32>, tensor<4x5xf32>, none) -> tensor<1x4xf32>
- return %0 : tensor<1x4xf32>
-// CHECK-LABEL: func.func @gemm_to_fc_opt(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x5xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x5xf32>) -> tensor<1x4xf32> {
-// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
-// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32>
-// CHECK: %[[VAL_4:.*]] = tosa.fully_connected %[[VAL_0]], %[[VAL_1]], %[[VAL_3]] : (tensor<1x5xf32>, tensor<4x5xf32>, tensor<4xf32>) -> tensor<1x4xf32>
-// CHECK: return %[[VAL_4]] : tensor<1x4xf32>
-// CHECK: }
-}
\ No newline at end of file
diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir
index 3654d493ea..10081ba077 100644
--- a/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir
+++ b/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir
@@ -5,13 +5,17 @@ func.func @test_gemm_to_matmul(%arg0: tensor<3x5xf32>, %arg1: tensor<5x4xf32>, %
return %0 : tensor<3x4xf32>
// CHECK-LABEL: func.func @test_gemm_to_matmul
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x5xf32>, [[PARAM_1_:%.+]]: tensor<5x4xf32>, [[PARAM_2_:%.+]]: tensor<3x4xf32>) -> tensor<3x4xf32> {
-// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 1, 3, 5>} : (tensor<3x5xf32>) -> tensor<1x3x5xf32>
-// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 5, 4>} : (tensor<5x4xf32>) -> tensor<1x5x4xf32>
+// CHECK-DAG: [[SHAPE_0:%.+]] = tosa.const_shape {value = dense<[1, 3, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<3x5xf32>, !tosa.shape<3>) -> tensor<1x3x5xf32>
+// CHECK-DAG: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 5, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<5x4xf32>, !tosa.shape<3>) -> tensor<1x5x4xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_2_:%.+]] = tosa.matmul [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x5xf32>, tensor<1x5x4xf32>) -> tensor<1x3x4xf32>
-// CHECK-DAG: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 3, 4>} : (tensor<3x4xf32>) -> tensor<1x3x4xf32>
+// CHECK: [[VAR_2_:%.+]] = tosa.matmul [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x5xf32>, tensor<1x5x4xf32>) -> tensor<1x3x4xf32>
+// CHECK-DAG: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<3x4xf32>, !tosa.shape<3>) -> tensor<1x3x4xf32>
// CHECK: [[VAR_4_:%.+]] = tosa.add [[VAR_2_]], [[VAR_3_]] : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
-// CHECK: [[VAR_5_:%.+]] = tosa.reshape [[VAR_4_]] {new_shape = array} : (tensor<1x3x4xf32>) -> tensor<3x4xf32>
+// CHECK-DAG: [[SHAPE_3:%.+]] = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: [[VAR_5_:%.+]] = tosa.reshape [[VAR_4_]], [[SHAPE_3]] : (tensor<1x3x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32>
// CHECK: return [[VAR_5_]] : tensor<3x4xf32>
// CHECK: }
}
@@ -23,14 +27,19 @@ func.func @test_alpha(%arg0: tensor<3x6xf32>, %arg1: tensor<6x4xf32>, %arg2: ten
return %0 : tensor<3x4xf32>
// CHECK-LABEL: func.func @test_alpha
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x6xf32>, [[PARAM_1_:%.+]]: tensor<6x4xf32>, [[PARAM_2_:%.+]]: tensor<3x4xf32>) -> tensor<3x4xf32> {
-// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<3x6xf32>) -> tensor<1x3x6xf32>
-// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<1x6x4xf32>
-// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.618000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
-// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32>
-// CHECK-DAG: [[VAR_4_:%.+]] = tosa.matmul [[VAR_3_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> tensor<1x3x4xf32>
-// CHECK-DAG: [[VAR_5_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array} : (tensor<3x4xf32>) -> tensor<1x3x4xf32>
+// CHECK-DAG: [[SHAPE_0:%.+]] = tosa.const_shape {value = dense<[1, 3, 6]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<3x6xf32>, !tosa.shape<3>) -> tensor<1x3x6xf32>
+// CHECK-DAG: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 6, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<6x4xf32>, !tosa.shape<3>) -> tensor<1x6x4xf32>
+// CHECK: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.618000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_0_]], [[ZERO]] : (tensor<1x1x1xf32>, tensor<1x3x6xf32>, tensor<1xi8>) -> tensor<1x3x6xf32>
+// CHECK: [[VAR_4_:%.+]] = tosa.matmul [[VAR_3_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> tensor<1x3x4xf32>
+// CHECK-DAG: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_5_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<3x4xf32>, !tosa.shape<3>) -> tensor<1x3x4xf32>
// CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_4_]], [[VAR_5_]] : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
-// CHECK: [[VAR_7_:%.+]] = tosa.reshape [[VAR_6_]] {new_shape = array} : (tensor<1x3x4xf32>) -> tensor<3x4xf32>
+// CHECK-DAG: [[SHAPE_3:%.+]] = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: [[VAR_7_:%.+]] = tosa.reshape [[VAR_6_]], [[SHAPE_3]] : (tensor<1x3x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32>
// CHECK: return [[VAR_7_]] : tensor<3x4xf32>
// CHECK: }
}
@@ -42,15 +51,20 @@ func.func @test_beta(%arg0: tensor<3x6xf32>, %arg1: tensor<6x6xf32>, %arg2: tens
return %0 : tensor<3x6xf32>
// CHECK-LABEL: func.func @test_beta
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x6xf32>, [[PARAM_1_:%.+]]: tensor<6x6xf32>, [[PARAM_2_:%.+]]: tensor<3x6xf32>) -> tensor<3x6xf32> {
-// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<3x6xf32>) -> tensor<1x3x6xf32>
-// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<6x6xf32>) -> tensor<1x6x6xf32>
-// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.349000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
-// CHECK-DAG: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array} : (tensor<3x6xf32>) -> tensor<1x3x6xf32>
+// CHECK-DAG: [[SHAPE_0:%.+]] = tosa.const_shape {value = dense<[1, 3, 6]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<3x6xf32>, !tosa.shape<3>) -> tensor<1x3x6xf32>
+// CHECK-DAG: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 6, 6]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<6x6xf32>, !tosa.shape<3>) -> tensor<1x6x6xf32>
+// CHECK: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.349000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK-DAG: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 3, 6]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<3x6xf32>, !tosa.shape<3>) -> tensor<1x3x6xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_4_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_3_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32>
-// CHECK-DAG: [[VAR_5_:%.+]] = tosa.matmul [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x6xf32>) -> tensor<1x3x6xf32>
+// CHECK: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_3_]], [[ZERO]] : (tensor<1x1x1xf32>, tensor<1x3x6xf32>, tensor<1xi8>) -> tensor<1x3x6xf32>
+// CHECK: [[VAR_5_:%.+]] = tosa.matmul [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x6xf32>) -> tensor<1x3x6xf32>
// CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_5_]], [[VAR_4_]] : (tensor<1x3x6xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32>
-// CHECK: [[VAR_7_:%.+]] = tosa.reshape [[VAR_6_]] {new_shape = array} : (tensor<1x3x6xf32>) -> tensor<3x6xf32>
+// CHECK-DAG: [[SHAPE_3:%.+]] = tosa.const_shape {value = dense<[3, 6]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: [[VAR_7_:%.+]] = tosa.reshape [[VAR_6_]], [[SHAPE_3]] : (tensor<1x3x6xf32>, !tosa.shape<2>) -> tensor<3x6xf32>
// CHECK: return [[VAR_7_]] : tensor<3x6xf32>
// CHECK: }
}
@@ -62,14 +76,18 @@ func.func @test_transa(%arg0: tensor<6x3xf32>, %arg1: tensor<6x4xf32>, %arg2: te
return %0 : tensor<3x4xf32>
// CHECK-LABEL: func.func @test_transa
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<6x3xf32>, [[PARAM_1_:%.+]]: tensor<6x4xf32>, [[PARAM_2_:%.+]]: tensor<3x4xf32>) -> tensor<3x4xf32> {
-// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<6x3xf32>) -> tensor<1x6x3xf32>
-// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<1x6x4xf32>
-// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: [[SHAPE_0:%.+]] = tosa.const_shape {value = dense<[1, 6, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<6x3xf32>, !tosa.shape<3>) -> tensor<1x6x3xf32>
+// CHECK-DAG: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 6, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<6x4xf32>, !tosa.shape<3>) -> tensor<1x6x4xf32>
+// CHECK: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: [[VAR_3_:%.+]] = tosa.transpose [[VAR_0_]], [[VAR_2_]] : (tensor<1x6x3xf32>, tensor<3xi32>) -> tensor<1x3x6xf32>
-// CHECK-DAG: [[VAR_4_:%.+]] = tosa.matmul [[VAR_3_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> tensor<1x3x4xf32>
-// CHECK-DAG: [[VAR_5_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array