@@ -129,7 +129,6 @@ class SROA_Helper {
129129 void RewriteMemIntrin (MemIntrinsic *MI, Value *OldV);
130130 void RewriteCall (CallInst *CI);
131131 void RewriteBitCast (BitCastInst *BCI);
132- void RewriteCallArg (CallInst *CI, unsigned ArgIdx, bool bIn, bool bOut);
133132};
134133
135134} // namespace
@@ -1478,6 +1477,46 @@ void isSafePHISelectUseForScalarRepl(Instruction *I, uint64_t Offset,
14781477 }
14791478}
14801479
1480+ static bool isUDTIntrinsicArg (CallInst *CI, unsigned OpIdx) {
1481+ if (HLOpcodeGroup::HLIntrinsic != GetHLOpcodeGroup (CI->getCalledFunction ()))
1482+ return false ;
1483+ switch (static_cast <IntrinsicOp>(GetHLOpcode (CI))) {
1484+ case IntrinsicOp::IOP_TraceRay:
1485+ if (OpIdx == HLOperandIndex::kTraceRayRayDescOpIdx )
1486+ return true ;
1487+ if (OpIdx == HLOperandIndex::kTraceRayPayloadPreOpIdx )
1488+ return true ;
1489+ break ;
1490+ case IntrinsicOp::IOP_ReportHit:
1491+ if (OpIdx == HLOperandIndex::kReportIntersectionAttributeOpIdx )
1492+ return true ;
1493+ break ;
1494+ case IntrinsicOp::IOP_CallShader:
1495+ if (OpIdx == HLOperandIndex::kCallShaderPayloadOpIdx )
1496+ return true ;
1497+ break ;
1498+ case IntrinsicOp::MOP_DxHitObject_FromRayQuery:
1499+ if (OpIdx ==
1500+ HLOperandIndex::kHitObjectFromRayQuery_WithAttrs_AttributeOpIdx )
1501+ return true ;
1502+ break ;
1503+ case IntrinsicOp::MOP_DxHitObject_TraceRay:
1504+ // TODO: Remove RayDesc for flattening
1505+ if (OpIdx == HLOperandIndex::kHitObjectTraceRay_RayDescOpIdx )
1506+ return true ;
1507+ if (OpIdx == HLOperandIndex::kHitObjectTraceRay_PayloadPreOpIdx )
1508+ return true ;
1509+ break ;
1510+ case IntrinsicOp::MOP_DxHitObject_Invoke:
1511+ if (OpIdx == HLOperandIndex::kHitObjectInvoke_PayloadOpIdx )
1512+ return true ;
1513+ break ;
1514+ default :
1515+ break ;
1516+ }
1517+ return false ;
1518+ }
1519+
14811520// / isSafeForScalarRepl - Check if instruction I is a safe use with regard to
14821521// / performing scalar replacement of alloca AI. The results are flagged in
14831522// / the Info parameter. Offset indicates the position within AI that is
@@ -1536,15 +1575,9 @@ void isSafeForScalarRepl(Instruction *I, uint64_t Offset, AllocaInfo &Info) {
15361575 if (HLOpcodeGroup::NotHL == group)
15371576 return MarkUnsafe (Info, User);
15381577 else if (HLOpcodeGroup::HLIntrinsic == group) {
1539- // TODO: should we check HL parameter type for UDT overload instead of
1540- // basing on IOP?
1541- IntrinsicOp opcode = static_cast <IntrinsicOp>(GetHLOpcode (CI));
1542- if (IntrinsicOp::IOP_TraceRay == opcode ||
1543- IntrinsicOp::MOP_DxHitObject_TraceRay == opcode ||
1544- IntrinsicOp::MOP_DxHitObject_Invoke == opcode ||
1545- IntrinsicOp::IOP_ReportHit == opcode ||
1546- IntrinsicOp::IOP_CallShader == opcode) {
1547- return MarkUnsafe (Info, User);
1578+ for (unsigned OpIdx = 0 ; OpIdx < CI->getNumArgOperands (); OpIdx++) {
1579+ if (CI->getArgOperand (OpIdx) == I && isUDTIntrinsicArg (CI, OpIdx))
1580+ return MarkUnsafe (Info, User);
15481581 }
15491582 }
15501583 } else {
@@ -2666,8 +2699,7 @@ void SROA_Helper::RewriteBitCast(BitCastInst *BCI) {
26662699// / replace OldVal with alloca and
26672700// / copy in copy out data between alloca and flattened NewElts
26682701// / in CallInst.
2669- void SROA_Helper::RewriteCallArg (CallInst *CI, unsigned ArgIdx, bool bIn,
2670- bool bOut) {
2702+ static void RewriteCallArg (CallInst *CI, unsigned ArgIdx, bool bIn, bool bOut) {
26712703 Function *F = CI->getParent ()->getParent ();
26722704 IRBuilder<> AllocaBuilder (dxilutil::FindAllocaInsertionPt (F));
26732705 const DataLayout &DL = F->getParent ()->getDataLayout ();
@@ -2678,16 +2710,60 @@ void SROA_Helper::RewriteCallArg(CallInst *CI, unsigned ArgIdx, bool bIn,
26782710 Value *Alloca = AllocaBuilder.CreateAlloca (userTyElt);
26792711 IRBuilder<> Builder (CI);
26802712 if (bIn) {
2681- MemCpyInst *cpy = cast<MemCpyInst>(Builder.CreateMemCpy (
2682- Alloca, userTyV, DL.getTypeAllocSize (userTyElt), false ));
2683- RewriteMemIntrin (cpy, cpy->getRawSource ());
2713+ Builder.CreateMemCpy (Alloca, userTyV, DL.getTypeAllocSize (userTyElt),
2714+ false );
26842715 }
26852716 CI->setArgOperand (ArgIdx, Alloca);
26862717 if (bOut) {
26872718 Builder.SetInsertPoint (CI->getNextNode ());
2688- MemCpyInst *cpy = cast<MemCpyInst>(Builder.CreateMemCpy (
2689- userTyV, Alloca, DL.getTypeAllocSize (userTyElt), false ));
2690- RewriteMemIntrin (cpy, cpy->getRawSource ());
2719+ Builder.CreateMemCpy (userTyV, Alloca, DL.getTypeAllocSize (userTyElt),
2720+ false );
2721+ }
2722+ }
2723+
2724+ static void copyIntrinsicUDTArgs (HLModule &HLM) {
2725+ // Iterate HLIntrinsic function users
2726+ // For specific intrinsics, use RewriteCallArg on UDT args
2727+ for (Function &F : HLM.GetModule ()->functions ()) {
2728+ if (F.isIntrinsic () || !F.isDeclaration ())
2729+ continue ;
2730+ if (GetHLOpcodeGroup (&F) != HLOpcodeGroup::HLIntrinsic)
2731+ continue ;
2732+ // Iterate users
2733+ for (User *U : F.users ()) {
2734+ if (CallInst *CI = dyn_cast<CallInst>(U)) {
2735+ switch (static_cast <IntrinsicOp>(GetHLOpcode (CI))) {
2736+ case IntrinsicOp::IOP_TraceRay:
2737+ // TODO: Remove RayDesc for flattening
2738+ RewriteCallArg (CI, HLOperandIndex::kTraceRayRayDescOpIdx ,
2739+ /* bIn*/ true , /* bOut*/ false );
2740+ RewriteCallArg (CI, HLOperandIndex::kTraceRayPayloadPreOpIdx ,
2741+ /* bIn*/ true , /* bOut*/ true );
2742+ break ;
2743+ case IntrinsicOp::IOP_ReportHit:
2744+ RewriteCallArg (CI, HLOperandIndex::kReportIntersectionAttributeOpIdx ,
2745+ /* bIn*/ true , /* bOut*/ false );
2746+ break ;
2747+ case IntrinsicOp::IOP_CallShader:
2748+ RewriteCallArg (CI, HLOperandIndex::kCallShaderPayloadOpIdx ,
2749+ /* bIn*/ true , /* bOut*/ true );
2750+ break ;
2751+ case IntrinsicOp::MOP_DxHitObject_FromRayQuery:
2752+ if (CI->getNumArgOperands () ==
2753+ HLOperandIndex::kHitObjectFromRayQuery_WithAttrs_NumOp ) {
2754+ RewriteCallArg (
2755+ CI,
2756+ HLOperandIndex::kHitObjectFromRayQuery_WithAttrs_AttributeOpIdx ,
2757+ /* bIn*/ true , /* bOut*/ false );
2758+ }
2759+ break ;
2760+ case IntrinsicOp::MOP_DxHitObject_TraceRay:
2761+ RewriteCallArg (CI, HLOperandIndex::kHitObjectTraceRay_PayloadPreOpIdx ,
2762+ /* bIn*/ true , /* bOut*/ true );
2763+ break ;
2764+ }
2765+ }
2766+ }
26912767 }
26922768}
26932769
@@ -2741,10 +2817,23 @@ static CallInst *RewriteWithFlattenedHLIntrinsicCall(CallInst *CI,
27412817
27422818// / RewriteCall - Replace OldVal with flattened NewElts in CallInst.
27432819void SROA_Helper::RewriteCall (CallInst *CI) {
2744- HLOpcodeGroup group = GetHLOpcodeGroupByName (CI->getCalledFunction ());
2745- if (group != HLOpcodeGroup::NotHL) {
2820+ HLOpcodeGroup Group = GetHLOpcodeGroupByName (CI->getCalledFunction ());
2821+ if (Group != HLOpcodeGroup::NotHL) {
27462822 unsigned opcode = GetHLOpcode (CI);
2747- if (group == HLOpcodeGroup::HLIntrinsic) {
2823+ if (Group == HLOpcodeGroup::HLIntrinsic) {
2824+ // RayQuery this pointer replacement.
2825+ if (OldVal->getType ()->isPointerTy () &&
2826+ dxilutil::IsHLSLRayQueryType (
2827+ OldVal->getType ()->getPointerElementType ())) {
2828+ // For RayQuery methods, we want to replace the RayQuery this pointer
2829+ // with a load and use of the underlying handle value.
2830+ // This will allow elimination of RayQuery types earlier.
2831+ RewriteWithFlattenedHLIntrinsicCall (CI, OldVal, NewElts,
2832+ /* loadElts*/ true );
2833+ DeadInsts.push_back (CI);
2834+ return ;
2835+ }
2836+
27482837 IntrinsicOp IOP = static_cast <IntrinsicOp>(opcode);
27492838 switch (IOP) {
27502839 case IntrinsicOp::MOP_Append: {
@@ -2756,84 +2845,42 @@ void SROA_Helper::RewriteCall(CallInst *CI) {
27562845 /* loadElts*/ false );
27572846 DeadInsts.push_back (CI);
27582847 } break ;
2759- case IntrinsicOp::IOP_TraceRay: {
2760- if (OldVal ==
2761- CI->getArgOperand (HLOperandIndex::kTraceRayRayDescOpIdx )) {
2762- RewriteCallArg (CI, HLOperandIndex::kTraceRayRayDescOpIdx ,
2763- /* bIn*/ true , /* bOut*/ false );
2764- } else {
2765- DXASSERT (OldVal ==
2766- CI->getArgOperand (HLOperandIndex::kTraceRayPayLoadOpIdx ),
2767- " else invalid TraceRay" );
2768- RewriteCallArg (CI, HLOperandIndex::kTraceRayPayLoadOpIdx ,
2769- /* bIn*/ true , /* bOut*/ true );
2770- }
2771- } break ;
2772- case IntrinsicOp::IOP_ReportHit: {
2773- RewriteCallArg (CI, HLOperandIndex::kReportIntersectionAttributeOpIdx ,
2774- /* bIn*/ true , /* bOut*/ false );
2775- } break ;
2776- case IntrinsicOp::IOP_CallShader: {
2777- RewriteCallArg (CI, HLOperandIndex::kCallShaderPayloadOpIdx ,
2778- /* bIn*/ true , /* bOut*/ true );
2779- } break ;
2780- case IntrinsicOp::MOP_DxHitObject_MakeMiss: {
2848+ // case IntrinsicOp::IOP_TraceRay:
2849+ // if (OldVal ==
2850+ // CI->getArgOperand(HLOperandIndex::kTraceRayRayDescOpIdx)) {
2851+ // // TODO: flatten RayDesc
2852+ // RewriteWithFlattenedHLIntrinsicCall(CI, OldVal, NewElts,
2853+ // /*loadElts*/ true);
2854+ // DeadInsts.push_back(CI);
2855+ // }
2856+ // break;
2857+ // case IntrinsicOp::MOP_DxHitObject_TraceRay:
2858+ // if (OldVal ==
2859+ // CI->getArgOperand(HLOperandIndex::kHitObjectTraceRay_RayDescOpIdx)) {
2860+ // // TODO: flatten RayDesc
2861+ // RewriteWithFlattenedHLIntrinsicCall(CI, OldVal, NewElts,
2862+ // /*loadElts*/ true);
2863+ // DeadInsts.push_back(CI);
2864+ // }
2865+ // break;
2866+ case IntrinsicOp::MOP_DxHitObject_MakeMiss:
27812867 if (OldVal ==
27822868 CI->getArgOperand (HLOperandIndex::kHitObjectMakeMissRayDescOpIdx )) {
27832869 RewriteWithFlattenedHLIntrinsicCall (CI, OldVal, NewElts,
27842870 /* loadElts*/ true );
27852871 DeadInsts.push_back (CI);
27862872 }
2787- } break ;
2788- case IntrinsicOp::MOP_TraceRayInline: {
2873+ break ;
2874+ case IntrinsicOp::MOP_TraceRayInline:
27892875 if (OldVal ==
27902876 CI->getArgOperand (HLOperandIndex::kTraceRayInlineRayDescOpIdx )) {
27912877 RewriteWithFlattenedHLIntrinsicCall (CI, OldVal, NewElts,
27922878 /* loadElts*/ true );
27932879 DeadInsts.push_back (CI);
27942880 break ;
27952881 }
2796- }
27972882 LLVM_FALLTHROUGH;
2798- case IntrinsicOp::MOP_DxHitObject_FromRayQuery: {
2799- const bool IsWithAttrs =
2800- CI->getNumArgOperands () ==
2801- HLOperandIndex::kHitObjectFromRayQuery_WithAttrs_NumOp ;
2802- if (IsWithAttrs &&
2803- (OldVal ==
2804- CI->getArgOperand (
2805- HLOperandIndex::
2806- kHitObjectFromRayQuery_WithAttrs_AttributeOpIdx ))) {
2807- RewriteCallArg (
2808- CI,
2809- HLOperandIndex::kHitObjectFromRayQuery_WithAttrs_AttributeOpIdx ,
2810- /* bIn*/ true , /* bOut*/ false );
2811- break ;
2812- }
2813-
2814- // For RayQuery methods, we want to replace the RayQuery this pointer
2815- // with a load and use of the underlying handle value.
2816- // This will allow elimination of RayQuery types earlier.
2817- RewriteWithFlattenedHLIntrinsicCall (CI, OldVal, NewElts,
2818- /* loadElts*/ true );
2819- DeadInsts.push_back (CI);
2820- break ;
2821- }
28222883 default :
2823- // RayQuery this pointer replacement.
2824- if (OldVal->getType ()->isPointerTy () &&
2825- CI->getNumArgOperands () >= HLOperandIndex::kHandleOpIdx &&
2826- OldVal == CI->getArgOperand (HLOperandIndex::kHandleOpIdx ) &&
2827- dxilutil::IsHLSLRayQueryType (
2828- OldVal->getType ()->getPointerElementType ())) {
2829- // For RayQuery methods, we want to replace the RayQuery this pointer
2830- // with a load and use of the underlying handle value.
2831- // This will allow elimination of RayQuery types earlier.
2832- RewriteWithFlattenedHLIntrinsicCall (CI, OldVal, NewElts,
2833- /* loadElts*/ true );
2834- DeadInsts.push_back (CI);
2835- break ;
2836- }
28372884 DXASSERT (0 , " cannot flatten hlsl intrinsic." );
28382885 }
28392886 }
@@ -4416,6 +4463,9 @@ class SROA_Parameter_HLSL : public ModulePass {
44164463 F->eraseFromParent ();
44174464 }
44184465
4466+ // Expand flattened copy-in/copy-out for intrinsic UDT args:
4467+ copyIntrinsicUDTArgs (*m_pHLModule);
4468+
44194469 // SROA globals and allocas.
44204470 SROAGlobalAndAllocas (*m_pHLModule, m_HasDbgInfo);
44214471
0 commit comments