diff --git a/docs/SPIR-V.rst b/docs/SPIR-V.rst index 899b587492..b5e9c05079 100644 --- a/docs/SPIR-V.rst +++ b/docs/SPIR-V.rst @@ -320,6 +320,7 @@ Supported extensions * SPV_KHR_maximal_reconvergence * SPV_KHR_float_controls * SPV_NV_shader_subgroup_partitioned +* SPV_KHR_quad_control Vulkan specific attributes -------------------------- @@ -4008,6 +4009,8 @@ Quad ``QuadReadAcrossX()`` ``OpGroupNonUniformQuadSwap`` Quad ``QuadReadAcrossY()`` ``OpGroupNonUniformQuadSwap`` Quad ``QuadReadAcrossDiagonal()`` ``OpGroupNonUniformQuadSwap`` Quad ``QuadReadLaneAt()`` ``OpGroupNonUniformQuadBroadcast`` +Quad ``QuadAny()`` ``OpGroupNonUniformQuadAnyKHR`` +Quad ``QuadAll()`` ``OpGroupNonUniformQuadAllKHR`` N/A ``WaveMatch()`` ``OpGroupNonUniformPartitionNV`` Multiprefix ``WaveMultiPrefixSum()`` ``OpGroupNonUniform*Add`` ``PartitionedExclusiveScanNV`` Multiprefix ``WaveMultiPrefixProduct()`` ``OpGroupNonUniform*Mul`` ``PartitionedExclusiveScanNV`` @@ -4016,6 +4019,11 @@ Multiprefix ``WaveMultiPrefixBitOr()`` ``OpGroupNonUniformLogicalOr`` ` Multiprefix ``WaveMultiPrefixBitXor()`` ``OpGroupNonUniformLogicalXor`` ``PartitionedExclusiveScanNV`` ============= ============================ =================================== ============================== +``QuadAny`` and ``QuadAll`` will use the ``OpGroupNonUniformQuadAnyKHR`` and +``OpGroupNonUniformQuadAllKHR`` instructions if the ``SPV_KHR_quad_control`` +extension is enabled. If it is not, they will fall back to constructing the +value using multiple calls to ``OpGroupNonUniformQuadSwap``. 
+ The Implicit ``vk`` Namespace ============================= diff --git a/tools/clang/include/clang/SPIRV/FeatureManager.h b/tools/clang/include/clang/SPIRV/FeatureManager.h index 8a9755ae79..3c1871df37 100644 --- a/tools/clang/include/clang/SPIRV/FeatureManager.h +++ b/tools/clang/include/clang/SPIRV/FeatureManager.h @@ -64,6 +64,7 @@ enum class Extension { KHR_maximal_reconvergence, KHR_float_controls, NV_shader_subgroup_partitioned, + KHR_quad_control, Unknown, }; diff --git a/tools/clang/include/clang/SPIRV/SpirvBuilder.h b/tools/clang/include/clang/SPIRV/SpirvBuilder.h index f03735115b..87d6c1713d 100644 --- a/tools/clang/include/clang/SPIRV/SpirvBuilder.h +++ b/tools/clang/include/clang/SPIRV/SpirvBuilder.h @@ -239,7 +239,7 @@ class SpirvBuilder { /// \brief Creates an operation with the given OpGroupNonUniform* SPIR-V /// opcode. SpirvGroupNonUniformOp *createGroupNonUniformOp( - spv::Op op, QualType resultType, spv::Scope execScope, + spv::Op op, QualType resultType, llvm::Optional execScope, llvm::ArrayRef operands, SourceLocation, llvm::Optional groupOp = llvm::None); diff --git a/tools/clang/include/clang/SPIRV/SpirvInstruction.h b/tools/clang/include/clang/SPIRV/SpirvInstruction.h index 7ec1375bde..34070aa59b 100644 --- a/tools/clang/include/clang/SPIRV/SpirvInstruction.h +++ b/tools/clang/include/clang/SPIRV/SpirvInstruction.h @@ -1514,7 +1514,8 @@ class SpirvFunctionCall : public SpirvInstruction { /// \brief OpGroupNonUniform* instructions class SpirvGroupNonUniformOp : public SpirvInstruction { public: - SpirvGroupNonUniformOp(spv::Op opcode, QualType resultType, spv::Scope scope, + SpirvGroupNonUniformOp(spv::Op opcode, QualType resultType, + llvm::Optional scope, llvm::ArrayRef operands, SourceLocation loc, llvm::Optional group); @@ -1528,7 +1529,8 @@ class SpirvGroupNonUniformOp : public SpirvInstruction { bool invokeVisitor(Visitor *v) override; - spv::Scope getExecutionScope() const { return execScope; } + bool hasExecutionScope() const { 
return execScope.hasValue(); } + spv::Scope getExecutionScope() const { return execScope.getValue(); } llvm::ArrayRef getOperands() const { return operands; } @@ -1546,7 +1548,7 @@ class SpirvGroupNonUniformOp : public SpirvInstruction { } private: - spv::Scope execScope; + llvm::Optional execScope; llvm::SmallVector operands; llvm::Optional groupOp; }; diff --git a/tools/clang/lib/SPIRV/CapabilityVisitor.cpp b/tools/clang/lib/SPIRV/CapabilityVisitor.cpp index c2b5acff53..840231a2a0 100644 --- a/tools/clang/lib/SPIRV/CapabilityVisitor.cpp +++ b/tools/clang/lib/SPIRV/CapabilityVisitor.cpp @@ -882,6 +882,9 @@ bool CapabilityVisitor::visit(SpirvModule *, Visitor::Phase phase) { addCapability(spv::Capability::InterpolationFunction); + addExtensionAndCapabilitiesIfEnabled(Extension::KHR_quad_control, + {spv::Capability::QuadControlKHR}); + return true; } diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index 6f6f5f88cd..713ffac3ac 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -1108,9 +1108,10 @@ bool EmitVisitor::visit(SpirvGroupNonUniformOp *inst) { initInstruction(inst); curInst.push_back(inst->getResultTypeId()); curInst.push_back(getOrAssignResultId(inst)); - curInst.push_back(typeHandler.getOrCreateConstantInt( - llvm::APInt(32, static_cast(inst->getExecutionScope())), - context.getUIntType(32), /* isSpecConst */ false)); + if (inst->hasExecutionScope()) + curInst.push_back(typeHandler.getOrCreateConstantInt( + llvm::APInt(32, static_cast(inst->getExecutionScope())), + context.getUIntType(32), /* isSpecConst */ false)); if (inst->hasGroupOp()) curInst.push_back(static_cast(inst->getGroupOp())); for (auto *operand : inst->getOperands()) diff --git a/tools/clang/lib/SPIRV/FeatureManager.cpp b/tools/clang/lib/SPIRV/FeatureManager.cpp index a8ee1de000..7fb449fee9 100644 --- a/tools/clang/lib/SPIRV/FeatureManager.cpp +++ b/tools/clang/lib/SPIRV/FeatureManager.cpp @@ -226,6 
+226,7 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) { .Case("SPV_KHR_float_controls", Extension::KHR_float_controls) .Case("SPV_NV_shader_subgroup_partitioned", Extension::NV_shader_subgroup_partitioned) + .Case("SPV_KHR_quad_control", Extension::KHR_quad_control) .Default(Extension::Unknown); } @@ -297,6 +298,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) { return "SPV_KHR_float_controls"; case Extension::NV_shader_subgroup_partitioned: return "SPV_NV_shader_subgroup_partitioned"; + case Extension::KHR_quad_control: + return "SPV_KHR_quad_control"; default: break; } diff --git a/tools/clang/lib/SPIRV/SpirvBuilder.cpp b/tools/clang/lib/SPIRV/SpirvBuilder.cpp index 1275e2b252..290792ab61 100644 --- a/tools/clang/lib/SPIRV/SpirvBuilder.cpp +++ b/tools/clang/lib/SPIRV/SpirvBuilder.cpp @@ -432,7 +432,7 @@ SpirvSpecConstantBinaryOp *SpirvBuilder::createSpecConstantBinaryOp( } SpirvGroupNonUniformOp *SpirvBuilder::createGroupNonUniformOp( - spv::Op op, QualType resultType, spv::Scope execScope, + spv::Op op, QualType resultType, llvm::Optional execScope, llvm::ArrayRef operands, SourceLocation loc, llvm::Optional groupOp) { assert(insertPoint && "null insert point"); diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 579af04ea6..5b86e52d34 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -9204,6 +9204,10 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) { case hlsl::IntrinsicOp::IOP_QuadReadLaneAt: retVal = processWaveQuadWideShuffle(callExpr, hlslOpcode); break; + case hlsl::IntrinsicOp::IOP_QuadAny: + case hlsl::IntrinsicOp::IOP_QuadAll: + retVal = processWaveQuadAnyAll(callExpr, hlslOpcode); + break; case hlsl::IntrinsicOp::IOP_abort: case hlsl::IntrinsicOp::IOP_GetRenderTargetSampleCount: case hlsl::IntrinsicOp::IOP_GetRenderTargetSamplePosition: { @@ -10158,6 +10162,53 @@ 
SpirvEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr, opcode, retType, spv::Scope::Subgroup, {value, target}, srcLoc); } +SpirvInstruction *SpirvEmitter::processWaveQuadAnyAll(const CallExpr *callExpr, + hlsl::IntrinsicOp op) { + // Signatures: + // bool QuadAny(bool localValue) + // bool QuadAll(bool localValue) + assert(callExpr->getNumArgs() == 1); + assert(op == hlsl::IntrinsicOp::IOP_QuadAny || + op == hlsl::IntrinsicOp::IOP_QuadAll); + featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation", + callExpr->getExprLoc()); + + auto *predicate = doExpr(callExpr->getArg(0)); + const auto srcLoc = callExpr->getExprLoc(); + + if (!featureManager.isExtensionEnabled(Extension::KHR_quad_control)) { + // We can't use QuadAny/QuadAll, so implement them using QuadSwap. We + // will read the value at each quad invocation, then combine them. + + spv::Op reducer = op == hlsl::IntrinsicOp::IOP_QuadAny + ? spv::Op::OpLogicalOr + : spv::Op::OpLogicalAnd; + + SpirvInstruction *result = predicate; + + for (size_t i = 0; i < 3; i++) { + SpirvInstruction *invocationValue = spvBuilder.createGroupNonUniformOp( + spv::Op::OpGroupNonUniformQuadSwap, astContext.BoolTy, + spv::Scope::Subgroup, + {predicate, spvBuilder.getConstantInt(astContext.UnsignedIntTy, + llvm::APInt(32, i))}, + srcLoc); + result = spvBuilder.createBinaryOp(reducer, astContext.BoolTy, result, + invocationValue, srcLoc); + } + + return result; + } + + spv::Op opcode = op == hlsl::IntrinsicOp::IOP_QuadAny + ? 
spv::Op::OpGroupNonUniformQuadAnyKHR + : spv::Op::OpGroupNonUniformQuadAllKHR; + + return spvBuilder.createGroupNonUniformOp(opcode, astContext.BoolTy, + llvm::Optional(), + {predicate}, srcLoc); +} + SpirvInstruction * SpirvEmitter::processWaveActiveAllEqual(const CallExpr *callExpr) { assert(callExpr->getNumArgs() == 1); diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index eca038527f..36aa693f1b 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -657,6 +657,10 @@ class SpirvEmitter : public ASTConsumer { SpirvInstruction *processWaveQuadWideShuffle(const CallExpr *, hlsl::IntrinsicOp op); + /// Processes SM6.7 quad any/all. + SpirvInstruction *processWaveQuadAnyAll(const CallExpr *, + hlsl::IntrinsicOp op); + /// Generates the Spir-V instructions needed to implement the given call to /// WaveActiveAllEqual. Returns a pointer to the instruction that produces the /// final result. diff --git a/tools/clang/lib/SPIRV/SpirvInstruction.cpp b/tools/clang/lib/SPIRV/SpirvInstruction.cpp index 21aada9e82..5bfa4a68da 100644 --- a/tools/clang/lib/SPIRV/SpirvInstruction.cpp +++ b/tools/clang/lib/SPIRV/SpirvInstruction.cpp @@ -677,7 +677,7 @@ SpirvFunctionCall::SpirvFunctionCall(QualType resultType, SourceLocation loc, function(fn), args(argsVec.begin(), argsVec.end()) {} SpirvGroupNonUniformOp::SpirvGroupNonUniformOp( - spv::Op op, QualType resultType, spv::Scope scope, + spv::Op op, QualType resultType, llvm::Optional scope, llvm::ArrayRef operandsVec, SourceLocation loc, llvm::Optional group) : SpirvInstruction(IK_GroupNonUniformOp, op, resultType, loc), @@ -709,6 +709,8 @@ SpirvGroupNonUniformOp::SpirvGroupNonUniformOp( case spv::Op::OpGroupNonUniformLogicalAnd: case spv::Op::OpGroupNonUniformLogicalOr: case spv::Op::OpGroupNonUniformLogicalXor: + case spv::Op::OpGroupNonUniformQuadAnyKHR: + case spv::Op::OpGroupNonUniformQuadAllKHR: assert(operandsVec.size() == 1); break; @@ -740,6 
+742,11 @@ SpirvGroupNonUniformOp::SpirvGroupNonUniformOp( assert(false && "Unexpected Group non-uniform opcode"); break; } + + if (op != spv::Op::OpGroupNonUniformQuadAnyKHR && + op != spv::Op::OpGroupNonUniformQuadAllKHR) { + assert(scope.hasValue()); + } } SpirvImageOp::SpirvImageOp( diff --git a/tools/clang/test/CodeGenSPIRV/sm6.quad-any-all.hlsl b/tools/clang/test/CodeGenSPIRV/sm6.quad-any-all.hlsl new file mode 100644 index 0000000000..fb9f6e0d76 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/sm6.quad-any-all.hlsl @@ -0,0 +1,41 @@ +// RUN: %dxc -T cs_6_0 -E main -fspv-target-env=vulkan1.1 -fcgl %s -spirv | FileCheck %s --check-prefixes=CHECK,QUAD +// RUN: %dxc -T cs_6_0 -E main -fspv-target-env=vulkan1.1 -fspv-extension=SPV_KHR_16bit_storage -fcgl %s -spirv | FileCheck %s --check-prefixes=CHECK,NOQUAD +// RUN: not %dxc -T cs_6_0 -E main -fspv-target-env=vulkan1.0 -fcgl %s -spirv 2>&1 | FileCheck %s --check-prefixes=ERROR + +// CHECK: ; Version: 1.3 + +// QUAD: OpCapability QuadControlKHR +// QUAD: OpExtension "SPV_KHR_quad_control" + +RWStructuredBuffer values; + +[numthreads(32, 1, 1)] +void main(uint3 id: SV_DispatchThreadID) { + uint outIdx = (id.y * 8) + id.x; + +// CHECK: [[val1:%[0-9]+]] = OpIEqual %bool {{%[0-9]+}} +// QUAD-NEXT: {{%[0-9]+}} = OpGroupNonUniformQuadAnyKHR %bool [[val1]] + +// NOQUAD-NEXT: [[inv0:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val1]] %uint_0 +// NOQUAD-NEXT: [[or0:%[0-9]+]] = OpLogicalOr %bool [[val1]] [[inv0]] +// NOQUAD-NEXT: [[inv1:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val1]] %uint_1 +// NOQUAD-NEXT: [[or1:%[0-9]+]] = OpLogicalOr %bool [[or0]] [[inv1]] +// NOQUAD-NEXT: [[inv2:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val1]] %uint_2 +// NOQUAD-NEXT: [[or2:%[0-9]+]] = OpLogicalOr %bool [[or1]] [[inv2]] + +// ERROR: 27:24: error: Vulkan 1.1 is required for Wave Operation but not permitted to use + values[outIdx].x = QuadAny(outIdx % 4 == 0) ? 
1.0 : 2.0; + +// CHECK: [[val2:%[0-9]+]] = OpIEqual %bool {{%[0-9]+}} +// QUAD-NEXT: {{%[0-9]+}} = OpGroupNonUniformQuadAllKHR %bool [[val2]] + +// NOQUAD-NEXT: [[inv0:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val2]] %uint_0 +// NOQUAD-NEXT: [[or0:%[0-9]+]] = OpLogicalAnd %bool [[val2]] [[inv0]] +// NOQUAD-NEXT: [[inv1:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val2]] %uint_1 +// NOQUAD-NEXT: [[or1:%[0-9]+]] = OpLogicalAnd %bool [[or0]] [[inv1]] +// NOQUAD-NEXT: [[inv2:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val2]] %uint_2 +// NOQUAD-NEXT: [[or2:%[0-9]+]] = OpLogicalAnd %bool [[or1]] [[inv2]] + +// ERROR: 40:24: error: Vulkan 1.1 is required for Wave Operation but not permitted to use + values[outIdx].y = QuadAll(outIdx % 2 == 0) ? 3.0 : 4.0; +}