Skip to content

Commit 8a59115

Browse files
authored
Allow derivatives on mesh nodes (#6507)
Previously, mesh nodes (that is, node shaders with mesh launch types) would emit diagnostics when they contain a derivative operation. However, mesh nodes should allow derivative operations. This PR adjusts this rule for mesh nodes, so that derivative operations are allowed. An accompanying regression test ensures that diagnostics that would previously be emitted are no longer emitted. Fixes #6480
1 parent e6c1c4f commit 8a59115

7 files changed

Lines changed: 233 additions & 4 deletions

File tree

docs/DXIL.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3216,7 +3216,7 @@ SM.HSINPUTCONTROLPOINTCOUNTRANGE HS input control point count must be [
32163216
SM.HULLPASSTHRUCONTROLPOINTCOUNTMATCH For pass thru hull shader, input control point count must match output control point count
32173217
SM.INCOMPATIBLECALLINENTRY Features used in internal function calls must be compatible with entry
32183218
SM.INCOMPATIBLEDERIVINCOMPUTESHADERMODEL Derivatives in compute-model shaders require shader model 6.6 and above
3219-
SM.INCOMPATIBLEDERIVLAUNCH Node shaders only support derivatives in broadcasting launch mode
3219+
SM.INCOMPATIBLEDERIVLAUNCH Node shaders only support derivatives in broadcasting or mesh launch modes
32203220
SM.INCOMPATIBLEOPERATION Operations used in entry function must be compatible with shader stage and other properties
32213221
SM.INCOMPATIBLEREQUIRESGROUP Functions requiring groupshared memory must be called from shaders with a visible group
32223222
SM.INCOMPATIBLESHADERMODEL Functions may only use features available in the current shader model

lib/DXIL/DxilModule.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2148,6 +2148,9 @@ static void AdjustMinimumShaderModelAndFlags(const DxilFunctionProps *props,
21482148
// Requires flag for support on SM 6.6+.
21492149
flags.SetDerivativesInMeshAndAmpShaders(true);
21502150
DXIL::UpdateToMaxOfVersions(minMajor, minMinor, 6, 6);
2151+
} else if (props->IsNode()) {
2152+
if (props->Node.LaunchType == DXIL::NodeLaunchType::Mesh)
2153+
flags.SetDerivativesInMeshAndAmpShaders(true);
21512154
}
21522155
}
21532156

lib/HLSL/DxilValidation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5430,7 +5430,8 @@ struct CompatibilityChecker {
54305430
static_cast<uint32_t>(ConflictFlags::DerivInComputeShaderModel);
54315431
} else if (shaderKind == DXIL::ShaderKind::Node) {
54325432
// Only broadcasting launch supports derivatives.
5433-
if (props.Node.LaunchType != DXIL::NodeLaunchType::Broadcasting)
5433+
if (props.Node.LaunchType != DXIL::NodeLaunchType::Broadcasting &&
5434+
props.Node.LaunchType != DXIL::NodeLaunchType::Mesh)
54345435
maskForDeriv |= static_cast<uint32_t>(ConflictFlags::DerivLaunch);
54355436
// Thread launch node has no group.
54365437
if (props.Node.LaunchType == DXIL::NodeLaunchType::Thread)

tools/clang/lib/Sema/SemaHLSL.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11234,7 +11234,8 @@ void Sema::DiagnoseReachableHLSLMethodCall(const CXXMethodDecl *MD,
1123411234
} break;
1123511235
case DXIL::ShaderKind::Node: {
1123611236
if (const auto *pAttr = EntryDecl->getAttr<HLSLNodeLaunchAttr>()) {
11237-
if (pAttr->getLaunchType() != "broadcasting") {
11237+
if (pAttr->getLaunchType() != "broadcasting" &&
11238+
pAttr->getLaunchType() != "mesh") {
1123811239
Diags.Report(Loc,
1123911240
diag::warn_hlsl_derivatives_in_wrong_shader_kind)
1124011241
<< MD->getNameAsString() << EntryDecl->getNameAsString();
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
// RUN: %dxilver 1.9 | %dxc -T lib_6_9 %s | %D3DReflect %s | %FileCheck %s -check-prefixes=RDAT
2+
3+
// Ensure that categories of deriv ops are allowed for node shaders.
4+
// Ensure that the OptFeatureInfo_UsesDerivatives flag is set as well.
5+
6+
// RDAT: FunctionTable[{{.*}}] = {
7+
8+
Texture2D<float4> T2D : register(t0, space0);
9+
SamplerState Samp : register(s0, space0);
10+
RWByteAddressBuffer BAB : register(u1, space0);
11+
12+
///////////////////////////////////////////////////////////////////////////////
13+
// Category: derivatives ddx/ddy/ddx_coarse/ddy_coarse/ddx_fine/ddy_fine
14+
15+
// RDAT-LABEL: UnmangledName: "node_deriv"
16+
// RDAT: FeatureInfo1: (DerivativesInMeshAndAmpShaders)
17+
// RDAT: FeatureInfo2: (Opt_UsesDerivatives)
18+
// RDAT: ShaderStageFlag: (Node)
19+
// RDAT: MinShaderTarget: 0xf0069
20+
21+
[shader("node")]
22+
[NodeLaunch("mesh")]
23+
[OutputTopology("line")]
24+
[NodeDispatchGrid(1, 1, 1)]
25+
[NumThreads(4,4,1)]
26+
void node_deriv(uint3 tid : SV_GroupThreadID) {
27+
float2 uv = tid.xy / float2(4, 4);
28+
float2 ddx_uv = ddx(uv);
29+
BAB.Store(0, ddx_uv);
30+
}
31+
32+
// RDAT-LABEL: UnmangledName: "use_deriv"
33+
// RDAT: FeatureInfo1: 0
34+
// RDAT: FeatureInfo2: (Opt_UsesDerivatives)
35+
// RDAT: ShaderStageFlag: (Pixel | Compute | Library | Mesh | Amplification | Node)
36+
// RDAT: MinShaderTarget: 0x60060
37+
38+
[noinline] export
39+
void use_deriv(float2 uv) {
40+
float2 ddx_uv = ddx(uv);
41+
BAB.Store(0, ddx_uv);
42+
}
43+
44+
// RDAT-LABEL: UnmangledName: "node_deriv_in_call"
45+
// RDAT: FeatureInfo1: (DerivativesInMeshAndAmpShaders)
46+
// RDAT: FeatureInfo2: (Opt_UsesDerivatives)
47+
// RDAT: ShaderStageFlag: (Node)
48+
// RDAT: MinShaderTarget: 0xf0069
49+
50+
[shader("node")]
51+
[NodeLaunch("mesh")]
52+
[OutputTopology("line")]
53+
[NodeDispatchGrid(1, 1, 1)]
54+
[NumThreads(4,4,1)]
55+
void node_deriv_in_call(uint3 tid : SV_GroupThreadID) {
56+
float2 uv = tid.xy / float2(4, 4);
57+
use_deriv(uv);
58+
}
59+
60+
///////////////////////////////////////////////////////////////////////////////
61+
// Category: CalculateLOD
62+
63+
// RDAT-LABEL: UnmangledName: "node_calclod"
64+
// RDAT: FeatureInfo1: (DerivativesInMeshAndAmpShaders)
65+
// RDAT: FeatureInfo2: (Opt_UsesDerivatives)
66+
// RDAT: ShaderStageFlag: (Node)
67+
// RDAT: MinShaderTarget: 0xf0069
68+
69+
[shader("node")]
70+
[NodeLaunch("mesh")]
71+
[OutputTopology("line")]
72+
[NodeDispatchGrid(1, 1, 1)]
73+
[NumThreads(4,4,1)]
74+
void node_calclod(uint3 tid : SV_GroupThreadID) {
75+
float2 uv = tid.xy / float2(4, 4);
76+
float lod = T2D.CalculateLevelOfDetail(Samp, uv);
77+
BAB.Store(0, lod);
78+
}
79+
80+
// RDAT-LABEL: UnmangledName: "use_calclod"
81+
// RDAT: FeatureInfo1: 0
82+
// RDAT: FeatureInfo2: (Opt_UsesDerivatives)
83+
// RDAT: ShaderStageFlag: (Pixel | Compute | Library | Mesh | Amplification | Node)
84+
// RDAT: MinShaderTarget: 0x60060
85+
86+
[noinline] export
87+
void use_calclod(float2 uv) {
88+
float lod = T2D.CalculateLevelOfDetail(Samp, uv);
89+
BAB.Store(0, lod);
90+
}
91+
92+
// RDAT-LABEL: UnmangledName: "node_calclod_in_call"
93+
// RDAT: FeatureInfo1: (DerivativesInMeshAndAmpShaders)
94+
// RDAT: FeatureInfo2: (Opt_UsesDerivatives)
95+
// RDAT: ShaderStageFlag: (Node)
96+
// RDAT: MinShaderTarget: 0xf0069
97+
98+
[shader("node")]
99+
[NodeLaunch("mesh")]
100+
[OutputTopology("line")]
101+
[NodeDispatchGrid(1, 1, 1)]
102+
[NumThreads(4,4,1)]
103+
void node_calclod_in_call(uint3 tid : SV_GroupThreadID) {
104+
float2 uv = tid.xy / float2(4, 4);
105+
use_calclod(uv);
106+
}
107+
108+
///////////////////////////////////////////////////////////////////////////////
109+
// Category: Sample with implicit derivatives
110+
111+
// RDAT-LABEL: UnmangledName: "node_sample"
112+
// RDAT: FeatureInfo1: (DerivativesInMeshAndAmpShaders)
113+
// RDAT: FeatureInfo2: (Opt_UsesDerivatives)
114+
// RDAT: ShaderStageFlag: (Node)
115+
// RDAT: MinShaderTarget: 0xf0069
116+
117+
[shader("node")]
118+
[NodeLaunch("mesh")]
119+
[OutputTopology("line")]
120+
[NodeDispatchGrid(1, 1, 1)]
121+
[NumThreads(4,4,1)]
122+
void node_sample(uint3 tid : SV_GroupThreadID) {
123+
float2 uv = tid.xy / float2(4, 4);
124+
float4 color = T2D.Sample(Samp, uv);
125+
BAB.Store(0, color);
126+
}
127+
128+
// RDAT-LABEL: UnmangledName: "use_sample"
129+
// RDAT: FeatureInfo1: 0
130+
// RDAT: FeatureInfo2: (Opt_UsesDerivatives)
131+
// RDAT: ShaderStageFlag: (Pixel | Compute | Library | Mesh | Amplification | Node)
132+
// RDAT: MinShaderTarget: 0x60060
133+
134+
[noinline] export
135+
void use_sample(float2 uv) {
136+
float4 color = T2D.Sample(Samp, uv);
137+
BAB.Store(0, color);
138+
}
139+
140+
// RDAT-LABEL: UnmangledName: "node_sample_in_call"
141+
// RDAT: FeatureInfo1: (DerivativesInMeshAndAmpShaders)
142+
// RDAT: FeatureInfo2: (Opt_UsesDerivatives)
143+
// RDAT: ShaderStageFlag: (Node)
144+
// RDAT: MinShaderTarget: 0xf0069
145+
146+
[shader("node")]
147+
[NodeLaunch("mesh")]
148+
[OutputTopology("line")]
149+
[NodeDispatchGrid(1, 1, 1)]
150+
[NumThreads(4,4,1)]
151+
void node_sample_in_call(uint3 tid : SV_GroupThreadID) {
152+
float2 uv = tid.xy / float2(4, 4);
153+
use_sample(uv);
154+
}
155+
156+
///////////////////////////////////////////////////////////////////////////////
157+
// Category: Quad ops
158+
// Quad ops do not set the UsesDerivatives flag, only requiring wave ops flag.
159+
160+
// RDAT-LABEL: UnmangledName: "node_quad"
161+
// RDAT: FeatureInfo1: (WaveOps)
162+
// RDAT: FeatureInfo2: 0
163+
// RDAT: MinShaderTarget: 0xf0069
164+
165+
[shader("node")]
166+
[NodeLaunch("mesh")]
167+
[OutputTopology("line")]
168+
[NodeDispatchGrid(1, 1, 1)]
169+
[NumThreads(4,4,1)]
170+
void node_quad(uint3 tid : SV_GroupThreadID) {
171+
float2 uv = tid.xy / float2(4, 4);
172+
float2 result = QuadReadAcrossDiagonal(uv);
173+
BAB.Store(0, result);
174+
}
175+
176+
// RDAT-LABEL: UnmangledName: "use_quad"
177+
// RDAT: FeatureInfo1: (WaveOps)
178+
// RDAT: FeatureInfo2: 0
179+
// RDAT: ShaderStageFlag: (Pixel | Compute | Library | Mesh | Amplification | Node)
180+
// RDAT: MinShaderTarget: 0x60060
181+
182+
[noinline] export
183+
void use_quad(float2 uv) {
184+
float2 result = QuadReadAcrossDiagonal(uv);
185+
BAB.Store(0, result);
186+
}
187+
188+
// RDAT-LABEL: UnmangledName: "node_quad_in_call"
189+
// RDAT: FeatureInfo1: (WaveOps)
190+
// RDAT: FeatureInfo2: 0
191+
// RDAT: ShaderStageFlag: (Node)
192+
// RDAT: MinShaderTarget: 0xf0069
193+
194+
[shader("node")]
195+
[NodeLaunch("mesh")]
196+
[OutputTopology("line")]
197+
[NodeDispatchGrid(1, 1, 1)]
198+
[NumThreads(4,4,1)]
199+
void node_quad_in_call(uint3 tid : SV_GroupThreadID) {
200+
float2 uv = tid.xy / float2(4, 4);
201+
use_quad(uv);
202+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// REQUIRES: dxil-1-9
2+
// RUN: %dxc -Tlib_6_9 %s -verify
3+
4+
// expected-no-diagnostics
5+
6+
SamplerComparisonState s;
7+
Texture1D t;
8+
9+
10+
[Shader("node")]
11+
[NodeLaunch("mesh")]
12+
[OutputTopology("line")]
13+
[NumThreads(4,1,1)]
14+
[NodeDispatchGrid(4,5,6)]
15+
[NodeIsProgramEntry]
16+
[NodeID("some_call_me_tim", 7)]
17+
[NodeLocalRootArgumentsTableIndex(13)]
18+
void node01_mesh_dispatch() {
19+
float a = 3.0;
20+
(void)(t.CalculateLevelOfDetail(s, a) +
21+
t.CalculateLevelOfDetailUnclamped(s, a));
22+
}

utils/hct/hctdb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8127,7 +8127,7 @@ def build_valrules(self):
81278127
)
81288128
self.add_valrule_msg(
81298129
"Sm.IncompatibleDerivLaunch",
8130-
"Node shaders only support derivatives in broadcasting launch mode",
8130+
"Node shaders only support derivatives in broadcasting or mesh launch modes",
81318131
"Function called from %0 launch node shader uses derivatives; only broadcasting launch supports derivatives.",
81328132
)
81338133

0 commit comments

Comments
 (0)