Skip to content

Commit 7db2fc1

Browse files
authored
[SPIR-V] Implement WaveMatch intrinsic function (#6546)
Adds support for the `WaveMatch()` intrinsic function from Shader Model 6.5 using the `OpGroupNonUniformPartitionNV` instruction from the `SPV_NV_shader_subgroup_partitioned` extension. SPIRV-Tools bumped to include: KhronosGroup/SPIRV-Tools#5648 Fixes #6545
1 parent af955d6 commit 7db2fc1

7 files changed

Lines changed: 43 additions & 2 deletions

File tree

docs/SPIR-V.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3826,8 +3826,8 @@ RayQuery Mapping to SPIR-V
38263826
|``.WorldRayOrigin` | ``OpRayQueryGetWorldRayOriginKHR`` |
38273827
+---------------------------------------------------+-------------------------------------------------------------------------+
38283828

3829-
Shader Model 6.0 Wave Intrinsics
3830-
================================
3829+
Shader Model 6.0+ Wave Intrinsics
3830+
=================================
38313831

38323832

38333833
Note that Wave intrinsics requires SPIR-V 1.3, which is supported by Vulkan 1.1.
@@ -3865,6 +3865,7 @@ Quad ``QuadReadAcrossX()`` ``OpGroupNonUniformQuadSwap``
38653865
Quad ``QuadReadAcrossY()`` ``OpGroupNonUniformQuadSwap``
38663866
Quad ``QuadReadAcrossDiagonal()`` ``OpGroupNonUniformQuadSwap``
38673867
Quad ``QuadReadLaneAt()`` ``OpGroupNonUniformQuadBroadcast``
3868+
N/A ``WaveMatch()`` ``OpGroupNonUniformPartitionNV``
38683869
============= ============================ =================================== ======================
38693870

38703871
The Implicit ``vk`` Namespace

tools/clang/include/clang/SPIRV/FeatureManager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ enum class Extension {
6161
KHR_fragment_shader_barycentric,
6262
KHR_maximal_reconvergence,
6363
KHR_float_controls,
64+
NV_shader_subgroup_partitioned,
6465
Unknown,
6566
};
6667

tools/clang/lib/SPIRV/CapabilityVisitor.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,10 @@ bool CapabilityVisitor::visit(SpirvModule *, Visitor::Phase phase) {
898898
Extension::KHR_vulkan_memory_model,
899899
{spv::Capability::VulkanMemoryModelDeviceScope});
900900

901+
addExtensionAndCapabilitiesIfEnabled(
902+
Extension::NV_shader_subgroup_partitioned,
903+
{spv::Capability::GroupNonUniformPartitionedNV});
904+
901905
return true;
902906
}
903907

tools/clang/lib/SPIRV/FeatureManager.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
197197
.Case("SPV_KHR_maximal_reconvergence",
198198
Extension::KHR_maximal_reconvergence)
199199
.Case("SPV_KHR_float_controls", Extension::KHR_float_controls)
200+
.Case("SPV_NV_shader_subgroup_partitioned",
201+
Extension::NV_shader_subgroup_partitioned)
200202
.Default(Extension::Unknown);
201203
}
202204

@@ -264,6 +266,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) {
264266
return "SPV_KHR_maximal_reconvergence";
265267
case Extension::KHR_float_controls:
266268
return "SPV_KHR_float_controls";
269+
case Extension::NV_shader_subgroup_partitioned:
270+
return "SPV_NV_shader_subgroup_partitioned";
267271
default:
268272
break;
269273
}

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8716,6 +8716,9 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
87168716
case hlsl::IntrinsicOp::IOP_WaveReadLaneFirst:
87178717
retVal = processWaveBroadcast(callExpr);
87188718
break;
8719+
case hlsl::IntrinsicOp::IOP_WaveMatch:
8720+
retVal = processWaveMatch(callExpr);
8721+
break;
87198722
case hlsl::IntrinsicOp::IOP_QuadReadAcrossX:
87208723
case hlsl::IntrinsicOp::IOP_QuadReadAcrossY:
87218724
case hlsl::IntrinsicOp::IOP_QuadReadAcrossDiagonal:
@@ -9710,6 +9713,17 @@ SpirvEmitter::processWaveActiveAllEqualMatrix(SpirvInstruction *arg,
97109713
return spvBuilder.createCompositeConstruct(booleanMatrixType, rows, srcLoc);
97119714
}
97129715

9716+
SpirvInstruction *SpirvEmitter::processWaveMatch(const CallExpr *callExpr) {
9717+
assert(callExpr->getNumArgs() == 1);
9718+
const auto loc = callExpr->getExprLoc();
9719+
9720+
// The SPV_NV_shader_subgroup_partitioned extension requires SPIR-V 1.3.
9721+
featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation", loc);
9722+
9723+
SpirvInstruction *arg = doExpr(callExpr->getArg(0));
9724+
return spvBuilder.createUnaryOp(spv::Op::OpGroupNonUniformPartitionNV,
9725+
callExpr->getType(), arg, loc);
9726+
}
97139727
SpirvInstruction *SpirvEmitter::processIntrinsicModf(const CallExpr *callExpr) {
97149728
// Signature is: ret modf(x, ip)
97159729
// [in] x: the input floating-point value.

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,9 @@ class SpirvEmitter : public ASTConsumer {
655655
processWaveActiveAllEqualMatrix(SpirvInstruction *arg, QualType,
656656
clang::SourceLocation srcLoc);
657657

658+
/// Processes SM6.5 WaveMatch function.
659+
SpirvInstruction *processWaveMatch(const CallExpr *);
660+
658661
/// Processes the NonUniformResourceIndex intrinsic function.
659662
SpirvInstruction *processIntrinsicNonUniformResourceIndex(const CallExpr *);
660663

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: %dxc -E main -T ps_6_5 -spirv -O0 -fspv-target-env=vulkan1.1 %s | FileCheck %s
2+
// RUN: not %dxc -E main -T ps_6_5 -spirv -O0 %s 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
// CHECK-ERROR: error: Vulkan 1.1 is required for Wave Operation but not permitted to use
5+
6+
// CHECK: OpCapability GroupNonUniformPartitionedNV
7+
// CHECK: OpExtension "SPV_NV_shader_subgroup_partitioned"
8+
9+
uint4 main(uint4 input : ATTR0) : SV_Target {
10+
// CHECK: [[input:%[0-9]+]] = OpLoad %v4uint %input
11+
// CHECK: {{%[0-9]+}} = OpGroupNonUniformPartitionNV %v4uint [[input]]
12+
uint4 res = WaveMatch(input);
13+
return res;
14+
}

0 commit comments

Comments
 (0)