From d8fe1a6d4f855dc8ce5875aa93f84829c04e22df Mon Sep 17 00:00:00 2001 From: Jack Elliott Date: Thu, 26 Feb 2026 13:22:50 +1300 Subject: [PATCH 1/3] Mark wave intrinsics as convergent to prevent loop miscompilation The JumpThreading pass could restructure loops containing wave intrinsics by threading edges through the loop latch block. This moved wave ops (WaveReadLaneFirst, WaveActiveCountBits, etc.) from inside the loop to after it, changing the set of active lanes at the call site. On SIMT hardware, this produced incorrect results because all lanes reconverge at the post-loop point rather than only the matching subset. Fix: - DxilConvergentMark: Mark wave-sensitive HL functions with the Attribute::Convergent attribute before optimizer passes run. - JumpThreading: In ThreadEdge, walk backward from the latch to the loop header to identify loop body blocks. If any contains a convergent call, prevent threading through the latch. - DxilOperations::GetOpFunc: Mark DXIL wave op functions with Attribute::Convergent for post-dxilgen optimizer passes. Fixes a bug where a material binning pattern using WaveReadLaneFirst + WaveActiveCountBits in a while-loop produced correct results at -Od but incorrect results with optimizations enabled. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- lib/DXIL/DxilOperations.cpp | 11 ++++ lib/HLSL/DxilConvergent.cpp | 20 +++++-- lib/Transforms/Scalar/JumpThreading.cpp | 42 +++++++++++++++ .../convergent/wave-in-loop-not-sunk.hlsl | 52 +++++++++++++++++++ 4 files changed, 122 insertions(+), 3 deletions(-) create mode 100644 tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/convergent/wave-in-loop-not-sunk.hlsl diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index eb5b2a2ceb..1116919843 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -6738,6 +6738,14 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { if (existF->getFunctionType() != pFT) return nullptr; F = existF; + // HLSL Change Begin - ensure attributes are set on existing functions. + if (OpProps.FuncAttr != Attribute::None && + !F->hasFnAttribute(OpProps.FuncAttr)) + F->addFnAttr(OpProps.FuncAttr); + // Mark wave ops as convergent since they depend on the active lane set. + if (IsDxilOpWave(opCode) && !F->hasFnAttribute(Attribute::Convergent)) + F->addFnAttr(Attribute::Convergent); + // HLSL Change End UpdateCache(opClass, pOverloadType, F); return F; } @@ -6749,6 +6757,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { F->addFnAttr(Attribute::NoUnwind); if (OpProps.FuncAttr != Attribute::None) F->addFnAttr(OpProps.FuncAttr); + // HLSL Change - mark wave ops as convergent. + if (IsDxilOpWave(opCode)) + F->addFnAttr(Attribute::Convergent); return F; } diff --git a/lib/HLSL/DxilConvergent.cpp b/lib/HLSL/DxilConvergent.cpp index a96af39fd8..fb3526c253 100644 --- a/lib/HLSL/DxilConvergent.cpp +++ b/lib/HLSL/DxilConvergent.cpp @@ -44,14 +44,28 @@ class DxilConvergentMark : public ModulePass { bool runOnModule(Module &M) override { const ShaderModel *SM = M.GetOrCreateHLModule().GetShaderModel(); + + bool bUpdated = false; + + // HLSL Change Begin - Mark wave-sensitive HL functions as convergent. + // This prevents optimizer passes (especially JumpThreading) from + // restructuring control flow around wave ops, which would change + // the set of active lanes at wave op call sites. + for (Function &F : M.functions()) { + if (F.isDeclaration() && IsHLWaveSensitive(&F) && + !F.hasFnAttribute(Attribute::Convergent)) { + F.addFnAttr(Attribute::Convergent); + bUpdated = true; + } + } + // HLSL Change End + // Can skip if in a shader and version that doesn't support derivatives. if (!SM->IsPS() && !SM->IsLib() && (!SM->IsSM66Plus() || (!SM->IsCS() && !SM->IsMS() && !SM->IsAS()))) - return false; + return bUpdated; SupportsVectors = SM->IsSM69Plus(); - bool bUpdated = false; - for (Function &F : M.functions()) { if (F.isDeclaration()) continue; diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp index e4757b472e..a97c639870 100644 --- a/lib/Transforms/Scalar/JumpThreading.cpp +++ b/lib/Transforms/Scalar/JumpThreading.cpp @@ -24,6 +24,7 @@ #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/CallSite.h" // HLSL Change - for convergent call detection #include "llvm/IR/DataLayout.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" @@ -1388,6 +1389,47 @@ bool JumpThreading::ThreadEdge(BasicBlock *BB, return false; } + // HLSL Change Begin - Don't thread through loop latch blocks when the loop + // body contains convergent calls (e.g., wave intrinsics). Threading through + // a latch can restructure the loop so that convergent calls that were inside + // the loop end up outside it, changing which lanes are active at those call + // sites on SIMT hardware. + for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) { + BasicBlock *Header = *SI; + if (!LoopHeaders.count(Header)) + continue; + // BB is a loop latch (has a back-edge to Header). Walk backward from BB + // to find all blocks in the loop body and check for convergent calls. + SmallVector Worklist; + SmallPtrSet InLoop; + InLoop.insert(Header); // Seed to prevent going above header. + Worklist.push_back(BB); + bool HasConvergent = false; + while (!Worklist.empty() && !HasConvergent) { + BasicBlock *WBB = Worklist.pop_back_val(); + if (!InLoop.insert(WBB).second) + continue; + for (auto &I : *WBB) { + if (auto CS = CallSite(&I)) { + if (CS.hasFnAttr(Attribute::Convergent)) { + HasConvergent = true; + break; + } + } + } + if (!HasConvergent) + for (pred_iterator PI = pred_begin(WBB), PE = pred_end(WBB); PI != PE; + ++PI) + Worklist.push_back(*PI); + } + if (HasConvergent) { + DEBUG(dbgs() << " Not threading across loop latch BB '" << BB->getName() + << "' - loop body has convergent calls\n"); + return false; + } + } + // HLSL Change End + unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB, BBDupThreshold); if (JumpThreadCost > BBDupThreshold) { DEBUG(dbgs() << " Not threading BB '" << BB->getName() diff --git a/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/convergent/wave-in-loop-not-sunk.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/convergent/wave-in-loop-not-sunk.hlsl new file mode 100644 index 0000000000..a108c9b29d --- /dev/null +++ b/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/convergent/wave-in-loop-not-sunk.hlsl @@ -0,0 +1,52 @@ +// RUN: %dxc -T cs_6_6 -E main %s | FileCheck %s + +// Regression test for a bug where the optimizer (JumpThreading) would +// restructure a while-loop containing wave intrinsics, moving +// WaveActiveCountBits outside the loop. This changes the set of active +// lanes at the wave op call site, producing incorrect results on SIMT +// hardware. + +// Verify that WaveAllBitCount (opcode 135) appears BEFORE the loop's +// back-edge phi, ensuring it stays inside the loop body. + +// CHECK: call i32 @dx.op.waveReadLaneFirst +// CHECK: call i32 @dx.op.waveAllOp +// CHECK: call i1 @dx.op.waveIsFirstLane +// CHECK: phi i32 +// CHECK: br i1 + +RWStructuredBuffer Output : register(u1); + +cbuffer Constants : register(b0) { + uint Width; + uint Height; + uint NumMaterials; +}; + +[numthreads(32, 1, 1)] +void main(uint3 DTid : SV_DispatchThreadID) { + uint x = DTid.x; + uint y = DTid.y; + + if (x >= Width || y >= Height) + return; + + // Compute a material ID per lane (simple hash). + uint materialID = ((x * 7) + (y * 13)) % NumMaterials; + + // Binning loop: each iteration peels off one material group. + // WaveReadLaneFirst picks a material, matching lanes count themselves + // with WaveActiveCountBits, and the first lane in the group writes + // the count. Non-matching lanes loop back for the next material. + bool go = true; + while (go) { + uint firstMat = WaveReadLaneFirst(materialID); + if (firstMat == materialID) { + uint count = WaveActiveCountBits(true); + if (WaveIsFirstLane()) { + InterlockedAdd(Output[firstMat], count); + } + go = false; + } + } +} From 57651ced48b1719d1f954ebe08f928ffd2187033 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 26 Feb 2026 00:31:51 +0000 Subject: [PATCH 2/3] chore: autopublish 2026-02-26T00:31:51Z --- lib/Transforms/Scalar/JumpThreading.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp index a97c639870..e097f375ae 100644 --- a/lib/Transforms/Scalar/JumpThreading.cpp +++ b/lib/Transforms/Scalar/JumpThreading.cpp @@ -11,7 +11,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" @@ -34,6 +33,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" From 8dc2b90c7a6ed399020f9d957a87e3d3bdd0f71e Mon Sep 17 00:00:00 2001 From: Jack Elliott Date: Thu, 26 Feb 2026 17:05:14 +1300 Subject: [PATCH 3/3] Address PR review: remove Hungarian notation and HLSL Change markers - Rename bUpdated -> Updated (drop Hungarian prefix per reviewer feedback) - Remove // HLSL Change Begin/End markers in DxilConvergent.cpp since the file is already HLSL-specific Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- lib/DXIL/DxilOperations.cpp | 5 ++--- lib/HLSL/DxilConvergent.cpp | 13 ++++++------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index 1116919843..7726a79d9a 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -6738,14 +6738,13 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { if (existF->getFunctionType() != pFT) return nullptr; F = existF; - // HLSL Change Begin - ensure attributes are set on existing functions. + // Ensure attributes are set on existing functions. if (OpProps.FuncAttr != Attribute::None && !F->hasFnAttribute(OpProps.FuncAttr)) F->addFnAttr(OpProps.FuncAttr); // Mark wave ops as convergent since they depend on the active lane set. if (IsDxilOpWave(opCode) && !F->hasFnAttribute(Attribute::Convergent)) F->addFnAttr(Attribute::Convergent); - // HLSL Change End UpdateCache(opClass, pOverloadType, F); return F; } @@ -6757,7 +6756,7 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { F->addFnAttr(Attribute::NoUnwind); if (OpProps.FuncAttr != Attribute::None) F->addFnAttr(OpProps.FuncAttr); - // HLSL Change - mark wave ops as convergent. + // Mark wave ops as convergent since they depend on the active lane set. if (IsDxilOpWave(opCode)) F->addFnAttr(Attribute::Convergent); diff --git a/lib/HLSL/DxilConvergent.cpp b/lib/HLSL/DxilConvergent.cpp index fb3526c253..4848533c6b 100644 --- a/lib/HLSL/DxilConvergent.cpp +++ b/lib/HLSL/DxilConvergent.cpp @@ -45,9 +45,9 @@ class DxilConvergentMark : public ModulePass { bool runOnModule(Module &M) override { const ShaderModel *SM = M.GetOrCreateHLModule().GetShaderModel(); - bool bUpdated = false; + bool Updated = false; - // HLSL Change Begin - Mark wave-sensitive HL functions as convergent. + // Mark wave-sensitive HL functions as convergent. // This prevents optimizer passes (especially JumpThreading) from // restructuring control flow around wave ops, which would change // the set of active lanes at wave op call sites. @@ -55,15 +55,14 @@ class DxilConvergentMark : public ModulePass { if (F.isDeclaration() && IsHLWaveSensitive(&F) && !F.hasFnAttribute(Attribute::Convergent)) { F.addFnAttr(Attribute::Convergent); - bUpdated = true; + Updated = true; } } - // HLSL Change End // Can skip if in a shader and version that doesn't support derivatives. if (!SM->IsPS() && !SM->IsLib() && (!SM->IsSM66Plus() || (!SM->IsCS() && !SM->IsMS() && !SM->IsAS()))) - return bUpdated; + return Updated; SupportsVectors = SM->IsSM69Plus(); for (Function &F : M.functions()) { @@ -80,13 +79,13 @@ class DxilConvergentMark : public ModulePass { if (PropagateConvergent(V, &F, PDR)) { // TODO: emit warning here. } - bUpdated = true; + Updated = true; } } } } - return bUpdated; + return Updated; } private: