@@ -2193,6 +2193,174 @@ uint32_t GetElementType(uint32_t type_id, Instruction::iterator start,
21932193 return type_id;
21942194}
21952195
2196+ // If the input to an OpCompositeExtract is an OpCopyLogical, then we can
2197+ // hoist the extraction before the copy.
2198+ bool CopyLogicalFeedingExtract (IRContext* context, Instruction* inst,
2199+ const std::vector<const analysis::Constant*>&) {
2200+ assert (inst->opcode () == spv::Op::OpCompositeExtract &&
2201+ " Wrong opcode. Should be OpCompositeExtract." );
2202+
2203+ analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr ();
2204+ uint32_t cid = inst->GetSingleWordInOperand (kExtractCompositeIdInIdx );
2205+ Instruction* cinst = def_use_mgr->GetDef (cid);
2206+
2207+ if (cinst->opcode () != spv::Op::OpCopyLogical) {
2208+ return false ;
2209+ }
2210+
2211+ uint32_t original_composite_id = cinst->GetSingleWordInOperand (0 );
2212+ Instruction* original_composite_inst =
2213+ def_use_mgr->GetDef (original_composite_id);
2214+
2215+ std::vector<uint32_t > indices;
2216+ for (uint32_t i = 1 ; i < inst->NumInOperands (); ++i) {
2217+ indices.push_back (inst->GetSingleWordInOperand (i));
2218+ }
2219+
2220+ uint32_t original_element_type_id =
2221+ GetElementType (original_composite_inst->type_id (), inst->begin () + 3 ,
2222+ inst->end (), def_use_mgr);
2223+ assert (original_element_type_id != 0 &&
2224+ " Could not find the element type. Invalid SPIR-V." );
2225+
2226+ InstructionBuilder ir_builder (
2227+ context, inst,
2228+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping );
2229+
2230+ Instruction* new_extract = ir_builder.AddCompositeExtract (
2231+ original_element_type_id, original_composite_id, indices);
2232+
2233+ if (original_element_type_id == inst->type_id ())
2234+ inst->SetOpcode (spv::Op::OpCopyObject);
2235+ else
2236+ inst->SetOpcode (spv::Op::OpCopyLogical);
2237+ inst->SetInOperands ({{SPV_OPERAND_TYPE_ID, {new_extract->result_id ()}}});
2238+ return true ;
2239+ }
2240+
2241+ // If the input to an OpCompositeExtract is an OpLoad, we can change the
2242+ // load into a load of an OpAccessChain.
2243+ bool LoadFeedingExtract (IRContext* context, Instruction* inst,
2244+ const std::vector<const analysis::Constant*>&) {
2245+ assert (inst->opcode () == spv::Op::OpCompositeExtract &&
2246+ " Wrong opcode. Should be OpCompositeExtract." );
2247+
2248+ analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr ();
2249+ uint32_t cid = inst->GetSingleWordInOperand (kExtractCompositeIdInIdx );
2250+ Instruction* cinst = def_use_mgr->GetDef (cid);
2251+
2252+ if (cinst->opcode () != spv::Op::OpLoad) {
2253+ return false ;
2254+ }
2255+
2256+ Instruction* composite_type_inst = def_use_mgr->GetDef (cinst->type_id ());
2257+ if (composite_type_inst->opcode () != spv::Op::OpTypeStruct &&
2258+ composite_type_inst->opcode () != spv::Op::OpTypeArray) {
2259+ return false ;
2260+ }
2261+
2262+ // Check the memory operands.
2263+ if (cinst->NumInOperands () > 1 ) {
2264+ uint32_t memory_access_mask = cinst->GetSingleWordInOperand (1 );
2265+ if (memory_access_mask & uint32_t (spv::MemoryAccessMask::Volatile)) {
2266+ return false ;
2267+ }
2268+ }
2269+
2270+ uint32_t ptr_id = cinst->GetSingleWordInOperand (0 );
2271+ Instruction* ptr_inst = def_use_mgr->GetDef (ptr_id);
2272+ Instruction* ptr_type_inst = def_use_mgr->GetDef (ptr_inst->type_id ());
2273+ assert (ptr_type_inst->opcode () == spv::Op::OpTypePointer);
2274+ spv::StorageClass storage_class =
2275+ static_cast <spv::StorageClass>(ptr_type_inst->GetSingleWordInOperand (0 ));
2276+
2277+ // If the storage class is Function or Private, we do not want to fold.
2278+ // These are the storage classes that the local-access-chain-convert pass
2279+ // works on.
2280+ if (storage_class == spv::StorageClass::Function ||
2281+ storage_class == spv::StorageClass::Private) {
2282+ return false ;
2283+ }
2284+
2285+ analysis::ConstantManager* const_mgr = context->get_constant_mgr ();
2286+ analysis::TypeManager* type_mgr = context->get_type_mgr ();
2287+ std::vector<uint32_t > index_ids;
2288+ for (uint32_t i = 1 ; i < inst->NumInOperands (); ++i) {
2289+ uint32_t index = inst->GetSingleWordInOperand (i);
2290+ const analysis::Constant* index_const =
2291+ const_mgr->GetConstant (type_mgr->GetUIntType (), {index});
2292+ index_ids.push_back (
2293+ const_mgr->GetDefiningInstruction (index_const)->result_id ());
2294+ }
2295+
2296+ InstructionBuilder ir_builder (
2297+ context, cinst,
2298+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping );
2299+
2300+ uint32_t element_ptr_type_id =
2301+ type_mgr->FindPointerToType (inst->type_id (), storage_class);
2302+ if (element_ptr_type_id == 0 ) {
2303+ return false ;
2304+ }
2305+
2306+ Instruction* access_chain =
2307+ ir_builder.AddAccessChain (element_ptr_type_id, ptr_id, index_ids);
2308+ std::vector<Operand> load_operands;
2309+ load_operands.push_back ({SPV_OPERAND_TYPE_ID, {access_chain->result_id ()}});
2310+
2311+ if (cinst->NumInOperands () > 1 ) {
2312+ uint32_t memory_access_mask = cinst->GetSingleWordInOperand (1 );
2313+ load_operands.push_back (
2314+ {SPV_OPERAND_TYPE_MEMORY_ACCESS, {memory_access_mask}});
2315+
2316+ uint32_t current_operand_index = 2 ;
2317+ if (memory_access_mask & uint32_t (spv::MemoryAccessMask::Aligned)) {
2318+ uint32_t original_alignment =
2319+ cinst->GetSingleWordInOperand (current_operand_index);
2320+
2321+ std::vector<uint32_t > extract_indices;
2322+ for (uint32_t i = 1 ; i < inst->NumInOperands (); ++i) {
2323+ extract_indices.push_back (inst->GetSingleWordInOperand (i));
2324+ }
2325+
2326+ std::optional<uint32_t > offset =
2327+ type_mgr->GetType (cinst->type_id ())->GetByteOffset (extract_indices);
2328+ if (!offset) {
2329+ return false ;
2330+ }
2331+
2332+ uint32_t new_alignment = original_alignment;
2333+ if (*offset != 0 ) {
2334+ uint32_t offset_alignment = *offset & ~(*offset - 1 );
2335+ new_alignment = std::min (original_alignment, offset_alignment);
2336+ }
2337+
2338+ load_operands.push_back (
2339+ {SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, {new_alignment}});
2340+ current_operand_index++;
2341+ }
2342+
2343+ // Copy the remaining operands
2344+ for (; current_operand_index < cinst->NumInOperands ();
2345+ ++current_operand_index) {
2346+ load_operands.push_back (cinst->GetInOperand (current_operand_index));
2347+ }
2348+ }
2349+
2350+ uint32_t load_result_id = context->TakeNextId ();
2351+ if (load_result_id == 0 ) return false ;
2352+
2353+ std::unique_ptr<Instruction> new_load_inst (
2354+ new Instruction (context, spv::Op::OpLoad, inst->type_id (), load_result_id,
2355+ load_operands));
2356+ Instruction* new_load = ir_builder.AddInstruction (std::move (new_load_inst));
2357+
2358+ inst->SetOpcode (spv::Op::OpCopyObject);
2359+ inst->SetInOperands ({{SPV_OPERAND_TYPE_ID, {new_load->result_id ()}}});
2360+
2361+ return true ;
2362+ }
2363+
// Returns true if |inst_1| and |inst_2| have the same indexes that will be used
21972365// to index into a composite object, excluding the last index. The two
21982366// instructions must have the same opcode, and be either OpCompositeExtract or
@@ -4309,6 +4477,8 @@ void FoldingRules::AddFoldingRules() {
43094477 CompositeConstructFeedingExtract);
43104478 rules_[spv::Op::OpCompositeExtract].push_back (VectorShuffleFeedingExtract ());
43114479 rules_[spv::Op::OpCompositeExtract].push_back (FMixFeedingExtract ());
4480+ rules_[spv::Op::OpCompositeExtract].push_back (CopyLogicalFeedingExtract);
4481+ rules_[spv::Op::OpCompositeExtract].push_back (LoadFeedingExtract);
43124482
43134483 rules_[spv::Op::OpCompositeInsert].push_back (
43144484 CompositeInsertToCompositeConstruct);
0 commit comments