Skip to content

Commit 2bdb776

Browse files
author
Greg Roth
authored
Enable system values for mesh nodes (#6472)
This allows the system values common to mesh shaders and broadcasting nodes to be used in mesh nodes by letting mesh nodes accept most of the inputs that broadcasting nodes do. The exception is writable dispatch records (RWDispatchNodeInputRecord). Some error messages were also reworded to account for mesh nodes, which required updating some existing tests to accept the new diagnostics. New tests were added to ensure that all the system values are accepted and treated properly, and that invalid usages are rejected. Fixes #6470
1 parent 15fe220 commit 2bdb776

12 files changed

Lines changed: 361 additions & 20 deletions

File tree

lib/DXIL/DxilNodeProps.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,18 @@ bool NodeFlags::IsValidNodeKind() const {
7070

7171
bool NodeFlags::RecordTypeMatchesLaunchType(
7272
DXIL::NodeLaunchType launchType) const {
73-
DXIL::NodeIOFlags recordLaunchType = (DXIL::NodeIOFlags)(
73+
DXIL::NodeIOFlags granularity = (DXIL::NodeIOFlags)(
7474
(uint32_t)m_Flags & (uint32_t)DXIL::NodeIOFlags::RecordGranularityMask);
75+
uint32_t writable =
76+
((uint32_t)m_Flags & (uint32_t)DXIL::NodeIOFlags::ReadWrite);
7577
return (launchType == DXIL::NodeLaunchType::Broadcasting &&
76-
recordLaunchType == DXIL::NodeIOFlags::DispatchRecord) ||
78+
granularity == DXIL::NodeIOFlags::DispatchRecord) ||
7779
(launchType == DXIL::NodeLaunchType::Coalescing &&
78-
recordLaunchType == DXIL::NodeIOFlags::GroupRecord) ||
80+
granularity == DXIL::NodeIOFlags::GroupRecord) ||
7981
(launchType == DXIL::NodeLaunchType::Thread &&
80-
recordLaunchType == DXIL::NodeIOFlags::ThreadRecord);
82+
granularity == DXIL::NodeIOFlags::ThreadRecord) ||
83+
(launchType == DXIL::NodeLaunchType::Mesh &&
84+
granularity == DXIL::NodeIOFlags::DispatchRecord && !writable);
8185
}
8286

8387
void NodeFlags::SetTrackRWInputSharing() {

lib/HLSL/DxilValidation.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2180,6 +2180,8 @@ std::string GetLaunchTypeStr(DXIL::NodeLaunchType LT) {
21802180
return "Coalescing";
21812181
case DXIL::NodeLaunchType::Thread:
21822182
return "Thread";
2183+
case DXIL::NodeLaunchType::Mesh:
2184+
return "Mesh";
21832185
default:
21842186
return "Invalid";
21852187
}
@@ -2423,7 +2425,8 @@ static void ValidateDxilOperationCallInProfile(CallInst *CI,
24232425
break;
24242426
}
24252427

2426-
if (nodeLaunchType == DXIL::NodeLaunchType::Broadcasting)
2428+
if (nodeLaunchType == DXIL::NodeLaunchType::Broadcasting ||
2429+
nodeLaunchType == DXIL::NodeLaunchType::Mesh)
24272430
break;
24282431

24292432
ValCtx.EmitInstrFormatError(
@@ -2436,7 +2439,8 @@ static void ValidateDxilOperationCallInProfile(CallInst *CI,
24362439
break;
24372440
}
24382441

2439-
if (nodeLaunchType == DXIL::NodeLaunchType::Broadcasting)
2442+
if (nodeLaunchType == DXIL::NodeLaunchType::Broadcasting ||
2443+
nodeLaunchType == DXIL::NodeLaunchType::Mesh)
24402444
break;
24412445

24422446
ValCtx.EmitInstrFormatError(
@@ -2450,7 +2454,8 @@ static void ValidateDxilOperationCallInProfile(CallInst *CI,
24502454
}
24512455

24522456
if (nodeLaunchType == DXIL::NodeLaunchType::Broadcasting ||
2453-
nodeLaunchType == DXIL::NodeLaunchType::Coalescing)
2457+
nodeLaunchType == DXIL::NodeLaunchType::Coalescing ||
2458+
nodeLaunchType == DXIL::NodeLaunchType::Mesh)
24542459
break;
24552460

24562461
ValCtx.EmitInstrFormatError(CI,
@@ -2466,7 +2471,8 @@ static void ValidateDxilOperationCallInProfile(CallInst *CI,
24662471
}
24672472

24682473
if (nodeLaunchType == DXIL::NodeLaunchType::Broadcasting ||
2469-
nodeLaunchType == DXIL::NodeLaunchType::Coalescing)
2474+
nodeLaunchType == DXIL::NodeLaunchType::Coalescing ||
2475+
nodeLaunchType == DXIL::NodeLaunchType::Mesh)
24702476
break;
24712477

24722478
ValCtx.EmitInstrFormatError(CI,
@@ -3582,6 +3588,9 @@ static void ValidateNodeInputRecord(Function *F, ValidationContext &ValCtx) {
35823588
case DXIL::NodeLaunchType::Thread:
35833589
validInputs = "{RW}ThreadNodeInputRecord";
35843590
break;
3591+
case DXIL::NodeLaunchType::Mesh:
3592+
validInputs = "DispatchNodeInputRecord";
3593+
break;
35853594
default:
35863595
llvm_unreachable("invalid launch type");
35873596
}

tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7851,7 +7851,7 @@ def err_hlsl_wg_nodetrackrwinputsharing_invalid : Error<
78517851
"NodeTrackRWInputSharing attribute cannot be applied to Input Records that are not RWDispatchNodeInputRecord">;
78527852
def err_hlsl_wg_input_kind : Error<
78537853
"'%0' may not be used with %1 nodes (only %select{DispatchNodeInputRecord or RWDispatchNodeInputRecord|"
7854-
"GroupNodeInputRecords, RWGroupNodeInputRecords, or EmptyNodeInput|ThreadNodeInputRecord or RWThreadNodeInputRecord}2)">;
7854+
"GroupNodeInputRecords, RWGroupNodeInputRecords, or EmptyNodeInput|ThreadNodeInputRecord or RWThreadNodeInputRecord|DispatchNodeInputRecord}2)">;
78557855
def err_hlsl_wg_attr_only_on_output : Error<
78567856
"attribute %0 may only be used with output nodes">;
78577857
def err_hlsl_wg_attr_only_on_output_or_input_record : Error<
@@ -7891,9 +7891,9 @@ def err_hlsl_incompatible_node_attr : Error<
78917891
def err_hlsl_missing_node_attr : Error<
78927892
"Node shader '%0' with %1 launch type requires '%2' attribute">;
78937893
def err_hlsl_missing_dispatchgrid_attr : Error<
7894-
"Broadcasting node shader '%0' must have either the NodeDispatchGrid or NodeMaxDispatchGrid attribute">;
7894+
"Broadcasting/Mesh node shader '%0' must have either the NodeDispatchGrid or NodeMaxDispatchGrid attribute">;
78957895
def err_hlsl_missing_dispatchgrid_semantic : Error<
7896-
"Broadcasting node shader '%0' with NodeMaxDispatchGrid attribute must declare an input record containing a field with SV_DispatchGrid semantic">;
7896+
"Broadcasting/Mesh node shader '%0' with NodeMaxDispatchGrid attribute must declare an input record containing a field with SV_DispatchGrid semantic">;
78977897
def err_hlsl_dispatchgrid_semantic_already_specified : Error<
78987898
"a field with SV_DispatchGrid has already been specified">;
78997899
def err_hlsl_incompatible_dispatchgrid_semantic_type : Error<

tools/clang/lib/Sema/SemaHLSL.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15182,6 +15182,9 @@ static bool nodeInputIsCompatible(DXIL::NodeIOKind IOType,
1518215182
DXIL::NodeLaunchType launchType) {
1518315183
switch (IOType) {
1518415184
case DXIL::NodeIOKind::DispatchNodeInputRecord:
15185+
return launchType == DXIL::NodeLaunchType::Broadcasting ||
15186+
launchType == DXIL::NodeLaunchType::Mesh;
15187+
1518515188
case DXIL::NodeIOKind::RWDispatchNodeInputRecord:
1518615189
return launchType == DXIL::NodeLaunchType::Broadcasting;
1518715190

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// RUN: %dxc -T lib_6_9 %s | FileCheck %s
2+
3+
// REQUIRES: dxil-1-9
4+
5+
// Test all valid mesh node input parameters work
6+
7+
// CHECK: define void @node01()
8+
// CHECK: %[[tid_x:.+]] = call i32 @dx.op.threadId.i32(i32 93, i32 0) ; ThreadId(component)
9+
// CHECK: %[[tid_y:.+]] = call i32 @dx.op.threadId.i32(i32 93, i32 1) ; ThreadId(component)
10+
// CHECK: %[[tid_z:.+]] = call i32 @dx.op.threadId.i32(i32 93, i32 2) ; ThreadId(component)
11+
12+
// CHECK: %[[ftid:.+]] = call i32 @dx.op.flattenedThreadIdInGroup.i32(i32 96) ; FlattenedThreadIdInGroup()
13+
14+
// CHECK: %[[tid_group_x:.+]] = call i32 @dx.op.threadIdInGroup.i32(i32 95, i32 0) ; ThreadIdInGroup(component)
15+
// CHECK: %[[tid_group_y:.+]] = call i32 @dx.op.threadIdInGroup.i32(i32 95, i32 1) ; ThreadIdInGroup(component)
16+
// CHECK: %[[tid_group_z:.+]] = call i32 @dx.op.threadIdInGroup.i32(i32 95, i32 2) ; ThreadIdInGroup(component)
17+
18+
// CHECK: %[[Hdl:.+]] = call %dx.types.NodeRecordHandle @dx.op.createNodeInputRecordHandle(i32 250, i32 0) ; CreateNodeInputRecordHandle(MetadataIdx)
19+
// CHECK: %[[annotHdl:.+]] = call %dx.types.NodeRecordHandle @dx.op.annotateNodeRecordHandle(i32 251, %dx.types.NodeRecordHandle %[[Hdl]], %dx.types.NodeRecordInfo { i32 97, i32 52 }) ; AnnotateNodeRecordHandle(noderecord,props)
20+
21+
// CHECK: %[[node_ptr:.+]] = call %struct.RECORD.0 addrspace(6)* @dx.op.getNodeRecordPtr.struct.RECORD.0(i32 239, %dx.types.NodeRecordHandle %[[annotHdl]], i32 0) ; GetNodeRecordPtr(recordhandle,arrayIndex)
22+
// CHECK: %[[ptr:.+]] = getelementptr inbounds %struct.RECORD.0, %struct.RECORD.0 addrspace(6)* %[[node_ptr]], i32 0, i32 0, i32 0
23+
// CHECK: %[[ld1:.+]] = load i32, i32 addrspace(6)* %[[ptr]], align 4
24+
// CHECK: %[[ptr:.+]] = getelementptr inbounds %struct.RECORD.0, %struct.RECORD.0 addrspace(6)* %[[node_ptr]], i32 0, i32 0, i32 1
25+
// CHECK: %[[ld2:.+]] = load i32, i32 addrspace(6)* %[[ptr]], align 4
26+
// CHECK: %[[ptr:.+]] = getelementptr inbounds %struct.RECORD.0, %struct.RECORD.0 addrspace(6)* %[[node_ptr]], i32 0, i32 0, i32 2
27+
// CHECK: %[[ld3:.+]] = load i32, i32 addrspace(6)* %[[ptr]], align 4
28+
// CHECK: call void @dx.op.rawBufferStore.i32(i32 140, %dx.types.Handle %{{[0-9]+}}, i32 1, i32 0, i32 %[[ld1]], i32 %[[ld2]], i32 %[[ld3]], i32 undef, i8 7, i32 4) ; RawBufferStore(uav,index,elementOffset,value0,value1,value2,value3,mask,alignment)
29+
30+
// CHECK: %[[ptr:.+]] = getelementptr inbounds %struct.RECORD.0, %struct.RECORD.0 addrspace(6)* %[[node_ptr]], i32 0, i32 1, i32 0
31+
// CHECK: %[[ld1:.+]] = load i32, i32 addrspace(6)* %[[ptr]], align 4
32+
// CHECK: %[[ptr:.+]] = getelementptr inbounds %struct.RECORD.0, %struct.RECORD.0 addrspace(6)* %[[node_ptr]], i32 0, i32 1, i32 1
33+
// CHECK: %[[ld2:.+]] = load i32, i32 addrspace(6)* %[[ptr]], align 4
34+
// CHECK: %[[ptr:.+]] = getelementptr inbounds %struct.RECORD.0, %struct.RECORD.0 addrspace(6)* %[[node_ptr]], i32 0, i32 1, i32 2
35+
// CHECK: %[[ld3:.+]] = load i32, i32 addrspace(6)* %[[ptr]], align 4
36+
// CHECK: call void @dx.op.rawBufferStore.i32(i32 140, %dx.types.Handle %{{[0-9]+}}, i32 2, i32 0, i32 %[[ld1]], i32 %[[ld2]], i32 %[[ld3]], i32 undef, i8 7, i32 4) ; RawBufferStore(uav,index,elementOffset,value0,value1,value2,value3,mask,alignment)
37+
38+
// CHECK: %[[ptr:.+]] = getelementptr %struct.RECORD.0, %struct.RECORD.0 addrspace(6)* %[[node_ptr]], i32 0, i32 2
39+
// CHECK: %[[ld1:.+]] = load i32, i32 addrspace(6)* %[[ptr]], align 4
40+
// CHECK: call void @dx.op.rawBufferStore.i32(i32 140, %dx.types.Handle %{{[0-9]+}}, i32 3, i32 0, i32 %[[ld1]], i32 %[[ld1]], i32 %[[ld1]], i32 undef, i8 7, i32 4) ; RawBufferStore(uav,index,elementOffset,value0,value1,value2,value3,mask,alignment)
41+
42+
// CHECK: %[[ptr:.+]] = getelementptr inbounds %struct.RECORD.0, %struct.RECORD.0 addrspace(6)* %[[node_ptr]], i32 0, i32 3, i32 0
43+
// CHECK: %[[ld1:.+]] = load i32, i32 addrspace(6)* %[[ptr]], align 4
44+
// CHECK: %[[ptr:.+]] = getelementptr inbounds %struct.RECORD.0, %struct.RECORD.0 addrspace(6)* %[[node_ptr]], i32 0, i32 3, i32 1
45+
// CHECK: %[[ld2:.+]] = load i32, i32 addrspace(6)* %[[ptr]], align 4
46+
// CHECK: %[[ptr:.+]] = getelementptr inbounds %struct.RECORD.0, %struct.RECORD.0 addrspace(6)* %[[node_ptr]], i32 0, i32 3, i32 2
47+
// CHECK: %[[ld3:.+]] = load i32, i32 addrspace(6)* %[[ptr]], align 4
48+
// CHECK: call void @dx.op.rawBufferStore.i32(i32 140, %dx.types.Handle %{{[0-9]+}}, i32 4, i32 0, i32 %[[ld1]], i32 %[[ld2]], i32 %[[ld3]], i32 undef, i8 7, i32 4) ; RawBufferStore(uav,index,elementOffset,value0,value1,value2,value3,mask,alignment)
49+
// CHECK: ret void
50+
51+
struct RECORD
52+
{
53+
uint3 dtid;
54+
uint3 gid;
55+
uint gidx;
56+
uint3 gtid;
57+
uint3 dg : SV_DispatchGrid;
58+
};
59+
60+
RWStructuredBuffer<uint3> outbuf;
61+
62+
[Shader("node")]
63+
[numthreads(4,4,4)]
64+
[NodeMaxDispatchGrid(4,4,4)]
65+
[NodeLaunch("mesh")]
66+
[OutputTopology("line")]
67+
void node01(DispatchNodeInputRecord<RECORD> input,
68+
uint3 DTID : SV_DispatchThreadID,
69+
uint3 GID : SV_GroupID,
70+
uint GIdx : SV_GroupIndex,
71+
uint3 GTID : SV_GroupThreadID )
72+
{
73+
outbuf[0] = input.Get().dg;
74+
if (any(DTID) != 0)
75+
return;
76+
outbuf[1] = input.Get().dtid;
77+
if (any(GID) != 0)
78+
return;
79+
outbuf[2] = input.Get().gid;
80+
if (GIdx != 0)
81+
return;
82+
outbuf[3] = input.Get().gidx;
83+
if (any(GTID) != 0)
84+
return;
85+
outbuf[4] = input.Get().gtid;
86+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: %dxc -T lib_6_9 -verify %s
2+
3+
// REQUIRES: dxil-1-9
4+
5+
// Test that invalid mesh node input parameters fail with appropriate diagnostics
6+
7+
struct RECORD {
8+
uint3 gtid;
9+
};
10+
11+
[Shader("node")]
12+
[numthreads(4,4,4)]
13+
[NodeDispatchGrid(4,4,4)]
14+
[NodeLaunch("mesh")] // expected-note {{Launch type defined here}}
15+
void node01_rw(RWDispatchNodeInputRecord<RECORD> input, // expected-error {{'RWDispatchNodeInputRecord' may not be used with mesh nodes (only DispatchNodeInputRecord)}}
16+
uint3 GTID : SV_GroupThreadID ) {
17+
input.Get().gtid = GTID;
18+
}
19+
20+
[Shader("node")]
21+
[numthreads(4,4,4)]
22+
[NodeMaxDispatchGrid(4,4,4)]
23+
[NodeLaunch("mesh")]
24+
void node02_maxdisp(DispatchNodeInputRecord<RECORD> input, // expected-error {{Broadcasting/Mesh node shader 'node02_maxdisp' with NodeMaxDispatchGrid attribute must declare an input record containing a field with SV_DispatchGrid semantic}}
25+
uint3 GTID : SV_GroupThreadID ) {
26+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Source file for altered mesh-node-inputs.ll
2+
// Not intended for independent testing
3+
4+
// Run line required in this location, so we'll verify compilation succeeds.
5+
// RUN: %dxc -T lib_6_8 %s | FileCheck %s
6+
// CHECK: define void @node_RWDispatchNodeInputRecord()
7+
// CHECK: define void @node_GroupNodeInputRecords()
8+
// CHECK: define void @node_RWGroupNodeInputRecords()
9+
// CHECK: define void @node_ThreadNodeInputRecord()
10+
// CHECK: define void @node_RWThreadNodeInputRecord()
11+
12+
RWBuffer<uint> buf0;
13+
14+
struct RECORD {
15+
uint ival;
16+
};
17+
18+
[Shader("node")]
19+
[NumThreads(1024,1,1)]
20+
[NodeDispatchGrid(64,1,1)]
21+
[NodeLaunch("broadcasting")]
22+
void node_RWDispatchNodeInputRecord(RWDispatchNodeInputRecord<RECORD> input) {
23+
buf0[0] = input.Get().ival;
24+
}
25+
26+
[Shader("node")]
27+
[NodeLaunch("coalescing")]
28+
[NumThreads(1024,1,1)]
29+
void node_GroupNodeInputRecords(GroupNodeInputRecords<RECORD> input) {
30+
buf0[0] = input.Get().ival;
31+
}
32+
33+
[Shader("node")]
34+
[NodeLaunch("coalescing")]
35+
[NumThreads(1024,1,1)]
36+
void node_RWGroupNodeInputRecords(RWGroupNodeInputRecords<RECORD> input) {
37+
buf0[0] = input.Get().ival;
38+
}
39+
40+
[Shader("node")]
41+
[NodeLaunch("thread")]
42+
void node_ThreadNodeInputRecord(ThreadNodeInputRecord<RECORD> input) {
43+
buf0[0] = input.Get().ival;
44+
}
45+
46+
[Shader("node")]
47+
[NodeLaunch("thread")]
48+
void node_RWThreadNodeInputRecord(RWThreadNodeInputRecord<RECORD> input) {
49+
buf0[0] = input.Get().ival;
50+
}

0 commit comments

Comments
 (0)