Skip to content

Commit 178ec84

Browse files
authored
Clean up krnl.dim, krnl.shape, and krnl.getref (onnx#2733)
* Remove krnl.dim and krnl.shape Signed-off-by: Tung D. Le <[email protected]> * Clean up krnl.getref Signed-off-by: Tung D. Le <[email protected]> --------- Signed-off-by: Tung D. Le <[email protected]>
1 parent 0507600 commit 178ec84

24 files changed

+53
-893
lines changed

docs/Dialects/krnl.md

-83
Original file line numberDiff line numberDiff line change
@@ -313,37 +313,6 @@ intend to optimize.
313313
| :----: | ----------- |
314314
&laquo;unnamed&raquo; | variadic of any type
315315

316-
### `krnl.dim` (KrnlDimOp)
317-
318-
_Krnl dimensions operation._
319-
320-
Emits the dimension of a MemRef independent of the MemRef alloc:
321-
322-
```
323-
"krnl.dim"(%memref, %index)
324-
```
325-
326-
The index identifies the dimension within the shape which is going to be emitted.
327-
Initially the krnl.dim operation depends on the alloc of the MemRef.
328-
Unlike the std.dim operation which maintains a dependency on the alloc of the MemRef, the dimension emitted by krnl.dim will not depend on the alloc operation of the MemRef once the krnl.dim operation is lowered.
329-
330-
Any changes to the original MemRef size after the krnl.dim has been lowered will not be picked up by the emitted dimension. This allows the original MemRef to be safely modified via code transformations or affine map normalization without the risk of changing the value already emitted via krnl.dim.
331-
332-
Traits: MemRefsNormalizable
333-
334-
#### Operands:
335-
336-
| Operand | Description |
337-
| :-----: | ----------- |
338-
| `alloc` | memref of any type values
339-
| `index` | index
340-
341-
#### Results:
342-
343-
| Result | Description |
344-
| :----: | ----------- |
345-
| `dimension` | index
346-
347316
### `krnl.entry_point` (KrnlEntryPointOp)
348317

349318
_Indicate ONNX entry point_
@@ -429,34 +398,6 @@ current tile being iterated over.
429398
| :----: | ----------- |
430399
| `ind_var_vals` | variadic of any type
431400

432-
### `krnl.getref` (KrnlGetRefOp)
433-
434-
_Krnl a MemRef from within another MemRef starting at a specific offset._
435-
436-
Retrieves a MemRef from within another MemRef:
437-
438-
```
439-
"krnl.getref"(%memref, %offset)
440-
```
441-
The offset is an integer which is used as an index into the input MemRef. It works
442-
just like an array index.
443-
444-
Traits: MemRefsNormalizable
445-
446-
#### Operands:
447-
448-
| Operand | Description |
449-
| :-----: | ----------- |
450-
| `mempool` | memref of any type values
451-
| `offset` | integer
452-
| `value` | variadic of index
453-
454-
#### Results:
455-
456-
| Result | Description |
457-
| :----: | ----------- |
458-
| `output` | memref of any type values
459-
460401
### `krnl.global` (KrnlGlobalOp)
461402

462403
_Krnl global operation_
@@ -1234,30 +1175,6 @@ Traits: MemRefsNormalizable
12341175
| `seq` | memref of any type values
12351176
| `index` | index
12361177

1237-
### `krnl.shape` (KrnlShapeOp)
1238-
1239-
_Krnl operation to retrieve the shape of a MemRef._
1240-
1241-
Extracts the shape of a MemRef:
1242-
```
1243-
"krnl.shape"(%memref)
1244-
```
1245-
The return result is of `shape.type`.
1246-
1247-
Traits: MemRefsNormalizable
1248-
1249-
#### Operands:
1250-
1251-
| Operand | Description |
1252-
| :-----: | ----------- |
1253-
| `alloc` | memref of any type values
1254-
1255-
#### Results:
1256-
1257-
| Result | Description |
1258-
| :----: | ----------- |
1259-
| `shape` | memref of any type values
1260-
12611178
### `krnl.specialized_kernel` (KrnlSpecializedKernel)
12621179

12631180
_Krnl specialized kernel op_

src/Compiler/CompilerPasses.cpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,6 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
193193
// from ONNX dialect to Standard dialect exposes additional canonicalization
194194
// opportunities.
195195
pm.addPass(mlir::createCanonicalizerPass());
196-
pm.addNestedPass<func::FuncOp>(
197-
onnx_mlir::createDisconnectKrnlDimFromAllocPass());
198-
pm.addPass(mlir::createCanonicalizerPass());
199196
}
200197

201198
void addKrnlToAffinePasses(mlir::PassManager &pm) {
@@ -315,4 +312,4 @@ void addPasses(mlir::OwningOpRef<ModuleOp> &module, mlir::PassManager &pm,
315312
addKrnlToLLVMPasses(pm, outputNameNoExt, /*enableCSE=*/true);
316313
}
317314

318-
} // namespace onnx_mlir
315+
} // namespace onnx_mlir

src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -773,8 +773,6 @@ void ConvertKrnlToAffinePass::runOnOperation() {
773773
ConversionTarget target(*ctx);
774774
// Legal/illegal ops.
775775
target.addIllegalOp<KrnlTerminatorOp>();
776-
// krnl.dim operations must be lowered prior to this pass.
777-
target.addIllegalOp<KrnlDimOp>();
778776
target.addIllegalOp<KrnlMatMulOp>();
779777
target.addIllegalOp<KrnlCopyToBufferOp>();
780778
target.addIllegalOp<KrnlCopyFromBufferOp>();

src/Conversion/KrnlToLLVM/CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ add_onnx_mlir_library(OMKrnlToLLVM
55
KrnlFindIndex.cpp
66
KrnlCall.cpp
77
KrnlEntryPoint.cpp
8-
KrnlGetRef.cpp
98
KrnlGlobal.cpp
109
KrnlInstrument.cpp
1110
KrnlMemcpy.cpp

src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,6 @@ void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter,
956956
krnl::populateLoweringKrnlCallOpPattern(typeConverter, patterns, ctx);
957957
krnl::populateLoweringKrnlFindIndexOpPattern(typeConverter, patterns, ctx);
958958
krnl::populateLoweringKrnlGlobalOpPattern(typeConverter, patterns, ctx);
959-
krnl::populateLoweringKrnlGetRefOpPattern(typeConverter, patterns, ctx);
960959
krnl::populateLoweringKrnlInstrumentOpPattern(typeConverter, patterns, ctx);
961960
krnl::populateLoweringKrnlMemcpyOpPattern(typeConverter, patterns, ctx);
962961
krnl::populateLoweringKrnlPrintOpPattern(typeConverter, patterns, ctx);

src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp

-3
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ void populateLoweringKrnlFindIndexOpPattern(
6767
mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
6868
mlir::MLIRContext *ctx);
6969

70-
void populateLoweringKrnlGetRefOpPattern(mlir::LLVMTypeConverter &typeConverter,
71-
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
72-
7370
void populateLoweringKrnlGlobalOpPattern(mlir::LLVMTypeConverter &typeConverter,
7471
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
7572

src/Conversion/KrnlToLLVM/KrnlGetRef.cpp

-172
This file was deleted.

src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp

+28-11
Original file line numberDiff line numberDiff line change
@@ -257,20 +257,37 @@ struct ONNXCategoryMapperOpLowering
257257
.Case<IntegerType>(
258258
[&](IntegerType) { inputElem = createKrnl.load(memref, loopInd); })
259259
.Case<krnl::StringType>([&](krnl::StringType stringType) {
260-
MathBuilder createMath(createKrnl);
261-
Value zero = createMath.constant(
262-
createMath.getBuilder().getIntegerType(64), 0);
263260
ArrayRef<int64_t> shape =
264261
memref.getType().cast<ShapedType>().getShape();
265262
SmallVector<int64_t, 4> newShape;
266-
for (uint64_t i = 0; i < shape.size(); i++)
267-
newShape.emplace_back(
268-
(shape[i] == ShapedType::kDynamic) ? 1 : shape[i]);
269-
auto memRefType = MemRefType::get(
270-
newShape, krnl::StringType::get(elementType.getContext()));
271-
// Sole use of krnl.getRef.
272-
Value stringMemRef = createKrnl.getRef(memRefType, memref, zero);
273-
inputElem = createKrnl.load(stringMemRef, loopInd);
263+
bool hasDynamicDim = false;
264+
for (uint64_t i = 0; i < shape.size(); i++) {
265+
if (shape[i] == ShapedType::kDynamic) {
266+
newShape.emplace_back(1);
267+
hasDynamicDim = true;
268+
} else {
269+
newShape.emplace_back(shape[i]);
270+
}
271+
}
272+
if (!hasDynamicDim) {
273+
inputElem = createKrnl.load(memref, loopInd);
274+
} else {
275+
MemRefBuilder createMemRef(createKrnl);
276+
MemRefType memRefType = MemRefType::get(
277+
newShape, krnl::StringType::get(elementType.getContext()));
278+
SmallVector<int64_t, 4> offsets(shape.size(), 0);
279+
SmallVector<int64_t, 4> strides;
280+
int64_t alignmentOffset; // not used, just to make the function call
281+
// completed.
282+
if (getStridesAndOffset(memRefType, strides, alignmentOffset)
283+
.failed())
284+
llvm_unreachable("Failed to get strides");
285+
Value stringMemRef =
286+
createMemRef
287+
.subView(memRefType, memref, offsets, newShape, strides)
288+
.getResult();
289+
inputElem = createKrnl.load(stringMemRef, loopInd);
290+
}
274291
})
275292
.Default([&](Type type) {
276293
llvm::errs() << "type: " << type << "\n";

src/Dialect/Krnl/DialectBuilder.cpp

-9
Original file line numberDiff line numberDiff line change
@@ -211,19 +211,10 @@ void KrnlBuilder::matmul(Value A, ValueRange aStart, Value B, ValueRange bStart,
211211
globalUBs[1], globalUBs[2], simdize, unroll, overCompute);
212212
}
213213

214-
Value KrnlBuilder::dim(Type type, Value alloc, Value index) const {
215-
return b().create<KrnlDimOp>(loc(), type, alloc, index);
216-
}
217-
218214
KrnlMovableOp KrnlBuilder::movable() const {
219215
return b().create<KrnlMovableOp>(loc());
220216
}
221217

222-
KrnlGetRefOp KrnlBuilder::getRef(
223-
Type type, Value memref, Value offset, ValueRange indices) const {
224-
return b().create<KrnlGetRefOp>(loc(), type, memref, offset, indices);
225-
}
226-
227218
Value KrnlBuilder::constant(MemRefType type, StringRef name,
228219
std::optional<Attribute> value, std::optional<IntegerAttr> offset,
229220
std::optional<IntegerAttr> alignment) const {

0 commit comments

Comments
 (0)