Skip to content

Commit d8fe1a6

Browse files
JoeCitizenCopilot
andcommitted
Mark wave intrinsics as convergent to prevent loop miscompilation
The JumpThreading pass could restructure loops containing wave intrinsics by threading edges through the loop latch block. This moved wave ops (WaveReadLaneFirst, WaveActiveCountBits, etc.) from inside the loop to after it, changing the set of active lanes at the call site. On SIMT hardware, this produced incorrect results because all lanes reconverge at the post-loop point rather than only the matching subset. Fix: - DxilConvergentMark: Mark wave-sensitive HL functions with the Attribute::Convergent attribute before optimizer passes run. - JumpThreading: In ThreadEdge, walk backward from the latch to the loop header to identify loop body blocks. If any contains a convergent call, prevent threading through the latch. - DxilOperations::GetOpFunc: Mark DXIL wave op functions with Attribute::Convergent for post-dxilgen optimizer passes. Fixes a bug where a material binning pattern using WaveReadLaneFirst + WaveActiveCountBits in a while-loop produced correct results at -Od but incorrect results with optimizations enabled. Co-authored-by: Copilot <[email protected]>
1 parent c44a383 commit d8fe1a6

4 files changed

Lines changed: 122 additions & 3 deletions

File tree

lib/DXIL/DxilOperations.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6738,6 +6738,14 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
67386738
if (existF->getFunctionType() != pFT)
67396739
return nullptr;
67406740
F = existF;
6741+
// HLSL Change Begin - ensure attributes are set on existing functions.
6742+
if (OpProps.FuncAttr != Attribute::None &&
6743+
!F->hasFnAttribute(OpProps.FuncAttr))
6744+
F->addFnAttr(OpProps.FuncAttr);
6745+
// Mark wave ops as convergent since they depend on the active lane set.
6746+
if (IsDxilOpWave(opCode) && !F->hasFnAttribute(Attribute::Convergent))
6747+
F->addFnAttr(Attribute::Convergent);
6748+
// HLSL Change End
67416749
UpdateCache(opClass, pOverloadType, F);
67426750
return F;
67436751
}
@@ -6749,6 +6757,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
67496757
F->addFnAttr(Attribute::NoUnwind);
67506758
if (OpProps.FuncAttr != Attribute::None)
67516759
F->addFnAttr(OpProps.FuncAttr);
6760+
// HLSL Change - mark wave ops as convergent.
6761+
if (IsDxilOpWave(opCode))
6762+
F->addFnAttr(Attribute::Convergent);
67526763

67536764
return F;
67546765
}

lib/HLSL/DxilConvergent.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,28 @@ class DxilConvergentMark : public ModulePass {
4444

4545
bool runOnModule(Module &M) override {
4646
const ShaderModel *SM = M.GetOrCreateHLModule().GetShaderModel();
47+
48+
bool bUpdated = false;
49+
50+
// HLSL Change Begin - Mark wave-sensitive HL functions as convergent.
51+
// This prevents optimizer passes (especially JumpThreading) from
52+
// restructuring control flow around wave ops, which would change
53+
// the set of active lanes at wave op call sites.
54+
for (Function &F : M.functions()) {
55+
if (F.isDeclaration() && IsHLWaveSensitive(&F) &&
56+
!F.hasFnAttribute(Attribute::Convergent)) {
57+
F.addFnAttr(Attribute::Convergent);
58+
bUpdated = true;
59+
}
60+
}
61+
// HLSL Change End
62+
4763
// Can skip if in a shader and version that doesn't support derivatives.
4864
if (!SM->IsPS() && !SM->IsLib() &&
4965
(!SM->IsSM66Plus() || (!SM->IsCS() && !SM->IsMS() && !SM->IsAS())))
50-
return false;
66+
return bUpdated;
5167
SupportsVectors = SM->IsSM69Plus();
5268

53-
bool bUpdated = false;
54-
5569
for (Function &F : M.functions()) {
5670
if (F.isDeclaration())
5771
continue;

lib/Transforms/Scalar/JumpThreading.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "llvm/Analysis/LazyValueInfo.h"
2525
#include "llvm/Analysis/Loads.h"
2626
#include "llvm/Analysis/TargetLibraryInfo.h"
27+
#include "llvm/IR/CallSite.h" // HLSL Change - for convergent call detection
2728
#include "llvm/IR/DataLayout.h"
2829
#include "llvm/IR/IntrinsicInst.h"
2930
#include "llvm/IR/LLVMContext.h"
@@ -1388,6 +1389,47 @@ bool JumpThreading::ThreadEdge(BasicBlock *BB,
13881389
return false;
13891390
}
13901391

1392+
// HLSL Change Begin - Don't thread through loop latch blocks when the loop
1393+
// body contains convergent calls (e.g., wave intrinsics). Threading through
1394+
// a latch can restructure the loop so that convergent calls that were inside
1395+
// the loop end up outside it, changing which lanes are active at those call
1396+
// sites on SIMT hardware.
1397+
for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) {
1398+
BasicBlock *Header = *SI;
1399+
if (!LoopHeaders.count(Header))
1400+
continue;
1401+
// BB is a loop latch (has a back-edge to Header). Walk backward from BB
1402+
// to find all blocks in the loop body and check for convergent calls.
1403+
SmallVector<BasicBlock *, 16> Worklist;
1404+
SmallPtrSet<BasicBlock *, 16> InLoop;
1405+
InLoop.insert(Header); // Seed to prevent going above header.
1406+
Worklist.push_back(BB);
1407+
bool HasConvergent = false;
1408+
while (!Worklist.empty() && !HasConvergent) {
1409+
BasicBlock *WBB = Worklist.pop_back_val();
1410+
if (!InLoop.insert(WBB).second)
1411+
continue;
1412+
for (auto &I : *WBB) {
1413+
if (auto CS = CallSite(&I)) {
1414+
if (CS.hasFnAttr(Attribute::Convergent)) {
1415+
HasConvergent = true;
1416+
break;
1417+
}
1418+
}
1419+
}
1420+
if (!HasConvergent)
1421+
for (pred_iterator PI = pred_begin(WBB), PE = pred_end(WBB); PI != PE;
1422+
++PI)
1423+
Worklist.push_back(*PI);
1424+
}
1425+
if (HasConvergent) {
1426+
DEBUG(dbgs() << " Not threading across loop latch BB '" << BB->getName()
1427+
<< "' - loop body has convergent calls\n");
1428+
return false;
1429+
}
1430+
}
1431+
// HLSL Change End
1432+
13911433
unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB, BBDupThreshold);
13921434
if (JumpThreadCost > BBDupThreshold) {
13931435
DEBUG(dbgs() << " Not threading BB '" << BB->getName()
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// RUN: %dxc -T cs_6_6 -E main %s | FileCheck %s
2+
3+
// Regression test for a bug where the optimizer (JumpThreading) would
4+
// restructure a while-loop containing wave intrinsics, moving
5+
// WaveActiveCountBits outside the loop. This changes the set of active
6+
// lanes at the wave op call site, producing incorrect results on SIMT
7+
// hardware.
8+
9+
// Verify that WaveAllBitCount (opcode 135) appears BEFORE the loop's
10+
// back-edge phi, ensuring it stays inside the loop body.
11+
12+
// CHECK: call i32 @dx.op.waveReadLaneFirst
13+
// CHECK: call i32 @dx.op.waveAllOp
14+
// CHECK: call i1 @dx.op.waveIsFirstLane
15+
// CHECK: phi i32
16+
// CHECK: br i1
17+
18+
RWStructuredBuffer<uint> Output : register(u1);
19+
20+
cbuffer Constants : register(b0) {
21+
uint Width;
22+
uint Height;
23+
uint NumMaterials;
24+
};
25+
26+
[numthreads(32, 1, 1)]
27+
void main(uint3 DTid : SV_DispatchThreadID) {
28+
uint x = DTid.x;
29+
uint y = DTid.y;
30+
31+
if (x >= Width || y >= Height)
32+
return;
33+
34+
// Compute a material ID per lane (simple hash).
35+
uint materialID = ((x * 7) + (y * 13)) % NumMaterials;
36+
37+
// Binning loop: each iteration peels off one material group.
38+
// WaveReadLaneFirst picks a material, matching lanes count themselves
39+
// with WaveActiveCountBits, and the first lane in the group writes
40+
// the count. Non-matching lanes loop back for the next material.
41+
bool go = true;
42+
while (go) {
43+
uint firstMat = WaveReadLaneFirst(materialID);
44+
if (firstMat == materialID) {
45+
uint count = WaveActiveCountBits(true);
46+
if (WaveIsFirstLane()) {
47+
InterlockedAdd(Output[firstMat], count);
48+
}
49+
go = false;
50+
}
51+
}
52+
}

0 commit comments

Comments
 (0)