LLVM update 43d71ba #3086

Merged: 70 commits, May 19, 2025. The diff shown below covers changes from 14 of those commits.

Commits (70)
85adea4
update float types, tosa, other misc changes
brnorris03 Feb 21, 2025
575843d
fix buildOnnxToTosaPaddingConstOp
brnorris03 Feb 22, 2025
4b8bc70
fix lit tests (wip)
brnorris03 Feb 22, 2025
7015679
updte doc
brnorris03 Feb 22, 2025
65ff3c0
use stablehlo tagged version
brnorris03 Feb 22, 2025
685c3f5
fixed more lit tests
brnorris03 Feb 23, 2025
fa9186e
fix .clang-format
brnorris03 Feb 23, 2025
09ca5c1
fix lit (wip)
brnorris03 Feb 23, 2025
8fc25e5
revert .clang-format change
brnorris03 Feb 23, 2025
96b0354
fix lit tests
brnorris03 Feb 24, 2025
8eec54d
fix formatting
brnorris03 Feb 24, 2025
5c2da76
lit tests pass (except jni -- not tested)
brnorris03 Feb 24, 2025
9a472ac
manually fix formatting; can't get clang-format to do it on any of my…
brnorris03 Feb 24, 2025
4c02657
revert lit test changes unrelated to update
brnorris03 Feb 24, 2025
010f3ce
Merge branch 'main' into llvm-update-0e779ad
AlexandreEichenberger Mar 3, 2025
0eb511e
update llvm and stablhlo shas, misc minor updates
brnorris03 Feb 25, 2025
c0dc186
remove non-existent passes
brnorris03 Feb 25, 2025
ff7c7a1
lit updates (wip)
brnorris03 Mar 3, 2025
31da901
Bump Upsample to Opset 10 and change the opset versioning to allow to…
jorickert Mar 5, 2025
292271a
Improve scripts (#3089)
AlexandreEichenberger Mar 7, 2025
25c0be7
Bump various ops to opset 21, adding int4/uint4 and 8 bit float suppo…
jorickert Mar 11, 2025
4bb8153
Added minimal support to do some timing of OM Runtime functionality (…
AlexandreEichenberger Mar 19, 2025
c79bc0a
adding __errno_location call for mvs (#3099)
christopherlmunoz Mar 21, 2025
b393925
Rewriting pattern to remove WhereOp and EqualOp. (#3094)
imaihal Mar 21, 2025
e053067
Enable NNPA saturation by default and change the option to --nnpa-dis…
tungld Mar 26, 2025
5e40f19
removing weak attribute of errorno (#3103)
christopherlmunoz Mar 26, 2025
4cee701
Fix the custom build link for docs/Docker.md (#3104)
qjivy Mar 27, 2025
4df054f
Python driver for torch model (#3093)
chentong319 Mar 27, 2025
83d68d6
implement (#3108)
chentong319 Apr 2, 2025
2be25d2
Followups for torch model driver (#3106)
chentong319 Apr 3, 2025
9843a33
Fix an error in ZHighConstantPropagation for QuantizedStick (#3112)
tungld Apr 8, 2025
8353536
Add z17 for -march (#3113)
chentong319 Apr 10, 2025
7c29b12
Decompose Hardswish into simpler ONNX ops (#3107)
kumarappan-cmyk Apr 14, 2025
1279b35
Reorder relu to maxpool optimization pass in ONNX dialect (#3109)
Arkar-Hema Apr 14, 2025
531f682
Move onnx.Constant before the root op when fusing onnx ops (#3119)
tungld Apr 15, 2025
4879853
Support QLinearMatMul on CPU (#3117)
tungld Apr 16, 2025
51a4609
Update black-format-check.yml (#3118)
andife Apr 16, 2025
adc83ed
Merge nested concat Ops optimization pass in ONNX dialect (#3111)
Arkar-Hema Apr 16, 2025
ddf39de
Enhance shape inference for ONNX Reshape (#3122)
tungld Apr 18, 2025
1d1fa8e
update zdnn1.1.2 (#3130)
Sunny-Anand Apr 18, 2025
7a2d25e
Updating supported ops on NNPA md for z17. (#3120)
christopherlmunoz Apr 18, 2025
1c2a9ba
fix CVE-2025-32434 (#3135)
Sunny-Anand Apr 21, 2025
069d129
Fuse consecutive clips pattern (#3132)
kumarappan-cmyk Apr 22, 2025
bd0597d
Replace deprecated applyPatternsAndFoldGreedily with applyPatternsGre…
jorickert Apr 23, 2025
b2bba6c
Fix clang-format
jorickert Apr 23, 2025
864e6e5
Replace bufferization::createOwnershipBasedBufferDeallocationPass wit…
jorickert Apr 23, 2025
bc6ef60
Update onnx-to-tosa reshape lit test
jorickert Apr 23, 2025
d8999ee
Move gemm_to_fc tests to gemm_to_matmul
jorickert Apr 24, 2025
2acd812
Change tosaBuilder::mul function signature to make clear that the shi…
jorickert Apr 28, 2025
1357f0b
Disable buffer_loop_hoisting test as it gets completly optimized away
jorickert Apr 30, 2025
20a1cf7
Guard against dynamic dim in result
jorickert Apr 30, 2025
5bc4937
Use resize operaton input and output type to calculate the border, in…
jorickert Apr 30, 2025
3b21c0a
Guard against linear interpolation of integer types
jorickert Apr 30, 2025
319933a
Add test for disallowed onnx.Resize on its with linear interpolation …
jorickert May 5, 2025
a1e584d
Add 'Pure' annotation to some krnl ops and recreate documentation
jorickert May 6, 2025
3fbc8f8
Build stablehlo with static libs
jorickert May 7, 2025
5627762
Disable memref.prefetch since it does not work with the new bufferiza…
tungld May 9, 2025
51897ad
Conv add const where the constant is a scalar (#3145)
AlexandreEichenberger Apr 30, 2025
55251a0
added support for Celu op (#3139)
logeshwaranmcw Apr 30, 2025
7a5c8e0
Fix some warnings related to stickification for NNPA (#3147)
tungld May 1, 2025
35f73c0
Removing duplicate file (#3146)
christopherlmunoz May 1, 2025
a0ff012
migrated instance/group normalization from decompose to canonicalize …
AlexandreEichenberger May 2, 2025
6faf742
Fusion of Matmul add covering the stacked/unstacked/bcast1/bcast23 pa…
AlexandreEichenberger May 2, 2025
7a621c3
Support --march=native (#3134)
chentong319 May 6, 2025
0c5b71b
fix another error on s390x
tungld May 9, 2025
65164ba
undo changes made by mistake
tungld May 9, 2025
e404842
lower Ub to LLVM since vector.shape_cast is lowered to UB
tungld May 9, 2025
7ee8a9c
Merge branch 'main' into llvm-update-0e779ad
jorickert May 14, 2025
5cb6029
Merge branch 'main' into llvm-update-0e779ad
AlexandreEichenberger May 16, 2025
b89116b
Merge branch 'main' into llvm-update-0e779ad
jorickert May 19, 2025
2 changes: 1 addition & 1 deletion docs/BuildOnLinuxOSX.md
@@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
cd llvm-project && git checkout 0e779ad4998ef65907502101c5b82ede05ddfa4e && cd ..
```

[same-as-file]: <> (utils/build-mlir.sh)
2 changes: 1 addition & 1 deletion docs/BuildOnWindows.md
@@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
cd llvm-project && git checkout 0e779ad4998ef65907502101c5b82ede05ddfa4e && cd ..
```

[same-as-file]: <> (utils/build-mlir.cmd)
@@ -35,7 +35,7 @@ ApiRegistry RegisterAllApis(MLIRContext *context) {
auto int16Ty = IntegerType::get(context, 16);
auto int32Ty = IntegerType::get(context, 32);
auto int64Ty = IntegerType::get(context, 64);
auto float32Ty = FloatType::getF32(context);
auto float32Ty = Float32Type::get(context);

// Declare API type as an enum value, its string name and an LLVM Type
// specifying its signature.
@@ -570,7 +570,7 @@ Type getZTensorStructTy(MLIRContext *context) {
Type llvmI64Ty = IntegerType::get(context, 64);
Type llvmI1Ty = IntegerType::get(context, 1);
Type llvmI8Ty = IntegerType::get(context, 8);
Type llvmF32Ty = FloatType::getF32(context);
Type llvmF32Ty = Float32Type::get(context);
Type llvmArray3I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 3);
Type llvmArray20I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 20);
Type llvmI8PtrTy = krnl::getPointerType(context, llvmI8Ty);
@@ -662,7 +662,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
scaleTy.isF32() && "Wrong type for zTensor's rec_scale. Must be float");
create.llvm.store(recScale, recScalePtr);
} else {
Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
create.llvm.store(zero, recScalePtr);
}

@@ -675,7 +675,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
offsetTy.isF32() && "Wrong type for zTensor's offset. Must be float");
create.llvm.store(offset, offsetPtr);
} else {
Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
create.llvm.store(zero, offsetPtr);
}

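Most of the C++ changes in this update follow one mechanical pattern: the newer LLVM/MLIR revision drops the FloatType::getF16/getF32/getF64/getBF16 factory helpers (and the Builder::getFloat8...Type helpers) in favor of per-type ::get and Builder::getType<T>() calls. A minimal before/after sketch, illustrative only and not code from this PR:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Illustrative helper: build an f32 type with the API used after this update.
static Type makeF32(MLIRContext *ctx) {
  // Before the update: FloatType::getF32(ctx)
  // After the update:
  return Float32Type::get(ctx);
}
```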
4 changes: 2 additions & 2 deletions src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp
@@ -80,10 +80,10 @@ class KrnlRandomNormalOpLowering : public ConversionPattern {
// or
// (memref<3x4x5xf64>, index, f64, f64, f64)
Type llvmVoidTy = LLVM::LLVMVoidType::get(context);
Type llvmOptionsTy = FloatType::getF32(context);
Type llvmOptionsTy = Float32Type::get(context);
Type llvmOutputTy = getPointerType(context, llvmOptionsTy);
if (inType.isF64()) {
llvmOptionsTy = FloatType::getF64(context);
llvmOptionsTy = Float64Type::get(context);
llvmOutputTy = getPointerType(context, llvmOptionsTy);
}
Type llvmI64Ty = IntegerType::get(context, 64);
13 changes: 6 additions & 7 deletions src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp
@@ -172,19 +172,19 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
Type outType = op->getResultTypes().front();
Type llvmInType, llvmOutType;
if (inType.isF16())
llvmInType = FloatType::getF16(context);
llvmInType = Float16Type::get(context);
else if (inType.isF32())
llvmInType = FloatType::getF32(context);
llvmInType = Float32Type::get(context);
else if (inType.isF64())
llvmInType = FloatType::getF64(context);
llvmInType = Float64Type::get(context);
else if (inType.isBF16())
llvmInType = FloatType::getBF16(context);
llvmInType = BFloat16Type::get(context);
if (outType.isInteger(1))
llvmOutType = IntegerType::get(context, 1);
else if (outType.isF32())
llvmOutType = FloatType::getF32(context);
llvmOutType = Float32Type::get(context);
else if (outType.isF64())
llvmOutType = FloatType::getF64(context);
llvmOutType = Float64Type::get(context);

// Insert and/or get reference to elementary math function declaration.
assert(
@@ -214,7 +214,6 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
return SymbolRefAttr::get(context, mathFuncName);

// Create function declaration.
// auto llvmF32Ty = FloatType::get(context);
auto llvmFnType =
LLVM::LLVMFunctionType::get(llvmOutType, ArrayRef<Type>({llvmInType}));

4 changes: 2 additions & 2 deletions src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp
@@ -62,7 +62,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern {

// Get memRefDescriptor, the new memref descriptor.
MemRefDescriptor memRefDescriptor =
MemRefDescriptor::undef(rewriter, loc, targetStructType);
MemRefDescriptor::poison(rewriter, loc, targetStructType);
auto targetElementPtrType = memRefDescriptor.getElementPtrType();

// Set the new memref to the same buffer as the source memref.
@@ -78,7 +78,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern {

int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(targetType, strides, offset)))
if (failed(targetType.getStridesAndOffset(strides, offset)))
return failure();

// Unhandled dynamic offset.
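Two further upstream renames appear in this file: MemRefDescriptor::undef becomes MemRefDescriptor::poison, and the free function getStridesAndOffset(type, strides, offset) becomes a member of MemRefType. A short sketch of the member-call form; the wrapper function itself is illustrative, not part of the PR:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Illustrative: query the static strides and offset of a memref type using
// the member API available after this LLVM update.
static LogicalResult queryLayout(MemRefType memRefType,
                                 llvm::SmallVectorImpl<int64_t> &strides,
                                 int64_t &offset) {
  // Previously: getStridesAndOffset(memRefType, strides, offset)
  return memRefType.getStridesAndOffset(strides, offset);
}
```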
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
@@ -281,7 +281,7 @@ struct ONNXCategoryMapperOpLowering
SmallVector<int64_t, 4> strides;
int64_t alignmentOffset; // not used, just to make the function call
// completed.
if (getStridesAndOffset(memRefType, strides, alignmentOffset)
if (memRefType.getStridesAndOffset(strides, alignmentOffset)
.failed())
llvm_unreachable("Failed to get strides");
Value stringMemRef =
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/Math/LRN.cpp
@@ -52,7 +52,7 @@ struct ONNXLRNOpLowering : public OpConversionPattern<ONNXLRNOp> {
float alphaLit = adaptor.getAlpha().convertToFloat();
float betaLit = adaptor.getBeta().convertToFloat();
int sizeLit = adaptor.getSize();
auto f32Type = FloatType::getF32(rewriter.getContext());
auto f32Type = Float32Type::get(rewriter.getContext());
Value biasValue = create.math.constant(f32Type, biasLit);
Value alphaDivSizeValue =
create.math.constant(f32Type, alphaLit / static_cast<float>(sizeLit));
24 changes: 16 additions & 8 deletions src/Conversion/ONNXToTOSA/DialectBuilder.cpp
@@ -15,6 +15,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -147,14 +148,16 @@ Value TosaBuilder::transpose(Value &value, llvm::ArrayRef<int32_t> perm) {

Value TosaBuilder::slice(Value &inputConst, llvm::ArrayRef<int64_t> size,
llvm::ArrayRef<int64_t> start) {
DenseI64ArrayAttr sizeAttr = rewriter().getDenseI64ArrayAttr(size);
DenseI64ArrayAttr startAttr = rewriter().getDenseI64ArrayAttr(start);
auto startVal =
mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(start));
auto sizeVal =
mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(size));
Value newSliceInput =
tosa::CreateOpAndInfer<mlir::tosa::SliceOp>(rewriter(), loc(),
RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(size.size(), ShapedType::kDynamic),
mlir::cast<ShapedType>(inputConst.getType()).getElementType()),
inputConst, startAttr, sizeAttr);
inputConst, startVal, sizeVal);
return newSliceInput;
}

@@ -164,8 +167,9 @@ Value TosaBuilder::reshape(Value &value, llvm::ArrayRef<int64_t> shape) {
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(shape.size(), ShapedType::kDynamic),
valueType.getElementType());
return tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(
rewriter(), loc(), newValueType, value, shapeAttr);
return tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter(), loc(),
newValueType, value,
mlir::tosa::getTosaConstShape(rewriter(), loc(), shapeAttr));
}

Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
@@ -178,8 +182,12 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());

auto int8Type = rewriter().getI8Type();
auto shiftValue = TosaBuilder::createConst(
ArrayRef<int8_t>{static_cast<int8_t>(shift)}, {1}, int8Type);
return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
rewriter(), loc(), newValueType, lhs, rhs, shift);
rewriter(), loc(), newValueType, lhs, rhs, shiftValue);
}

Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {
@@ -236,8 +244,8 @@ template Value TosaBuilder::binaryOp<mlir::tosa::SubOp>(Value &lhs, Value &rhs);
// Return null if none is found.
ElementsAttr IndexExprBuilderForTosa::getConst(Value value) {
auto definingOp = value.getDefiningOp();
// If we have a cast between index/integer, skip it, i.e. get the defining op
// that is the input to the cast.
// If we have a cast between index/integer, skip it, i.e. get the defining
// op that is the input to the cast.
if (auto castOp = dyn_cast_or_null<arith::IndexCastOp>(definingOp)) {
Value input = castOp.getIn();
definingOp = input.getDefiningOp();
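The TosaBuilder changes above track the newer TOSA operand conventions: tosa.slice and tosa.reshape now take their start/size/shape as !tosa.shape operands built with mlir::tosa::getTosaConstShape, and tosa.mul takes its shift as a rank-1 i8 constant operand rather than an integer attribute. A small sketch of building such a shape operand; the wrapper is illustrative and assumes the getTosaConstShape overload already used in this PR:

```cpp
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Illustrative: materialize a static shape as the !tosa.shape operand that
// tosa.reshape expects after this update.
static Value makeReshapeShape(PatternRewriter &rewriter, Location loc,
                              llvm::ArrayRef<int64_t> newShape) {
  return tosa::getTosaConstShape(rewriter, loc, llvm::to_vector(newShape));
}
```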
20 changes: 15 additions & 5 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
@@ -121,11 +121,21 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> {
// Quantized types are not supported right now (in type conversion).
// Once they are, the input should be rescaled for quantized types. (TBD)
// Maps to `tosa.clamp` which has both int and fp limits.
rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(op, op.getType(), input,
rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
rewriter.getF32FloatAttr(0.0f),
rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
auto inputElementType =
llvm::cast<TensorType>(op.getType()).getElementType();
if (llvm::isa<IntegerType>(inputElementType)) {
auto minClamp = rewriter.getI64IntegerAttr(0);
auto maxClamp =
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max());
rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(
op, op.getType(), input, minClamp, maxClamp);
} else {
auto minClamp = rewriter.getF32FloatAttr(0.0f);
auto maxClamp =
rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(
op, op.getType(), input, minClamp, maxClamp);
}
return success();
}
};
6 changes: 4 additions & 2 deletions src/Conversion/ONNXToTOSA/Math/Gemm.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
@@ -67,13 +68,14 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {

llvm::SmallVector<int64_t> dynamicTensorShape = {
ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic};

A = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
RankedTensorType::get(dynamicTensorShape, AType.getElementType()), A,
rewriter.getDenseI64ArrayAttr(newShapeA))
mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeA))
.getResult();
B = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
RankedTensorType::get(dynamicTensorShape, BType.getElementType()), B,
rewriter.getDenseI64ArrayAttr(newShapeB))
mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeB))
.getResult();

// If transA or transB are present, create Transpose operators.
5 changes: 3 additions & 2 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
@@ -14,6 +14,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
@@ -60,8 +61,8 @@ Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter,
}
tosaPads.insert(tosaPads.end(), lastVals.begin(), lastVals.end());
TosaBuilder tosaBuilder(rewriter, loc);
return tosaBuilder.getConst(
tosaPads, {static_cast<int64_t>(tosaPads.size())});

return mlir::tosa::getTosaConstShape(rewriter, loc, tosaPads);
}

} // namespace tosa
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
@@ -45,6 +45,7 @@ T getValueFromTosaConst(mlir::Value &val) {
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc,
mlir::Type result_ty, Args &&... args) {

auto op = rewriter.create<TosaOp>(loc, result_ty, args...);

mlir::InferShapedTypeOpInterface shapeInterface =
@@ -64,6 +65,7 @@ TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc,
// the new result shaped type. This is because rescale can include a cast to
// different bit-width types and does not have a TypeAttr to define the
// target type.
assert(returnedShapes.size() >= 1 && "Expected at least one returned shape");
auto predictedShape = returnedShapes[0];
if (predictedShape.hasRank())
updateType(nullptr, op, predictedShape.getDims(),
10 changes: 5 additions & 5 deletions src/Dialect/ONNX/ElementsAttr/BType.cpp
@@ -55,10 +55,10 @@ Type mlirTypeOfBType(BType btype, MLIRContext *ctx) {
case BType::FLOAT : return b.getF32Type();
case BType::FLOAT16 : return b.getF16Type();
case BType::BFLOAT16 : return b.getBF16Type();
case BType::FLOAT8E4M3FN : return b.getFloat8E4M3FNType();
case BType::FLOAT8E4M3FNUZ : return b.getFloat8E4M3FNUZType();
case BType::FLOAT8E5M2 : return b.getFloat8E5M2Type();
case BType::FLOAT8E5M2FNUZ : return b.getFloat8E5M2FNUZType();
case BType::FLOAT8E4M3FN : return b.getType<Float8E4M3FNType>();
case BType::FLOAT8E4M3FNUZ : return b.getType<Float8E4M3FNUZType>();
case BType::FLOAT8E5M2 : return b.getType<Float8E5M2Type>();
case BType::FLOAT8E5M2FNUZ : return b.getType<Float8E5M2FNUZType>();
default: llvm_unreachable("unsupported data type");
}
// clang-format on
@@ -104,4 +104,4 @@ BType wideBTypeOfBType(BType d) {
[](auto btype) { return toBType<typename BTypeTrait<btype>::widetype>; });
}

} // namespace onnx_mlir
} // namespace onnx_mlir
2 changes: 1 addition & 1 deletion src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp
@@ -96,7 +96,7 @@ LogicalResult ONNXOneHotEncoderOp::inferShapes(
return success();

ONNXOneHotEncoderOpShapeHelper shapeHelper(getOperation(), {});
return shapeHelper.computeShapeAndUpdateType(FloatType::getF32(getContext()));
return shapeHelper.computeShapeAndUpdateType(Float32Type::get(getContext()));
return success();
}

2 changes: 1 addition & 1 deletion src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp
@@ -452,7 +452,7 @@ LogicalResult ONNXScalerOp::inferShapes(
ONNXUnaryOpShapeHelper shapeHelper(getOperation(), {});
RankedTensorType xType = mlir::dyn_cast<RankedTensorType>(getX().getType());
return shapeHelper.computeShapeAndUpdateType(
FloatType::getF32(getContext()), xType.getEncoding());
Float32Type::get(getContext()), xType.getEncoding());
}

//===----------------------------------------------------------------------===//
14 changes: 7 additions & 7 deletions src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp
@@ -47,16 +47,16 @@ std::vector<Type> ONNXRandomNormalOp::resultTypeInference() {
Type elementType;
if (auto attr = getDtypeAttr()) {
if (getDtype() == 0) {
elementType = FloatType::getF16(getContext());
elementType = Float16Type::get(getContext());
} else if (getDtype() == 1) {
elementType = FloatType::getF32(getContext());
elementType = Float32Type::get(getContext());
} else if (getDtype() == 2) {
elementType = FloatType::getF64(getContext());
elementType = Float64Type::get(getContext());
} else {
llvm_unreachable("dtype not supported for RandomNormal");
}
} else {
elementType = FloatType::getF32(getContext());
elementType = Float32Type::get(getContext());
}
return {UnrankedTensorType::get(elementType)};
}
@@ -68,11 +68,11 @@ LogicalResult ONNXRandomNormalOp::inferShapes(
LogicalResult ONNXRandomNormalOp::inferShapes(
std::function<void(Region &)> doShapeInference) {
auto elementTypeID = getDtype();
Type elementType = FloatType::getF32(getContext());
Type elementType = Float32Type::get(getContext());
if (elementTypeID == 0)
elementType = FloatType::getF16(getContext());
elementType = Float16Type::get(getContext());
else if (elementTypeID == 2)
elementType = FloatType::getF64(getContext());
elementType = Float64Type::get(getContext());

ONNXRandomNormalOpShapeHelper shapeHelper(getOperation(), {});
return shapeHelper.computeShapeAndUpdateType(elementType);
14 changes: 6 additions & 8 deletions src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp
@@ -42,13 +42,11 @@ LogicalResult ONNXRandomNormalLikeOp::verify() {
if (elementTypeID < 0 || elementTypeID > 2) {
return emitOpError("dtype not 0, 1 or 2.");
}
if (elementTypeID == 0 && outputType != FloatType::getF16(getContext()))
if (elementTypeID == 0 && outputType != Float16Type::get(getContext()))
return emitOpError("output tensor does match 0 dtype.");
else if (elementTypeID == 1 &&
outputType != FloatType::getF32(getContext()))
else if (elementTypeID == 1 && outputType != Float32Type::get(getContext()))
return emitOpError("output tensor does match 1 dtype.");
else if (elementTypeID == 2 &&
outputType != FloatType::getF64(getContext()))
else if (elementTypeID == 2 && outputType != Float64Type::get(getContext()))
return emitOpError("output tensor does match 2 dtype.");
} else if (inputType != outputType) {
return emitOpError("output and input element types do not match.");
@@ -75,11 +73,11 @@ LogicalResult ONNXRandomNormalLikeOp::inferShapes(
} else {
int64_t elementTypeID = elementTypeIDDType.value();
if (elementTypeID == 0)
elementType = FloatType::getF16(getContext());
elementType = Float16Type::get(getContext());
else if (elementTypeID == 1)
elementType = FloatType::getF32(getContext());
elementType = Float32Type::get(getContext());
else if (elementTypeID == 2)
elementType = FloatType::getF64(getContext());
elementType = Float64Type::get(getContext());
else
return emitError("dtype attribute is invalid (use: 0, 1 or 2)");
}