Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/SPIR-V.rst
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ Supported extensions
* SPV_KHR_maximal_reconvergence
* SPV_KHR_float_controls
* SPV_NV_shader_subgroup_partitioned
* SPV_KHR_quad_control

Vulkan specific attributes
--------------------------
Expand Down Expand Up @@ -4008,6 +4009,8 @@ Quad ``QuadReadAcrossX()`` ``OpGroupNonUniformQuadSwap``
Quad ``QuadReadAcrossY()`` ``OpGroupNonUniformQuadSwap``
Quad ``QuadReadAcrossDiagonal()`` ``OpGroupNonUniformQuadSwap``
Quad ``QuadReadLaneAt()`` ``OpGroupNonUniformQuadBroadcast``
Quad ``QuadAny()`` ``OpGroupNonUniformQuadAnyKHR``
Quad ``QuadAll()`` ``OpGroupNonUniformQuadAllKHR``
N/A ``WaveMatch()`` ``OpGroupNonUniformPartitionNV``
Multiprefix ``WaveMultiPrefixSum()`` ``OpGroupNonUniform*Add`` ``PartitionedExclusiveScanNV``
Multiprefix ``WaveMultiPrefixProduct()`` ``OpGroupNonUniform*Mul`` ``PartitionedExclusiveScanNV``
Expand All @@ -4016,6 +4019,11 @@ Multiprefix ``WaveMultiPrefixBitOr()`` ``OpGroupNonUniformLogicalOr`` `
Multiprefix ``WaveMultiPrefixBitXor()`` ``OpGroupNonUniformLogicalXor`` ``PartitionedExclusiveScanNV``
============= ============================ =================================== ==============================

``QuadAny`` and ``QuadAll`` will use the ``OpGroupNonUniformQuadAnyKHR`` and
``OpGroupNonUniformQuadAllKHR`` instructions if the ``SPV_KHR_quad_control``
extension is enabled. If it is not, they will fall back to constructing the
value using multiple calls to ``OpGroupNonUniformQuadBroadcast``.

The Implicit ``vk`` Namespace
=============================

Expand Down
1 change: 1 addition & 0 deletions tools/clang/include/clang/SPIRV/FeatureManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ enum class Extension {
KHR_maximal_reconvergence,
KHR_float_controls,
NV_shader_subgroup_partitioned,
KHR_quad_control,
Unknown,
};

Expand Down
2 changes: 1 addition & 1 deletion tools/clang/include/clang/SPIRV/SpirvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class SpirvBuilder {
/// \brief Creates an operation with the given OpGroupNonUniform* SPIR-V
/// opcode.
SpirvGroupNonUniformOp *createGroupNonUniformOp(
spv::Op op, QualType resultType, spv::Scope execScope,
spv::Op op, QualType resultType, llvm::Optional<spv::Scope> execScope,
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation,
llvm::Optional<spv::GroupOperation> groupOp = llvm::None);

Expand Down
8 changes: 5 additions & 3 deletions tools/clang/include/clang/SPIRV/SpirvInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1514,7 +1514,8 @@ class SpirvFunctionCall : public SpirvInstruction {
/// \brief OpGroupNonUniform* instructions
class SpirvGroupNonUniformOp : public SpirvInstruction {
public:
SpirvGroupNonUniformOp(spv::Op opcode, QualType resultType, spv::Scope scope,
SpirvGroupNonUniformOp(spv::Op opcode, QualType resultType,
llvm::Optional<spv::Scope> scope,
llvm::ArrayRef<SpirvInstruction *> operands,
SourceLocation loc,
llvm::Optional<spv::GroupOperation> group);
Expand All @@ -1528,7 +1529,8 @@ class SpirvGroupNonUniformOp : public SpirvInstruction {

bool invokeVisitor(Visitor *v) override;

spv::Scope getExecutionScope() const { return execScope; }
bool hasExecutionScope() const { return execScope.hasValue(); }
spv::Scope getExecutionScope() const { return execScope.getValue(); }

llvm::ArrayRef<SpirvInstruction *> getOperands() const { return operands; }

Expand All @@ -1546,7 +1548,7 @@ class SpirvGroupNonUniformOp : public SpirvInstruction {
}

private:
spv::Scope execScope;
llvm::Optional<spv::Scope> execScope;
llvm::SmallVector<SpirvInstruction *, 4> operands;
llvm::Optional<spv::GroupOperation> groupOp;
};
Expand Down
3 changes: 3 additions & 0 deletions tools/clang/lib/SPIRV/CapabilityVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,9 @@ bool CapabilityVisitor::visit(SpirvModule *, Visitor::Phase phase) {

addCapability(spv::Capability::InterpolationFunction);

addExtensionAndCapabilitiesIfEnabled(Extension::KHR_quad_control,
{spv::Capability::QuadControlKHR});

return true;
}

Expand Down
7 changes: 4 additions & 3 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1108,9 +1108,10 @@ bool EmitVisitor::visit(SpirvGroupNonUniformOp *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
curInst.push_back(typeHandler.getOrCreateConstantInt(
llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
context.getUIntType(32), /* isSpecConst */ false));
if (inst->hasExecutionScope())
curInst.push_back(typeHandler.getOrCreateConstantInt(
llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
context.getUIntType(32), /* isSpecConst */ false));
if (inst->hasGroupOp())
curInst.push_back(static_cast<uint32_t>(inst->getGroupOp()));
for (auto *operand : inst->getOperands())
Expand Down
3 changes: 3 additions & 0 deletions tools/clang/lib/SPIRV/FeatureManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
.Case("SPV_KHR_float_controls", Extension::KHR_float_controls)
.Case("SPV_NV_shader_subgroup_partitioned",
Extension::NV_shader_subgroup_partitioned)
.Case("SPV_KHR_quad_control", Extension::KHR_quad_control)
.Default(Extension::Unknown);
}

Expand Down Expand Up @@ -297,6 +298,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) {
return "SPV_KHR_float_controls";
case Extension::NV_shader_subgroup_partitioned:
return "SPV_NV_shader_subgroup_partitioned";
case Extension::KHR_quad_control:
return "SPV_KHR_quad_control";
default:
break;
}
Expand Down
2 changes: 1 addition & 1 deletion tools/clang/lib/SPIRV/SpirvBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ SpirvSpecConstantBinaryOp *SpirvBuilder::createSpecConstantBinaryOp(
}

SpirvGroupNonUniformOp *SpirvBuilder::createGroupNonUniformOp(
spv::Op op, QualType resultType, spv::Scope execScope,
spv::Op op, QualType resultType, llvm::Optional<spv::Scope> execScope,
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation loc,
llvm::Optional<spv::GroupOperation> groupOp) {
assert(insertPoint && "null insert point");
Expand Down
51 changes: 51 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9204,6 +9204,10 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
case hlsl::IntrinsicOp::IOP_QuadReadLaneAt:
retVal = processWaveQuadWideShuffle(callExpr, hlslOpcode);
break;
case hlsl::IntrinsicOp::IOP_QuadAny:
case hlsl::IntrinsicOp::IOP_QuadAll:
retVal = processWaveQuadAnyAll(callExpr, hlslOpcode);
break;
case hlsl::IntrinsicOp::IOP_abort:
case hlsl::IntrinsicOp::IOP_GetRenderTargetSampleCount:
case hlsl::IntrinsicOp::IOP_GetRenderTargetSamplePosition: {
Expand Down Expand Up @@ -10158,6 +10162,53 @@ SpirvEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr,
opcode, retType, spv::Scope::Subgroup, {value, target}, srcLoc);
}

SpirvInstruction *SpirvEmitter::processWaveQuadAnyAll(const CallExpr *callExpr,
hlsl::IntrinsicOp op) {
// Signatures:
// bool QuadAny(bool localValue)
// bool QuadAll(bool localValue)
assert(callExpr->getNumArgs() == 1);
assert(op == hlsl::IntrinsicOp::IOP_QuadAny ||
op == hlsl::IntrinsicOp::IOP_QuadAll);
featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
callExpr->getExprLoc());

auto *predicate = doExpr(callExpr->getArg(0));
const auto srcLoc = callExpr->getExprLoc();

if (!featureManager.isExtensionEnabled(Extension::KHR_quad_control)) {
// We can't use QuadAny/QuadAll, so implement them using QuadSwap. We
// will read the value at each quad invocation, then combine them.

spv::Op reducer = op == hlsl::IntrinsicOp::IOP_QuadAny
? spv::Op::OpLogicalOr
: spv::Op::OpLogicalAnd;

SpirvInstruction *result = predicate;

for (size_t i = 0; i < 3; i++) {
SpirvInstruction *invocationValue = spvBuilder.createGroupNonUniformOp(
spv::Op::OpGroupNonUniformQuadSwap, astContext.BoolTy,
spv::Scope::Subgroup,
{predicate, spvBuilder.getConstantInt(astContext.UnsignedIntTy,
llvm::APInt(32, i))},
srcLoc);
result = spvBuilder.createBinaryOp(reducer, astContext.BoolTy, result,
invocationValue, srcLoc);
}

return result;
}

spv::Op opcode = op == hlsl::IntrinsicOp::IOP_QuadAny
? spv::Op::OpGroupNonUniformQuadAnyKHR
: spv::Op::OpGroupNonUniformQuadAllKHR;

return spvBuilder.createGroupNonUniformOp(opcode, astContext.BoolTy,
llvm::Optional<spv::Scope>(),
{predicate}, srcLoc);
}

SpirvInstruction *
SpirvEmitter::processWaveActiveAllEqual(const CallExpr *callExpr) {
assert(callExpr->getNumArgs() == 1);
Expand Down
4 changes: 4 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,10 @@ class SpirvEmitter : public ASTConsumer {
SpirvInstruction *processWaveQuadWideShuffle(const CallExpr *,
hlsl::IntrinsicOp op);

/// Processes SM6.7 quad any/all.
SpirvInstruction *processWaveQuadAnyAll(const CallExpr *,
hlsl::IntrinsicOp op);

/// Generates the Spir-V instructions needed to implement the given call to
/// WaveActiveAllEqual. Returns a pointer to the instruction that produces the
/// final result.
Expand Down
9 changes: 8 additions & 1 deletion tools/clang/lib/SPIRV/SpirvInstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ SpirvFunctionCall::SpirvFunctionCall(QualType resultType, SourceLocation loc,
function(fn), args(argsVec.begin(), argsVec.end()) {}

SpirvGroupNonUniformOp::SpirvGroupNonUniformOp(
spv::Op op, QualType resultType, spv::Scope scope,
spv::Op op, QualType resultType, llvm::Optional<spv::Scope> scope,
llvm::ArrayRef<SpirvInstruction *> operandsVec, SourceLocation loc,
llvm::Optional<spv::GroupOperation> group)
: SpirvInstruction(IK_GroupNonUniformOp, op, resultType, loc),
Expand Down Expand Up @@ -709,6 +709,8 @@ SpirvGroupNonUniformOp::SpirvGroupNonUniformOp(
case spv::Op::OpGroupNonUniformLogicalAnd:
case spv::Op::OpGroupNonUniformLogicalOr:
case spv::Op::OpGroupNonUniformLogicalXor:
case spv::Op::OpGroupNonUniformQuadAnyKHR:
case spv::Op::OpGroupNonUniformQuadAllKHR:
assert(operandsVec.size() == 1);
break;

Expand Down Expand Up @@ -740,6 +742,11 @@ SpirvGroupNonUniformOp::SpirvGroupNonUniformOp(
assert(false && "Unexpected Group non-uniform opcode");
break;
}

if (op != spv::Op::OpGroupNonUniformQuadAnyKHR &&
op != spv::Op::OpGroupNonUniformQuadAllKHR) {
assert(scope.hasValue());
}
}

SpirvImageOp::SpirvImageOp(
Expand Down
41 changes: 41 additions & 0 deletions tools/clang/test/CodeGenSPIRV/sm6.quad-any-all.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: %dxc -T cs_6_0 -E main -fspv-target-env=vulkan1.1 -fcgl %s -spirv | FileCheck %s --check-prefixes=CHECK,QUAD
// RUN: %dxc -T cs_6_0 -E main -fspv-target-env=vulkan1.1 -fspv-extension=SPV_KHR_16bit_storage -fcgl %s -spirv | FileCheck %s --check-prefixes=CHECK,NOQUAD
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we do a test with vulkan1.0 and check for the error? I think we can simply add another run command.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// RUN: not %dxc -T cs_6_0 -E main -fspv-target-env=vulkan1.0 -fcgl %s -spirv 2>&1 | FileCheck %s --check-prefixes=ERROR

// CHECK: ; Version: 1.3

// QUAD: OpCapability QuadControlKHR
// QUAD: OpExtension "SPV_KHR_quad_control"

RWStructuredBuffer<float3> values;

[numthreads(32, 1, 1)]
void main(uint3 id: SV_DispatchThreadID) {
uint outIdx = (id.y * 8) + id.x;

// CHECK: [[val1:%[0-9]+]] = OpIEqual %bool {{%[0-9]+}}
// QUAD-NEXT: {{%[0-9]+}} = OpGroupNonUniformQuadAnyKHR %bool [[val1]]

// NOQUAD-NEXT: [[inv0:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val1]] %uint_0
// NOQUAD-NEXT: [[or0:%[0-9]+]] = OpLogicalOr %bool [[val1]] [[inv0]]
// NOQUAD-NEXT: [[inv1:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val1]] %uint_1
// NOQUAD-NEXT: [[or1:%[0-9]+]] = OpLogicalOr %bool [[or0]] [[inv1]]
// NOQUAD-NEXT: [[inv2:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val1]] %uint_2
// NOQUAD-NEXT: [[or2:%[0-9]+]] = OpLogicalOr %bool [[or1]] [[inv2]]

// ERROR: 27:24: error: Vulkan 1.1 is required for Wave Operation but not permitted to use
values[outIdx].x = QuadAny(outIdx % 4 == 0) ? 1.0 : 2.0;

// CHECK: [[val2:%[0-9]+]] = OpIEqual %bool {{%[0-9]+}}
// QUAD-NEXT: {{%[0-9]+}} = OpGroupNonUniformQuadAllKHR %bool [[val2]]

// NOQUAD-NEXT: [[inv0:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val2]] %uint_0
// NOQUAD-NEXT: [[or0:%[0-9]+]] = OpLogicalAnd %bool [[val2]] [[inv0]]
// NOQUAD-NEXT: [[inv1:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val2]] %uint_1
// NOQUAD-NEXT: [[or1:%[0-9]+]] = OpLogicalAnd %bool [[or0]] [[inv1]]
// NOQUAD-NEXT: [[inv2:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val2]] %uint_2
// NOQUAD-NEXT: [[or2:%[0-9]+]] = OpLogicalAnd %bool [[or1]] [[inv2]]

// ERROR: 40:24: error: Vulkan 1.1 is required for Wave Operation but not permitted to use
values[outIdx].y = QuadAll(outIdx % 2 == 0) ? 3.0 : 4.0;
}