Skip to content

Commit 78b881b

Browse files
authored
opt: Fold OpCompositeExtract feeding from OpCopyLogical or OpLoad (#6614)
- CopyLogicalFeedingExtract: Hoist OpCompositeExtract before OpCopyLogical. If the input to an OpCompositeExtract is an OpCopyLogical, we can extract from the original composite and then copy the result. - LoadFeedingExtract: Change OpLoad + OpCompositeExtract to OpAccessChain + OpLoad. If the input to an OpCompositeExtract is an OpLoad, we can load the specific element instead of the entire composite. This is restricted to non-Function/Private storage classes to avoid interfering with local-access-chain-convert. Updated fold_test.cpp with 8 new test cases and updated existing tests to use SPV_ENV_UNIVERSAL_1_5. Fixes #6611
1 parent ff5c503 commit 78b881b

5 files changed

Lines changed: 529 additions & 15 deletions

File tree

source/opt/folding_rules.cpp

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2193,6 +2193,174 @@ uint32_t GetElementType(uint32_t type_id, Instruction::iterator start,
21932193
return type_id;
21942194
}
21952195

2196+
// If the input to an OpCompositeExtract is an OpCopyLogical, then we can
2197+
// hoist the extraction before the copy.
2198+
bool CopyLogicalFeedingExtract(IRContext* context, Instruction* inst,
2199+
const std::vector<const analysis::Constant*>&) {
2200+
assert(inst->opcode() == spv::Op::OpCompositeExtract &&
2201+
"Wrong opcode. Should be OpCompositeExtract.");
2202+
2203+
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2204+
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
2205+
Instruction* cinst = def_use_mgr->GetDef(cid);
2206+
2207+
if (cinst->opcode() != spv::Op::OpCopyLogical) {
2208+
return false;
2209+
}
2210+
2211+
uint32_t original_composite_id = cinst->GetSingleWordInOperand(0);
2212+
Instruction* original_composite_inst =
2213+
def_use_mgr->GetDef(original_composite_id);
2214+
2215+
std::vector<uint32_t> indices;
2216+
for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
2217+
indices.push_back(inst->GetSingleWordInOperand(i));
2218+
}
2219+
2220+
uint32_t original_element_type_id =
2221+
GetElementType(original_composite_inst->type_id(), inst->begin() + 3,
2222+
inst->end(), def_use_mgr);
2223+
assert(original_element_type_id != 0 &&
2224+
"Could not find the element type. Invalid SPIR-V.");
2225+
2226+
InstructionBuilder ir_builder(
2227+
context, inst,
2228+
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
2229+
2230+
Instruction* new_extract = ir_builder.AddCompositeExtract(
2231+
original_element_type_id, original_composite_id, indices);
2232+
2233+
if (original_element_type_id == inst->type_id())
2234+
inst->SetOpcode(spv::Op::OpCopyObject);
2235+
else
2236+
inst->SetOpcode(spv::Op::OpCopyLogical);
2237+
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_extract->result_id()}}});
2238+
return true;
2239+
}
2240+
2241+
// If the input to an OpCompositeExtract is an OpLoad, we can change the
2242+
// load into a load of an OpAccessChain.
2243+
bool LoadFeedingExtract(IRContext* context, Instruction* inst,
2244+
const std::vector<const analysis::Constant*>&) {
2245+
assert(inst->opcode() == spv::Op::OpCompositeExtract &&
2246+
"Wrong opcode. Should be OpCompositeExtract.");
2247+
2248+
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2249+
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
2250+
Instruction* cinst = def_use_mgr->GetDef(cid);
2251+
2252+
if (cinst->opcode() != spv::Op::OpLoad) {
2253+
return false;
2254+
}
2255+
2256+
Instruction* composite_type_inst = def_use_mgr->GetDef(cinst->type_id());
2257+
if (composite_type_inst->opcode() != spv::Op::OpTypeStruct &&
2258+
composite_type_inst->opcode() != spv::Op::OpTypeArray) {
2259+
return false;
2260+
}
2261+
2262+
// Check the memory operands.
2263+
if (cinst->NumInOperands() > 1) {
2264+
uint32_t memory_access_mask = cinst->GetSingleWordInOperand(1);
2265+
if (memory_access_mask & uint32_t(spv::MemoryAccessMask::Volatile)) {
2266+
return false;
2267+
}
2268+
}
2269+
2270+
uint32_t ptr_id = cinst->GetSingleWordInOperand(0);
2271+
Instruction* ptr_inst = def_use_mgr->GetDef(ptr_id);
2272+
Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_inst->type_id());
2273+
assert(ptr_type_inst->opcode() == spv::Op::OpTypePointer);
2274+
spv::StorageClass storage_class =
2275+
static_cast<spv::StorageClass>(ptr_type_inst->GetSingleWordInOperand(0));
2276+
2277+
// If the storage class is Function or Private, we do not want to fold.
2278+
// These are the storage classes that the local-access-chain-convert pass
2279+
// works on.
2280+
if (storage_class == spv::StorageClass::Function ||
2281+
storage_class == spv::StorageClass::Private) {
2282+
return false;
2283+
}
2284+
2285+
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2286+
analysis::TypeManager* type_mgr = context->get_type_mgr();
2287+
std::vector<uint32_t> index_ids;
2288+
for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
2289+
uint32_t index = inst->GetSingleWordInOperand(i);
2290+
const analysis::Constant* index_const =
2291+
const_mgr->GetConstant(type_mgr->GetUIntType(), {index});
2292+
index_ids.push_back(
2293+
const_mgr->GetDefiningInstruction(index_const)->result_id());
2294+
}
2295+
2296+
InstructionBuilder ir_builder(
2297+
context, cinst,
2298+
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
2299+
2300+
uint32_t element_ptr_type_id =
2301+
type_mgr->FindPointerToType(inst->type_id(), storage_class);
2302+
if (element_ptr_type_id == 0) {
2303+
return false;
2304+
}
2305+
2306+
Instruction* access_chain =
2307+
ir_builder.AddAccessChain(element_ptr_type_id, ptr_id, index_ids);
2308+
std::vector<Operand> load_operands;
2309+
load_operands.push_back({SPV_OPERAND_TYPE_ID, {access_chain->result_id()}});
2310+
2311+
if (cinst->NumInOperands() > 1) {
2312+
uint32_t memory_access_mask = cinst->GetSingleWordInOperand(1);
2313+
load_operands.push_back(
2314+
{SPV_OPERAND_TYPE_MEMORY_ACCESS, {memory_access_mask}});
2315+
2316+
uint32_t current_operand_index = 2;
2317+
if (memory_access_mask & uint32_t(spv::MemoryAccessMask::Aligned)) {
2318+
uint32_t original_alignment =
2319+
cinst->GetSingleWordInOperand(current_operand_index);
2320+
2321+
std::vector<uint32_t> extract_indices;
2322+
for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
2323+
extract_indices.push_back(inst->GetSingleWordInOperand(i));
2324+
}
2325+
2326+
std::optional<uint32_t> offset =
2327+
type_mgr->GetType(cinst->type_id())->GetByteOffset(extract_indices);
2328+
if (!offset) {
2329+
return false;
2330+
}
2331+
2332+
uint32_t new_alignment = original_alignment;
2333+
if (*offset != 0) {
2334+
uint32_t offset_alignment = *offset & ~(*offset - 1);
2335+
new_alignment = std::min(original_alignment, offset_alignment);
2336+
}
2337+
2338+
load_operands.push_back(
2339+
{SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, {new_alignment}});
2340+
current_operand_index++;
2341+
}
2342+
2343+
// Copy the remaining operands
2344+
for (; current_operand_index < cinst->NumInOperands();
2345+
++current_operand_index) {
2346+
load_operands.push_back(cinst->GetInOperand(current_operand_index));
2347+
}
2348+
}
2349+
2350+
uint32_t load_result_id = context->TakeNextId();
2351+
if (load_result_id == 0) return false;
2352+
2353+
std::unique_ptr<Instruction> new_load_inst(
2354+
new Instruction(context, spv::Op::OpLoad, inst->type_id(), load_result_id,
2355+
load_operands));
2356+
Instruction* new_load = ir_builder.AddInstruction(std::move(new_load_inst));
2357+
2358+
inst->SetOpcode(spv::Op::OpCopyObject);
2359+
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_load->result_id()}}});
2360+
2361+
return true;
2362+
}
2363+
21962364
// Returns true if |inst_1| and |inst_2| have the same indexes that will be used
21972365
// to index into a composite object, excluding the last index. The two
21982366
// instructions must have the same opcode, and be either OpCompositeExtract or
@@ -4309,6 +4477,8 @@ void FoldingRules::AddFoldingRules() {
43094477
CompositeConstructFeedingExtract);
43104478
rules_[spv::Op::OpCompositeExtract].push_back(VectorShuffleFeedingExtract());
43114479
rules_[spv::Op::OpCompositeExtract].push_back(FMixFeedingExtract());
4480+
rules_[spv::Op::OpCompositeExtract].push_back(CopyLogicalFeedingExtract);
4481+
rules_[spv::Op::OpCompositeExtract].push_back(LoadFeedingExtract);
43124482

43134483
rules_[spv::Op::OpCompositeInsert].push_back(
43144484
CompositeInsertToCompositeConstruct);

source/opt/types.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,78 @@ uint64_t Type::NumberOfComponents() const {
303303
}
304304
}
305305

306+
std::optional<uint32_t> Type::GetByteOffset(
307+
const std::vector<uint32_t>& access_chain) const {
308+
uint32_t offset = 0;
309+
const Type* current_type = this;
310+
for (uint32_t index : access_chain) {
311+
if (const Struct* struct_type = current_type->AsStruct()) {
312+
std::optional<uint32_t> member_offset;
313+
for (const auto& deco : struct_type->element_decorations()) {
314+
if (deco.first != index) continue;
315+
for (const auto& inst : deco.second) {
316+
if (inst[0] == uint32_t(spv::Decoration::Offset)) {
317+
member_offset = inst[1];
318+
break;
319+
}
320+
}
321+
}
322+
if (!member_offset) return {};
323+
offset += *member_offset;
324+
current_type = struct_type->element_types()[index];
325+
} else if (const Array* array_type = current_type->AsArray()) {
326+
std::optional<uint32_t> array_stride;
327+
for (const auto& deco : array_type->decorations()) {
328+
if (deco[0] == uint32_t(spv::Decoration::ArrayStride)) {
329+
array_stride = deco[1];
330+
break;
331+
}
332+
}
333+
if (!array_stride) return {};
334+
offset += *array_stride * index;
335+
current_type = array_type->element_type();
336+
} else if (const RuntimeArray* runtime_array_type =
337+
current_type->AsRuntimeArray()) {
338+
std::optional<uint32_t> array_stride;
339+
for (const auto& deco : runtime_array_type->decorations()) {
340+
if (deco[0] == uint32_t(spv::Decoration::ArrayStride)) {
341+
array_stride = deco[1];
342+
break;
343+
}
344+
}
345+
if (!array_stride) return {};
346+
offset += *array_stride * index;
347+
current_type = runtime_array_type->element_type();
348+
} else if (const Matrix* matrix_type = current_type->AsMatrix()) {
349+
std::optional<uint32_t> matrix_stride;
350+
for (const auto& deco : matrix_type->decorations()) {
351+
if (deco[0] == uint32_t(spv::Decoration::MatrixStride)) {
352+
matrix_stride = deco[1];
353+
break;
354+
}
355+
}
356+
if (!matrix_stride) return {};
357+
offset += *matrix_stride * index;
358+
current_type = matrix_type->element_type();
359+
} else if (const Vector* vector_type = current_type->AsVector()) {
360+
const Type* component_type = vector_type->element_type();
361+
uint32_t component_size = 0;
362+
if (component_type->AsInteger()) {
363+
component_size = component_type->AsInteger()->width() / 8;
364+
} else if (component_type->AsFloat()) {
365+
component_size = component_type->AsFloat()->width() / 8;
366+
} else {
367+
return {};
368+
}
369+
offset += component_size * index;
370+
current_type = component_type;
371+
} else {
372+
return {};
373+
}
374+
}
375+
return offset;
376+
}
377+
306378
bool Integer::IsSameImpl(const Type* that, IsSameCache*) const {
307379
const Integer* it = that->AsInteger();
308380
return it && width_ == it->width_ && signed_ == it->signed_ &&

source/opt/types.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <map>
2323
#include <memory>
24+
#include <optional>
2425
#include <set>
2526
#include <string>
2627
#include <unordered_map>
@@ -196,6 +197,13 @@ class Type {
196197
// non-composite type.
197198
uint64_t NumberOfComponents() const;
198199

200+
// Returns the byte offset of the member of this type that is identified
201+
// by |access_chain|. The vector |access_chain| is a series of integers that
202+
// are used to pick members as in the |OpCompositeExtract| instructions.
203+
// Returns {} if the offset cannot be computed.
204+
std::optional<uint32_t> GetByteOffset(
205+
const std::vector<uint32_t>& access_chain) const;
206+
199207
// A bunch of methods for casting this type to a given type. Returns this if
200208
// the cast can be done, nullptr otherwise.
201209
// clang-format off

0 commit comments

Comments
 (0)