Skip to content

Commit 9021cae

Browse files
authored
Expand WaveActiveAllEqual with vector and matrix arguments. (#5428)
For vector and matrix inputs to WaveActiveAllEqual, the OpGroupNonUniformAllEqual instruction must be applied on each element of the vector or matrix, and then have the results combined into a vector or matrix of bools. We make that change in this PR. Fixes #5426
1 parent 325ae8e commit 9021cae

10 files changed

Lines changed: 263 additions & 38 deletions

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8540,7 +8540,7 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
85408540
retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformBallot);
85418541
break;
85428542
case hlsl::IntrinsicOp::IOP_WaveActiveAllEqual:
8543-
retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAllEqual);
8543+
retVal = processWaveActiveAllEqual(callExpr);
85448544
break;
85458545
case hlsl::IntrinsicOp::IOP_WaveActiveCountBits:
85468546
retVal = processWaveCountBits(callExpr, spv::GroupOperation::Reduce);
@@ -9384,6 +9384,74 @@ SpirvEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr,
93849384
opcode, retType, spv::Scope::Subgroup, value, target, srcLoc);
93859385
}
93869386

9387+
/// Lowers a call to the WaveActiveAllEqual intrinsic.
///
/// The call must have exactly one argument. A Vulkan 1.1 target environment is
/// requested first, then the argument is evaluated once and handed to the
/// scalar, vector, or matrix helper, selected by the shape of the call's
/// return type. Returns the instruction produced by the selected helper.
SpirvInstruction *
SpirvEmitter::processWaveActiveAllEqual(const CallExpr *callExpr) {
  assert(callExpr->getNumArgs() == 1);

  // Every diagnostic and emitted instruction is attributed to the call site.
  const clang::SourceLocation callLoc = callExpr->getExprLoc();
  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                  callLoc);

  SpirvInstruction *operand = doExpr(callExpr->getArg(0));
  const QualType resultType = callExpr->getCallReturnType(astContext);

  if (isScalarType(resultType)) {
    return processWaveActiveAllEqualScalar(operand, callLoc);
  }

  if (isVectorType(resultType)) {
    return processWaveActiveAllEqualVector(operand, callLoc);
  }

  // Anything that is neither scalar nor vector is expected to be a matrix.
  assert(isMxNMatrix(resultType));
  return processWaveActiveAllEqualMatrix(operand, resultType, callLoc);
}
9404+
9405+
/// Emits a single OpGroupNonUniformAllEqual on `arg`, using subgroup scope and
/// a bool result type. `srcLoc` is the location attached to the instruction.
/// Returns the newly created instruction.
SpirvInstruction *
SpirvEmitter::processWaveActiveAllEqualScalar(SpirvInstruction *arg,
                                              clang::SourceLocation srcLoc) {
  const spv::Op voteOpcode = spv::Op::OpGroupNonUniformAllEqual;
  const spv::Scope voteScope = spv::Scope::Subgroup;
  return spvBuilder.createGroupNonUniformUnaryOp(
      srcLoc, voteOpcode, astContext.BoolTy, voteScope, arg);
}
9412+
9413+
/// Implements WaveActiveAllEqual for the vector value `arg`.
///
/// Each component is extracted and voted on individually through
/// processWaveActiveAllEqualScalar; the per-component results are then
/// assembled into a bool vector with the same component count. `srcLoc` is
/// attached to every instruction emitted here. Returns the composite-construct
/// instruction that produces the bool vector.
SpirvInstruction *
SpirvEmitter::processWaveActiveAllEqualVector(SpirvInstruction *arg,
                                              clang::SourceLocation srcLoc) {
  QualType componentType;
  uint32_t componentCount = 0;
  isVectorType(arg->getAstResultType(), &componentType, &componentCount);
  assert(componentCount >= 2 && "Vector size in spir-v must be at least 2");

  // One vote result per component, kept in component order.
  llvm::SmallVector<SpirvInstruction *, 4> votes;
  for (uint32_t index = 0; index != componentCount; ++index) {
    SpirvInstruction *component = spvBuilder.createCompositeExtract(
        componentType, arg, {index}, srcLoc);
    SpirvInstruction *vote = processWaveActiveAllEqualScalar(component, srcLoc);
    votes.push_back(vote);
  }

  return spvBuilder.createCompositeConstruct(
      astContext.getExtVectorType(astContext.BoolTy, componentCount), votes,
      srcLoc);
}
9433+
9434+
/// Implements WaveActiveAllEqual for the matrix value `arg`.
///
/// Every row is extracted as a vector of `numberOfColumns` elements and
/// processed by processWaveActiveAllEqualVector; the per-row bool vectors are
/// then combined into a value of type `booleanMatrixType`. `srcLoc` is
/// attached to every instruction emitted here. Returns the composite-construct
/// instruction that produces the final result.
SpirvInstruction *
SpirvEmitter::processWaveActiveAllEqualMatrix(SpirvInstruction *arg,
                                              QualType booleanMatrixType,
                                              clang::SourceLocation srcLoc) {
  uint32_t numberOfRows = 0;
  uint32_t numberOfColumns = 0;
  QualType elementType;
  isMxNMatrix(arg->getAstResultType(), &elementType, &numberOfRows,
              &numberOfColumns);
  // Both dimensions are validated up front: the rows are extracted one by one
  // below, and each row is built as a vector of numberOfColumns elements that
  // the vector helper requires to have at least 2 components.
  assert(numberOfRows >= 2 && "Matrix row count in spir-v must be at least 2");
  assert(numberOfColumns >= 2 &&
         "Matrix column count in spir-v must be at least 2");

  QualType rowType = astContext.getExtVectorType(elementType, numberOfColumns);
  llvm::SmallVector<SpirvInstruction *, 4> rows;
  for (uint32_t i = 0; i < numberOfRows; ++i) {
    SpirvInstruction *row =
        spvBuilder.createCompositeExtract(rowType, arg, {i}, srcLoc);
    rows.push_back(processWaveActiveAllEqualVector(row, srcLoc));
  }
  return spvBuilder.createCompositeConstruct(booleanMatrixType, rows, srcLoc);
}
9454+
93879455
SpirvInstruction *SpirvEmitter::processIntrinsicModf(const CallExpr *callExpr) {
93889456
// Signature is: ret modf(x, ip)
93899457
// [in] x: the input floating-point value.

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,35 @@ class SpirvEmitter : public ASTConsumer {
608608
SpirvInstruction *processWaveQuadWideShuffle(const CallExpr *,
609609
hlsl::IntrinsicOp op);
610610

611+
/// Generates the Spir-V instructions needed to implement the given call to
612+
/// WaveActiveAllEqual. Returns a pointer to the instruction that produces the
613+
/// final result.
614+
SpirvInstruction *processWaveActiveAllEqual(const CallExpr *);
615+
616+
/// Generates the Spir-V instructions needed to implement WaveActiveAllEqual
617+
/// with the scalar input `arg`. Returns a pointer to the instruction that
618+
/// produces the final result. srcLoc should be the source location of the
619+
/// original call.
620+
SpirvInstruction *
621+
processWaveActiveAllEqualScalar(SpirvInstruction *arg,
622+
clang::SourceLocation srcLoc);
623+
624+
/// Generates the Spir-V instructions needed to implement WaveActiveAllEqual
625+
/// with the vector input `arg`. Returns a pointer to the instruction that
626+
/// produces the final result. srcLoc should be the source location of the
627+
/// original call.
628+
SpirvInstruction *
629+
processWaveActiveAllEqualVector(SpirvInstruction *arg,
630+
clang::SourceLocation srcLoc);
631+
632+
/// Generates the Spir-V instructions needed to implement WaveActiveAllEqual
633+
/// with the matrix input `arg`. Returns a pointer to the instruction that
634+
/// produces the final result. srcLoc should be the source location of the
635+
/// original call.
636+
SpirvInstruction *
637+
processWaveActiveAllEqualMatrix(SpirvInstruction *arg, QualType,
638+
clang::SourceLocation srcLoc);
639+
611640
/// Processes the NonUniformResourceIndex intrinsic function.
612641
SpirvInstruction *processIntrinsicNonUniformResourceIndex(const CallExpr *);
613642

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: %dxc -T cs_6_0 -HV 2018 -E main -fspv-target-env=vulkan1.1
2+
3+
struct S {
4+
float2x2 val;
5+
bool res;
6+
};
7+
8+
RWStructuredBuffer<S> values;
9+
10+
// CHECK: OpCapability GroupNonUniformVote
11+
12+
[numthreads(32, 1, 1)]
13+
void main(uint3 id: SV_DispatchThreadID) {
14+
15+
// Each element of the matrix must be extracted, and be passed to OpGroupNonUniformAllEqual.
16+
// CHECK: [[ld:%\w+]] = OpLoad %mat2v2float {{%\w+}}
17+
18+
// Process the first row.
19+
// CHECK: [[row_0:%\w+]] = OpCompositeExtract %v2float [[ld]] 0
20+
// CHECK: [[element_0_0:%\w+]] = OpCompositeExtract %float [[row_0]] 0
21+
// CHECK: [[res_0_0:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[element_0_0]]
22+
// CHECK: [[element_0_1:%\w+]] = OpCompositeExtract %float [[row_0]] 1
23+
// CHECK: [[res_0_1:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[element_0_1]]
24+
25+
// Combine the results in a row for the results matrix.
26+
// CHECK: [[res_0:%\w+]] = OpCompositeConstruct %v2bool [[res_0_0]] [[res_0_1]]
27+
28+
// Process the second row.
29+
// CHECK: [[row_1:%\w+]] = OpCompositeExtract %v2float [[ld]] 1
30+
// CHECK: [[element_1_0:%\w+]] = OpCompositeExtract %float [[row_1]] 0
31+
// CHECK: [[res_1_0:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[element_1_0]]
32+
// CHECK: [[element_1_1:%\w+]] = OpCompositeExtract %float [[row_1]] 1
33+
// CHECK: [[res_1_1:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[element_1_1]]
34+
35+
36+
// Combine the results in a row for the results matrix.
37+
// CHECK: [[res_1:%\w+]] = OpCompositeConstruct %v2bool [[res_1_0]] [[res_1_1]]
38+
39+
// Combine the results for each row in a "matrix" for the final result.
40+
// CHECK: [[res_matrix:%\w+]] = OpCompositeConstruct %_arr_v2bool_uint_2 [[res_0]] [[res_1]]
41+
42+
// Apply the `all` to the entire matrix.
43+
// CHECK: [[res_vec0:%\w+]] = OpCompositeExtract %v2bool [[res_matrix]] 0
44+
// CHECK: [[all0:%\w+]] = OpAll %bool [[res_vec0]]
45+
// CHECK: [[res_vec1:%\w+]] = OpCompositeExtract %v2bool [[res_matrix]] 1
46+
// CHECK: [[all1:%\w+]] = OpAll %bool [[res_vec1]]
47+
// CHECK: [[all_vec:%\w+]] = OpCompositeConstruct %v2bool [[all0]] [[all1]]
48+
// CHECK: [[all:%\w+]] = OpAll %bool [[all_vec]]
49+
50+
// Do the select.
51+
// CHECK: OpSelect %uint [[all]] %uint_1 %uint_0
52+
values[id.x].res = all(WaveActiveAllEqual(values[id.x].val));
53+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: %dxc -T cs_6_0 -HV 2018 -E main -fspv-target-env=vulkan1.1
2+
3+
struct S {
4+
float1x1 val;
5+
bool res;
6+
};
7+
8+
RWStructuredBuffer<S> values;
9+
10+
// CHECK: OpCapability GroupNonUniformVote
11+
12+
[numthreads(32, 1, 1)]
13+
void main(uint3 id: SV_DispatchThreadID) {
14+
15+
// For a 1x1 matrix, the spirv type should become a scalar because Spir-V cannot have a 1x1 matrix.
16+
17+
// CHECK: [[ld:%\w+]] = OpLoad %float {{%\w+}}
18+
// CHECK: [[res:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[ld]]
19+
// CHECK: OpSelect %uint [[res]] %uint_1 %uint_0
20+
values[id.x].res = all(WaveActiveAllEqual(values[id.x].val));
21+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: %dxc -T cs_6_0 -HV 2018 -E main -fspv-target-env=vulkan1.1
2+
3+
struct S {
4+
uint val;
5+
bool res;
6+
};
7+
8+
RWStructuredBuffer<S> values;
9+
10+
// CHECK: OpCapability GroupNonUniformVote
11+
12+
[numthreads(32, 1, 1)]
13+
void main(uint3 id: SV_DispatchThreadID) {
14+
// CHECK: [[eq:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 {{%\w+}}
15+
// CHECK: OpSelect %uint [[eq]] %uint_1 %uint_0
16+
values[id.x].res = WaveActiveAllEqual(values[id.x].val);
17+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: %dxc -T cs_6_0 -HV 2018 -E main -fspv-target-env=vulkan1.1
2+
3+
struct S {
4+
float4 val;
5+
bool res;
6+
};
7+
8+
RWStructuredBuffer<S> values;
9+
10+
// CHECK: OpCapability GroupNonUniformVote
11+
12+
[numthreads(32, 1, 1)]
13+
void main(uint3 id: SV_DispatchThreadID) {
14+
15+
// Each element of the vector must be extracted, and be passed to OpGroupNonUniformAllEqual.
16+
// CHECK: [[ld:%\w+]] = OpLoad %v4float
17+
// CHECK: [[element0:%\w+]] = OpCompositeExtract %float [[ld]] 0
18+
// CHECK: [[res0:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[element0]]
19+
// CHECK: [[element1:%\w+]] = OpCompositeExtract %float [[ld]] 1
20+
// CHECK: [[res1:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[element1]]
21+
// CHECK: [[element2:%\w+]] = OpCompositeExtract %float [[ld]] 2
22+
// CHECK: [[res2:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[element2]]
23+
// CHECK: [[element3:%\w+]] = OpCompositeExtract %float [[ld]] 3
24+
// CHECK: [[res3:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[element3]]
25+
26+
// Then the results must be combined into a boolean vector.
27+
// CHECK: [[vec_res:%\w+]] = OpCompositeConstruct %v4bool [[res0]] [[res1]] [[res2]] [[res3]]
28+
// CHECK: [[res:%\w+]] = OpAll %bool [[vec_res]]
29+
// CHECK: OpSelect %uint [[res]] %uint_1 %uint_0
30+
values[id.x].res = all(WaveActiveAllEqual(values[id.x].val));
31+
}

tools/clang/test/CodeGenSPIRV/sm6.wave-active-all-equal.hlsl

Lines changed: 0 additions & 27 deletions
This file was deleted.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: %dxc -T cs_6_0 -E main -HV 2018 -fspv-target-env=vulkan1.0
2+
3+
struct S {
4+
float4 val1;
5+
uint val2;
6+
bool res;
7+
};
8+
9+
RWStructuredBuffer<S> values;
10+
11+
[numthreads(32, 1, 1)]
12+
void main(uint3 id: SV_DispatchThreadID) {
13+
uint x = id.x;
14+
values[x].res = WaveActiveAllEqual(values[x].val1) && WaveActiveAllEqual(values[x].val2);
15+
}
16+
17+
// CHECK: sm6.wave-active-all-equal.vulkan1.0.hlsl:14:21: error: Vulkan 1.1 is required for Wave Operation but not permitted to use
18+
// CHECK-NEXT: values[x].res = WaveActiveAllEqual(values[x].val1) && WaveActiveAllEqual(values[x].val2);
19+
// CHECK: note: please specify your target environment via command line option -fspv-target-env=

tools/clang/test/CodeGenSPIRV/sm6.wave-active-all-equal.vulkan1.2.hlsl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,15 @@ void main(uint3 id: SV_DispatchThreadID) {
1717
uint x = id.x;
1818
// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_StorageBuffer_v4float %values %int_0 {{%\d+}} %int_0
1919
// CHECK-NEXT: [[f32val:%\d+]] = OpLoad %v4float [[ptr]]
20-
// TODO: The front end will return bool4 for the first call, which acutally should be bool.
21-
// XXXXX-NEXT: {{%\d+}} = OpGroupNonUniformAllEqual %bool %uint_3 [[f32val]]
20+
// CHECK: [[element0:%\w+]] = OpCompositeExtract %float [[f32val]] 0
21+
// CHECK: [[res0:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[element0]]
22+
// CHECK: [[element1:%\w+]] = OpCompositeExtract %float [[f32val]] 1
23+
// CHECK: [[res1:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[element1]]
24+
// CHECK: [[element2:%\w+]] = OpCompositeExtract %float [[f32val]] 2
25+
// CHECK: [[res2:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[element2]]
26+
// CHECK: [[element3:%\w+]] = OpCompositeExtract %float [[f32val]] 3
27+
// CHECK: [[res3:%\w+]] = OpGroupNonUniformAllEqual %bool %uint_3 [[element3]]
28+
// CHECK: [[vec_res:%\w+]] = OpCompositeConstruct %v4bool [[res0]] [[res1]] [[res2]] [[res3]]
2229

2330
// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_StorageBuffer_uint %values %int_0 {{%\d+}} %int_1
2431
// CHECK-NEXT: [[u32val:%\d+]] = OpLoad %uint [[ptr]]

tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,11 +1587,20 @@ TEST_F(FileTest, SM6WaveActiveBallot) {
15871587
runFileTest("sm6.wave-active-ballot.hlsl");
15881588
}
15891589

1590-
// Shader model 6.0 wave reduction
1591-
// TODO(5410): Still unclear what should happen with WaveActiveAllEqual with a vector parameter.
1592-
// For now, the generated SPIR-V is invalid.
1593-
TEST_F(FileTest, SM6WaveActiveAllEqual) {
1594-
runFileTest("sm6.wave-active-all-equal.hlsl", Expect::Success, /*runValidation=*/ false);
1590+
TEST_F(FileTest, SM6WaveActiveAllEqualScalar) {
1591+
runFileTest("sm6.wave-active-all-equal-scalar.hlsl");
1592+
}
1593+
TEST_F(FileTest, SM6WaveActiveAllEqualVector) {
1594+
runFileTest("sm6.wave-active-all-equal-vector.hlsl");
1595+
}
1596+
TEST_F(FileTest, SM6WaveActiveAllEqualMatrix) {
1597+
runFileTest("sm6.wave-active-all-equal-matrix.hlsl");
1598+
}
1599+
TEST_F(FileTest, SM6WaveActiveAllEqualMatrix1x1) {
1600+
runFileTest("sm6.wave-active-all-equal-matrix1x1.hlsl");
1601+
}
1602+
TEST_F(FileTest, SM6WaveActiveAllEqualVulkan1_0) {
1603+
runFileTest("sm6.wave-active-all-equal.vulkan1.0.hlsl", Expect::Failure);
15951604
}
15961605
TEST_F(FileTest, SM6WaveActiveSum) {
15971606
runFileTest("sm6.wave-active-sum.hlsl");
@@ -2950,9 +2959,7 @@ TEST_F(FileTest, CompatibilityWithVk1p1) {
29502959
runFileTest("sm6.quad-read-across-x.vulkan1.2.hlsl");
29512960
runFileTest("sm6.quad-read-across-y.vulkan1.2.hlsl");
29522961
runFileTest("sm6.quad-read-lane-at.vulkan1.2.hlsl");
2953-
// TODO(5410): Still unclear what should happen with WaveActiveAllEqual with a vector parameter.
2954-
// For now, the generated SPIR-V is invalid.
2955-
runFileTest("sm6.wave-active-all-equal.vulkan1.2.hlsl", Expect::Success, /*runValidation=*/false);
2962+
runFileTest("sm6.wave-active-all-equal.vulkan1.2.hlsl");
29562963
runFileTest("sm6.wave-active-all-true.vulkan1.2.hlsl");
29572964
runFileTest("sm6.wave-active-any-true.vulkan1.2.hlsl");
29582965
runFileTest("sm6.wave-active-ballot.vulkan1.2.hlsl");

0 commit comments

Comments
 (0)