Skip to content

Commit 77f146e

Browse files
authored
Fix issues with resource array aliasing support (#3810)
SROA: Skip arrays of object/matrix as well. Fix issues with ReplaceUseOfZeroInit* - Fix ReplaceUseOfZeroInitPostDom bailing on use of memcpy dest. When the memcpy was not in the entry block, any use of dest after memcpy would abort replacement, as if it was unsafe to replace dest with src. But uses of dest dominated by memcpy *are* safe to replace. - Fix ReplaceUseOfZeroInit* misuse of post-dom. Post-dom was used to detect whether it was safe to replace uses of dest before the memcpy with zero when dest was zeroinitialized. But post-dom is not the right way to tell if this is safe. It is unsafe if any uses *could* follow the memcpy. So the new test is to gather a set of blocks that could be reachable from the successors of the memcpy block (where memcpy ends the block because we split it there). If a use is in one of those blocks, it is not safe to replace the use with zeroinitializer, and because it was not dominated by the memcpy either, it's not safe to replace dest with src here either, so memcpy replacement must abort. - Re-merge blocks after splitting in ReplaceUseOfZeroInitBeforeDef. Restore CFG after ReplaceUseOfZeroInitBeforeDef, rather than leaving it in more of a mess to be cleaned up much later. This has the side benefit of preserving more trivial entry-block replacement opportunities. DFE should apply to internal functions in lib target as well. - Also, iterating over RemoveUnusedFunctions will allow removal of more functions only called by other internal functions being removed. Add LowerStaticGlobalIntoAlloca before SROA for LowerMemcpy opportunities - Existing location after SROA must also be there for non-resource aggregates - Don't preserve zeroinitializer when moving static GV to alloca for objects - Minor opt in LowerStaticGlobalIntoAlloca: skip fn decls more cheaply
1 parent e5801bd commit 77f146e

6 files changed

Lines changed: 127 additions & 63 deletions

File tree

lib/DXIL/DxilUtil.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ bool RemoveUnusedFunctions(Module &M, Function *EntryFunc,
121121
for (auto &F : M.functions()) {
122122
if (&F == EntryFunc || &F == PatchConstantFunc)
123123
continue;
124-
if (F.isDeclaration() || !IsLib) {
124+
if (F.isDeclaration() || !IsLib ||
125+
F.hasInternalLinkage()) {
125126
if (F.user_empty())
126127
deadList.emplace_back(&F);
127128
}

lib/HLSL/HLDeadFunctionElimination.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,11 @@ class HLDeadFunctionElimination : public ModulePass {
3535
Function *EntryFunc = HLM.GetEntryFunction();
3636
Function *PatchConstantFunc = HLM.GetPatchConstantFunction();
3737

38-
return dxilutil::RemoveUnusedFunctions(M, EntryFunc, PatchConstantFunc,
39-
IsLib);
38+
bool bChanged = false;
39+
while (dxilutil::RemoveUnusedFunctions(M, EntryFunc, PatchConstantFunc,
40+
IsLib))
41+
bChanged = true;
42+
return bChanged;
4043
}
4144

4245
return false;

lib/Transforms/IPO/PassManagerBuilder.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ static void addHLSLPasses(bool HLSLHighLevel, unsigned OptLevel, bool OnlyWarnOn
223223
MPM.add(createHLDeadFunctionEliminationPass());
224224
}
225225

226+
// Do this before scalarrepl-param-hlsl for opportunities to move things
227+
// like resource arrays to alloca, allowing more likely memcpy replacement.
228+
MPM.add(createLowerStaticGlobalIntoAlloca());
229+
226230
// Expand buffer store intrinsics before we SROA
227231
MPM.add(createHLExpandStoreIntrinsicsPass());
228232

lib/Transforms/Scalar/SROA.cpp

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,17 @@ static Value *foldPHINodeOrSelectInst(Instruction &I) {
627627
return foldSelectInst(cast<SelectInst>(I));
628628
}
629629

630+
// HLSL Change - Detect HLSL Object or Matrix [array] type
631+
// These types should be SROA'd elsewhere as necessary.
//
// Ty          - type to classify; may be a pointer to, and/or an
//               (arbitrarily nested) array of, the type of interest.
// SkipHLSLMat - when false, HLSL matrix types are not reported; HLSL object
//               types are reported regardless of this flag.
// Returns true when the innermost element type is an HLSL object type, or an
// HLSL matrix type with SkipHLSLMat set.
632+
bool SkipHLSLType(Type *Ty, bool SkipHLSLMat) {
633+
// Look through a single level of pointer only.
if (Ty->isPointerTy())
634+
Ty = Ty->getPointerElementType();
635+
// Peel every array dimension so arrays of objects/matrices are detected too.
while (Ty->isArrayTy())
636+
Ty = Ty->getArrayElementType();
637+
return (SkipHLSLMat && hlsl::HLMatrixType::isa(Ty)) ||
638+
hlsl::dxilutil::IsHLSLObjectType(Ty);
639+
}
640+
630641
/// \brief Builder for the alloca slices.
631642
///
632643
/// This class builds a set of alloca slices by recursively visiting the uses
@@ -697,21 +708,10 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
697708
if (BC.use_empty())
698709
return markAsDead(BC);
699710
// HLSL Change Begin - not sroa matrix type.
700-
if (PointerType *PT = dyn_cast<PointerType>(BC.getType())) {
701-
Type *EltTy = PT->getElementType();
702-
if ((SkipHLSLMat && hlsl::HLMatrixType::isa(EltTy)) ||
703-
hlsl::dxilutil::IsHLSLObjectType(EltTy)) {
704-
AS.PointerEscapingInstr = &BC;
705-
return;
706-
}
707-
if (PointerType *SrcPT = dyn_cast<PointerType>(BC.getSrcTy())) {
708-
Type *SrcEltTy = SrcPT->getElementType();
709-
if ((SkipHLSLMat && hlsl::HLMatrixType::isa(SrcEltTy)) ||
710-
hlsl::dxilutil::IsHLSLObjectType(SrcEltTy)) {
711-
AS.PointerEscapingInstr = &BC;
712-
return;
713-
}
714-
}
711+
if (SkipHLSLType(BC.getType(), SkipHLSLMat) ||
712+
SkipHLSLType(BC.getSrcTy(), SkipHLSLMat)) {
713+
AS.PointerEscapingInstr = &BC;
714+
return;
715715
}
716716
// HLSL Change End.
717717
return Base::visitBitCastInst(BC);
@@ -775,8 +775,7 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
775775

776776
void visitLoadInst(LoadInst &LI) {
777777
// HLSL Change Begin - not sroa matrix type.
778-
if ((SkipHLSLMat && hlsl::HLMatrixType::isa(LI.getType())) ||
779-
hlsl::dxilutil::IsHLSLObjectType(LI.getType()))
778+
if (SkipHLSLType(LI.getType(), SkipHLSLMat))
780779
return PI.setEscapedAndAborted(&LI);
781780
// HLSL Change End.
782781
assert((!LI.isSimple() || LI.getType()->isSingleValueType()) &&
@@ -796,8 +795,7 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
796795
if (ValOp == *U)
797796
return PI.setEscapedAndAborted(&SI);
798797
// HLSL Change Begin - not sroa matrix type.
799-
if ((SkipHLSLMat && hlsl::HLMatrixType::isa(ValOp->getType())) ||
800-
hlsl::dxilutil::IsHLSLObjectType(ValOp->getType()))
798+
if (SkipHLSLType(ValOp->getType(), SkipHLSLMat))
801799
return PI.setEscapedAndAborted(&SI);
802800
// HLSL Change End.
803801

@@ -3366,8 +3364,7 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
33663364
if (!LI.isSimple() || LI.getType()->isSingleValueType())
33673365
return false;
33683366
// HLSL Change Begin - not sroa matrix type.
3369-
if ((SkipHLSLMat && hlsl::HLMatrixType::isa(LI.getType())) ||
3370-
hlsl::dxilutil::IsHLSLObjectType(LI.getType()))
3367+
if (SkipHLSLType(LI.getType(), SkipHLSLMat))
33713368
return false;
33723369
// HLSL Change End.
33733370

@@ -3405,8 +3402,7 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
34053402
if (V->getType()->isSingleValueType())
34063403
return false;
34073404
// HLSL Change Begin - not sroa matrix type.
3408-
if ((SkipHLSLMat && hlsl::HLMatrixType::isa(V->getType())) ||
3409-
hlsl::dxilutil::IsHLSLObjectType(V->getType()))
3405+
if (SkipHLSLType(V->getType(), SkipHLSLMat))
34103406
return false;
34113407
// HLSL Change End.
34123408
// We have an aggregate being stored, split it apart.
@@ -3419,17 +3415,9 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
34193415

34203416
bool visitBitCastInst(BitCastInst &BC) {
34213417
// HLSL Change Begin - not sroa matrix type.
3422-
if (PointerType *PT = dyn_cast<PointerType>(BC.getType())) {
3423-
Type *EltTy = PT->getElementType();
3424-
if ((SkipHLSLMat && hlsl::HLMatrixType::isa(EltTy)) ||
3425-
hlsl::dxilutil::IsHLSLObjectType(EltTy))
3426-
return false;
3427-
if (PointerType *SrcPT = dyn_cast<PointerType>(BC.getSrcTy())) {
3428-
Type *SrcEltTy = SrcPT->getElementType();
3429-
if ((SkipHLSLMat && hlsl::HLMatrixType::isa(SrcEltTy)) ||
3430-
hlsl::dxilutil::IsHLSLObjectType(SrcEltTy))
3431-
return false;
3432-
}
3418+
if (SkipHLSLType(BC.getType(), SkipHLSLMat) ||
3419+
SkipHLSLType(BC.getSrcTy(), SkipHLSLMat)) {
3420+
return false;
34333421
}
34343422
// HLSL Change End.
34353423
enqueueUsers(BC);
@@ -4420,8 +4408,7 @@ bool SROA::runOnAlloca(AllocaInst &AI) {
44204408
hlsl::dxilutil::IsHLSLObjectType(
44214409
AI.getAllocatedType()) || // HLSL Change - not sroa resource type.
44224410
// HLSL Change Begin - not sroa matrix type.
4423-
(SkipHLSLMat &&
4424-
hlsl::HLMatrixType::isa(AI.getAllocatedType())) ||
4411+
SkipHLSLType(AI.getAllocatedType(), SkipHLSLMat) ||
44254412
// HLSL Change End.
44264413
DL.getTypeAllocSize(AI.getAllocatedType()) == 0)
44274414
return false;

lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
5050
#include "llvm/Transforms/Utils/SSAUpdater.h"
5151
#include "llvm/Transforms/Utils/Local.h"
52+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
5253
#include "dxc/HLSL/HLOperations.h"
5354
#include "dxc/DXIL/DxilConstants.h"
5455
#include "dxc/HLSL/HLModule.h"
@@ -3443,7 +3444,7 @@ static bool ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
34433444
if (V != Src && V->hasOneUse() && Src->hasOneUse())
34443445
return false;
34453446

3446-
// If the memcpy doesn't dominate all its users,
3447+
// If the source of the memcpy (Src) doesn't dominate all users of dest (V),
34473448
// full replacement isn't possible without complicated PHI insertion
34483449
// This will likely replace with ld/st which will be replaced in mem2reg
34493450
if (Instruction *SrcI = dyn_cast<Instruction>(Src))
@@ -3587,32 +3588,40 @@ static bool ReplaceUseOfZeroInitEntry(Instruction *I, Value *V) {
35873588
return true;
35883589
}
35893590

3590-
static bool ReplaceUseOfZeroInitPostDom(Instruction *I, Value *V,
3591-
PostDominatorTree &PDT) {
3591+
// If a V user is dominated by memcpy (I),
3592+
// skip it - memcpy dest can simply alias to src for this user.
3593+
// If the V user may follow the memcpy (I),
3594+
// return false - memcpy dest not safe to replace with src.
3595+
// Otherwise,
3596+
// replace use with zeroinitializer.
3597+
static bool ReplaceUseOfZeroInit(Instruction *I, Value *V,
3598+
DominatorTree &DT,
3599+
SmallPtrSet<BasicBlock*, 8> &Reachable) {
35923600
BasicBlock *BB = I->getParent();
35933601
Function *F = I->getParent()->getParent();
35943602
for (auto U = V->user_begin(); U != V->user_end(); ) {
35953603
Instruction *UI = dyn_cast<Instruction>(*(U++));
3596-
if (!UI)
3604+
if (!UI || UI == I)
35973605
continue;
35983606
if (UI->getParent()->getParent() != F)
35993607
continue;
36003608

3601-
if (!PDT.dominates(BB, UI->getParent()))
3609+
// Skip properly dominated users
3610+
if (DT.properlyDominates(BB, UI->getParent()))
3611+
continue;
3612+
3613+
// If user is found in memcpy successor list
3614+
// then the user is not safe to replace with zeroinitializer.
3615+
if (Reachable.count(UI->getParent()))
36023616
return false;
36033617

3618+
// Remaining cases are where I:
3619+
// - is at the end of the same block
3620+
// - does not precede UI on any path
36043621
if (isa<GetElementPtrInst>(UI) || isa<BitCastInst>(UI)) {
3605-
if (!ReplaceUseOfZeroInitPostDom(I, UI, PDT))
3606-
return false;
3607-
else
3622+
if (ReplaceUseOfZeroInit(I, UI, DT, Reachable))
36083623
continue;
3609-
}
3610-
3611-
if (BB != UI->getParent() || UI == I)
3612-
continue;
3613-
// I is the last inst in the block after split.
3614-
// Any inst in current block is before I.
3615-
if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
3624+
} else if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
36163625
LI->replaceAllUsesWith(ConstantAggregateZero::get(LI->getType()));
36173626
LI->eraseFromParent();
36183627
continue;
@@ -3621,23 +3630,42 @@ static bool ReplaceUseOfZeroInitPostDom(Instruction *I, Value *V,
36213630
}
36223631
return true;
36233632
}
3633+
3634+
// Recursively collect all successors of BB and BB's successors.
3635+
// BB will not be in set unless it's reachable through its successors.
//
// BB        - block whose successors seed the walk.
// Reachable - in/out set; blocks are added as they are first encountered.
//             Blocks already present on entry are not walked again.
3636+
static void CollectReachableBBs(BasicBlock *BB, SmallPtrSet<BasicBlock*, 8> &Reachable) {
3637+
for (auto S : successors(BB)) {
3638+
// Only recurse when S was newly inserted; this guard is what stops the
// walk from revisiting blocks (e.g. around a loop in the CFG).
if (Reachable.insert(S).second)
3639+
CollectReachableBBs(S, Reachable);
3640+
}
3641+
}
3642+
36243643
// When a zero-initialized GV has only one definition, all uses before the def should
36253644
// use zero.
36263645
static bool ReplaceUseOfZeroInitBeforeDef(Instruction *I, GlobalVariable *GV) {
36273646
BasicBlock *BB = I->getParent();
36283647
Function *F = I->getParent()->getParent();
36293648
// Make sure I is the last inst for BB.
3649+
BasicBlock *NewBB = nullptr;
36303650
if (I != BB->getTerminator())
3631-
BB->splitBasicBlock(I->getNextNode());
3651+
NewBB = BB->splitBasicBlock(I->getNextNode());
36323652

3653+
bool bSuccess = false;
36333654
if (&F->getEntryBlock() == I->getParent()) {
3634-
return ReplaceUseOfZeroInitEntry(I, GV);
3655+
bSuccess = ReplaceUseOfZeroInitEntry(I, GV);
36353656
} else {
3636-
// Post dominator tree.
3637-
PostDominatorTree PDT;
3638-
PDT.runOnFunction(*F);
3639-
return ReplaceUseOfZeroInitPostDom(I, GV, PDT);
3657+
DominatorTree DT;
3658+
DT.recalculate(*F);
3659+
SmallPtrSet<BasicBlock*, 8> Reachable;
3660+
CollectReachableBBs(BB, Reachable);
3661+
bSuccess = ReplaceUseOfZeroInit(I, GV, DT, Reachable);
36403662
}
3663+
3664+
// Re-merge basic block to keep things simpler
3665+
if (NewBB)
3666+
llvm::MergeBlockIntoPredecessor(NewBB);
3667+
3668+
return bSuccess;
36413669
}
36423670

36433671
// Use `DT` to trace all users and make sure `I`'s BB dominates them all
@@ -5918,7 +5946,7 @@ class LowerStaticGlobalIntoAlloca : public ModulePass {
59185946
}
59195947
} else {
59205948
for (Function &F : M) {
5921-
if (!HLM.IsEntry(&F)) {
5949+
if (F.isDeclaration() || !HLM.IsEntry(&F)) {
59225950
continue;
59235951
}
59245952
entryAndInitFunctionSet.insert(&F);
@@ -5936,7 +5964,7 @@ class LowerStaticGlobalIntoAlloca : public ModulePass {
59365964
}
59375965
} else {
59385966
for (Function &F : M) {
5939-
if (!DM.IsEntry(&F))
5967+
if (F.isDeclaration() || !DM.IsEntry(&F))
59405968
continue;
59415969
entryAndInitFunctionSet.insert(&F);
59425970
}
@@ -6171,14 +6199,17 @@ bool LowerStaticGlobalIntoAlloca::lowerStaticGlobalIntoAlloca(
61716199
GlobalVariable *GV, const DataLayout &DL, DxilTypeSystem &typeSys,
61726200
SetVector<Function *> &entryAndInitFunctionSet) {
61736201
GV->removeDeadConstantUsers();
6202+
bool bIsObjectTy = dxilutil::IsHLSLObjectType(
6203+
dxilutil::StripArrayTypes(GV->getType()->getElementType()));
61746204
// Create alloca for each entry.
61756205
DenseMap<Function *, AllocaInst *> allocaMap;
61766206
for (Function *F : entryAndInitFunctionSet) {
61776207
IRBuilder<> Builder(dxilutil::FindAllocaInsertionPt(F));
61786208
AllocaInst *AI = Builder.CreateAlloca(GV->getType()->getElementType());
61796209
allocaMap[F] = AI;
61806210
// Store the initializer if one exists.
6181-
if (GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer())) {
6211+
if (GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer()) &&
6212+
!bIsObjectTy) { // Do not zero-initialize object allocas
61826213
Builder.CreateStore(GV->getInitializer(), GV);
61836214
}
61846215
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: %dxc -T lib_6_5 %s | FileCheck %s
2+
3+
// Make sure we don't use alloca, meaning we successfully aliased the static
4+
// and local resource arrays to the original global array.
5+
6+
// CHECK: define void @main()
7+
// CHECK-NOT: alloca
8+
9+
Texture2D<float4> buf[16];
10+
static Texture2D<float4> s_buf[16];
11+
12+
void Init(Texture2D<float4> buf[16]) {
13+
s_buf = buf;
14+
}
15+
16+
float4 Use(Texture2D<float4> buf[16], uint i) {
17+
return buf[i][int2(3,4)];
18+
}
19+
20+
float4 Use(uint i) {
21+
return Use(s_buf, i);
22+
}
23+
24+
[shader("pixel")]
25+
float4 main(uint i:I) : SV_Target {
26+
Texture2D<float4> lbuf[16];
27+
lbuf = buf;
28+
Init(lbuf);
29+
return Use(i);
30+
}
31+
32+
[shader("pixel")]
33+
float4 other_main(uint i:I) : SV_Target {
34+
Texture2D<float4> lbuf[16];
35+
lbuf = buf;
36+
Init(lbuf);
37+
return Use(i);
38+
}

0 commit comments

Comments
 (0)