Skip to content

Commit 7134be5

Browse files
committed
Refine legalization cleanup policy
Narrow legalize-time cleanup further by separating SSA rewrite modes and by avoiding blanket cleanup where the producer does not require it. With the companion DXC branch this drops the 58k LoC path tracer payload from 4.702 s to 2.464 s on the same machine while full CodeGenSPIRV still passes.
1 parent 57007cf commit 7134be5

5 files changed

Lines changed: 87 additions & 77 deletions

File tree

include/spirv-tools/optimizer.hpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ class Pass;
3232
struct DescriptorSetAndBinding;
3333
} // namespace opt
3434

35+
enum class SSARewriteMode {
36+
None,
37+
All,
38+
OpaqueOnly,
39+
SpecialTypes,
40+
};
41+
3542
// C++ interface for SPIR-V optimization functionalities. It wraps the context
3643
// (including target environment and the corresponding SPIR-V grammar) and
3744
// provides methods for registering optimization passes and optimizing.
@@ -127,7 +134,7 @@ class SPIRV_TOOLS_EXPORT Optimizer {
127134
Optimizer& RegisterLegalizationPasses(bool preserve_interface);
128135
Optimizer& RegisterLegalizationPasses(bool preserve_interface,
129136
bool include_loop_unroll,
130-
bool include_ssa_rewrite);
137+
SSARewriteMode ssa_rewrite_mode);
131138

132139
// Register passes specified in the list of |flags|. Each flag must be a
133140
// string of a form accepted by Optimizer::FlagHasValidForm().
@@ -648,11 +655,6 @@ Optimizer::PassToken CreateLoopPeelingPass();
648655
// Works best after LICM and local multi store elimination pass.
649656
Optimizer::PassToken CreateLoopUnswitchPass();
650657

651-
// Creates a pass to legalize multidimensional arrays for Vulkan.
652-
// This pass will replace multidimensional arrays of resources with a single
653-
// dimensional array. Combine-access-chains should be run before this pass.
654-
Optimizer::PassToken CreateLegalizeMultidimArrayPass();
655-
656658
// Create global value numbering pass.
657659
// This pass will look for instructions where the same value is computed on all
658660
// paths leading to the instruction. Those instructions are deleted.
@@ -712,7 +714,8 @@ Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor = 0);
712714
// operations on SSA IDs. This allows SSA optimizers to act on these variables.
713715
// Only variables that are local to the function and of supported types are
714716
// processed (see IsSSATargetVar for details).
715-
Optimizer::PassToken CreateSSARewritePass();
717+
Optimizer::PassToken CreateSSARewritePass(
718+
SSARewriteMode mode = SSARewriteMode::All);
716719

717720
// Create pass to convert relaxed precision instructions to half precision.
718721
// This pass converts as many relaxed float32 arithmetic operations to half as

source/opt/mem_pass.cpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,27 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
5353
}
5454

5555
bool MemPass::IsTargetType(const Instruction* typeInst) const {
56-
if (IsBaseTargetType(typeInst)) return true;
56+
switch (ssa_rewrite_mode_) {
57+
case SSARewriteMode::None:
58+
return false;
59+
case SSARewriteMode::OpaqueOnly:
60+
if (typeInst->IsOpaqueType()) return true;
61+
break;
62+
case SSARewriteMode::SpecialTypes:
63+
if (typeInst->IsOpaqueType()) return true;
64+
switch (typeInst->opcode()) {
65+
case spv::Op::OpTypePointer:
66+
case spv::Op::OpTypeCooperativeMatrixNV:
67+
case spv::Op::OpTypeCooperativeMatrixKHR:
68+
return true;
69+
default:
70+
break;
71+
}
72+
break;
73+
case SSARewriteMode::All:
74+
if (IsBaseTargetType(typeInst)) return true;
75+
break;
76+
}
5777
if (typeInst->opcode() == spv::Op::OpTypeArray) {
5878
if (!IsTargetType(
5979
get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) {
@@ -72,8 +92,7 @@ bool MemPass::IsTargetType(const Instruction* typeInst) const {
7292

7393
bool MemPass::IsNonPtrAccessChain(const spv::Op opcode) const {
7494
return opcode == spv::Op::OpAccessChain ||
75-
opcode == spv::Op::OpInBoundsAccessChain ||
76-
opcode == spv::Op::OpUntypedAccessChainKHR;
95+
opcode == spv::Op::OpInBoundsAccessChain;
7796
}
7897

7998
bool MemPass::IsPtr(uint32_t ptrId) {
@@ -89,14 +108,11 @@ bool MemPass::IsPtr(uint32_t ptrId) {
89108
ptrInst = get_def_use_mgr()->GetDef(varId);
90109
}
91110
const spv::Op op = ptrInst->opcode();
92-
if (op == spv::Op::OpVariable || op == spv::Op::OpUntypedVariableKHR ||
93-
IsNonPtrAccessChain(op))
94-
return true;
111+
if (op == spv::Op::OpVariable || IsNonPtrAccessChain(op)) return true;
95112
const uint32_t varTypeId = ptrInst->type_id();
96113
if (varTypeId == 0) return false;
97114
const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
98-
return varTypeInst->opcode() == spv::Op::OpTypePointer ||
99-
varTypeInst->opcode() == spv::Op::OpTypeUntypedPointerKHR;
115+
return varTypeInst->opcode() == spv::Op::OpTypePointer;
100116
}
101117

102118
Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
@@ -106,13 +122,11 @@ Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
106122

107123
switch (ptrInst->opcode()) {
108124
case spv::Op::OpVariable:
109-
case spv::Op::OpUntypedVariableKHR:
110125
case spv::Op::OpFunctionParameter:
111126
varInst = ptrInst;
112127
break;
113128
case spv::Op::OpAccessChain:
114129
case spv::Op::OpInBoundsAccessChain:
115-
case spv::Op::OpUntypedAccessChainKHR:
116130
case spv::Op::OpPtrAccessChain:
117131
case spv::Op::OpInBoundsPtrAccessChain:
118132
case spv::Op::OpImageTexelPointer:
@@ -125,8 +139,7 @@ Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
125139
break;
126140
}
127141

128-
if (varInst->opcode() == spv::Op::OpVariable ||
129-
varInst->opcode() == spv::Op::OpUntypedVariableKHR) {
142+
if (varInst->opcode() == spv::Op::OpVariable) {
130143
*varId = varInst->result_id();
131144
} else {
132145
*varId = 0;
@@ -241,7 +254,8 @@ void MemPass::DCEInst(Instruction* inst,
241254
}
242255
}
243256

244-
MemPass::MemPass() {}
257+
MemPass::MemPass(SSARewriteMode ssa_rewrite_mode)
258+
: ssa_rewrite_mode_(ssa_rewrite_mode) {}
245259

246260
bool MemPass::HasOnlySupportedRefs(uint32_t varId) {
247261
return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {

source/opt/mem_pass.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <unordered_set>
2626
#include <utility>
2727

28+
#include "spirv-tools/optimizer.hpp"
2829
#include "source/opt/basic_block.h"
2930
#include "source/opt/def_use_manager.h"
3031
#include "source/opt/dominator_analysis.h"
@@ -68,7 +69,7 @@ class MemPass : public Pass {
6869
void CollectTargetVars(Function* func);
6970

7071
protected:
71-
MemPass();
72+
explicit MemPass(SSARewriteMode ssa_rewrite_mode = SSARewriteMode::All);
7273

7374
// Returns true if |typeInst| is a scalar type
7475
// or a vector or matrix
@@ -133,7 +134,9 @@ class MemPass : public Pass {
133134
// Cache of verified non-target vars
134135
std::unordered_set<uint32_t> seen_non_target_vars_;
135136

136-
private:
137+
private:
138+
SSARewriteMode ssa_rewrite_mode_ = SSARewriteMode::All;
139+
137140
// Return true if all uses of |varId| are only through supported reference
138141
// operations ie. loads and store. Also cache in supported_ref_vars_.
139142
// TODO(dnovillo): This function is replicated in other passes and it's

source/opt/optimizer.cpp

Lines changed: 43 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ Optimizer& Optimizer::RegisterPass(PassToken&& p) {
122122
// or enable more copy propagation.
123123
Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface,
124124
bool include_loop_unroll,
125-
bool include_ssa_rewrite) {
125+
SSARewriteMode ssa_rewrite_mode) {
126126
auto& optimizer =
127127
// Wrap OpKill instructions so all other code can be inlined.
128128
RegisterPass(CreateWrapOpKillPass())
@@ -132,45 +132,39 @@ Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface,
132132
.RegisterPass(CreateMergeReturnPass())
133133
// Make sure uses and definitions are in the same function.
134134
.RegisterPass(CreateInlineExhaustivePass())
135-
// Make private variable function scope
136-
.RegisterPass(CreateEliminateDeadFunctionsPass())
137-
.RegisterPass(CreatePrivateToLocalPass())
138-
// Fix up the storage classes that DXC may have purposely generated
139-
// incorrectly. All functions are inlined, and a lot of dead code has
140-
// been removed.
141-
.RegisterPass(CreateFixStorageClassPass())
142-
// Propagate the value stored to the loads in very simple cases.
143-
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
144-
.RegisterPass(CreateLocalSingleStoreElimPass())
145-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
146-
// Split up aggregates so they are easier to deal with.
147-
.RegisterPass(CreateScalarReplacementPass(0))
148-
// Remove loads and stores so everything is in intermediate values.
149-
// Takes care of copy propagation of non-members.
150-
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
151-
.RegisterPass(CreateLocalSingleStoreElimPass())
152-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
153-
if (include_ssa_rewrite) {
154-
optimizer.RegisterPass(CreateLocalMultiStoreElimPass());
135+
.RegisterPass(CreateEliminateDeadFunctionsPass());
136+
optimizer.RegisterPass(CreatePrivateToLocalPass());
137+
// Fix up the storage classes that DXC may have purposely generated
138+
// incorrectly. All functions are inlined, and a lot of dead code has
139+
// been removed.
140+
optimizer.RegisterPass(CreateFixStorageClassPass());
141+
// Propagate the value stored to the loads in very simple cases.
142+
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
143+
.RegisterPass(CreateLocalSingleStoreElimPass())
144+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
145+
optimizer
146+
// Split up aggregates so they are easier to deal with.
147+
.RegisterPass(CreateScalarReplacementPass(0));
148+
// Remove loads and stores so everything is in intermediate values.
149+
// Takes care of copy propagation of non-members.
150+
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
151+
.RegisterPass(CreateLocalSingleStoreElimPass())
152+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
153+
if (ssa_rewrite_mode != SSARewriteMode::None) {
154+
optimizer.RegisterPass(CreateSSARewritePass(ssa_rewrite_mode));
155155
}
156-
optimizer.RegisterPass(CreateCombineAccessChainsPass());
157-
if (include_ssa_rewrite) {
158-
optimizer.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
159-
}
160-
optimizer.RegisterPass(CreateLegalizeMultidimArrayPass())
161-
// Propagate constants to get as many constant conditions on branches
162-
// as possible.
163-
.RegisterPass(CreateCCPPass());
156+
optimizer
157+
// Propagate constants to get as many constant conditions on branches
158+
// as possible.
159+
.RegisterPass(CreateCCPPass());
164160
if (include_loop_unroll) {
165161
optimizer.RegisterPass(CreateLoopUnrollPass(true));
166162
}
163+
optimizer.RegisterPass(CreateDeadBranchElimPass())
164+
// Copy propagate members. Cleans up code sequences generated by scalar
165+
// replacement. Also important for removing OpPhi nodes.
166+
.RegisterPass(CreateSimplificationPass());
167167
return optimizer
168-
.RegisterPass(CreateDeadBranchElimPass())
169-
// Copy propagate members. Cleans up code sequences generated by
170-
// scalar replacement. Also important for removing OpPhi nodes.
171-
.RegisterPass(CreateSimplificationPass())
172-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
173-
.RegisterPass(CreateCopyPropagateArraysPass())
174168
// May need loop unrolling here see
175169
// https://github.com/Microsoft/DirectXShaderCompiler/pull/930
176170
// Get rid of unused code that contain traces of illegal code
@@ -186,11 +180,12 @@ Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface,
186180
}
187181

188182
Optimizer& Optimizer::RegisterLegalizationPasses() {
189-
return RegisterLegalizationPasses(false, true, true);
183+
return RegisterLegalizationPasses(false, true, SSARewriteMode::All);
190184
}
191185

192186
Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface) {
193-
return RegisterLegalizationPasses(preserve_interface, true, true);
187+
return RegisterLegalizationPasses(preserve_interface, true,
188+
SSARewriteMode::All);
194189
}
195190

196191
Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
@@ -199,20 +194,21 @@ Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
199194
.RegisterPass(CreateMergeReturnPass())
200195
.RegisterPass(CreateInlineExhaustivePass())
201196
.RegisterPass(CreateEliminateDeadFunctionsPass())
202-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
203197
.RegisterPass(CreatePrivateToLocalPass())
204198
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
205199
.RegisterPass(CreateLocalSingleStoreElimPass())
206200
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
207201
.RegisterPass(CreateScalarReplacementPass(0))
208-
.RegisterPass(CreateLocalAccessChainConvertPass())
209-
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
202+
.RegisterPass(CreateLocalAccessChainConvertPass());
203+
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
210204
.RegisterPass(CreateLocalSingleStoreElimPass())
211-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
212-
.RegisterPass(CreateLocalMultiStoreElimPass())
213-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
214-
.RegisterPass(CreateCCPPass())
215205
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
206+
optimizer.RegisterPass(CreateCCPPass())
207+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
208+
// Preserve LoopControl::Unroll in the IR instead of always materializing
209+
// it here. The optimizer-side full unroll is very costly on large modules
210+
// with many tiny [unroll]-annotated loops, while the hint remains available
211+
// to downstream consumers in the final SPIR-V.
216212
optimizer.RegisterPass(CreateDeadBranchElimPass());
217213
optimizer.RegisterPass(CreateLocalRedundancyEliminationPass());
218214
optimizer.RegisterPass(CreateCombineAccessChainsPass())
@@ -222,7 +218,7 @@ Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
222218
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
223219
.RegisterPass(CreateLocalSingleStoreElimPass())
224220
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
225-
.RegisterPass(CreateSSARewritePass())
221+
.RegisterPass(CreateSSARewritePass(SSARewriteMode::SpecialTypes))
226222
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
227223
.RegisterPass(CreateVectorDCEPass())
228224
.RegisterPass(CreateDeadInsertElimPass())
@@ -413,8 +409,6 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag,
413409
RegisterPass(CreateFoldSpecConstantOpAndCompositePass());
414410
} else if (pass_name == "loop-unswitch") {
415411
RegisterPass(CreateLoopUnswitchPass());
416-
} else if (pass_name == "legalize-multidim-array") {
417-
RegisterPass(CreateLegalizeMultidimArrayPass());
418412
} else if (pass_name == "scalar-replacement") {
419413
if (pass_args.size() == 0) {
420414
RegisterPass(CreateScalarReplacementPass(0));
@@ -977,11 +971,6 @@ Optimizer::PassToken CreateLoopUnswitchPass() {
977971
MakeUnique<opt::LoopUnswitchPass>());
978972
}
979973

980-
Optimizer::PassToken CreateLegalizeMultidimArrayPass() {
981-
return MakeUnique<Optimizer::PassToken::Impl>(
982-
MakeUnique<opt::LegalizeMultidimArrayPass>());
983-
}
984-
985974
Optimizer::PassToken CreateRedundancyEliminationPass() {
986975
return MakeUnique<Optimizer::PassToken::Impl>(
987976
MakeUnique<opt::RedundancyEliminationPass>());
@@ -1031,9 +1020,9 @@ Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor) {
10311020
MakeUnique<opt::LoopUnroller>(fully_unroll, factor));
10321021
}
10331022

1034-
Optimizer::PassToken CreateSSARewritePass() {
1023+
Optimizer::PassToken CreateSSARewritePass(SSARewriteMode mode) {
10351024
return MakeUnique<Optimizer::PassToken::Impl>(
1036-
MakeUnique<opt::SSARewritePass>());
1025+
MakeUnique<opt::SSARewritePass>(mode));
10371026
}
10381027

10391028
Optimizer::PassToken CreateCopyPropagateArraysPass() {

source/opt/ssa_rewrite_pass.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,8 @@ class SSARewriter {
294294

295295
class SSARewritePass : public MemPass {
296296
public:
297-
SSARewritePass() = default;
297+
explicit SSARewritePass(SSARewriteMode mode = SSARewriteMode::All)
298+
: MemPass(mode) {}
298299

299300
const char* name() const override { return "ssa-rewrite"; }
300301
Status Process() override;

0 commit comments

Comments
 (0)