Skip to content

Commit 79ed896

Browse files
Implicit broadcast of scalar values to vector values in the MathBuilder constructor (#2900)
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent f11a21c commit 79ed896

File tree

4 files changed

+261
-132
lines changed

4 files changed

+261
-132
lines changed

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ Value emitPostProcessingFor(ConversionPatternRewriter &rewriter, Location loc,
4242

4343
template <typename Op>
4444
static void CheckIfCustomScalarOpIsSupported(Type elementType) {
45-
Type actualElementType = MathBuilder::elementTypeWithVector(elementType);
45+
Type actualElementType =
46+
MathBuilder::elementTypeOfScalarOrVector(elementType);
4647
if (mlir::isa<mlir::IntegerType>(actualElementType)) {
4748
if constexpr (std::is_same<ScalarIOp<Op>, CustomScalarOp>::value)
4849
return;
@@ -914,7 +915,7 @@ Value emitScalarOpFor<ONNXSignOp>(ConversionPatternRewriter &rewriter,
914915
// ConstantOp 0,
915916
// %Y)
916917
Value plusSelect;
917-
if (create.math.isUnsignedIntegerWithVector(elementType)) {
918+
if (create.math.isScalarOrVectorUnsignedInteger(elementType)) {
918919
// Unsigned integers are by definition positive.
919920
plusSelect = one;
920921
} else {
@@ -1188,7 +1189,7 @@ Value emitScalarOpFor<ONNXModOp>(ConversionPatternRewriter &rewriter,
11881189
MultiDialectBuilder<MathBuilder, KrnlBuilder> create(rewriter, loc);
11891190

11901191
// TODO: here we assume fmod=1, what should if that is not the case?
1191-
if (create.math.isFloatWithVector(elementType)) {
1192+
if (create.math.isScalarOrVectorFloat(elementType)) {
11921193
// fmod is always 1. Behavior is like numpy.fmod.
11931194
// The sign of the remainder is the same as the dividend.
11941195
Value rem = create.math.rem(dividend, divisor);
@@ -1201,7 +1202,7 @@ Value emitScalarOpFor<ONNXModOp>(ConversionPatternRewriter &rewriter,
12011202
return create.math.copySign(rem, dividend);
12021203
#endif
12031204
}
1204-
if (create.math.isIntegerWithVector(elementType)) {
1205+
if (create.math.isScalarOrVectorInteger(elementType)) {
12051206
// "math.rem" returns "minus" for minus dividend and "plus or zero" for plus
12061207
// dividend. We call the math.rem's return value "mathRemainder". However
12071208
// onnx.ModOp should return "minus" for minus divisor and "plus or zero" for

src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ mlir::Value emitScalarOpFor(mlir::ConversionPatternRewriter &rewriter,
228228
// int. Thus we look at the type the first input argument, and not the output
229229
// elementType.
230230
mlir::Type actualElementType =
231-
MathBuilder::elementTypeWithVector(scalarOperands[0].getType());
231+
MathBuilder::elementTypeOfScalarOrVector(scalarOperands[0]);
232232
// Perform int or float operation depending on the actual elementary type.
233233
if (mlir::isa<mlir::IntegerType>(actualElementType)) {
234234
// Generate the integer code only if the scalar integer op is non-void

0 commit comments

Comments
 (0)