Skip to content

Commit f9389db

Browse files
authored
PIX: Correct the disambiguation of AS+MS threads for mesh shader output (#6592)
PIX requires that all vertex information generated by these passes be uniquely identified by vertex id and MS thread id. This change fixes the MS thread id part in two places: the amplification shader and the mesh shader. To be unique across an entire DispatchMesh call, we must uniquify the AS thread group, the AS thread, the MS thread group and the MS thread. This is a lot of multiplying and adding, and there wasn't quite enough math going on here before. In the AS case, we now generate a unique "flat" thread id from the flat-thread-id-in-group (the already-available system value) and the "flat group id", which we synthesize by multiplying together the group id components with the DispatchMesh API's thread group counts, and then multiplying that by the number of threads each AS group launches, then add the flat-thread-id-in-group. (This flat id then goes into an expanded version of the AS->MS payload, the code for which was pre-existing.) The MS will either treat the incoming AS thread id as its unique thread-group-within-the-whole-dispatch id. If the AS is not active, the instrumentation herein will synthesize a flat id in the same way as the AS did before it passed that id through the payload, again from the DispatchMesh parameters (newly-added params to that pass) and the flat-thread-in-group. In addition to the new filecheck tests for this, there is also a new filecheck test to cover coercion of non-i32 types to i32 before being written to PIX's output UAV, which I happened to notice wasn't adequately tested.
1 parent f6d1759 commit f9389db

6 files changed

Lines changed: 190 additions & 32 deletions

File tree

lib/DxilPIXPasses/DxilPIXAddTidToAmplificationShaderPayload.cpp

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -117,29 +117,58 @@ bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule(Module &M) {
117117

118118
llvm::IRBuilder<> B(UserInstruction);
119119

120-
auto ThreadIdFunc =
121-
HlslOP->GetOpFunc(DXIL::OpCode::ThreadId, Type::getInt32Ty(Ctx));
122-
Constant *Opcode = HlslOP->GetU32Const((unsigned)DXIL::OpCode::ThreadId);
123120
Constant *Zero32Arg = HlslOP->GetU32Const(0);
124121
Constant *One32Arg = HlslOP->GetU32Const(1);
125122
Constant *Two32Arg = HlslOP->GetU32Const(2);
126123

127-
auto ThreadIdX =
128-
B.CreateCall(ThreadIdFunc, {Opcode, Zero32Arg}, "ThreadIdX");
129-
auto ThreadIdY =
130-
B.CreateCall(ThreadIdFunc, {Opcode, One32Arg}, "ThreadIdY");
131-
auto ThreadIdZ =
132-
B.CreateCall(ThreadIdFunc, {Opcode, Two32Arg}, "ThreadIdZ");
133-
134-
auto *XxY =
135-
B.CreateMul(ThreadIdX, HlslOP->GetU32Const(m_DispatchArgumentY));
136-
auto *XplusY = B.CreateAdd(ThreadIdY, XxY);
137-
auto *XYxZ = B.CreateMul(XplusY, HlslOP->GetU32Const(m_DispatchArgumentZ));
138-
auto *XYZ = B.CreateAdd(ThreadIdZ, XYxZ);
124+
auto GroupIdFunc =
125+
HlslOP->GetOpFunc(DXIL::OpCode::GroupId, Type::getInt32Ty(Ctx));
126+
Constant *GroupIdOpcode =
127+
HlslOP->GetU32Const((unsigned)DXIL::OpCode::GroupId);
128+
auto *GroupIdX =
129+
B.CreateCall(GroupIdFunc, {GroupIdOpcode, Zero32Arg}, "GroupIdX");
130+
auto *GroupIdY =
131+
B.CreateCall(GroupIdFunc, {GroupIdOpcode, One32Arg}, "GroupIdY");
132+
auto *GroupIdZ =
133+
B.CreateCall(GroupIdFunc, {GroupIdOpcode, Two32Arg}, "GroupIdZ");
134+
135+
// FlatGroupID = z + y*numZ + x*numY*numZ
136+
// Where x,y,z are the group ID components, and numZ and numY are the
137+
// corresponding AS group-count arguments to the DispatchMesh Direct3D API
138+
auto *GroupYxNumZ = B.CreateMul(
139+
GroupIdY, HlslOP->GetU32Const(m_DispatchArgumentZ), "GroupYxNumZ");
140+
auto *FlatGroupNumZY = B.CreateAdd(GroupIdZ, GroupYxNumZ, "FlatGroupNumZY");
141+
auto *GroupXxNumYZ = B.CreateMul(
142+
GroupIdX,
143+
HlslOP->GetU32Const(m_DispatchArgumentY * m_DispatchArgumentZ),
144+
"GroupXxNumYZ");
145+
auto *FlatGroupID =
146+
B.CreateAdd(GroupXxNumYZ, FlatGroupNumZY, "FlatGroFlatGroupIDupNum");
147+
148+
// The ultimate goal is a single unique thread ID for this AS thread.
149+
// So take the flat group number, multiply it by the number of
150+
// threads per group...
151+
auto *FlatGroupIDWithSpaceForThreadInGroupId = B.CreateMul(
152+
FlatGroupID,
153+
HlslOP->GetU32Const(DM.GetNumThreads(0) * DM.GetNumThreads(1) *
154+
DM.GetNumThreads(2)),
155+
"FlatGroupIDWithSpaceForThreadInGroupId");
156+
157+
auto *FlattenedThreadIdInGroupFunc = HlslOP->GetOpFunc(
158+
DXIL::OpCode::FlattenedThreadIdInGroup, Type::getInt32Ty(Ctx));
159+
Constant *FlattenedThreadIdInGroupOpcode =
160+
HlslOP->GetU32Const((unsigned)DXIL::OpCode::FlattenedThreadIdInGroup);
161+
auto FlatThreadIdInGroup = B.CreateCall(FlattenedThreadIdInGroupFunc,
162+
{FlattenedThreadIdInGroupOpcode},
163+
"FlattenedThreadIdInGroup");
164+
165+
// ...and add the flat thread id:
166+
auto *FlatId = B.CreateAdd(FlatGroupIDWithSpaceForThreadInGroupId,
167+
FlatThreadIdInGroup, "FlatId");
139168

140169
AddValueToExpandedPayload(HlslOP, B, expanded, NewStructAlloca,
141170
OriginalPayloadStructType->getStructNumElements(),
142-
XYZ);
171+
FlatId);
143172
AddValueToExpandedPayload(
144173
HlslOP, B, expanded, NewStructAlloca,
145174
OriginalPayloadStructType->getStructNumElements() + 1,

lib/DxilPIXPasses/DxilPIXMeshShaderOutputInstrumentation.cpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class DxilPIXMeshShaderOutputInstrumentation : public ModulePass {
6969

7070
uint64_t m_UAVSize = 1024 * 1024;
7171
bool m_ExpandPayload = false;
72+
uint32_t m_DispatchArgumentY = 1;
73+
uint32_t m_DispatchArgumentZ = 1;
7274

7375
struct BuilderContext {
7476
Module &M;
@@ -91,6 +93,8 @@ class DxilPIXMeshShaderOutputInstrumentation : public ModulePass {
9193
void DxilPIXMeshShaderOutputInstrumentation::applyOptions(PassOptions O) {
9294
GetPassOptionUInt64(O, "UAVSize", &m_UAVSize, 1024 * 1024);
9395
GetPassOptionBool(O, "expand-payload", &m_ExpandPayload, 0);
96+
GetPassOptionUInt32(O, "dispatchArgY", &m_DispatchArgumentY, 1);
97+
GetPassOptionUInt32(O, "dispatchArgZ", &m_DispatchArgumentZ, 1);
9498
}
9599

96100
uint32_t DxilPIXMeshShaderOutputInstrumentation::UAVDumpingGroundOffset() {
@@ -244,16 +248,24 @@ SmallVector<Value *, 2> DxilPIXMeshShaderOutputInstrumentation::
244248
auto *GroupIdZ =
245249
Builder.CreateCall(GroupIdFunc, {Opcode, Two32Arg}, "GroupIdZ");
246250

247-
auto *XxY = AmplificationShaderIsActive
248-
? Builder.CreateMul(GroupIdX, ASDispatchMeshYCount)
249-
: GroupIdX;
250-
auto *XplusY = Builder.CreateAdd(GroupIdY, XxY);
251-
auto *XYxZ = AmplificationShaderIsActive
252-
? Builder.CreateMul(XplusY, ASDispatchMeshZCount)
253-
: XplusY;
254-
auto *XYZ = Builder.CreateAdd(GroupIdZ, XYxZ);
255-
256-
ret.push_back(XYZ);
251+
// flattend group number = z + y*numZ + x*numY*numZ
252+
if (AmplificationShaderIsActive) {
253+
auto *GroupYxNumZ = Builder.CreateMul(GroupIdY, ASDispatchMeshZCount);
254+
auto *FlatGroupNumZY = Builder.CreateAdd(GroupIdZ, GroupYxNumZ);
255+
auto *GroupXxNumZ = Builder.CreateMul(GroupIdX, ASDispatchMeshZCount);
256+
auto *GroupXxNumYZ = Builder.CreateMul(GroupXxNumZ, ASDispatchMeshYCount);
257+
auto *FlatGroupNum = Builder.CreateAdd(GroupXxNumYZ, FlatGroupNumZY);
258+
ret.push_back(FlatGroupNum);
259+
} else {
260+
auto *GroupYxNumZ =
261+
Builder.CreateMul(GroupIdY, HlslOP->GetU32Const(m_DispatchArgumentZ));
262+
auto *FlatGroupNumZY = Builder.CreateAdd(GroupIdZ, GroupYxNumZ);
263+
auto *GroupXxNumYZ =
264+
Builder.CreateMul(GroupIdX, HlslOP->GetU32Const(m_DispatchArgumentY *
265+
m_DispatchArgumentZ));
266+
auto *FlatGroupNum = Builder.CreateAdd(GroupXxNumYZ, FlatGroupNumZY);
267+
ret.push_back(FlatGroupNum);
268+
}
257269

258270
return ret;
259271
}
Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,36 @@
11
// RUN: %dxc -Emain -Tas_6_6 %s | %opt -S -hlsl-dxil-PIX-add-tid-to-as-payload,dispatchArgY=3,dispatchArgZ=7 | %FileCheck %s
22

3-
// CHECK: mul i32 %ThreadIdX, 3
4-
// CHECK: mul i32
5-
// CHECK: , 7
6-
// CHECK: @dx.op.dispatchMesh.PIX_AS2MS_Expanded_Type
3+
// CHECK: [[AppsGroupIdX:%.*]] = call i32 @dx.op.groupId.i32(i32 94, i32 0)
4+
// CHECK: [[GROUPIDX:%.*]] = call i32 @dx.op.groupId.i32(i32 94, i32 0)
5+
// CHECK: [[GROUPIDY:%.*]] = call i32 @dx.op.groupId.i32(i32 94, i32 1)
6+
// CHECK: [[GROUPIDZ:%.*]] = call i32 @dx.op.groupId.i32(i32 94, i32 2)
77

8+
// The integer literals here come from the dispatchArg* arguments in the command line above
9+
// CHECK: [[YTIMES7:%.*]] = mul i32 [[GROUPIDY]], 7
10+
// CHECK: [[ZPLUSYTIMES7:%.*]] = add i32 [[GROUPIDZ]], [[YTIMES7]]
11+
// CHECK: [[XTIMES21:%.*]] = mul i32 [[GROUPIDX]], 21
12+
// CHECK: [[FLATGROUPNUM:%.*]] = add i32 [[XTIMES21]], [[ZPLUSYTIMES7]]
13+
// This 105 is the thread counts for the shader multiplied together
14+
// CHECK: [[FLATWITHSPACE:%.*]] = mul i32 [[FLATGROUPNUM]], 105
15+
// CHECK: [[FLATTENEDTHREADIDINGROUP:%.*]] = call i32 @dx.op.flattenedThreadIdInGroup.i32(i32 96)
16+
// CHECK: [[FLATID:%.*]] = add i32 [[FLATWITHSPACE]], [[FLATTENEDTHREADIDINGROUP]]
17+
18+
// Check that this flat ID is stored into the expanded payload:
19+
// CHECKL store i32 [[FLATID]], i32*
20+
21+
// Check that the Y and Z dispatch-mesh counts are emitted to the expanded payload:
22+
// CHECK: store i32 3, i32*
23+
// CHECK: store i32 4, i32*
824
struct MyPayload
925
{
1026
uint i;
1127
};
1228

13-
[numthreads(1, 1, 1)]
29+
[numthreads(3, 5, 7)]
1430
void main(uint gid : SV_GroupID)
1531
{
1632
MyPayload payload;
1733
payload.i = gid;
18-
DispatchMesh(1, 1, 1, payload);
34+
DispatchMesh(2, 3, 4, payload);
1935
}
36+
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// RUN: %dxc -EMSMain -enable-16bit-types -Tms_6_6 %s | %opt -S -hlsl-dxil-pix-meshshader-output-instrumentation,expand-payload=0,dispatchArgY=3,dispatchArgZ=7,UAVSize=8192 | %FileCheck %s
2+
3+
// Check that the instrumentation properly coerces different output types into int32
4+
5+
// CHECK: [[PAYLOAD:%.*]] = call %struct.MyPayload* @dx.op.getMeshPayload.struct.MyPayload(i32 170)
6+
7+
// FP16 value in 0th member
8+
// CHECK: [[F16:%.*]] = getelementptr inbounds %struct.MyPayload, %struct.MyPayload* [[PAYLOAD]], i32 0, i32 0
9+
// CHECK: [[F16LOADED:%.*]] = load half, half* [[F16]]
10+
// CHECK: [[F16CONV:%.*]] = fpext half [[F16LOADED]] to float
11+
12+
// uint16 value in 1st member
13+
// CHECK: [[U16:%.*]] = getelementptr inbounds %struct.MyPayload, %struct.MyPayload* [[PAYLOAD]], i32 0, i32 1
14+
// CHECK: [[U16LOADED:%.*]] = load i16, i16* [[U16]]
15+
// CHECK: [[U16CONV:%.*]] = uitofp i16 [[U16LOADED]] to float
16+
17+
// int16 value in 2nd member
18+
// CHECK: [[I16:%.*]] = getelementptr inbounds %struct.MyPayload, %struct.MyPayload* [[PAYLOAD]], i32 0, i32 2
19+
// CHECK: [[I16LOADED:%.*]] = load i16, i16* [[I16]]
20+
// CHECK: [[I16CONV:%.*]] = sitofp i16 [[I16LOADED]] to float
21+
22+
// float value in 3rd member
23+
// CHECK: [[FP:%.*]] = getelementptr inbounds %struct.MyPayload, %struct.MyPayload* [[PAYLOAD]], i32 0, i32 3
24+
// CHECK: [[FPLOADED:%.*]] = load float, float* [[FP]]
25+
26+
27+
// Check that these converted values are written out:
28+
// CHECK: [[F16COERCED:%.*]] = bitcast float [[F16CONV]] to i32
29+
// CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle [[ANYRESOURCE:%.*]], i32 [[ANYOFFSET:%.*]], i32 undef, i32 [[F16COERCED]]
30+
// CHECK: [[U16COERCED:%.*]] = bitcast float [[U16CONV]] to i32
31+
// CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle [[ANYRESOURCE:%.*]], i32 [[ANYOFFSET:%.*]], i32 undef, i32 [[U16COERCED]]
32+
// CHECK: [[I16COERCED:%.*]] = bitcast float [[I16CONV]] to i32
33+
// CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle [[ANYRESOURCE:%.*]], i32 [[ANYOFFSET:%.*]], i32 undef, i32 [[I16COERCED]]
34+
// CHECK: [[FPCOERCED:%.*]] = bitcast float [[FPLOADED]] to i32
35+
// CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle [[ANYRESOURCE:%.*]], i32 [[ANYOFFSET:%.*]], i32 undef, i32 [[FPCOERCED]]
36+
37+
struct PSInput
38+
{
39+
float4 position : SV_POSITION;
40+
};
41+
42+
struct MyPayload
43+
{
44+
half f16;
45+
uint16_t u16;
46+
int16_t i16;
47+
float f;
48+
};
49+
50+
[outputtopology("triangle")]
51+
[numthreads(4, 1, 1)]
52+
void MSMain(
53+
in payload MyPayload small,
54+
in uint tid : SV_GroupThreadID,
55+
out vertices PSInput verts[4],
56+
out indices uint3 triangles[2])
57+
{
58+
SetMeshOutputCounts(4, 2);
59+
verts[tid].position = float4(small.f16, small.u16, small.i16, small.f);
60+
triangles[tid % 2] = uint3(0, tid + 1, tid + 2);
61+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: %dxc -EMSMain -Tms_6_6 %s | %opt -S -hlsl-dxil-pix-meshshader-output-instrumentation,expand-payload=0,dispatchArgY=3,dispatchArgZ=7,UAVSize=8192 | %FileCheck %s
2+
3+
// Check that the instrumentation calculates the expected "flat" group ID:
4+
5+
// CHECK: [[GROUPIDX:%.*]] = call i32 @dx.op.groupId.i32(i32 94, i32 0)
6+
// CHECK: [[GROUPIDY:%.*]] = call i32 @dx.op.groupId.i32(i32 94, i32 1)
7+
// CHECK: [[GROUPIDZ:%.*]] = call i32 @dx.op.groupId.i32(i32 94, i32 2)
8+
9+
// The integer literals here come from the dispatchArg* arguments in the command line above
10+
// CHECK: [[YTIMES7:%.*]] = mul i32 [[GROUPIDY]], 7
11+
// CHECK: [[ZPLUSYTIMES7:%.*]] = add i32 [[GROUPIDZ]], [[YTIMES7]]
12+
// CHECK: [[XTIMES21:%.*]] = mul i32 [[GROUPIDX]], 21
13+
// CHECK: [[FLATID:%.*]] = add i32 [[XTIMES21]], [[ZPLUSYTIMES7]]
14+
15+
// CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle [[SOMERESOURCEHANDLE:%.*]], i32 [[SOMEOFFSET:%.*]], i32 undef, i32 [[FLATID]],
16+
struct PSInput
17+
{
18+
float4 position : SV_POSITION;
19+
};
20+
21+
struct MyPayload
22+
{
23+
uint i;
24+
};
25+
26+
[outputtopology("triangle")]
27+
[numthreads(4, 1, 1)]
28+
void MSMain(
29+
in payload MyPayload small,
30+
in uint tid : SV_GroupThreadID,
31+
out vertices PSInput verts[4],
32+
out indices uint3 triangles[2])
33+
{
34+
SetMeshOutputCounts(4, 2);
35+
verts[tid].position = float4(small.i, 0, 0, 0);
36+
triangles[tid % 2] = uint3(0, tid + 1, tid + 2);
37+
}

utils/hct/hctdb.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6116,6 +6116,8 @@ def add_pass(name, type_name, doc, opts):
61166116
[
61176117
{"n": "expand-payload", "t": "int", "c": 1},
61186118
{"n": "UAVSize", "t": "int", "c": 1},
6119+
{"n": "dispatchArgY", "t": "int", "c": 1},
6120+
{"n": "dispatchArgZ", "t": "int", "c": 1},
61196121
],
61206122
)
61216123
add_pass(

0 commit comments

Comments
 (0)