Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/SPIR-V.rst
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ Supported extensions
* SPV_KHR_maximal_reconvergence
* SPV_KHR_float_controls
* SPV_NV_shader_subgroup_partitioned
* SPV_KHR_quad_control

Vulkan specific attributes
--------------------------
Expand Down Expand Up @@ -4008,6 +4009,8 @@ Quad ``QuadReadAcrossX()`` ``OpGroupNonUniformQuadSwap``
Quad ``QuadReadAcrossY()`` ``OpGroupNonUniformQuadSwap``
Quad ``QuadReadAcrossDiagonal()`` ``OpGroupNonUniformQuadSwap``
Quad ``QuadReadLaneAt()`` ``OpGroupNonUniformQuadBroadcast``
Quad ``QuadAny()`` ``OpGroupNonUniformQuadAnyKHR``
Quad ``QuadAll()`` ``OpGroupNonUniformQuadAllKHR``
N/A ``WaveMatch()`` ``OpGroupNonUniformPartitionNV``
Multiprefix ``WaveMultiPrefixSum()`` ``OpGroupNonUniform*Add`` ``PartitionedExclusiveScanNV``
Multiprefix ``WaveMultiPrefixProduct()`` ``OpGroupNonUniform*Mul`` ``PartitionedExclusiveScanNV``
Expand All @@ -4016,6 +4019,11 @@ Multiprefix ``WaveMultiPrefixBitOr()`` ``OpGroupNonUniformLogicalOr`` `
Multiprefix ``WaveMultiPrefixBitXor()`` ``OpGroupNonUniformLogicalXor`` ``PartitionedExclusiveScanNV``
============= ============================ =================================== ==============================

``QuadAny`` and ``QuadAll`` will use the ``OpGroupNonUniformQuadAnyKHR`` and
``OpGroupNonUniformQuadAllKHR`` if the ``SPV_KHR_quad_control`` extension is
enabled. If it is not, they will fall back to constructing the value using
multiple calls to ``OpGroupNonUniformQuadBroadcast``.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
``QuadAny`` and ``QuadAll`` will use the ``OpGroupNonUniformQuadAnyKHR`` and
``OpGroupNonUniformQuadAllKHR`` if the ``SPV_KHR_quad_control`` extension is
enabled. If it is not, they will fall back to constructing the value using
multiple calls to ``OpGroupNonUniformQuadBroadcast``.
``QuadAny`` and ``QuadAll`` will use the ``OpGroupNonUniformQuadAnyKHR`` and
``OpGroupNonUniformQuadAllKHR`` instructions if the ``SPV_KHR_quad_control`` extension is
enabled. If it is not, they will fall back to constructing the value using
multiple calls to ``OpGroupNonUniformQuadBroadcast``.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


The Implicit ``vk`` Namespace
=============================

Expand Down
1 change: 1 addition & 0 deletions tools/clang/include/clang/SPIRV/FeatureManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ enum class Extension {
KHR_maximal_reconvergence,
KHR_float_controls,
NV_shader_subgroup_partitioned,
KHR_quad_control,
Unknown,
};

Expand Down
2 changes: 1 addition & 1 deletion tools/clang/include/clang/SPIRV/SpirvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class SpirvBuilder {
/// \brief Creates an operation with the given OpGroupNonUniform* SPIR-V
/// opcode.
SpirvGroupNonUniformOp *createGroupNonUniformOp(
spv::Op op, QualType resultType, spv::Scope execScope,
spv::Op op, QualType resultType, llvm::Optional<spv::Scope> execScope,
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation,
llvm::Optional<spv::GroupOperation> groupOp = llvm::None);

Expand Down
8 changes: 5 additions & 3 deletions tools/clang/include/clang/SPIRV/SpirvInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1514,7 +1514,8 @@ class SpirvFunctionCall : public SpirvInstruction {
/// \brief OpGroupNonUniform* instructions
class SpirvGroupNonUniformOp : public SpirvInstruction {
public:
SpirvGroupNonUniformOp(spv::Op opcode, QualType resultType, spv::Scope scope,
SpirvGroupNonUniformOp(spv::Op opcode, QualType resultType,
llvm::Optional<spv::Scope> scope,
llvm::ArrayRef<SpirvInstruction *> operands,
SourceLocation loc,
llvm::Optional<spv::GroupOperation> group);
Expand All @@ -1528,7 +1529,8 @@ class SpirvGroupNonUniformOp : public SpirvInstruction {

bool invokeVisitor(Visitor *v) override;

spv::Scope getExecutionScope() const { return execScope; }
bool hasExecutionScope() const { return execScope.hasValue(); }
spv::Scope getExecutionScope() const { return execScope.getValue(); }

llvm::ArrayRef<SpirvInstruction *> getOperands() const { return operands; }

Expand All @@ -1546,7 +1548,7 @@ class SpirvGroupNonUniformOp : public SpirvInstruction {
}

private:
spv::Scope execScope;
llvm::Optional<spv::Scope> execScope;
llvm::SmallVector<SpirvInstruction *, 4> operands;
llvm::Optional<spv::GroupOperation> groupOp;
};
Expand Down
3 changes: 3 additions & 0 deletions tools/clang/lib/SPIRV/CapabilityVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,9 @@ bool CapabilityVisitor::visit(SpirvModule *, Visitor::Phase phase) {

addCapability(spv::Capability::InterpolationFunction);

addExtensionAndCapabilitiesIfEnabled(Extension::KHR_quad_control,
{spv::Capability::QuadControlKHR});

return true;
}

Expand Down
7 changes: 4 additions & 3 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1108,9 +1108,10 @@ bool EmitVisitor::visit(SpirvGroupNonUniformOp *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
curInst.push_back(typeHandler.getOrCreateConstantInt(
llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
context.getUIntType(32), /* isSpecConst */ false));
if (inst->hasExecutionScope())
curInst.push_back(typeHandler.getOrCreateConstantInt(
llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
context.getUIntType(32), /* isSpecConst */ false));
if (inst->hasGroupOp())
curInst.push_back(static_cast<uint32_t>(inst->getGroupOp()));
for (auto *operand : inst->getOperands())
Expand Down
3 changes: 3 additions & 0 deletions tools/clang/lib/SPIRV/FeatureManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
.Case("SPV_KHR_float_controls", Extension::KHR_float_controls)
.Case("SPV_NV_shader_subgroup_partitioned",
Extension::NV_shader_subgroup_partitioned)
.Case("SPV_KHR_quad_control", Extension::KHR_quad_control)
.Default(Extension::Unknown);
}

Expand Down Expand Up @@ -297,6 +298,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) {
return "SPV_KHR_float_controls";
case Extension::NV_shader_subgroup_partitioned:
return "SPV_NV_shader_subgroup_partitioned";
case Extension::KHR_quad_control:
return "SPV_KHR_quad_control";
default:
break;
}
Expand Down
2 changes: 1 addition & 1 deletion tools/clang/lib/SPIRV/SpirvBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ SpirvSpecConstantBinaryOp *SpirvBuilder::createSpecConstantBinaryOp(
}

SpirvGroupNonUniformOp *SpirvBuilder::createGroupNonUniformOp(
spv::Op op, QualType resultType, spv::Scope execScope,
spv::Op op, QualType resultType, llvm::Optional<spv::Scope> execScope,
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation loc,
llvm::Optional<spv::GroupOperation> groupOp) {
assert(insertPoint && "null insert point");
Expand Down
51 changes: 51 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9204,6 +9204,10 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
case hlsl::IntrinsicOp::IOP_QuadReadLaneAt:
retVal = processWaveQuadWideShuffle(callExpr, hlslOpcode);
break;
case hlsl::IntrinsicOp::IOP_QuadAny:
case hlsl::IntrinsicOp::IOP_QuadAll:
retVal = processWaveQuadAnyAll(callExpr, hlslOpcode);
break;
case hlsl::IntrinsicOp::IOP_abort:
case hlsl::IntrinsicOp::IOP_GetRenderTargetSampleCount:
case hlsl::IntrinsicOp::IOP_GetRenderTargetSamplePosition: {
Expand Down Expand Up @@ -10158,6 +10162,53 @@ SpirvEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr,
opcode, retType, spv::Scope::Subgroup, {value, target}, srcLoc);
}

SpirvInstruction *SpirvEmitter::processWaveQuadAnyAll(const CallExpr *callExpr,
hlsl::IntrinsicOp op) {
// Signatures:
// bool QuadAny(bool localValue)
// bool QuadAll(bool localValue)
assert(callExpr->getNumArgs() == 1);
assert(op == hlsl::IntrinsicOp::IOP_QuadAny ||
op == hlsl::IntrinsicOp::IOP_QuadAll);
featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
callExpr->getExprLoc());

auto *predicate = doExpr(callExpr->getArg(0));
const auto srcLoc = callExpr->getExprLoc();

if (!featureManager.isExtensionEnabled(Extension::KHR_quad_control)) {
// We can't use QuadAny/QuadAll, so implement them using QuadSwap. We
// will read the value at each quad invocation, then combine them.

spv::Op reducer = op == hlsl::IntrinsicOp::IOP_QuadAny
? spv::Op::OpLogicalOr
: spv::Op::OpLogicalAnd;

SpirvInstruction *result = predicate;

for (size_t i = 0; i < 3; i++) {
SpirvInstruction *invocationValue = spvBuilder.createGroupNonUniformOp(
spv::Op::OpGroupNonUniformQuadSwap, astContext.BoolTy,
spv::Scope::Subgroup,
{predicate, spvBuilder.getConstantInt(astContext.UnsignedIntTy,
llvm::APInt(32, i))},
srcLoc);
result = spvBuilder.createBinaryOp(reducer, astContext.BoolTy, result,
invocationValue, srcLoc);
}

return result;
}

spv::Op opcode = op == hlsl::IntrinsicOp::IOP_QuadAny
? spv::Op::OpGroupNonUniformQuadAnyKHR
: spv::Op::OpGroupNonUniformQuadAllKHR;

return spvBuilder.createGroupNonUniformOp(opcode, astContext.BoolTy,
llvm::Optional<spv::Scope>(),
{predicate}, srcLoc);
}

SpirvInstruction *
SpirvEmitter::processWaveActiveAllEqual(const CallExpr *callExpr) {
assert(callExpr->getNumArgs() == 1);
Expand Down
4 changes: 4 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,10 @@ class SpirvEmitter : public ASTConsumer {
SpirvInstruction *processWaveQuadWideShuffle(const CallExpr *,
hlsl::IntrinsicOp op);

/// Processes SM6.7 quad any/all.
SpirvInstruction *processWaveQuadAnyAll(const CallExpr *,
hlsl::IntrinsicOp op);

/// Generates the Spir-V instructions needed to implement the given call to
/// WaveActiveAllEqual. Returns a pointer to the instruction that produces the
/// final result.
Expand Down
9 changes: 8 additions & 1 deletion tools/clang/lib/SPIRV/SpirvInstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ SpirvFunctionCall::SpirvFunctionCall(QualType resultType, SourceLocation loc,
function(fn), args(argsVec.begin(), argsVec.end()) {}

SpirvGroupNonUniformOp::SpirvGroupNonUniformOp(
spv::Op op, QualType resultType, spv::Scope scope,
spv::Op op, QualType resultType, llvm::Optional<spv::Scope> scope,
llvm::ArrayRef<SpirvInstruction *> operandsVec, SourceLocation loc,
llvm::Optional<spv::GroupOperation> group)
: SpirvInstruction(IK_GroupNonUniformOp, op, resultType, loc),
Expand Down Expand Up @@ -709,6 +709,8 @@ SpirvGroupNonUniformOp::SpirvGroupNonUniformOp(
case spv::Op::OpGroupNonUniformLogicalAnd:
case spv::Op::OpGroupNonUniformLogicalOr:
case spv::Op::OpGroupNonUniformLogicalXor:
case spv::Op::OpGroupNonUniformQuadAnyKHR:
case spv::Op::OpGroupNonUniformQuadAllKHR:
assert(operandsVec.size() == 1);
break;

Expand Down Expand Up @@ -740,6 +742,11 @@ SpirvGroupNonUniformOp::SpirvGroupNonUniformOp(
assert(false && "Unexpected Group non-uniform opcode");
break;
}

if (op != spv::Op::OpGroupNonUniformQuadAnyKHR &&
op != spv::Op::OpGroupNonUniformQuadAllKHR) {
assert(scope.hasValue());
}
}

SpirvImageOp::SpirvImageOp(
Expand Down
36 changes: 36 additions & 0 deletions tools/clang/test/CodeGenSPIRV/sm6.quad-any-all.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: %dxc -T cs_6_0 -E main -fspv-target-env=vulkan1.1 -fcgl %s -spirv | FileCheck %s --check-prefixes=CHECK,QUAD
// RUN: %dxc -T cs_6_0 -E main -fspv-target-env=vulkan1.1 -fspv-extension=SPV_KHR_16bit_storage -fcgl %s -spirv | FileCheck %s --check-prefixes=CHECK,NOQUAD
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we do a test with vulkan1.0 and check for the error? I think we can simply add another run command.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


// CHECK: ; Version: 1.3

// QUAD: OpCapability QuadControlKHR
// QUAD: OpExtension "SPV_KHR_quad_control"

RWStructuredBuffer<float3> values;

[numthreads(32, 1, 1)]
void main(uint3 id: SV_DispatchThreadID) {
uint outIdx = (id.y * 8) + id.x;

// CHECK: [[val1:%[0-9]+]] = OpIEqual %bool {{%[0-9]+}}
// QUAD-NEXT: {{%[0-9]+}} = OpGroupNonUniformQuadAnyKHR %bool [[val1]]

// NOQUAD-NEXT: [[inv0:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val1]] %uint_0
// NOQUAD-NEXT: [[or0:%[0-9]+]] = OpLogicalOr %bool [[val1]] [[inv0]]
// NOQUAD-NEXT: [[inv1:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val1]] %uint_1
// NOQUAD-NEXT: [[or1:%[0-9]+]] = OpLogicalOr %bool [[or0]] [[inv1]]
// NOQUAD-NEXT: [[inv2:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val1]] %uint_2
// NOQUAD-NEXT: [[or2:%[0-9]+]] = OpLogicalOr %bool [[or1]] [[inv2]]
values[outIdx].x = QuadAny(outIdx % 4 == 0) ? 1.0 : 2.0;

// CHECK: [[val2:%[0-9]+]] = OpIEqual %bool {{%[0-9]+}}
// QUAD-NEXT: {{%[0-9]+}} = OpGroupNonUniformQuadAllKHR %bool [[val2]]

// NOQUAD-NEXT: [[inv0:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val2]] %uint_0
// NOQUAD-NEXT: [[or0:%[0-9]+]] = OpLogicalAnd %bool [[val2]] [[inv0]]
// NOQUAD-NEXT: [[inv1:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val2]] %uint_1
// NOQUAD-NEXT: [[or1:%[0-9]+]] = OpLogicalAnd %bool [[or0]] [[inv1]]
// NOQUAD-NEXT: [[inv2:%[0-9]+]] = OpGroupNonUniformQuadSwap %bool %uint_3 [[val2]] %uint_2
// NOQUAD-NEXT: [[or2:%[0-9]+]] = OpLogicalAnd %bool [[or1]] [[inv2]]
values[outIdx].y = QuadAll(outIdx % 2 == 0) ? 3.0 : 4.0;
}