Skip to content

[Backport to 18] Handle OpVectorShuffle with differing vector sizes #2409

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/SPIRV/SPIRVInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ using namespace llvm;

namespace llvm {
class IntrinsicInst;
class IRBuilderBase;
}

namespace SPIRV {
Expand Down Expand Up @@ -552,6 +553,10 @@ std::string mapLLVMTypeToOCLType(const Type *Ty, bool Signed,
Type *PointerElementType = nullptr);
SPIRVDecorate *mapPostfixToDecorate(StringRef Postfix, SPIRVEntry *Target);

/// Return vector V extended with poison elements to match the number of
/// components of NewType.
Value *extendVector(Value *V, FixedVectorType *NewType, IRBuilderBase &Builder);

/// Add decorations to a SPIR-V entry.
/// \param Decs Each string is a postfix without _ at the beginning.
SPIRVValue *addDecorations(SPIRVValue *Target,
Expand Down
35 changes: 31 additions & 4 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2375,10 +2375,37 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
if (BB) {
Builder.SetInsertPoint(BB);
}
return mapValue(BV, Builder.CreateShuffleVector(
transValue(VS->getVector1(), F, BB),
transValue(VS->getVector2(), F, BB),
ConstantVector::get(Components), BV->getName()));
Value *Vec1 = transValue(VS->getVector1(), F, BB);
Value *Vec2 = transValue(VS->getVector2(), F, BB);
auto *Vec1Ty = cast<FixedVectorType>(Vec1->getType());
auto *Vec2Ty = cast<FixedVectorType>(Vec2->getType());
if (Vec1Ty->getNumElements() != Vec2Ty->getNumElements()) {
// LLVM's shufflevector requires that the two vector operands have the
// same type; SPIR-V's OpVectorShuffle allows the vector operands to
// differ in the number of components. Adjust for that by extending
// the smaller vector.
if (Vec1Ty->getNumElements() < Vec2Ty->getNumElements()) {
Vec1 = extendVector(Vec1, Vec2Ty, Builder);
// Extending Vec1 requires offsetting any Vec2 indices in Components by
// the number of new elements.
unsigned Offset = Vec2Ty->getNumElements() - Vec1Ty->getNumElements();
unsigned Vec2Start = Vec1Ty->getNumElements();
for (auto &C : Components) {
if (auto *CI = dyn_cast<ConstantInt>(C)) {
uint64_t V = CI->getZExtValue();
if (V >= Vec2Start) {
// This is a Vec2 index; add the offset to it.
C = ConstantInt::get(Int32Ty, V + Offset);
}
}
}
} else {
Vec2 = extendVector(Vec2, Vec1Ty, Builder);
}
}
return mapValue(
BV, Builder.CreateShuffleVector(
Vec1, Vec2, ConstantVector::get(Components), BV->getName()));
}

case OpBitReverse: {
Expand Down
18 changes: 18 additions & 0 deletions lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,24 @@ void removeFnAttr(CallInst *Call, Attribute::AttrKind Attr) {
Call->removeFnAttr(Attr);
}

Value *extendVector(Value *V, FixedVectorType *NewType,
IRBuilderBase &Builder) {
unsigned OldSize = cast<FixedVectorType>(V->getType())->getNumElements();
unsigned NewSize = NewType->getNumElements();
assert(OldSize < NewSize);
std::vector<Constant *> Components;
IntegerType *Int32Ty = Builder.getInt32Ty();
for (unsigned I = 0; I < NewSize; I++) {
if (I < OldSize)
Components.push_back(ConstantInt::get(Int32Ty, I));
else
Components.push_back(PoisonValue::get(Int32Ty));
}

return Builder.CreateShuffleVector(V, PoisonValue::get(V->getType()),
ConstantVector::get(Components), "vecext");
}

void saveLLVMModule(Module *M, const std::string &OutputFile) {
std::error_code EC;
ToolOutputFile Out(OutputFile.c_str(), EC, sys::fs::OF_None);
Expand Down
6 changes: 1 addition & 5 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2210,15 +2210,11 @@ class SPIRVVectorShuffleBase : public SPIRVInstTemplateBase {
protected:
void validate() const override {
SPIRVInstruction::validate();
SPIRVId Vector1 = Ops[0];
SPIRVId Vector2 = Ops[1];
[[maybe_unused]] SPIRVId Vector1 = Ops[0];
assert(OpCode == OpVectorShuffle);
assert(Type->isTypeVector());
assert(Type->getVectorComponentType() ==
getValueType(Vector1)->getVectorComponentType());
if (getValue(Vector1)->isForward() || getValue(Vector2)->isForward())
return;
assert(getValueType(Vector1) == getValueType(Vector2));
assert(Ops.size() - 2 == Type->getVectorComponentCount());
}
};
Expand Down
36 changes: 36 additions & 0 deletions test/OpVectorShuffle.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
; REQUIRES: spirv-as
; RUN: spirv-as --target-env spv1.0 -o %t.spv %s
; RUN: spirv-val %t.spv
; RUN: llvm-spirv -r -o - %t.spv | llvm-dis | FileCheck %s
OpCapability Addresses
OpCapability Kernel
OpMemoryModel Physical32 OpenCL
OpEntryPoint Kernel %1 "testVecShuffle"
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%uintv2 = OpTypeVector %uint 2
%uintv3 = OpTypeVector %uint 3
%uintv4 = OpTypeVector %uint 4
%func = OpTypeFunction %void %uintv2 %uintv3

%1 = OpFunction %void None %func
%pv2 = OpFunctionParameter %uintv2
%pv3 = OpFunctionParameter %uintv3
%entry = OpLabel

; Same vector lengths
%vs1 = OpVectorShuffle %uintv4 %pv3 %pv3 0 1 3 5
; CHECK: shufflevector <3 x i32> %[[#]], <3 x i32> %[[#]], <4 x i32> <i32 0, i32 1, i32 3, i32 5>

; vec1 smaller than vec2
%vs2 = OpVectorShuffle %uintv4 %pv2 %pv3 0 1 3 4
; CHECK: %[[VS2EXT:[0-9a-z]+]] = shufflevector <2 x i32> %0, <2 x i32> poison, <3 x i32> <i32 0, i32 1, i32 poison>
; CHECK: shufflevector <3 x i32> %[[VS2EXT]], <3 x i32> %[[#]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>

; vec1 larger than vec2
%vs3 = OpVectorShuffle %uintv4 %pv3 %pv2 0 1 3 4
; CHECK: %[[VS3EXT:[0-9a-z]+]] = shufflevector <2 x i32> %0, <2 x i32> poison, <3 x i32> <i32 0, i32 1, i32 poison>
; CHECK: shufflevector <3 x i32> %[[#]], <3 x i32> %[[VS3EXT]], <4 x i32> <i32 0, i32 1, i32 3, i32 4>

OpReturn
OpFunctionEnd