-
Notifications
You must be signed in to change notification settings - Fork 359
Implicit broadcast of scalar values to vector values in the MathBuilder constructor #2900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implicit broadcast of scalar values to vector values in the MathBuilder constructor #2900
Conversation
Signed-off-by: Alexandre Eichenberger <[email protected]>
Signed-off-by: Alexandre Eichenberger <[email protected]>
Signed-off-by: Alexandre Eichenberger <[email protected]>
return mlir::isa<FloatType>(elementType); | ||
} | ||
|
||
bool MathBuilder::splatToMatch(Value &first, Value &second) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Method performing the splat so that both inputs are either scalar or vectors
static bool isIntegerWithVector(mlir::Type elementOrVectorType); | ||
static bool isUnsignedIntegerWithVector(mlir::Type elementOrVectorType); | ||
static bool isFloatWithVector(mlir::Type elementOrVectorType); | ||
static bool isScalarOrVectorInteger(mlir::Type elementOrVectorType); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
renamed for the meaning of the operation to be less ambiguous.
@tungld implemented the scheme which we agreed on, thanks for the input. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
VectorType firstVectorType = | ||
VectorType::get(secondVectorType.getShape(), first.getType()); | ||
first = b().create<vector::SplatOp>(loc(), firstVectorType, first); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current code is good as it is. FYI, I just tried to rewrite this to make it more natural to read, e.g.: if first is not a vector, splat it to a vector, if second is not a vector, splat it to a vector
, instead of if first is a vector, splat second to a vector, ...
.
VectorType firstVectorType = mlir::dyn_cast<VectorType>(first.getType());
VectorType secondVectorType = mlir::dyn_cast<VectorType>(second.getType());
if (!firstVectorType && secondVectorType) {
firstVectorType = VectorType::get(secondVectorType.getShape(), first.getType());
first = b().create<vector::SplatOp>(loc(), firstVectorType, first);
}
if (firstVectorType && !secondVectorType) {
secondVectorType = VectorType::get(firstVectorType.getShape(), second.getType());
second = b().create<vector::SplatOp>(loc(), secondVectorType, second);
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like it, thanks for rewriting my code! Better than any chatbots could ever be :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:) but I guess there was an error in my code, I should not update firstVectorType
but use a local variable for it.
VectorType firstVectorType = mlir::dyn_cast<VectorType>(first.getType());
VectorType secondVectorType = mlir::dyn_cast<VectorType>(second.getType());
if (!firstVectorType && secondVectorType) {
VectorType ty = VectorType::get(secondVectorType.getShape(), first.getType());
first = b().create<vector::SplatOp>(loc(), ty, first);
}
if (firstVectorType && !secondVectorType) {
VectorType ty = VectorType::get(firstVectorType.getShape(), second.getType());
second = b().create<vector::SplatOp>(loc(), ty, second);
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not notice it, but in fact I changed the code a bit so that it actually does not matter. Here is the latest code
bool MathBuilder::splatToMatch(Value &first, Value &second) const {
Type firstType = first.getType();
Type secondType = second.getType();
VectorType firstVectorType = mlir::dyn_cast<VectorType>(firstType);
VectorType secondVectorType = mlir::dyn_cast<VectorType>(secondType);
VectorBuilder createVec(*this);
// Splat first if needed.
if (!firstVectorType && secondVectorType) {
firstVectorType = VectorType::get(secondVectorType.getShape(), firstType);
first = createVec.splat(firstVectorType, first);
return true;
}
// Splat second if needed.
if (firstVectorType && !secondVectorType) {
secondVectorType = VectorType::get(firstVectorType.getShape(), secondType);
second = createVec.splat(secondVectorType, second);
return true;
}
// Otherwise check compatibility.
assert(createVec.compatibleTypes(firstType, secondType) &&
"expected compatible types");
return false;
}
I like the early return as a style to avoid if-then-else
that I find harder to read.
Signed-off-by: Alexandre Eichenberger <[email protected]>
Signed-off-by: Alexandre Eichenberger <[email protected]>
Signed-off-by: Alexandre Eichenberger <[email protected]>
Signed-off-by: Alexandre Eichenberger <[email protected]>
Jenkins Linux ppc64le Build #14314 [push] Implicit broadcast of sc... started at 15:28 |
Jenkins Linux s390x Build #15289 [push] Implicit broadcast of sc... started at 15:16 |
Jenkins Linux amd64 Build #15284 [push] Implicit broadcast of sc... started at 14:16 |
Jenkins Linux amd64 Build #15284 [push] Implicit broadcast of sc... passed after 1 hr 7 min |
Jenkins Linux s390x Build #15289 [push] Implicit broadcast of sc... passed after 1 hr 42 min |
Jenkins Linux ppc64le Build #14314 [push] Implicit broadcast of sc... passed after 2 hr 15 min |
The math dialect in MLIR supports operations on scalar or vector inputs, but not a mixture of the two.
When generating SIMD code, we would like to generate a single source of code that can then be used to generate both SIMD code for the innermost SIMD loop, as well as scalar code for the remaining few iterations, if any.
The code to generate SIMD will load
memrefs
either as scalar or vectors; but often there are operations that needs constants, such as "compare to zero", "max of zero"... the scalar are needed then as a scalar in scalar code, but as a splatted constant when generating SIMD code.The most robust method to provide this functionality is to provide implicit "splat" for each of the MathBuilder operations.
This PR does this.
This code will later be put to good use in a subsequent PR.
As an illustration of how this is used, here is a preview of where this new capability in this PR will be useful.
In the above code,
scale
,zeroPoint
,qMin
,qMax
are scalar that are needed in scalar or simd mode. We don't want explicit splats as we don't know in which mode we run, so we want them to be implicitly added when theinpputVals[*]
are actually simd vectors.