Skip to content

Commit e5c51ea

Browse files
authored
Implement GetGroupWaveIndex and GetGroupWaveCount (#7959)
Implement Group Wave Index and Group Wave Count as proposed in https://github.com/microsoft/hlsl-specs/blob/main/proposals/0048-group-wave-index.md. Added two new intrinsics: - GetGroupWaveIndex - returns the index of the wave in the thread group; - GetGroupWaveCount - returns the number of waves in the thread group. Limited to Shader Model 6.10 and to Compute, Mesh, Node and Amplification shaders. Added a basic test.
1 parent 4969749 commit e5c51ea

22 files changed

Lines changed: 1093 additions & 4 deletions

docs/ReleaseNotes.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ The included licenses apply to the following files:
2424
- Header file `dxcpix.h` was added to the release package.
2525
- Moved Linear Algebra (Cooperative Vector) DXIL Opcodes to experimental Shader Model 6.10
2626
- Added support for `long long` and `unsigned long long` compile-time constant evaluation, fixes [#7952](https://github.com/microsoft/DirectXShaderCompiler/issues/7952).
27+
- Implemented GetGroupWaveIndex and GetGroupWaveCount in experimental Shader Model 6.10
28+
- [proposal](https://github.com/microsoft/hlsl-specs/blob/main/proposals/0048-group-wave-index.md)
29+
- GetGroupWaveIndex: New intrinsic for Compute, Mesh, Amplification and Node shaders which returns the index of the wave within the thread group that the thread is executing in.
30+
- GetGroupWaveCount: New intrinsic for Compute, Mesh, Amplification and Node shaders which returns the total number of waves executing within the thread group.
2731

2832
### Version 1.8.2505
2933

lib/DXIL/DxilOperations.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3938,7 +3938,7 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
39383938
if ((2147483649 <= op && op <= 2147483650)) {
39393939
major = 6;
39403940
minor = 10;
3941-
mask = SFLAG(Compute) | SFLAG(Mesh) | SFLAG(Amplification) | SFLAG(Library);
3941+
mask = SFLAG(Compute) | SFLAG(Mesh) | SFLAG(Amplification) | SFLAG(Node);
39423942
return;
39433943
}
39443944
// Instructions: ClusterID=2147483651, TriangleObjectPosition=2147483655

lib/DXIL/DxilShaderFlags.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,10 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F,
729729
if (OP::BarrierRequiresGroup(CI))
730730
requiresGroup = true;
731731
break;
732+
case DXIL::OpCode::GetGroupWaveIndex:
733+
case DXIL::OpCode::GetGroupWaveCount:
734+
requiresGroup = true;
735+
break;
732736
default:
733737
// Normal opcodes.
734738
break;

lib/HLSL/HLOperationLower.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7515,9 +7515,9 @@ constexpr IntrinsicLower gLowerTable[] = {
75157515

75167516
{IntrinsicOp::IOP_isnormal, TrivialIsSpecialFloat, DXIL::OpCode::IsNormal},
75177517

7518-
{IntrinsicOp::IOP_GetGroupWaveCount, EmptyLower,
7518+
{IntrinsicOp::IOP_GetGroupWaveCount, TranslateWaveToVal,
75197519
DXIL::OpCode::GetGroupWaveCount},
7520-
{IntrinsicOp::IOP_GetGroupWaveIndex, EmptyLower,
7520+
{IntrinsicOp::IOP_GetGroupWaveIndex, TranslateWaveToVal,
75217521
DXIL::OpCode::GetGroupWaveIndex},
75227522

75237523
{IntrinsicOp::IOP_ClusterID, EmptyLower, DXIL::OpCode::ClusterID},
@@ -7616,6 +7616,8 @@ static void TranslateBuiltinIntrinsic(CallInst *CI,
76167616
bool &Translated) {
76177617
unsigned opcode = hlsl::GetHLOpcode(CI);
76187618
const IntrinsicLower &lower = gLowerTable[opcode];
7619+
DXASSERT((unsigned)lower.IntriOpcode == opcode,
7620+
"Intrinsic lowering table index must match intrinsic opcode.");
76197621
Value *Result = lower.LowerFunc(CI, lower.IntriOpcode, lower.DxilOpcode,
76207622
helper, pObjHelper, Translated);
76217623
if (Result)

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4239,6 +4239,8 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
42394239
case spv::BuiltIn::LocalInvocationIndex:
42404240
case spv::BuiltIn::RemainingRecursionLevelsAMDX:
42414241
case spv::BuiltIn::ShaderIndexAMDX:
4242+
case spv::BuiltIn::SubgroupId:
4243+
case spv::BuiltIn::NumSubgroups:
42424244
sc = spv::StorageClass::Input;
42434245
break;
42444246
case spv::BuiltIn::TaskCountNV:

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9469,6 +9469,22 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
94699469
case hlsl::IntrinsicOp::IOP_WaveActiveCountBits:
94709470
retVal = processWaveCountBits(callExpr, spv::GroupOperation::Reduce);
94719471
break;
9472+
case hlsl::IntrinsicOp::IOP_GetGroupWaveIndex: {
9473+
featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "GetGroupWaveIndex",
9474+
srcLoc);
9475+
const QualType retType = callExpr->getCallReturnType(astContext);
9476+
auto *var =
9477+
declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupId, retType, srcLoc);
9478+
retVal = spvBuilder.createLoad(retType, var, srcLoc, srcRange);
9479+
} break;
9480+
case hlsl::IntrinsicOp::IOP_GetGroupWaveCount: {
9481+
featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "GetGroupWaveCount",
9482+
srcLoc);
9483+
const QualType retType = callExpr->getCallReturnType(astContext);
9484+
auto *var =
9485+
declIdMapper.getBuiltinVar(spv::BuiltIn::NumSubgroups, retType, srcLoc);
9486+
retVal = spvBuilder.createLoad(retType, var, srcLoc, srcRange);
9487+
} break;
94729488
case hlsl::IntrinsicOp::IOP_WaveActiveUSum:
94739489
case hlsl::IntrinsicOp::IOP_WaveActiveSum:
94749490
case hlsl::IntrinsicOp::IOP_WaveActiveUProduct:
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: %dxc -T cs_6_10 -E main -fspv-target-env=vulkan1.1 -fcgl %s -spirv | FileCheck %s
2+
3+
// CHECK: ; Version: 1.3
4+
5+
RWStructuredBuffer<uint> output: register(u0);
6+
7+
// CHECK: OpCapability GroupNonUniform
8+
9+
// CHECK: OpEntryPoint GLCompute
10+
// CHECK-SAME: %NumSubgroups
11+
12+
// CHECK: OpDecorate %NumSubgroups BuiltIn NumSubgroups
13+
14+
// CHECK: %NumSubgroups = OpVariable %_ptr_Input_uint Input
15+
16+
[numthreads(64, 1, 1)]
17+
void main(uint3 id: SV_DispatchThreadID) {
18+
// CHECK: OpLoad %uint %NumSubgroups
19+
output[id.x] = GetGroupWaveCount();
20+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: %dxc -T cs_6_10 -E main -fspv-target-env=vulkan1.1 -fcgl %s -spirv | FileCheck %s
2+
3+
// CHECK: ; Version: 1.3
4+
5+
RWStructuredBuffer<uint> output: register(u0);
6+
7+
// CHECK: OpCapability GroupNonUniform
8+
9+
// CHECK: OpEntryPoint GLCompute
10+
// CHECK-SAME: %SubgroupId
11+
12+
// CHECK: OpDecorate %SubgroupId BuiltIn SubgroupId
13+
14+
// CHECK: %SubgroupId = OpVariable %_ptr_Input_uint Input
15+
16+
[numthreads(64, 1, 1)]
17+
void main(uint3 id: SV_DispatchThreadID) {
18+
// CHECK: OpLoad %uint %SubgroupId
19+
output[id.x] = GetGroupWaveIndex();
20+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
; RUN: %dxopt %s -hlsl-passes-resume -dxilgen -S | FileCheck %s
2+
3+
; CHECK: call i32 @dx.op.getGroupWaveIndex(i32 -2147483647)
4+
; CHECK: call i32 @dx.op.getGroupWaveCount(i32 -2147483646)
5+
6+
; Generated from:
7+
; utils/hct/ExtractIRForPassTest.py -p dxilgen -o tools/clang/test/DXC/Passes/DxilGen/group-wave-index.ll tools/clang/test/HLSLFileCheckLit/hlsl/intrinsics/wave/group-wave-index.hlsl -- -T cs_6_10 -E main
8+
; Debug info manually stripped.
9+
10+
target datalayout = "e-m:e-p:32:32-i1:32-i8:32-i16:32-i32:32-i64:64-f16:32-f32:32-f64:64-n8:16:32:64"
11+
target triple = "dxil-ms-dx"
12+
13+
%"class.RWStructuredBuffer<unsigned int>" = type { i32 }
14+
%dx.types.Handle = type { i8* }
15+
%dx.types.ResourceProperties = type { i32, i32 }
16+
17+
@"\01?output0@@3V?$RWStructuredBuffer@I@@A" = external global %"class.RWStructuredBuffer<unsigned int>", align 4
18+
19+
; Function Attrs: nounwind
20+
define void @main(<3 x i32> %id) #0 {
21+
entry:
22+
%0 = call i32 @"dx.hl.op.rn.i32 (i32)"(i32 396)
23+
%1 = call i32 @"dx.hl.op.rn.i32 (i32)"(i32 395)
24+
%2 = load %"class.RWStructuredBuffer<unsigned int>", %"class.RWStructuredBuffer<unsigned int>"* @"\01?output0@@3V?$RWStructuredBuffer@I@@A"
25+
%3 = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %\22class.RWStructuredBuffer<unsigned int>\22)"(i32 0, %"class.RWStructuredBuffer<unsigned int>" %2)
26+
%4 = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %\22class.RWStructuredBuffer<unsigned int>\22)"(i32 14, %dx.types.Handle %3, %dx.types.ResourceProperties { i32 4108, i32 4 }, %"class.RWStructuredBuffer<unsigned int>" zeroinitializer)
27+
%5 = call i32* @"dx.hl.subscript.[].rn.i32* (i32, %dx.types.Handle, i32)"(i32 0, %dx.types.Handle %4, i32 0)
28+
store i32 %0, i32* %5
29+
%6 = load %"class.RWStructuredBuffer<unsigned int>", %"class.RWStructuredBuffer<unsigned int>"* @"\01?output0@@3V?$RWStructuredBuffer@I@@A"
30+
%7 = call %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %\22class.RWStructuredBuffer<unsigned int>\22)"(i32 0, %"class.RWStructuredBuffer<unsigned int>" %6)
31+
%8 = call %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %\22class.RWStructuredBuffer<unsigned int>\22)"(i32 14, %dx.types.Handle %7, %dx.types.ResourceProperties { i32 4108, i32 4 }, %"class.RWStructuredBuffer<unsigned int>" zeroinitializer)
32+
%9 = call i32* @"dx.hl.subscript.[].rn.i32* (i32, %dx.types.Handle, i32)"(i32 0, %dx.types.Handle %8, i32 16)
33+
store i32 %1, i32* %9
34+
ret void
35+
}
36+
37+
; Function Attrs: nounwind
38+
declare void @llvm.lifetime.start(i64, i8* nocapture) #0
39+
40+
; Function Attrs: nounwind
41+
declare void @llvm.lifetime.end(i64, i8* nocapture) #0
42+
43+
; Function Attrs: nounwind readnone
44+
declare i32 @"dx.hl.op.rn.i32 (i32)"(i32) #1
45+
46+
; Function Attrs: nounwind readnone
47+
declare i32* @"dx.hl.subscript.[].rn.i32* (i32, %dx.types.Handle, i32)"(i32, %dx.types.Handle, i32) #1
48+
49+
; Function Attrs: nounwind readnone
50+
declare %dx.types.Handle @"dx.hl.createhandle..%dx.types.Handle (i32, %\22class.RWStructuredBuffer<unsigned int>\22)"(i32, %"class.RWStructuredBuffer<unsigned int>") #1
51+
52+
; Function Attrs: nounwind readnone
53+
declare %dx.types.Handle @"dx.hl.annotatehandle..%dx.types.Handle (i32, %dx.types.Handle, %dx.types.ResourceProperties, %\22class.RWStructuredBuffer<unsigned int>\22)"(i32, %dx.types.Handle, %dx.types.ResourceProperties, %"class.RWStructuredBuffer<unsigned int>") #1
54+
55+
attributes #0 = { nounwind }
56+
attributes #1 = { nounwind readnone }
57+
58+
!pauseresume = !{!1}
59+
!llvm.ident = !{!2}
60+
!dx.version = !{!3}
61+
!dx.valver = !{!3}
62+
!dx.shaderModel = !{!4}
63+
!dx.typeAnnotations = !{!5, !11}
64+
!dx.entryPoints = !{!18}
65+
!dx.fnprops = !{!23}
66+
!dx.options = !{!24, !25}
67+
68+
!1 = !{!"hlsl-hlemit", !"hlsl-hlensure"}
69+
!2 = !{!"dxc(private) 1.8.0.5134 (Group-Wave-Intrinsics, 84e7262d3)"}
70+
!3 = !{i32 1, i32 10}
71+
!4 = !{!"cs", i32 6, i32 10}
72+
!5 = !{i32 0, %"class.RWStructuredBuffer<unsigned int>" undef, !6}
73+
!6 = !{i32 4, !7, !8}
74+
!7 = !{i32 6, !"h", i32 3, i32 0, i32 7, i32 5}
75+
!8 = !{i32 0, !9}
76+
!9 = !{!10}
77+
!10 = !{i32 0, i32 undef}
78+
!11 = !{i32 1, void (<3 x i32>)* @main, !12}
79+
!12 = !{!13, !15}
80+
!13 = !{i32 1, !14, !14}
81+
!14 = !{}
82+
!15 = !{i32 0, !16, !17}
83+
!16 = !{i32 4, !"SV_DispatchThreadID", i32 7, i32 5, i32 13, i32 3}
84+
!17 = !{i32 0}
85+
!18 = !{void (<3 x i32>)* @main, !"main", null, !19, null}
86+
!19 = !{null, !20, null, null}
87+
!20 = !{!21}
88+
!21 = !{i32 0, %"class.RWStructuredBuffer<unsigned int>"* @"\01?output0@@3V?$RWStructuredBuffer@I@@A", !"output0", i32 0, i32 0, i32 1, i32 12, i1 false, i1 false, i1 false, !22}
89+
!22 = !{i32 1, i32 4}
90+
!23 = !{void (<3 x i32>)* @main, i32 5, i32 1, i32 1, i32 1}
91+
!24 = !{i32 64}
92+
!25 = !{i32 -1}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// REQUIRES: dxil-1-10
2+
3+
// RUN: %dxc -T cs_6_10 -E main -fcgl %s | FileCheck %s --check-prefix=FCGL
4+
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s
5+
6+
// FCGL: call i32 @"dx.hl.op.rn.i32 (i32)"(i32 396)
7+
// FCGL: call i32 @"dx.hl.op.rn.i32 (i32)"(i32 395)
8+
9+
// CHECK: %[[Index:[^ ]+]] = call i32 @dx.op.getGroupWaveIndex(i32 -2147483647) ; GetGroupWaveIndex()
10+
// CHECK: %[[Count:[^ ]+]] = call i32 @dx.op.getGroupWaveCount(i32 -2147483646) ; GetGroupWaveCount()
11+
// CHECK: call void @dx.op.rawBufferStore.i32(i32 140, %dx.types.Handle %{{[^,]+}}, i32 0, i32 0, i32 %[[Index]], i32 undef, i32 undef, i32 undef, i8 1, i32 4)
12+
// CHECK: call void @dx.op.rawBufferStore.i32(i32 140, %dx.types.Handle %{{[^,]+}}, i32 16, i32 0, i32 %[[Count]], i32 undef, i32 undef, i32 undef, i8 1, i32 4)
13+
14+
RWStructuredBuffer<uint> output0 : register(u0);
15+
16+
[numthreads(1, 1, 1)]
17+
void main(uint3 id: SV_DispatchThreadID) {
18+
uint waveIdx = GetGroupWaveIndex();
19+
uint waveCount = GetGroupWaveCount();
20+
21+
output0[0] = waveIdx;
22+
output0[16] = waveCount;
23+
}

0 commit comments

Comments
 (0)