Skip to content

Commit 48c703b

Browse files
authored
opt: add pass to split combined image samplers (#6035)
* opt: add pass to split combined image samplers Remap as follows: - A combined-image sampler variable is replaced with a image variable and sampler variable at the same DescriptorSet and Binding, and copying other decoraitons. - A combined-image sampler function parameter is replaced with a image parameter and sampler parameter. - Recursively applies to pointer-to-combined, array-of-combined, and runtime-array-of-combined. - Remaps function types as needed. Removes the type definitions for pointer-to, array-of, and rt-array-of combined types. Maintains def-use analysis, and the type manager. Bug: crbug.com/398231475
1 parent 5986ec1 commit 48c703b

17 files changed

+2767
-7
lines changed

Android.mk

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ SPVTOOLS_OPT_SRC_FILES := \
177177
source/opt/scalar_replacement_pass.cpp \
178178
source/opt/set_spec_constant_default_value_pass.cpp \
179179
source/opt/simplification_pass.cpp \
180+
source/opt/split_combined_image_sampler_pass.cpp \
180181
source/opt/spread_volatile_semantics.cpp \
181182
source/opt/ssa_rewrite_pass.cpp \
182183
source/opt/strength_reduction_pass.cpp \

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ cc_library(
201201
deps = [
202202
":spirv_tools_internal",
203203
"@spirv_headers//:spirv_common_headers",
204+
"@spirv_headers//:spirv_c_headers",
204205
],
205206
)
206207

BUILD.gn

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,8 @@ static_library("spvtools_opt") {
781781
"source/opt/set_spec_constant_default_value_pass.h",
782782
"source/opt/simplification_pass.cpp",
783783
"source/opt/simplification_pass.h",
784+
"source/opt/split_combined_image_sampler_pass.cpp",
785+
"source/opt/split_combined_image_sampler_pass.h",
784786
"source/opt/spread_volatile_semantics.cpp",
785787
"source/opt/spread_volatile_semantics.h",
786788
"source/opt/ssa_rewrite_pass.cpp",

include/spirv-tools/optimizer.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ class SPIRV_TOOLS_EXPORT Optimizer {
240240

241241
private:
242242
struct SPIRV_TOOLS_LOCAL Impl; // Opaque struct for holding internal data.
243-
std::unique_ptr<Impl> impl_; // Unique pointer to internal data.
243+
std::unique_ptr<Impl> impl_; // Unique pointer to internal data.
244244
};
245245

246246
// Creates a null pass.
@@ -968,6 +968,11 @@ Optimizer::PassToken CreateInvocationInterlockPlacementPass();
968968
// Creates a pass to add/remove maximal reconvergence execution mode.
969969
// This pass either adds or removes maximal reconvergence from all entry points.
970970
Optimizer::PassToken CreateModifyMaximalReconvergencePass(bool add);
971+
972+
// Creates a pass to split combined image+sampler variables and function
973+
// parameters into separate image and sampler parts. Binding numbers and
974+
// other decorations are copied.
975+
Optimizer::PassToken CreateSplitCombinedImageSamplerPass();
971976
} // namespace spvtools
972977

973978
#endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_

source/opt/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
114114
scalar_replacement_pass.h
115115
set_spec_constant_default_value_pass.h
116116
simplification_pass.h
117+
split_combined_image_sampler_pass.h
117118
spread_volatile_semantics.h
118119
ssa_rewrite_pass.h
119120
strength_reduction_pass.h
@@ -230,6 +231,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
230231
scalar_replacement_pass.cpp
231232
set_spec_constant_default_value_pass.cpp
232233
simplification_pass.cpp
234+
split_combined_image_sampler_pass.cpp
233235
spread_volatile_semantics.cpp
234236
ssa_rewrite_pass.cpp
235237
strength_reduction_pass.cpp

source/opt/function.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <algorithm>
1919
#include <functional>
20+
#include <iterator>
2021
#include <memory>
2122
#include <string>
2223
#include <unordered_set>
@@ -39,6 +40,7 @@ class Function {
3940
public:
4041
using iterator = UptrVectorIterator<BasicBlock>;
4142
using const_iterator = UptrVectorIterator<BasicBlock, true>;
43+
using ParamList = std::vector<std::unique_ptr<Instruction>>;
4244

4345
// Creates a function instance declared by the given OpFunction instruction
4446
// |def_inst|.
@@ -77,6 +79,23 @@ class Function {
7779
// Does nothing if the function doesn't have such a parameter.
7880
inline void RemoveParameter(uint32_t id);
7981

82+
// Rewrites the function parameters by calling a replacer callback.
83+
// The replacer takes two parameters: an expiring unique pointer to a current
84+
// instruction, and a back-inserter into a new list of unique pointers to
85+
// instructions. The replacer is called for each current parameter, in order.
86+
// Not valid to call while also iterating through the parameter list, e.g.
87+
// within the ForEachParam method.
88+
using RewriteParamFn = std::function<void(
89+
std::unique_ptr<Instruction>&&, std::back_insert_iterator<ParamList>&)>;
90+
void RewriteParams(RewriteParamFn& replacer) {
91+
ParamList new_params;
92+
auto appender = std::back_inserter(new_params);
93+
for (auto& param : params_) {
94+
replacer(std::move(param), appender);
95+
}
96+
params_ = std::move(new_params);
97+
}
98+
8099
// Saves the given function end instruction.
81100
inline void SetFunctionEnd(std::unique_ptr<Instruction> end_inst);
82101

@@ -197,7 +216,7 @@ class Function {
197216
// The OpFunction instruction that begins the definition of this function.
198217
std::unique_ptr<Instruction> def_inst_;
199218
// All parameters to this function.
200-
std::vector<std::unique_ptr<Instruction>> params_;
219+
ParamList params_;
201220
// All debug instructions in this function's header.
202221
InstructionList debug_insts_in_header_;
203222
// All basic blocks inside this function in specification order

source/opt/ir_builder.h

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef SOURCE_OPT_IR_BUILDER_H_
1616
#define SOURCE_OPT_IR_BUILDER_H_
1717

18+
#include <cassert>
1819
#include <limits>
1920
#include <memory>
2021
#include <utility>
@@ -480,8 +481,11 @@ class InstructionBuilder {
480481
return AddInstruction(std::move(select));
481482
}
482483

483-
Instruction* AddAccessChain(uint32_t type_id, uint32_t base_ptr_id,
484-
std::vector<uint32_t> ids) {
484+
Instruction* AddOpcodeAccessChain(spv::Op opcode, uint32_t type_id,
485+
uint32_t base_ptr_id,
486+
const std::vector<uint32_t>& ids) {
487+
assert(opcode == spv::Op::OpAccessChain ||
488+
opcode == spv::Op::OpInBoundsAccessChain);
485489
std::vector<Operand> operands;
486490
operands.push_back({SPV_OPERAND_TYPE_ID, {base_ptr_id}});
487491

@@ -490,12 +494,22 @@ class InstructionBuilder {
490494
}
491495

492496
// TODO(1841): Handle id overflow.
493-
std::unique_ptr<Instruction> new_inst(
494-
new Instruction(GetContext(), spv::Op::OpAccessChain, type_id,
495-
GetContext()->TakeNextId(), operands));
497+
std::unique_ptr<Instruction> new_inst(new Instruction(
498+
GetContext(), opcode, type_id, GetContext()->TakeNextId(), operands));
496499
return AddInstruction(std::move(new_inst));
497500
}
498501

502+
Instruction* AddAccessChain(uint32_t type_id, uint32_t base_ptr_id,
503+
const std::vector<uint32_t>& ids) {
504+
return AddOpcodeAccessChain(spv::Op::OpAccessChain, type_id, base_ptr_id,
505+
ids);
506+
}
507+
Instruction* AddInBoundsAccessChain(uint32_t type_id, uint32_t base_ptr_id,
508+
const std::vector<uint32_t>& ids) {
509+
return AddOpcodeAccessChain(spv::Op::OpInBoundsAccessChain, type_id,
510+
base_ptr_id, ids);
511+
}
512+
499513
Instruction* AddLoad(uint32_t type_id, uint32_t base_ptr_id,
500514
uint32_t alignment = 0) {
501515
std::vector<Operand> operands;
@@ -514,6 +528,16 @@ class InstructionBuilder {
514528
return AddInstruction(std::move(new_inst));
515529
}
516530

531+
Instruction* AddCopyObject(uint32_t type_id, uint32_t value_id) {
532+
std::vector<Operand> operands{{SPV_OPERAND_TYPE_ID, {value_id}}};
533+
534+
// TODO(1841): Handle id overflow.
535+
std::unique_ptr<Instruction> new_inst(
536+
new Instruction(GetContext(), spv::Op::OpCopyObject, type_id,
537+
GetContext()->TakeNextId(), operands));
538+
return AddInstruction(std::move(new_inst));
539+
}
540+
517541
Instruction* AddVariable(uint32_t type_id, uint32_t storage_class) {
518542
std::vector<Operand> operands;
519543
operands.push_back({SPV_OPERAND_TYPE_STORAGE_CLASS, {storage_class}});
@@ -572,6 +596,26 @@ class InstructionBuilder {
572596
return AddInstruction(std::move(new_inst));
573597
}
574598

599+
Instruction* AddDecoration(uint32_t target_id, spv::Decoration d,
600+
const std::vector<uint32_t>& literals) {
601+
std::vector<Operand> operands;
602+
operands.push_back({SPV_OPERAND_TYPE_ID, {target_id}});
603+
operands.push_back({SPV_OPERAND_TYPE_DECORATION, {uint32_t(d)}});
604+
for (uint32_t literal : literals) {
605+
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}});
606+
}
607+
608+
std::unique_ptr<Instruction> new_inst(
609+
new Instruction(GetContext(), spv::Op::OpDecorate, 0, 0, operands));
610+
// Decorations are annotation instructions. Add it via the IR context,
611+
// so the decoration manager will be updated.
612+
// Decorations don't belong to basic blocks, so there is no need
613+
// to update the instruction to block mapping.
614+
Instruction* result = new_inst.get();
615+
GetContext()->AddAnnotationInst(std::move(new_inst));
616+
return result;
617+
}
618+
575619
Instruction* AddNaryExtendedInstruction(
576620
uint32_t result_type, uint32_t set, uint32_t instruction,
577621
const std::vector<uint32_t>& ext_operands) {
@@ -593,6 +637,23 @@ class InstructionBuilder {
593637
return AddInstruction(std::move(new_inst));
594638
}
595639

640+
Instruction* AddSampledImage(uint32_t sampled_image_type_id,
641+
uint32_t image_id, uint32_t sampler_id) {
642+
std::vector<Operand> operands;
643+
operands.push_back({SPV_OPERAND_TYPE_ID, {image_id}});
644+
operands.push_back({SPV_OPERAND_TYPE_ID, {sampler_id}});
645+
646+
uint32_t result_id = GetContext()->TakeNextId();
647+
if (result_id == 0) {
648+
return nullptr;
649+
}
650+
651+
std::unique_ptr<Instruction> new_inst(
652+
new Instruction(GetContext(), spv::Op::OpSampledImage,
653+
sampled_image_type_id, result_id, operands));
654+
return AddInstruction(std::move(new_inst));
655+
}
656+
596657
// Inserts the new instruction before the insertion point.
597658
Instruction* AddInstruction(std::unique_ptr<Instruction>&& insn) {
598659
Instruction* insn_ptr = &*insert_before_.InsertBefore(std::move(insn));

source/opt/optimizer.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag,
637637
}
638638
} else if (pass_name == "trim-capabilities") {
639639
RegisterPass(CreateTrimCapabilitiesPass());
640+
} else if (pass_name == "split-combined-image-sampler") {
641+
RegisterPass(CreateSplitCombinedImageSamplerPass());
640642
} else {
641643
Errorf(consumer(), nullptr, {},
642644
"Unknown flag '--%s'. Use --help for a list of valid flags",
@@ -1188,6 +1190,11 @@ Optimizer::PassToken CreateOpExtInstWithForwardReferenceFixupPass() {
11881190
MakeUnique<opt::OpExtInstWithForwardReferenceFixupPass>());
11891191
}
11901192

1193+
Optimizer::PassToken CreateSplitCombinedImageSamplerPass() {
1194+
return MakeUnique<Optimizer::PassToken::Impl>(
1195+
MakeUnique<opt::SplitCombinedImageSamplerPass>());
1196+
}
1197+
11911198
} // namespace spvtools
11921199

11931200
extern "C" {

source/opt/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
#include "source/opt/scalar_replacement_pass.h"
7878
#include "source/opt/set_spec_constant_default_value_pass.h"
7979
#include "source/opt/simplification_pass.h"
80+
#include "source/opt/split_combined_image_sampler_pass.h"
8081
#include "source/opt/spread_volatile_semantics.h"
8182
#include "source/opt/ssa_rewrite_pass.h"
8283
#include "source/opt/strength_reduction_pass.h"

0 commit comments

Comments
 (0)