
Commit 599ab54

update float types, tosa, other misc changes
Signed-off-by: Boyana Norris <[email protected]>
1 parent be4a2b8 commit 599ab54

19 files changed: +83 -66 lines changed


src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp

Lines changed: 4 additions & 4 deletions
@@ -35,7 +35,7 @@ ApiRegistry RegisterAllApis(MLIRContext *context) {
   auto int16Ty = IntegerType::get(context, 16);
   auto int32Ty = IntegerType::get(context, 32);
   auto int64Ty = IntegerType::get(context, 64);
-  auto float32Ty = FloatType::getF32(context);
+  auto float32Ty = Float32Type::get(context);

   // Declare API type as an enum value, its string name and an LLVM Type
   // specifying its signature.
@@ -570,7 +570,7 @@ Type getZTensorStructTy(MLIRContext *context) {
   Type llvmI64Ty = IntegerType::get(context, 64);
   Type llvmI1Ty = IntegerType::get(context, 1);
   Type llvmI8Ty = IntegerType::get(context, 8);
-  Type llvmF32Ty = FloatType::getF32(context);
+  Type llvmF32Ty = Float32Type::get(context);
   Type llvmArray3I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 3);
   Type llvmArray20I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 20);
   Type llvmI8PtrTy = krnl::getPointerType(context, llvmI8Ty);
@@ -662,7 +662,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
         scaleTy.isF32() && "Wrong type for zTensor's rec_scale. Must be float");
     create.llvm.store(recScale, recScalePtr);
   } else {
-    Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
+    Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
     create.llvm.store(zero, recScalePtr);
   }
@@ -675,7 +675,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
         offsetTy.isF32() && "Wrong type for zTensor's offset. Must be float");
     create.llvm.store(offset, offsetPtr);
   } else {
-    Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
+    Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
    create.llvm.store(zero, offsetPtr);
   }
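Note: the changes above track an upstream MLIR cleanup in which the static FloatType::getF32-style factories were removed in favor of per-type get() methods. A minimal sketch of the new spelling, assuming a recent MLIR (the helper name makeF32 is ours, for illustration only):

// Minimal sketch of the float-type API migration; assumes a recent MLIR
// where FloatType::getF32(context) no longer exists. makeF32 is a
// hypothetical helper, not part of this repo.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

mlir::Type makeF32(mlir::MLIRContext *context) {
  // Old: mlir::FloatType::getF32(context)
  // New: each concrete float type exposes its own get().
  return mlir::Float32Type::get(context);
}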

src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp

Lines changed: 2 additions & 2 deletions
@@ -80,10 +80,10 @@ class KrnlRandomNormalOpLowering : public ConversionPattern {
   // or
   // (memref<3x4x5xf64>, index, f64, f64, f64)
   Type llvmVoidTy = LLVM::LLVMVoidType::get(context);
-  Type llvmOptionsTy = FloatType::getF32(context);
+  Type llvmOptionsTy = Float32Type::get(context);
   Type llvmOutputTy = getPointerType(context, llvmOptionsTy);
   if (inType.isF64()) {
-    llvmOptionsTy = FloatType::getF64(context);
+    llvmOptionsTy = Float64Type::get(context);
     llvmOutputTy = getPointerType(context, llvmOptionsTy);
   }
   Type llvmI64Ty = IntegerType::get(context, 64);

src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp

Lines changed: 6 additions & 7 deletions
@@ -172,19 +172,19 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
   Type outType = op->getResultTypes().front();
   Type llvmInType, llvmOutType;
   if (inType.isF16())
-    llvmInType = FloatType::getF16(context);
+    llvmInType = Float16Type::get(context);
   else if (inType.isF32())
-    llvmInType = FloatType::getF32(context);
+    llvmInType = Float32Type::get(context);
   else if (inType.isF64())
-    llvmInType = FloatType::getF64(context);
+    llvmInType = Float64Type::get(context);
   else if (inType.isBF16())
-    llvmInType = FloatType::getBF16(context);
+    llvmInType = BFloat16Type::get(context);
   if (outType.isInteger(1))
     llvmOutType = IntegerType::get(context, 1);
   else if (outType.isF32())
-    llvmOutType = FloatType::getF32(context);
+    llvmOutType = Float32Type::get(context);
   else if (outType.isF64())
-    llvmOutType = FloatType::getF64(context);
+    llvmOutType = Float64Type::get(context);

   // Insert and/or get reference to elementary math function declaration.
   assert(
@@ -214,7 +214,6 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
     return SymbolRefAttr::get(context, mathFuncName);

   // Create function declaration.
-  // auto llvmF32Ty = FloatType::get(context);
   auto llvmFnType =
       LLVM::LLVMFunctionType::get(llvmOutType, ArrayRef<Type>({llvmInType}));
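Note: in the new API every builtin float width has its own type class, including bfloat16. A sketch of the full in-type mapping, assuming a recent MLIR (mapFloatType is a hypothetical helper for illustration):

// Sketch of mapping a source float type under the per-type get() API;
// assumes a recent MLIR. mapFloatType is not part of this repo.
#include "mlir/IR/BuiltinTypes.h"

mlir::Type mapFloatType(mlir::Type inType, mlir::MLIRContext *context) {
  if (inType.isF16())
    return mlir::Float16Type::get(context);
  if (inType.isBF16())
    return mlir::BFloat16Type::get(context); // BF16 has its own builtin type
  if (inType.isF32())
    return mlir::Float32Type::get(context);
  if (inType.isF64())
    return mlir::Float64Type::get(context);
  return mlir::Type(); // unsupported input type
}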

src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp

Lines changed: 2 additions & 2 deletions
@@ -62,7 +62,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern {

   // Get memRefDescriptor, the new memref descriptor.
   MemRefDescriptor memRefDescriptor =
-      MemRefDescriptor::undef(rewriter, loc, targetStructType);
+      MemRefDescriptor::poison(rewriter, loc, targetStructType);
   auto targetElementPtrType = memRefDescriptor.getElementPtrType();

   // Set the new memref to the same buffer as the source memref.
@@ -78,7 +78,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern {

   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(targetType, strides, offset)))
+  if (failed(targetType.getStridesAndOffset(strides, offset)))
     return failure();

   // Unhandled dynamic offset.
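Note: two upstream renames meet here. MemRefDescriptor::undef became MemRefDescriptor::poison (following LLVM's undef-to-poison migration), and the free function getStridesAndOffset(type, ...) became a member of MemRefType. A minimal sketch of the member-function query, assuming a recent MLIR (hasStaticStrides is a hypothetical helper):

// Sketch of querying strides/offset via the MemRefType member function;
// assumes a recent MLIR. hasStaticStrides is not part of this repo.
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

bool hasStaticStrides(mlir::MemRefType type) {
  int64_t offset;
  llvm::SmallVector<int64_t, 4> strides;
  // Fails if the layout is not strided.
  if (mlir::failed(type.getStridesAndOffset(strides, offset)))
    return false;
  return offset != mlir::ShapedType::kDynamic &&
         llvm::none_of(strides,
             [](int64_t s) { return s == mlir::ShapedType::kDynamic; });
}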

src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp

Lines changed: 1 addition & 1 deletion
@@ -281,7 +281,7 @@ struct ONNXCategoryMapperOpLowering
   SmallVector<int64_t, 4> strides;
   int64_t alignmentOffset; // not used, just to make the function call
                            // completed.
-  if (getStridesAndOffset(memRefType, strides, alignmentOffset)
+  if (memRefType.getStridesAndOffset(strides, alignmentOffset)
           .failed())
     llvm_unreachable("Failed to get strides");
   Value stringMemRef =

src/Conversion/ONNXToKrnl/Math/LRN.cpp

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ struct ONNXLRNOpLowering : public OpConversionPattern<ONNXLRNOp> {
   float alphaLit = adaptor.getAlpha().convertToFloat();
   float betaLit = adaptor.getBeta().convertToFloat();
   int sizeLit = adaptor.getSize();
-  auto f32Type = FloatType::getF32(rewriter.getContext());
+  auto f32Type = Float32Type::get(rewriter.getContext());
   Value biasValue = create.math.constant(f32Type, biasLit);
   Value alphaDivSizeValue =
       create.math.constant(f32Type, alphaLit / static_cast<float>(sizeLit));

src/Conversion/ONNXToTOSA/DialectBuilder.cpp

Lines changed: 16 additions & 8 deletions
@@ -15,6 +15,7 @@

 #include "mlir/Dialect/Arith/IR/Arith.h"

+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
 #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -147,14 +148,16 @@ Value TosaBuilder::transpose(Value &value, llvm::ArrayRef<int32_t> perm) {

 Value TosaBuilder::slice(Value &inputConst, llvm::ArrayRef<int64_t> size,
     llvm::ArrayRef<int64_t> start) {
-  DenseI64ArrayAttr sizeAttr = rewriter().getDenseI64ArrayAttr(size);
-  DenseI64ArrayAttr startAttr = rewriter().getDenseI64ArrayAttr(start);
+  auto startVal =
+      mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(start));
+  auto sizeVal =
+      mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(size));
   Value newSliceInput =
       tosa::CreateOpAndInfer<mlir::tosa::SliceOp>(rewriter(), loc(),
           RankedTensorType::get(
               llvm::SmallVector<int64_t, 4>(size.size(), ShapedType::kDynamic),
               mlir::cast<ShapedType>(inputConst.getType()).getElementType()),
-          inputConst, startAttr, sizeAttr);
+          inputConst, startVal, sizeVal);
   return newSliceInput;
 }
@@ -164,8 +167,9 @@ Value TosaBuilder::reshape(Value &value, llvm::ArrayRef<int64_t> shape) {
   Type newValueType = RankedTensorType::get(
       llvm::SmallVector<int64_t, 4>(shape.size(), ShapedType::kDynamic),
       valueType.getElementType());
-  return tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(
-      rewriter(), loc(), newValueType, value, shapeAttr);
+  return tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter(), loc(),
+      newValueType, value,
+      mlir::tosa::getTosaConstShape(rewriter(), loc(), shapeAttr));
 }

 Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
@@ -178,8 +182,12 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
   Type newValueType = RankedTensorType::get(
       llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
       lhsType.getElementType());
+
+  auto int8Type = rewriter().getI8Type();
+  auto shiftValue =
+      TosaBuilder::createConst(ArrayRef<int32_t>{shift}, {1}, int8Type);
   return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
-      rewriter(), loc(), newValueType, lhs, rhs, shift);
+      rewriter(), loc(), newValueType, lhs, rhs, shiftValue);
 }

 Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {
@@ -236,8 +244,8 @@ template Value TosaBuilder::binaryOp<mlir::tosa::SubOp>(Value &lhs, Value &rhs);
 // Return null if none is found.
 ElementsAttr IndexExprBuilderForTosa::getConst(Value value) {
   auto definingOp = value.getDefiningOp();
-  // If we have a cast between index/integer, skip it, i.e. get the defining op
-  // that is the input to the cast.
+  // If we have a cast between index/integer, skip it, i.e. get the defining
+  // op that is the input to the cast.
   if (auto castOp = dyn_cast_or_null<arith::IndexCastOp>(definingOp)) {
     Value input = castOp.getIn();
     definingOp = input.getDefiningOp();
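Note: newer TOSA expects shape-like operands of slice/reshape as !tosa.shape SSA values (materialized by a tosa.const_shape op) rather than DenseI64ArrayAttr, and tosa.mul now takes its shift as a rank-1 i8 tensor operand instead of an integer attribute, which is what the mul hunk above adapts to. A minimal sketch of building the shape operand, assuming a recent MLIR with getTosaConstShape in ConversionUtils.h (buildShapeOperand is a hypothetical helper):

// Sketch of building the !tosa.shape operand that slice/reshape now expect;
// assumes a recent MLIR. buildShapeOperand is not part of this repo.
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

mlir::Value buildShapeOperand(mlir::PatternRewriter &rewriter,
    mlir::Location loc, llvm::ArrayRef<int64_t> shape) {
  // Materializes a tosa.const_shape op holding `shape` and returns its result.
  return mlir::tosa::getTosaConstShape(rewriter, loc, llvm::to_vector(shape));
}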

src/Conversion/ONNXToTOSA/Math/Elementwise.cpp

Lines changed: 15 additions & 5 deletions
@@ -121,11 +121,21 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> {
     // Quantized types are not supported right now (in type conversion).
     // Once they are, the input should be rescaled for quantized types. (TBD)
     // Maps to `tosa.clamp` which has both int and fp limits.
-    rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(op, op.getType(), input,
-        rewriter.getI64IntegerAttr(0),
-        rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
-        rewriter.getF32FloatAttr(0.0f),
-        rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
+    auto inputElementType =
+        llvm::cast<TensorType>(op.getType()).getElementType();
+    if (llvm::isa<IntegerType>(inputElementType)) {
+      auto minClamp = rewriter.getI64IntegerAttr(0);
+      auto maxClamp =
+          rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max());
+      rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(
+          op, op.getType(), input, minClamp, maxClamp);
+    } else {
+      auto minClamp = rewriter.getF32FloatAttr(0.0f);
+      auto maxClamp =
+          rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
+      rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(
+          op, op.getType(), input, minClamp, maxClamp);
+    }
     return success();
   }
 };

src/Conversion/ONNXToTOSA/Math/Gemm.cpp

Lines changed: 6 additions & 4 deletions
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
 #include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
 #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
@@ -67,13 +68,14 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {

     llvm::SmallVector<int64_t> dynamicTensorShape = {
         ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic};
-    A = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
+
+    A = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
         RankedTensorType::get(dynamicTensorShape, AType.getElementType()), A,
-        rewriter.getDenseI64ArrayAttr(newShapeA))
-            .getResult();
+        mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeA))
+            .getResult();
     B = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
         RankedTensorType::get(dynamicTensorShape, BType.getElementType()), B,
-        rewriter.getDenseI64ArrayAttr(newShapeB))
+        mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeB))
         .getResult();

     // If transA or transB are present, create Transpose operators.
src/Dialect/ONNX/ElementsAttr/BType.cpp

Lines changed: 5 additions & 5 deletions
@@ -55,10 +55,10 @@ Type mlirTypeOfBType(BType btype, MLIRContext *ctx) {
     case BType::FLOAT         : return b.getF32Type();
     case BType::FLOAT16       : return b.getF16Type();
     case BType::BFLOAT16      : return b.getBF16Type();
-    case BType::FLOAT8E4M3FN  : return b.getFloat8E4M3FNType();
-    case BType::FLOAT8E4M3FNUZ: return b.getFloat8E4M3FNUZType();
-    case BType::FLOAT8E5M2    : return b.getFloat8E5M2Type();
-    case BType::FLOAT8E5M2FNUZ: return b.getFloat8E5M2FNUZType();
+    case BType::FLOAT8E4M3FN  : return b.getType<Float8E4M3FNType>();
+    case BType::FLOAT8E4M3FNUZ: return b.getType<Float8E4M3FNUZType>();
+    case BType::FLOAT8E5M2    : return b.getType<Float8E5M2Type>();
+    case BType::FLOAT8E5M2FNUZ: return b.getType<Float8E5M2FNUZType>();
     default: llvm_unreachable("unsupported data type");
   }
   // clang-format on
@@ -104,4 +104,4 @@ BType wideBTypeOfBType(BType d) {
       [](auto btype) { return toBType<typename BTypeTrait<btype>::widetype>; });
 }

-} // namespace onnx_mlir
\ No newline at end of file
+} // namespace onnx_mlir
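Note: the per-type builder helpers such as getFloat8E4M3FNType() were dropped upstream in favor of the generic Builder::getType<T>() template. A minimal sketch of the new spelling, assuming a recent MLIR (makeFloat8E4M3FN is a hypothetical helper):

// Sketch of the generic Builder::getType<T>() used above; assumes a recent
// MLIR. makeFloat8E4M3FN is not part of this repo.
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

mlir::Type makeFloat8E4M3FN(mlir::Builder &b) {
  // Old: b.getFloat8E4M3FNType()
  return b.getType<mlir::Float8E4M3FNType>();
}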
