Skip to content

Commit 76e7647

Browse files
author
Greg Roth
authored
Prevent sinking coord calculation for sample (#3657)
Merge #3655. The mark-convergent pass is meant to prevent unwanted movement of operations on the inputs of derivative ops. Previously it ran only on pixel shaders. Because derivatives are supported in compute, mesh, and amplification (CS/MS/AS) shaders as of shader model 6.6, the pass needs to run on those stages for that target too. (cherry picked from commit 93a9898)
1 parent 985b29b commit 76e7647

2 files changed

Lines changed: 50 additions & 1 deletion

File tree

lib/HLSL/DxilConvergent.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ class DxilConvergentMark : public ModulePass {
4747

4848
bool runOnModule(Module &M) override {
4949
if (M.HasHLModule()) {
50-
if (!M.GetHLModule().GetShaderModel()->IsPS())
50+
const ShaderModel *SM = M.GetHLModule().GetShaderModel();
51+
if (!SM->IsPS() && (!SM->IsSM66Plus() || (!SM->IsCS() && !SM->IsMS() && !SM->IsAS())))
5152
return false;
5253
}
5354
bool bUpdated = false;
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: %dxc -E MainCS -T cs_6_6 %s | FileCheck %s
2+
// RUN: %dxc -E MainAS -T as_6_6 %s | FileCheck %s
3+
// RUN: %dxc -E MainMS -T ms_6_6 %s | FileCheck %s
4+
5+
// Make sure the add that computes the sample coordinate is not sunk into the if block.
6+
// Compute shader variant of convergent.hlsl
7+
8+
// CHECK: add
9+
// CHECK: add
10+
// CHECK: icmp
11+
// CHECK-NEXT: br
12+
13+
14+
Texture2D<float4> tex;
15+
RWBuffer<float4> output;
16+
SamplerState s;
17+
18+
void doit(uint ix, uint3 id){
19+
20+
float2 coord = id.xy + id.z;
21+
float4 c = id.z;
22+
if (id.z > 2) {
23+
c += tex.Sample(s, coord);
24+
}
25+
output[ix] = c;
26+
27+
}
28+
29+
[numthreads(4,4,4)]
30+
void MainCS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
31+
doit(ix, id);
32+
}
33+
34+
struct Payload { int nothing; };
35+
36+
[numthreads(4,4,4)]
37+
void MainAS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
38+
doit(ix, id);
39+
Payload pld = (Payload)0;
40+
DispatchMesh(1,1,1,pld);
41+
}
42+
43+
44+
[numthreads(4,4,4)]
45+
[outputtopology("triangle")]
46+
void MainMS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
47+
doit(ix, id);
48+
}

0 commit comments

Comments
 (0)