Skip to content

Commit 25d1bf8

Browse files
committed
revert krnlops.
Signed-off-by: Haruki Imai <[email protected]>
1 parent ff9dc6e commit 25d1bf8

File tree

6 files changed

+390
-21
lines changed

6 files changed

+390
-21
lines changed

src/Conversion/KrnlToLLVM/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ add_onnx_mlir_library(OMKrnlToLLVM
66
KrnlFindIndex.cpp
77
KrnlCall.cpp
88
KrnlEntryPoint.cpp
9-
# KrnlGlobal.cpp
9+
KrnlGlobal.cpp
1010
KrnlInstrument.cpp
1111
KrnlMemcpy.cpp
1212
KrnlNone.cpp

src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,7 @@ void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter,
962962
verifyInputTensors);
963963
krnl::populateLoweringKrnlCallOpPattern(typeConverter, patterns, ctx);
964964
krnl::populateLoweringKrnlFindIndexOpPattern(typeConverter, patterns, ctx);
965+
krnl::populateLoweringKrnlGlobalOpPattern(typeConverter, patterns, ctx);
965966
krnl::populateLoweringConstantOpInterfacePattern(
966967
typeConverter, patterns, ctx);
967968
krnl::populateLoweringKrnlInstrumentOpPattern(typeConverter, patterns, ctx);

src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ void populateLoweringKrnlFindIndexOpPattern(
6868
mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
6969
mlir::MLIRContext *ctx);
7070

71+
void populateLoweringKrnlGlobalOpPattern(mlir::LLVMTypeConverter &typeConverter,
72+
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
73+
7174
void populateLoweringConstantOpInterfacePattern(
7275
mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
7376
mlir::MLIRContext *ctx);
Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//===------ KrnlGlobal.cpp - Lower KrnlGlobalOp ---------------------------===//
6+
//
7+
// Copyright 2019-2022 The IBM Research Authors.
8+
//
9+
// =============================================================================
10+
//
11+
// This file lowers the KrnlGlobalOp operator.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include <fstream>
16+
17+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
18+
#include "mlir/IR/BuiltinAttributes.h"
19+
#include "mlir/IR/DialectResourceBlobManager.h"
20+
#include "llvm/ADT/TypeSwitch.h"
21+
#include "llvm/Support/Debug.h"
22+
#include "llvm/Support/FileSystem.h"
23+
24+
#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
25+
#include "src/Dialect/Mlir/DialectBuilder.hpp"
26+
#include "src/Support/KrnlSupport.hpp"
27+
28+
#define DEBUG_TYPE "krnl_to_llvm"
29+
30+
using namespace mlir;
31+
32+
namespace onnx_mlir {
33+
namespace krnl {
34+
35+
/// This variable is initizalied inside ConvertKrnlToLLVMPass.
36+
extern std::string EXTERNAL_CONSTANT_PREFIX;
37+
38+
class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
39+
public:
40+
explicit KrnlGlobalOpLowering(
41+
LLVMTypeConverter &typeConverter, MLIRContext *context)
42+
: ConvertToLLVMPattern(
43+
KrnlGlobalOp::getOperationName(), context, typeConverter) {}
44+
45+
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
46+
ConversionPatternRewriter &rewriter) const override {
47+
auto krnlGlobalOp = llvm::dyn_cast<KrnlGlobalOp>(op);
48+
Location loc = krnlGlobalOp.getLoc();
49+
MLIRContext *context = krnlGlobalOp.getContext();
50+
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
51+
52+
// Basic type.
53+
Type llvmI8Ty = IntegerType::get(context, 8);
54+
Type llvmI8PtrTy = getPointerType(context, llvmI8Ty);
55+
56+
// The element type of the array.
57+
const Type type = op->getResult(0).getType();
58+
const MemRefType memRefTy = mlir::cast<mlir::MemRefType>(type);
59+
const Type constantElementType =
60+
typeConverter->convertType(memRefTy.getElementType());
61+
Type globalType = constantElementType;
62+
63+
// The llvm type of the global (example: [2 x [8 x float]]).
64+
const auto shape = mlir::dyn_cast<ArrayAttr>(krnlGlobalOp.getShape());
65+
if (shape.empty())
66+
globalType = LLVM::LLVMArrayType::get(mlir::cast<Type>(globalType), 1);
67+
else {
68+
for (int i = shape.size() - 1; i >= 0; i--)
69+
globalType = LLVM::LLVMArrayType::get(
70+
mlir::cast<Type>(globalType), ArrayAttrIntVal(shape, i));
71+
}
72+
73+
// Create the global at the entry of the module.
74+
LLVM::GlobalOp global;
75+
// Pointer to the raw data of the global.
76+
Value dataPtr;
77+
78+
if (krnlGlobalOp.getValue().has_value()) {
79+
auto value = krnlGlobalOp.getValue().value();
80+
TypeSwitch<Attribute>(value)
81+
.Case<DenseResourceElementsAttr>([&](DenseResourceElementsAttr attr) {
82+
global =
83+
lowerDenseResourceConstant(krnlGlobalOp, globalType, rewriter);
84+
})
85+
.Case<DenseElementsAttr>([&](DenseElementsAttr attr) {
86+
global = lowerDenseConstant(krnlGlobalOp, globalType, rewriter);
87+
})
88+
.Default([&](Attribute attr) {
89+
llvm_unreachable("Unsupported attribute type");
90+
});
91+
dataPtr = create.llvm.addressOf(global);
92+
} else {
93+
// Data are stored on files.
94+
global = lowerGlobalOpWithExternalFiles(krnlGlobalOp, rewriter);
95+
dataPtr = create.llvm.load(llvmI8PtrTy, create.llvm.addressOf(global));
96+
}
97+
98+
// Set the global alignment based on the alignment attribute if it exists,
99+
// otherwise use the module datalayout info.
100+
krnl::setAlignment(global, krnlGlobalOp.getAlignmentAttr(),
101+
krnlGlobalOp->getParentOfType<ModuleOp>(), rewriter,
102+
*getTypeConverter());
103+
104+
// Prepare data to be inserted into a MemRefDescriptor (a struct).
105+
MemRefDescriptor memRefDescr =
106+
createMemRefDescriptor(dataPtr, memRefTy, loc, rewriter);
107+
108+
rewriter.replaceOp(op, {memRefDescr});
109+
110+
return success();
111+
}
112+
113+
private:
114+
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
115+
return mlir::cast<IntegerAttr>(a.getValue()[i]).getInt();
116+
}
117+
118+
LLVM::GlobalOp lowerDenseResourceConstant(KrnlGlobalOp &krnlGlobalOp,
119+
Type globalType, ConversionPatternRewriter &rewriter) const {
120+
assert(krnlGlobalOp.getValue().has_value() &&
121+
"Expecting KrnlGlobalOp with a valid value");
122+
assert(
123+
mlir::isa<DenseResourceElementsAttr>(krnlGlobalOp.getValue().value()) &&
124+
"Expecting a global with an dense resource elements attribute");
125+
126+
MLIRContext *context = krnlGlobalOp.getContext();
127+
Location loc = krnlGlobalOp.getLoc();
128+
ModuleOp module = krnlGlobalOp->getParentOfType<ModuleOp>();
129+
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
130+
131+
OpBuilder::InsertionGuard insertGuard(rewriter);
132+
rewriter.setInsertionPointToStart(module.getBody());
133+
134+
auto blob =
135+
mlir::cast<DenseResourceElementsAttr>(krnlGlobalOp.getValue().value())
136+
.getRawHandle()
137+
.getBlob();
138+
assert(blob && "Expecting dense resource with a valid blob");
139+
ArrayRef<char> rawData = blob->getData();
140+
141+
// Check data size.
142+
uint64_t sizeInBytes = computeSizeInBytes(krnlGlobalOp);
143+
assert(((uint64_t)rawData.size() == sizeInBytes) && "Data size mismatch.");
144+
145+
StringRef data(rawData.data(), rawData.size());
146+
StringAttr llvmStringAttr = StringAttr::get(context, data);
147+
auto llvmArrayI8Ty =
148+
LLVM::LLVMArrayType::get(IntegerType::get(context, 8), sizeInBytes);
149+
LLVM::GlobalOp global = create.llvm.globalOp(llvmArrayI8Ty,
150+
/*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.getName(),
151+
llvmStringAttr);
152+
153+
LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";);
154+
return global;
155+
}
156+
157+
LLVM::GlobalOp lowerDenseConstant(KrnlGlobalOp &krnlGlobalOp, Type globalType,
158+
ConversionPatternRewriter &rewriter) const {
159+
assert(krnlGlobalOp.getValue().has_value() &&
160+
"Expecting KrnlGlobalOp with a valid value");
161+
assert(mlir::isa<DenseElementsAttr>(krnlGlobalOp.getValue().value()) &&
162+
"Expecting a global with an dense elements attribute");
163+
164+
Location loc = krnlGlobalOp.getLoc();
165+
ModuleOp module = krnlGlobalOp->getParentOfType<ModuleOp>();
166+
MLIRContext *context = krnlGlobalOp.getContext();
167+
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
168+
169+
Type llvmI8Ty = IntegerType::get(context, 8);
170+
171+
OpBuilder::InsertionGuard insertGuard(rewriter);
172+
rewriter.setInsertionPointToStart(module.getBody());
173+
174+
DenseElementsAttr denseAttr =
175+
mlir::cast<DenseElementsAttr>(krnlGlobalOp.getValue().value());
176+
177+
uint64_t sizeInBytes = computeSizeInBytes(krnlGlobalOp);
178+
LLVM::GlobalOp global;
179+
if (!(mlir::isa<StringType>(denseAttr.getElementType())) &&
180+
!(denseAttr.getElementType().isInteger(1)) && (!denseAttr.isSplat()) &&
181+
(sizeInBytes > 1024)) {
182+
183+
ArrayRef<char> rawData = denseAttr.getRawData();
184+
assert(
185+
((uint64_t)rawData.size() == sizeInBytes) && "Data size mismatch.");
186+
187+
auto llvmArrayI8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, sizeInBytes);
188+
StringRef data(rawData.data(), rawData.size());
189+
StringAttr llvmStringAttr = StringAttr::get(context, data);
190+
global = create.llvm.globalOp(llvmArrayI8Ty,
191+
/*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.getName(),
192+
llvmStringAttr);
193+
} else {
194+
if (mlir::isa<StringType>(denseAttr.getElementType()))
195+
global = lowerStringLiteral(krnlGlobalOp, globalType, rewriter);
196+
else
197+
global = create.llvm.globalOp(globalType,
198+
/*isConstant=*/true, LLVM::Linkage::Internal,
199+
krnlGlobalOp.getName(), krnlGlobalOp.getValue().value());
200+
}
201+
202+
LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";);
203+
return global;
204+
}
205+
206+
LLVM::GlobalOp lowerGlobalOpWithExternalFiles(
207+
KrnlGlobalOp &krnlGlobalOp, ConversionPatternRewriter &rewriter) const {
208+
Location loc = krnlGlobalOp.getLoc();
209+
MLIRContext *context = krnlGlobalOp.getContext();
210+
ModuleOp module = krnlGlobalOp.getOperation()->getParentOfType<ModuleOp>();
211+
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
212+
213+
Type llvmI8Ty = IntegerType::get(context, 8);
214+
Type llvmI8PtrTy = getPointerType(context, llvmI8Ty);
215+
Type llvmI64Ty = IntegerType::get(context, 64);
216+
217+
auto offset = krnlGlobalOp.getOffset();
218+
assert(offset.has_value() && "Missing offset value in KrnlGlobalOp");
219+
220+
// Data is store in `constants.bin` at offset.
221+
std::string constantName = krnlGlobalOp.getName().str();
222+
223+
// Emit globals at the begining of the module.
224+
OpBuilder::InsertionGuard insertGuard(rewriter);
225+
rewriter.setInsertionPointToStart(module.getBody());
226+
227+
// Create an uninitialized global. Data will be loaded at runtime.
228+
LLVM::GlobalOp global = create.llvm.globalOp(llvmI8PtrTy,
229+
/*isConstant=*/false, LLVM::Linkage::Internal,
230+
EXTERNAL_CONSTANT_PREFIX + "data_" + constantName, nullptr);
231+
{
232+
OpBuilder::InsertionGuard insertGuard(rewriter);
233+
Region &region = global.getInitializerRegion();
234+
Block *block = rewriter.createBlock(&region);
235+
// Initialize an array with the addresses of the global op.
236+
rewriter.setInsertionPoint(block, block->begin());
237+
create.llvm._return(create.llvm.null(llvmI8PtrTy));
238+
}
239+
240+
// Create a global to store offset.
241+
create.llvm.globalOp(llvmI64Ty,
242+
/*isConstant=*/true, LLVM::Linkage::Internal,
243+
EXTERNAL_CONSTANT_PREFIX + "offset_" + constantName,
244+
rewriter.getI64IntegerAttr(offset.value()));
245+
246+
return global;
247+
}
248+
249+
uint64_t computeSizeInBytes(KrnlGlobalOp &krnlGlobalOp) const {
250+
// Compute total number of elements.
251+
const auto shape = mlir::dyn_cast<ArrayAttr>(krnlGlobalOp.getShape());
252+
uint64_t numElements = 1;
253+
for (unsigned int i = 0; i < shape.size(); ++i)
254+
numElements *= ArrayAttrIntVal(shape, i);
255+
256+
const auto type = krnlGlobalOp.getResult().getType();
257+
const auto memRefTy = mlir::cast<mlir::MemRefType>(type);
258+
259+
// Special handling for bool.
260+
if (memRefTy.getElementType().isInteger(1))
261+
return llvm::divideCeil(numElements, 8);
262+
263+
return numElements * getMemRefEltSizeInBytes(memRefTy);
264+
}
265+
266+
// Store the given address into a MemRefDescriptor (a struct).
267+
MemRefDescriptor createMemRefDescriptor(Value address, MemRefType memRefType,
268+
Location loc, OpBuilder &builder) const {
269+
Type elementType = memRefType.getElementType();
270+
const LLVMTypeConverter &typeConverter = *getTypeConverter();
271+
Type llvmElemType = typeConverter.convertType(elementType);
272+
MLIRContext *context = builder.getContext();
273+
MultiDialectBuilder<LLVMBuilder> create(builder, loc);
274+
275+
// Prepare data to be inserted into a MemRefDescriptor (a struct).
276+
auto ptrType = getPointerType(context, llvmElemType);
277+
// Bitcast the address to the MemRefType's element type.
278+
Value bitCastOp = create.llvm.bitcast(ptrType, address);
279+
// Create llvm MemRef from original MemRef and fill the data pointers.
280+
return MemRefDescriptor::fromStaticShape(
281+
builder, loc, typeConverter, memRefType, bitCastOp);
282+
}
283+
284+
// Generate a global string for each krnlGlobalOp string value, and store
285+
// the address of the global strings into an array. Return the array address.
286+
LLVM::GlobalOp lowerStringLiteral(
287+
KrnlGlobalOp &krnlGlobalOp, Type globalType, OpBuilder &builder) const {
288+
assert(mlir::isa<DenseElementsAttr>(krnlGlobalOp.getValue().value()) &&
289+
"Expecting a dense value");
290+
291+
Location loc = krnlGlobalOp.getLoc();
292+
MultiDialectBuilder<LLVMBuilder> create(builder, loc);
293+
294+
DenseElementsAttr denseAttr =
295+
mlir::cast<DenseElementsAttr>(krnlGlobalOp.getValue().value());
296+
297+
Type i8PtrType = getI8PointerType(builder.getContext());
298+
299+
auto strs = denseAttr.getValues<StringRef>();
300+
// Collect total size of the strs.
301+
size_t totalSize = 0;
302+
for (StringRef str : strs) {
303+
// Add 1 for the null terminator.
304+
totalSize += str.size() + 1;
305+
}
306+
307+
// Concatenate all strings into one.
308+
std::vector<char> concatStr(totalSize);
309+
size_t offset = 0;
310+
std::vector<size_t> offsets;
311+
for (StringRef str : strs) {
312+
offsets.emplace_back(offset);
313+
std::copy(str.begin(), str.end(), concatStr.begin() + offset);
314+
concatStr[offset + str.size()] = '\0';
315+
offset += str.size() + 1;
316+
}
317+
318+
// Create a global for the concatenated string.
319+
StringRef data(concatStr.data(), concatStr.size());
320+
StringAttr llvmStringAttr = StringAttr::get(builder.getContext(), data);
321+
auto i8Type = IntegerType::get(builder.getContext(), 8);
322+
auto llvmArrayI8Ty = LLVM::LLVMArrayType::get(i8Type, totalSize);
323+
LLVM::GlobalOp globalStr = create.llvm.globalOp(llvmArrayI8Ty,
324+
/*isConstant=*/true, LLVM::Linkage::Internal,
325+
"om.strArray." + krnlGlobalOp.getName().str(), llvmStringAttr);
326+
327+
// Generate an LLVM GlobalOps with an initializer region containing one
328+
// block.
329+
auto arrayType = LLVM::LLVMArrayType::get(i8PtrType, offsets.size());
330+
auto global = create.llvm.globalOp(arrayType,
331+
/*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.getName(),
332+
Attribute());
333+
Region &region = global.getInitializerRegion();
334+
Block *block = builder.createBlock(&region);
335+
336+
// Initialize an array with the addresses of the global strings.
337+
builder.setInsertionPoint(block, block->begin());
338+
Value array = builder.create<LLVM::UndefOp>(loc, arrayType);
339+
340+
int32_t index = 0;
341+
Value lastValue = array;
342+
Value baseAddr = create.llvm.addressOf(globalStr);
343+
// Cast globalStr to i8Ptr.
344+
baseAddr = create.llvm.bitcast(i8PtrType, baseAddr);
345+
for (size_t offset : offsets) {
346+
// Get each str with gep base, offset.
347+
Value gepOp = create.llvm.getElemPtr(
348+
i8PtrType, i8Type, baseAddr, {(int32_t)offset});
349+
lastValue =
350+
create.llvm.insertValue(arrayType, lastValue, gepOp, {index++});
351+
}
352+
353+
create.llvm._return(lastValue);
354+
return global;
355+
}
356+
};
357+
358+
void populateLoweringKrnlGlobalOpPattern(LLVMTypeConverter &typeConverter,
359+
RewritePatternSet &patterns, MLIRContext *ctx) {
360+
patterns.insert<KrnlGlobalOpLowering>(typeConverter, ctx);
361+
}
362+
363+
} // namespace krnl
364+
} // namespace onnx_mlir

0 commit comments

Comments
 (0)