Skip to content

Commit c6506ba

Browse files
authored
[Backport to 14] Handle OpVectorShuffle with differing vector sizes (#2391) (#2413)
The SPIR-V to LLVM conversion would bail out when encountering an `OpVectorShuffle` whose vector operands differ in size. SPIR-V allows differing vector sizes, but LLVM's `shufflevector` does not. Remove the assert and insert an additional `shufflevector` to align the vector operands when needed. (cherry picked from commit 3df5e38)
1 parent d54b2ce commit c6506ba

File tree

5 files changed

+91
-9
lines changed

5 files changed

+91
-9
lines changed

lib/SPIRV/SPIRVInternal.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ using namespace llvm;
5959

6060
namespace llvm {
6161
class IntrinsicInst;
62+
class IRBuilderBase;
6263
}
6364

6465
namespace SPIRV {
@@ -603,6 +604,10 @@ std::string mapSPIRVTypeToOCLType(SPIRVType *Ty, bool Signed);
603604
std::string mapLLVMTypeToOCLType(const Type *Ty, bool Signed);
604605
SPIRVDecorate *mapPostfixToDecorate(StringRef Postfix, SPIRVEntry *Target);
605606

607+
/// Return vector V extended with poison elements to match the number of
608+
/// components of NewType.
609+
Value *extendVector(Value *V, FixedVectorType *NewType, IRBuilderBase &Builder);
610+
606611
/// Add decorations to a SPIR-V entry.
607612
/// \param Decs Each string is a postfix without _ at the beginning.
608613
SPIRVValue *addDecorations(SPIRVValue *Target,

lib/SPIRV/SPIRVReader.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2297,10 +2297,37 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
22972297
if (BB) {
22982298
Builder.SetInsertPoint(BB);
22992299
}
2300-
return mapValue(BV, Builder.CreateShuffleVector(
2301-
transValue(VS->getVector1(), F, BB),
2302-
transValue(VS->getVector2(), F, BB),
2303-
ConstantVector::get(Components), BV->getName()));
2300+
Value *Vec1 = transValue(VS->getVector1(), F, BB);
2301+
Value *Vec2 = transValue(VS->getVector2(), F, BB);
2302+
auto *Vec1Ty = cast<FixedVectorType>(Vec1->getType());
2303+
auto *Vec2Ty = cast<FixedVectorType>(Vec2->getType());
2304+
if (Vec1Ty->getNumElements() != Vec2Ty->getNumElements()) {
2305+
// LLVM's shufflevector requires that the two vector operands have the
2306+
// same type; SPIR-V's OpVectorShuffle allows the vector operands to
2307+
// differ in the number of components. Adjust for that by extending
2308+
// the smaller vector.
2309+
if (Vec1Ty->getNumElements() < Vec2Ty->getNumElements()) {
2310+
Vec1 = extendVector(Vec1, Vec2Ty, Builder);
2311+
// Extending Vec1 requires offsetting any Vec2 indices in Components by
2312+
// the number of new elements.
2313+
unsigned Offset = Vec2Ty->getNumElements() - Vec1Ty->getNumElements();
2314+
unsigned Vec2Start = Vec1Ty->getNumElements();
2315+
for (auto &C : Components) {
2316+
if (auto *CI = dyn_cast<ConstantInt>(C)) {
2317+
uint64_t V = CI->getZExtValue();
2318+
if (V >= Vec2Start) {
2319+
// This is a Vec2 index; add the offset to it.
2320+
C = ConstantInt::get(Int32Ty, V + Offset);
2321+
}
2322+
}
2323+
}
2324+
} else {
2325+
Vec2 = extendVector(Vec2, Vec1Ty, Builder);
2326+
}
2327+
}
2328+
return mapValue(
2329+
BV, Builder.CreateShuffleVector(
2330+
Vec1, Vec2, ConstantVector::get(Components), BV->getName()));
23042331
}
23052332

23062333
case OpBitReverse: {

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,24 @@ void removeFnAttr(CallInst *Call, Attribute::AttrKind Attr) {
9090
Call->removeFnAttr(Attr);
9191
}
9292

93+
Value *extendVector(Value *V, FixedVectorType *NewType,
94+
IRBuilderBase &Builder) {
95+
unsigned OldSize = cast<FixedVectorType>(V->getType())->getNumElements();
96+
unsigned NewSize = NewType->getNumElements();
97+
assert(OldSize < NewSize);
98+
std::vector<Constant *> Components;
99+
IntegerType *Int32Ty = Builder.getInt32Ty();
100+
for (unsigned I = 0; I < NewSize; I++) {
101+
if (I < OldSize)
102+
Components.push_back(ConstantInt::get(Int32Ty, I));
103+
else
104+
Components.push_back(PoisonValue::get(Int32Ty));
105+
}
106+
107+
return Builder.CreateShuffleVector(V, PoisonValue::get(V->getType()),
108+
ConstantVector::get(Components), "vecext");
109+
}
110+
93111
void saveLLVMModule(Module *M, const std::string &OutputFile) {
94112
std::error_code EC;
95113
ToolOutputFile Out(OutputFile.c_str(), EC, sys::fs::OF_None);

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2209,15 +2209,11 @@ class SPIRVVectorShuffleBase : public SPIRVInstTemplateBase {
22092209
protected:
22102210
void validate() const override {
22112211
SPIRVInstruction::validate();
2212-
SPIRVId Vector1 = Ops[0];
2213-
SPIRVId Vector2 = Ops[1];
2212+
[[maybe_unused]] SPIRVId Vector1 = Ops[0];
22142213
assert(OpCode == OpVectorShuffle);
22152214
assert(Type->isTypeVector());
22162215
assert(Type->getVectorComponentType() ==
22172216
getValueType(Vector1)->getVectorComponentType());
2218-
if (getValue(Vector1)->isForward() || getValue(Vector2)->isForward())
2219-
return;
2220-
assert(getValueType(Vector1) == getValueType(Vector2));
22212217
assert(Ops.size() - 2 == Type->getVectorComponentCount());
22222218
}
22232219
};

test/OpVectorShuffle.spvasm

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
; REQUIRES: spirv-as
2+
; RUN: spirv-as --target-env spv1.0 -o %t.spv %s
3+
; RUN: spirv-val %t.spv
4+
; RUN: llvm-spirv -r -o - %t.spv | llvm-dis | FileCheck %s
5+
OpCapability Addresses
6+
OpCapability Kernel
7+
OpMemoryModel Physical32 OpenCL
8+
OpEntryPoint Kernel %1 "testVecShuffle"
9+
%void = OpTypeVoid
10+
%uint = OpTypeInt 32 0
11+
%uintv2 = OpTypeVector %uint 2
12+
%uintv3 = OpTypeVector %uint 3
13+
%uintv4 = OpTypeVector %uint 4
14+
%func = OpTypeFunction %void %uintv2 %uintv3
15+
16+
%1 = OpFunction %void None %func
17+
%pv2 = OpFunctionParameter %uintv2
18+
%pv3 = OpFunctionParameter %uintv3
19+
%entry = OpLabel
20+
21+
; Same vector lengths
22+
%vs1 = OpVectorShuffle %uintv4 %pv3 %pv3 0 1 3 5
23+
; CHECK: shufflevector <3 x i32> %[[#]], <3 x i32> %[[#]], <4 x i32> <i32 0, i32 1, i32 3, i32 5>
24+
25+
; vec1 smaller than vec2
26+
%vs2 = OpVectorShuffle %uintv4 %pv2 %pv3 0 1 3 4
27+
; CHECK: %[[VS2EXT:[0-9a-z]+]] = shufflevector <2 x i32> %0, <2 x i32> poison, <3 x i32> <i32 0, i32 1, i32 undef>
28+
; CHECK: shufflevector <3 x i32> %[[VS2EXT]], <3 x i32> %[[#]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
29+
30+
; vec1 larger than vec2
31+
%vs3 = OpVectorShuffle %uintv4 %pv3 %pv2 0 1 3 4
32+
; CHECK: %[[VS3EXT:[0-9a-z]+]] = shufflevector <2 x i32> %0, <2 x i32> poison, <3 x i32> <i32 0, i32 1, i32 undef>
33+
; CHECK: shufflevector <3 x i32> %[[#]], <3 x i32> %[[VS3EXT]], <4 x i32> <i32 0, i32 1, i32 3, i32 4>
34+
35+
OpReturn
36+
OpFunctionEnd

0 commit comments

Comments
 (0)