Skip to content

Commit 39f76a3

Browse files
committed
Copy RayDesc args in too, copyIntrinsicUDTArgs -> copyIntrinsicAggArgs
This commit renames copyIntrinsicUDTArgs -> copyIntrinsicAggArgs and RewriteCallArg -> memcpyAggCallArg, and updates the helper's callers to copy the RayDesc aggregate arguments as well. This fixes the issue with a RayDesc argument provided directly from a cbuffer, since the incoming argument pointer will no longer be skipped by SROA. The expected IR for the tests will need to be updated after RayDesc is flattened, so test updates are deferred until then.
1 parent 39f8067 commit 39f76a3

1 file changed

Lines changed: 37 additions & 23 deletions

File tree

lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2695,11 +2695,11 @@ void SROA_Helper::RewriteBitCast(BitCastInst *BCI) {
26952695
RewriteForGEP(cast<GEPOperator>(GEP), GEPBuilder);
26962696
}
26972697

2698-
/// RewriteCallArg - For Functions which don't flat,
2699-
/// replace OldVal with alloca and
2700-
/// copy in copy out data between alloca and flattened NewElts
2701-
/// in CallInst.
2702-
static void RewriteCallArg(CallInst *CI, unsigned ArgIdx, bool bIn, bool bOut) {
2698+
/// memcpyAggCallArg - For an aggregate call argument, this replaces the
2699+
/// argument with an alloca and inserts a memcpy for input (if CopyIn) and
2700+
/// output (if CopyOut).
2701+
static void memcpyAggCallArg(CallInst *CI, unsigned ArgIdx, bool CopyIn,
2702+
bool CopyOut) {
27032703
Function *F = CI->getParent()->getParent();
27042704
IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(F));
27052705
const DataLayout &DL = F->getParent()->getDataLayout();
@@ -2709,21 +2709,24 @@ static void RewriteCallArg(CallInst *CI, unsigned ArgIdx, bool bIn, bool bOut) {
27092709
Type *userTyElt = userTy->getElementType();
27102710
Value *Alloca = AllocaBuilder.CreateAlloca(userTyElt);
27112711
IRBuilder<> Builder(CI);
2712-
if (bIn) {
2712+
if (CopyIn) {
27132713
Builder.CreateMemCpy(Alloca, userTyV, DL.getTypeAllocSize(userTyElt),
27142714
false);
27152715
}
27162716
CI->setArgOperand(ArgIdx, Alloca);
2717-
if (bOut) {
2717+
if (CopyOut) {
27182718
Builder.SetInsertPoint(CI->getNextNode());
27192719
Builder.CreateMemCpy(userTyV, Alloca, DL.getTypeAllocSize(userTyElt),
27202720
false);
27212721
}
27222722
}
27232723

2724-
static void copyIntrinsicUDTArgs(HLModule &HLM) {
2724+
static void copyIntrinsicAggArgs(HLModule &HLM) {
27252725
// Iterate HLIntrinsic function users
2726-
// For specific intrinsics, use RewriteCallArg on UDT args
2726+
// For specific intrinsics, use memcpyAggCallArg on aggregate args
2727+
// This ensures that the call does not directly use the pointer supplied,
2728+
// allowing certain arguments to be flattened, and UDT args to be correctly
2729+
// lowered.
27272730
for (Function &F : HLM.GetModule()->functions()) {
27282731
if (F.isIntrinsic() || !F.isDeclaration())
27292732
continue;
@@ -2734,32 +2737,43 @@ static void copyIntrinsicUDTArgs(HLModule &HLM) {
27342737
if (CallInst *CI = dyn_cast<CallInst>(U)) {
27352738
switch (static_cast<IntrinsicOp>(GetHLOpcode(CI))) {
27362739
case IntrinsicOp::IOP_TraceRay:
2737-
// TODO: Remove RayDesc for flattening
2738-
RewriteCallArg(CI, HLOperandIndex::kTraceRayRayDescOpIdx,
2739-
/*bIn*/ true, /*bOut*/ false);
2740-
RewriteCallArg(CI, HLOperandIndex::kTraceRayPayloadPreOpIdx,
2741-
/*bIn*/ true, /*bOut*/ true);
2740+
memcpyAggCallArg(CI, HLOperandIndex::kTraceRayRayDescOpIdx,
2741+
/*CopyIn*/ true, /*CopyOut*/ false);
2742+
memcpyAggCallArg(CI, HLOperandIndex::kTraceRayPayloadPreOpIdx,
2743+
/*CopyIn*/ true, /*CopyOut*/ true);
27422744
break;
27432745
case IntrinsicOp::IOP_ReportHit:
2744-
RewriteCallArg(CI, HLOperandIndex::kReportIntersectionAttributeOpIdx,
2745-
/*bIn*/ true, /*bOut*/ false);
2746+
memcpyAggCallArg(CI,
2747+
HLOperandIndex::kReportIntersectionAttributeOpIdx,
2748+
/*CopyIn*/ true, /*CopyOut*/ false);
27462749
break;
27472750
case IntrinsicOp::IOP_CallShader:
2748-
RewriteCallArg(CI, HLOperandIndex::kCallShaderPayloadOpIdx,
2749-
/*bIn*/ true, /*bOut*/ true);
2751+
memcpyAggCallArg(CI, HLOperandIndex::kCallShaderPayloadOpIdx,
2752+
/*CopyIn*/ true, /*CopyOut*/ true);
2753+
break;
2754+
case IntrinsicOp::MOP_TraceRayInline:
2755+
memcpyAggCallArg(CI, HLOperandIndex::kTraceRayInlineRayDescOpIdx,
2756+
/*CopyIn*/ true, /*CopyOut*/ false);
27502757
break;
27512758
case IntrinsicOp::MOP_DxHitObject_FromRayQuery:
27522759
if (CI->getNumArgOperands() ==
27532760
HLOperandIndex::kHitObjectFromRayQuery_WithAttrs_NumOp) {
2754-
RewriteCallArg(
2761+
memcpyAggCallArg(
27552762
CI,
27562763
HLOperandIndex::kHitObjectFromRayQuery_WithAttrs_AttributeOpIdx,
2757-
/*bIn*/ true, /*bOut*/ false);
2764+
/*CopyIn*/ true, /*CopyOut*/ false);
27582765
}
27592766
break;
2767+
case IntrinsicOp::MOP_DxHitObject_MakeMiss:
2768+
memcpyAggCallArg(CI, HLOperandIndex::kHitObjectMakeMissRayDescOpIdx,
2769+
/*CopyIn*/ true, /*CopyOut*/ false);
2770+
break;
27602771
case IntrinsicOp::MOP_DxHitObject_TraceRay:
2761-
RewriteCallArg(CI, HLOperandIndex::kHitObjectTraceRay_PayloadPreOpIdx,
2762-
/*bIn*/ true, /*bOut*/ true);
2772+
memcpyAggCallArg(CI, HLOperandIndex::kHitObjectTraceRay_RayDescOpIdx,
2773+
/*CopyIn*/ true, /*CopyOut*/ false);
2774+
memcpyAggCallArg(CI,
2775+
HLOperandIndex::kHitObjectTraceRay_PayloadPreOpIdx,
2776+
/*CopyIn*/ true, /*CopyOut*/ true);
27632777
break;
27642778
}
27652779
}
@@ -4464,7 +4478,7 @@ class SROA_Parameter_HLSL : public ModulePass {
44644478
}
44654479

44664480
// Expand flattened copy-in/copy-out for intrinsic aggregate args (UDT and RayDesc):
4467-
copyIntrinsicUDTArgs(*m_pHLModule);
4481+
copyIntrinsicAggArgs(*m_pHLModule);
44684482

44694483
// SROA globals and allocas.
44704484
SROAGlobalAndAllocas(*m_pHLModule, m_HasDbgInfo);

0 commit comments

Comments
 (0)