Skip to content

Commit 72e0fa4

Browse files
author
Greg Roth
authored
Support SetMeshOutputCounts in mesh nodes (#6476)
Allow the usage of SetMeshOutputCounts() in mesh launch node shaders and in nodes of no other launch type. This requires initially identifying the op as permitted in node shaders, and then restricting that permission later on. Because both the launch type and the opcode must be known at the point of the check, the locations where this could be done were limited. A future implementation might opt to encode the launch type in the stage mask, or else at least temporarily use a different opcode for a SetMeshOutputCounts call within a node shader. Fixes #6473
1 parent 9ec9169 commit 72e0fa4

6 files changed

Lines changed: 99 additions & 8 deletions

File tree

lib/DXIL/DxilModule.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2332,6 +2332,22 @@ void DxilModule::UpdateFunctionToShaderCompat(const llvm::Function *dxilFunc) {
23322332
OP::GetMinShaderModelAndMask(CI, bWithTranslation, m_ValMajor, m_ValMinor,
23332333
major, minor, mask);
23342334
DXIL::UpdateToMaxOfVersions(info.minMajor, info.minMinor, major, minor);
2335+
2336+
// Fix up permitting SetMeshOutputCounts in node shaders
2337+
2338+
// This is a hack, but it is required to reject SetMeshOutputCounts from
2339+
// non-mesh node shaders. It is the least invasive place where the launch
2340+
// type and the opcode are both known. Within GetMinShaderModelAndMask,
2341+
// the launch type isn't known. After this point, the opcode that
2342+
// seemingly indicates support for all node shaders is lost.
2343+
// Ultimately, we'll probably need to encode the launch type into the mask.
2344+
if (DXIL::OpCode::SetMeshOutputCounts == OP::GetDxilOpFuncCallInst(CI) &&
2345+
HasDxilFunctionProps(F)) {
2346+
const DxilFunctionProps &props = GetDxilFunctionProps(F);
2347+
if (props.shaderKind != DXIL::ShaderKind::Node ||
2348+
props.Node.LaunchType != DXIL::NodeLaunchType::Mesh)
2349+
mask &= ~SFLAG(Node);
2350+
}
23352351
info.mask &= mask;
23362352
} else if (const llvm::LoadInst *LI = dyn_cast<LoadInst>(user)) {
23372353
// If loading a groupshared variable, limit to CS/AS/MS/Node

lib/DXIL/DxilOperations.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3251,14 +3251,21 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
32513251
mask = SFLAG(Library) | SFLAG(Pixel);
32523252
return;
32533253
}
3254-
// Instructions: SetMeshOutputCounts=168, EmitIndices=169, GetMeshPayload=170,
3255-
// StoreVertexOutput=171, StorePrimitiveOutput=172
3256-
if ((168 <= op && op <= 172)) {
3254+
// Instructions: EmitIndices=169, GetMeshPayload=170, StoreVertexOutput=171,
3255+
// StorePrimitiveOutput=172
3256+
if ((169 <= op && op <= 172)) {
32573257
major = 6;
32583258
minor = 5;
32593259
mask = SFLAG(Mesh);
32603260
return;
32613261
}
3262+
// Instructions: SetMeshOutputCounts=168
3263+
if (op == 168) {
3264+
major = 6;
3265+
minor = 5;
3266+
mask = SFLAG(Mesh) | SFLAG(Node);
3267+
return;
3268+
}
32623269
// Instructions: CreateHandleFromHeap=218, Unpack4x8=219, Pack4x8=220,
32633270
// IsHelperLane=221
32643271
if ((218 <= op && op <= 221)) {

lib/HLSL/DxilValidation.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,7 +1463,8 @@ static void ValidateSignatureDxilOp(CallInst *CI, DXIL::OpCode opcode,
14631463
}
14641464
} break;
14651465
case DXIL::OpCode::SetMeshOutputCounts: {
1466-
if (!props.IsMS()) {
1466+
if (!props.IsMS() && (!props.IsNode() || props.Node.LaunchType !=
1467+
DXIL::NodeLaunchType::Mesh)) {
14671468
ValCtx.EmitInstrFormatError(CI, ValidationRule::SmOpcodeInInvalidFunction,
14681469
{"SetMeshOutputCounts", "Mesh shader"});
14691470
}
@@ -2795,9 +2796,11 @@ static void ValidateMsIntrinsics(Function *F, ValidationContext &ValCtx,
27952796
CallInst *setMeshOutputCounts,
27962797
CallInst *getMeshPayload) {
27972798
if (ValCtx.DxilMod.HasDxilFunctionProps(F)) {
2798-
DXIL::ShaderKind shaderKind =
2799-
ValCtx.DxilMod.GetDxilFunctionProps(F).shaderKind;
2800-
if (shaderKind != DXIL::ShaderKind::Mesh)
2799+
DxilFunctionProps &props = ValCtx.DxilMod.GetDxilFunctionProps(F);
2800+
DXIL::ShaderKind shaderKind = props.shaderKind;
2801+
if (shaderKind != DXIL::ShaderKind::Mesh &&
2802+
(shaderKind != DXIL::ShaderKind::Node ||
2803+
props.Node.LaunchType != DXIL::NodeLaunchType::Mesh))
28012804
return;
28022805
} else {
28032806
return;
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: not %dxc -T lib_6_9 %s 2>&1 | FileCheck %s
2+
3+
// REQUIRES: dxil-1-9
4+
5+
// Ensure that setMeshOutputCounts will be appropriately rejected by non-mesh nodes
6+
7+
RWBuffer<float> buf0;
8+
9+
// CHECK: 14: error: Function uses features incompatible with the shader stage (node) of the entry function.
10+
[Shader("node")]
11+
[NumThreads(1024,1,1)]
12+
[NodeDispatchGrid(64,1,1)]
13+
[NodeLaunch("broadcasting")]
14+
void node_broadcasting() {
15+
SetMeshOutputCounts(32, 16);
16+
buf0[0] = 1.0;
17+
}
18+
19+
// CHECK: 23: error: Function uses features incompatible with the shader stage (node) of the entry function.
20+
[Shader("node")]
21+
[NodeLaunch("coalescing")]
22+
[NumThreads(1024,1,1)]
23+
void node_coalescing() {
24+
SetMeshOutputCounts(32, 16);
25+
buf0[0] = 2.0;
26+
}
27+
28+
// CHECK: 31: error: Function uses features incompatible with the shader stage (node) of the entry function.
29+
[Shader("node")]
30+
[NodeLaunch("thread")]
31+
void node_thread() {
32+
SetMeshOutputCounts(32, 16);
33+
buf0[0] = 4.0;
34+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: %dxc -T lib_6_9 %s | FileCheck %s
2+
3+
// REQUIRES: dxil-1-9
4+
5+
// Ensure that setMeshOutputCounts can be correctly lowered in a mesh node
6+
7+
// CHECK: define void @node_setmeshoutputcounts()
8+
// CHECK: dx.op.setMeshOutputCounts(i32 168, i32 32, i32 16)
9+
// CHECK: ret void
10+
11+
// CHECK: declare void @dx.op.setMeshOutputCounts(i32, i32, i32) #0
12+
13+
RWBuffer<float> buf0;
14+
15+
16+
[Shader("node")]
17+
[NodeLaunch("mesh")]
18+
[outputtopology("triangle")]
19+
[numthreads(128, 1, 1)]
20+
[NodeDispatchGrid(64,1,1)]
21+
void node_setmeshoutputcounts() {
22+
SetMeshOutputCounts(32, 16);
23+
buf0[0] = 1.0;
24+
}

utils/hct/hctdb.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,9 +594,16 @@ def populate_categories_and_models(self):
594594
for i in "WaveMatch,WaveMultiPrefixOp,WaveMultiPrefixBitCount".split(","):
595595
self.name_idx[i].category = "Wave"
596596
self.name_idx[i].shader_model = 6, 5
597+
for i in "SetMeshOutputCounts".split(","):
598+
self.name_idx[i].category = "Mesh shader instructions"
599+
self.name_idx[i].shader_stages = (
600+
"mesh",
601+
"node",
602+
)
603+
self.name_idx[i].shader_model = 6, 5
597604
for (
598605
i
599-
) in "SetMeshOutputCounts,EmitIndices,GetMeshPayload,StoreVertexOutput,StorePrimitiveOutput".split(
606+
) in "EmitIndices,GetMeshPayload,StoreVertexOutput,StorePrimitiveOutput".split(
600607
","
601608
):
602609
self.name_idx[i].category = "Mesh shader instructions"

0 commit comments

Comments
 (0)