Skip to content

Commit b97f9b9

Browse files
authored
Handle passing an element of a RWStructureBuffer array to a function (#5447)
getFinalACSBufferCounter currently returns a CounterIdAliasPair. This is because all cases in the original design could return the entire counter variable. When using arrays for RWStructuredBuffers, we will need to be able to return something the represents a portion of the array that is needed. The best method of doing that is to write reusable functions that return a SpirvInstruction whose result is a pointer to the approparite element of the array. This is then used to allow copying individual counter variables. It does not handle copying the entire array of counter variables.
1 parent 3c06afb commit b97f9b9

10 files changed

Lines changed: 129 additions & 43 deletions

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -661,8 +661,14 @@ std::string StageVar::getSemanticStr() const {
661661
return ss.str();
662662
}
663663

664-
SpirvInstruction *CounterIdAliasPair::get(SpirvBuilder &builder,
665-
SpirvContext &spvContext) const {
664+
SpirvInstruction *CounterIdAliasPair::getAliasAddress() const {
665+
assert(isAlias);
666+
return counterVar;
667+
}
668+
669+
SpirvInstruction *
670+
CounterIdAliasPair::getCounterVariable(SpirvBuilder &builder,
671+
SpirvContext &spvContext) const {
666672
if (isAlias) {
667673
const auto *counterType = spvContext.getACSBufferCounterType();
668674
const auto *counterVarType =
@@ -689,7 +695,8 @@ bool CounterVarFields::assign(const CounterVarFields &srcFields,
689695
if (!srcField)
690696
return false;
691697

692-
field.counterVar.assign(*srcField, builder, context);
698+
field.counterVar.assign(srcField->getCounterVariable(builder, context),
699+
builder);
693700
}
694701

695702
return true;
@@ -729,7 +736,8 @@ bool CounterVarFields::assign(const CounterVarFields &srcFields,
729736
if (!srcField)
730737
return false;
731738

732-
field.counterVar.assign(*srcField, builder, context);
739+
field.counterVar.assign(srcField->getCounterVariable(builder, context),
740+
builder);
733741
for (uint32_t i = srcPrefix.size(); i < srcIndices.size(); ++i)
734742
srcIndices.pop_back();
735743
}

tools/clang/lib/SPIRV/DeclResultIdMapper.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,19 @@ class CounterIdAliasPair {
7171
CounterIdAliasPair(SpirvVariable *var, bool alias)
7272
: counterVar(var), isAlias(alias) {}
7373

74+
/// Returns the pointer to the counter variable alias. This returns a pointer
75+
/// that can be used as the address to a store instruction when storing to an
76+
/// alias counter.
77+
SpirvInstruction *getAliasAddress() const;
78+
7479
/// Returns the pointer to the counter variable. Dereferences first if this is
7580
/// an alias to a counter variable.
76-
SpirvInstruction *get(SpirvBuilder &builder, SpirvContext &spvContext) const;
81+
SpirvInstruction *getCounterVariable(SpirvBuilder &builder,
82+
SpirvContext &spvContext) const;
7783

78-
/// Stores the counter variable's pointer in srcPair to the curent counter
84+
/// Stores the counter variable pointed to by src to the curent counter
7985
/// variable. The current counter variable must be an alias.
80-
inline void assign(const CounterIdAliasPair &srcPair, SpirvBuilder &,
81-
SpirvContext &) const;
86+
inline void assign(SpirvInstruction *src, SpirvBuilder &) const;
8287

8388
private:
8489
SpirvVariable *counterVar;
@@ -906,12 +911,10 @@ bool SemanticInfo::isTarget() const {
906911
return semantic && semantic->GetKind() == hlsl::Semantic::Kind::Target;
907912
}
908913

909-
void CounterIdAliasPair::assign(const CounterIdAliasPair &srcPair,
910-
SpirvBuilder &builder,
911-
SpirvContext &context) const {
914+
void CounterIdAliasPair::assign(SpirvInstruction *src,
915+
SpirvBuilder &builder) const {
912916
assert(isAlias);
913-
builder.createStore(counterVar, srcPair.get(builder, context),
914-
/* SourceLocation */ {});
917+
builder.createStore(counterVar, src, /* SourceLocation */ {});
915918
}
916919

917920
DeclResultIdMapper::DeclResultIdMapper(ASTContext &context,

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4539,26 +4539,16 @@ SpirvEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
45394539
(void)doExpr(object);
45404540
}
45414541

4542-
const auto *counterPair = getFinalACSBufferCounter(object);
4543-
if (!counterPair) {
4542+
auto *counter = getFinalACSBufferCounterInstruction(object);
4543+
if (!counter) {
45444544
emitFatalError("cannot find the associated counter variable",
45454545
object->getExprLoc());
45464546
return nullptr;
45474547
}
45484548

4549-
llvm::SmallVector<SpirvInstruction *, 2> indexes;
4550-
if(const auto *arraySubscriptExpr = dyn_cast<ArraySubscriptExpr>(object)) {
4551-
// TODO(5440): This codes does not handle multi-dimensional arrays. We need
4552-
// to look at specific example to determine the best way to do it.
4553-
indexes.push_back(doExpr(arraySubscriptExpr->getIdx()));
4554-
}
4555-
45564549
// Add an extra 0 because the counter is wrapped in a struct.
4557-
indexes.push_back(zero);
4558-
4559-
auto *counterPtr = spvBuilder.createAccessChain(
4560-
astContext.IntTy, counterPair->get(spvBuilder, spvContext), indexes,
4561-
srcLoc, srcRange);
4550+
auto *counterPtr = spvBuilder.createAccessChain(astContext.IntTy, counter,
4551+
{zero}, srcLoc, srcRange);
45624552

45634553
SpirvInstruction *index = nullptr;
45644554
if (isInc) {
@@ -4596,13 +4586,13 @@ bool SpirvEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
45964586
// Handle AssocCounter#1 (see CounterVarFields comment)
45974587
if (const auto *dstPair =
45984588
declIdMapper.createOrGetCounterIdAliasPair(dstDecl)) {
4599-
const auto *srcPair = getFinalACSBufferCounter(srcExpr);
4600-
if (!srcPair) {
4589+
auto *srcCounter = getFinalACSBufferCounterInstruction(srcExpr);
4590+
if (!srcCounter) {
46014591
emitFatalError("cannot find the associated counter variable",
46024592
srcExpr->getExprLoc());
46034593
return false;
46044594
}
4605-
dstPair->assign(*srcPair, spvBuilder, spvContext);
4595+
dstPair->assign(srcCounter, spvBuilder);
46064596
return true;
46074597
}
46084598

@@ -4633,18 +4623,18 @@ bool SpirvEmitter::tryToAssignCounterVar(const Expr *dstExpr,
46334623
dstExpr = dstExpr->IgnoreParenCasts();
46344624
srcExpr = srcExpr->IgnoreParenCasts();
46354625

4636-
const auto *dstPair = getFinalACSBufferCounter(dstExpr);
4637-
const auto *srcPair = getFinalACSBufferCounter(srcExpr);
4626+
auto *dstCounter = getFinalACSBufferCounterAliasAddressInstruction(dstExpr);
4627+
auto *srcCounter = getFinalACSBufferCounterInstruction(srcExpr);
46384628

4639-
if ((dstPair == nullptr) != (srcPair == nullptr)) {
4629+
if ((dstCounter == nullptr) != (srcCounter == nullptr)) {
46404630
emitFatalError("cannot handle associated counter variable assignment",
46414631
srcExpr->getExprLoc());
46424632
return false;
46434633
}
46444634

46454635
// Handle AssocCounter#1 & AssocCounter#2
4646-
if (dstPair && srcPair) {
4647-
dstPair->assign(*srcPair, spvBuilder, spvContext);
4636+
if (dstCounter && srcCounter) {
4637+
spvBuilder.createStore(dstCounter, srcCounter, /* SourceLocation */ {});
46484638
return true;
46494639
}
46504640

@@ -4662,6 +4652,37 @@ bool SpirvEmitter::tryToAssignCounterVar(const Expr *dstExpr,
46624652
return false;
46634653
}
46644654

4655+
SpirvInstruction *SpirvEmitter::getFinalACSBufferCounterAliasAddressInstruction(
4656+
const Expr *expr) {
4657+
const CounterIdAliasPair *counter = getFinalACSBufferCounter(expr);
4658+
return (counter ? counter->getAliasAddress() : nullptr);
4659+
}
4660+
4661+
SpirvInstruction *
4662+
SpirvEmitter::getFinalACSBufferCounterInstruction(const Expr *expr) {
4663+
const CounterIdAliasPair *counterPair = getFinalACSBufferCounter(expr);
4664+
if (!counterPair)
4665+
return nullptr;
4666+
4667+
SpirvInstruction *counter =
4668+
counterPair->getCounterVariable(spvBuilder, spvContext);
4669+
const auto srcLoc = expr->getExprLoc();
4670+
4671+
// TODO(5440): This codes does not handle multi-dimensional arrays. We need
4672+
// to look at specific example to determine the best way to do it. Could a
4673+
// call to collectArrayStructIndices handle that for us?
4674+
llvm::SmallVector<SpirvInstruction *, 2> indexes;
4675+
if (const auto *arraySubscriptExpr = dyn_cast<ArraySubscriptExpr>(expr)) {
4676+
indexes.push_back(doExpr(arraySubscriptExpr->getIdx()));
4677+
}
4678+
4679+
if (!indexes.empty()) {
4680+
counter = spvBuilder.createAccessChain(spvContext.getACSBufferCounterType(),
4681+
counter, indexes, srcLoc);
4682+
}
4683+
return counter;
4684+
}
4685+
46654686
const CounterIdAliasPair *
46664687
SpirvEmitter::getFinalACSBufferCounter(const Expr *expr) {
46674688
// AssocCounter#1: referencing some stand-alone variable

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,21 @@ class SpirvEmitter : public ASTConsumer {
10511051
const Expr *srcExpr);
10521052
bool tryToAssignCounterVar(const Expr *dstExpr, const Expr *srcExpr);
10531053

1054+
/// Returns an instruction that points to the alias counter variable with the
1055+
/// entity represented by expr.
1056+
///
1057+
/// This method only handles final alias structured buffers, which means
1058+
/// AssocCounter#1 and AssocCounter#2.
1059+
SpirvInstruction *
1060+
getFinalACSBufferCounterAliasAddressInstruction(const Expr *expr);
1061+
1062+
/// Returns an instruction that points to the counter variable with the entity
1063+
/// represented by expr.
1064+
///
1065+
/// This method only handles final alias structured buffers, which means
1066+
/// AssocCounter#1 and AssocCounter#2.
1067+
SpirvInstruction *getFinalACSBufferCounterInstruction(const Expr *expr);
1068+
10541069
/// Returns the counter variable's information associated with the entity
10551070
/// represented by the given decl.
10561071
///

tools/clang/test/CodeGenSPIRV/type.rwstructured-buffer.array.counter.const.index.hlsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ RWStructuredBuffer<uint> g_rwbuffer[5] : register(u0, space2);
1717
float4 main(PSInput input) : SV_TARGET
1818
{
1919
// Correctly increment the counter.
20-
// CHECK: [[ac:%\d+]] = OpAccessChain %_ptr_Uniform_int %counter_var_g_rwbuffer %int_3 %uint_0
21-
// CHECK: OpAtomicIAdd %int [[ac]] %uint_1 %uint_0 %int_1
20+
// CHECK: [[ac1:%\d+]] = OpAccessChain %_ptr_Uniform_type_ACSBuffer_counter %counter_var_g_rwbuffer %int_3
21+
// CHECK: [[ac2:%\d+]] = OpAccessChain %_ptr_Uniform_int [[ac1]] %uint_0
22+
// CHECK: OpAtomicIAdd %int [[ac2]] %uint_1 %uint_0 %int_1
2223
g_rwbuffer[3].IncrementCounter();
2324

2425
// Correctly access the buffer.

tools/clang/test/CodeGenSPIRV/type.rwstructured-buffer.array.counter.flatten.hlsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ RWStructuredBuffer<uint> g_rwbuffer[5] : register(u0, space2);
2121
float4 main(PSInput input) : SV_TARGET
2222
{
2323
// Correctly increment the counter.
24-
// CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_Uniform_int %counter_var_g_rwbuffer {{%\d+}} %uint_0
25-
// CHECK: OpAtomicIAdd %int [[ac]] %uint_1 %uint_0 %int_1
24+
// CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_type_ACSBuffer_counter %counter_var_g_rwbuffer {{%\d+}}
25+
// CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_int [[ac1]] %uint_0
26+
// CHECK: OpAtomicIAdd %int [[ac2]] %uint_1 %uint_0 %int_1
2627
g_rwbuffer[input.idx].IncrementCounter();
2728

2829
// Correctly access the buffer.

tools/clang/test/CodeGenSPIRV/type.rwstructured-buffer.array.counter.hlsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ RWStructuredBuffer<uint> g_rwbuffer[5] : register(u0, space2);
1717
float4 main(PSInput input) : SV_TARGET
1818
{
1919
// Correctly increment the counter.
20-
// CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_Uniform_int %counter_var_g_rwbuffer {{%\d+}} %uint_0
21-
// CHECK: OpAtomicIAdd %int [[ac]] %uint_1 %uint_0 %int_1
20+
// CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_type_ACSBuffer_counter %counter_var_g_rwbuffer {{%\d+}}
21+
// CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_int [[ac1]] %uint_0
22+
// CHECK: OpAtomicIAdd %int [[ac2]] %uint_1 %uint_0 %int_1
2223
g_rwbuffer[input.idx].IncrementCounter();
2324

2425
// Correctly access the buffer.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// RUN: %dxc -T ps_6_6 -E main -O0 -fvk-allow-rwstructuredbuffer-arrays
2+
3+
struct PSInput
4+
{
5+
uint idx : COLOR;
6+
};
7+
8+
// CHECK: OpDecorate %g_rwbuffer DescriptorSet 2
9+
// CHECK: OpDecorate %g_rwbuffer Binding 0
10+
// CHECK: OpDecorate %counter_var_g_rwbuffer DescriptorSet 2
11+
// CHECK: OpDecorate %counter_var_g_rwbuffer Binding 1
12+
13+
// CHECK: %g_rwbuffer = OpVariable %_ptr_Uniform__arr_type_RWStructuredBuffer_uint_uint_5 Uniform
14+
// CHECK: %counter_var_g_rwbuffer = OpVariable %_ptr_Uniform__arr_type_ACSBuffer_counter_uint_5 Uniform
15+
RWStructuredBuffer<uint> g_rwbuffer[5] : register(u0, space2);
16+
17+
void func(RWStructuredBuffer<uint> local) {
18+
local.IncrementCounter();
19+
}
20+
21+
float4 main(PSInput input) : SV_TARGET
22+
{
23+
// CHECK: [[ac1:%\d+]] = OpAccessChain %_ptr_Uniform_type_ACSBuffer_counter %counter_var_g_rwbuffer {{%\d+}}
24+
// CHECK: [[ac2:%\d+]] = OpAccessChain %_ptr_Uniform_int [[ac1]] %uint_0
25+
// CHECK: OpAtomicIAdd %int [[ac2]] %uint_1 %uint_0 %int_1
26+
func(g_rwbuffer[input.idx]);
27+
28+
// CHECK: [[ac1:%\d+]] = OpAccessChain %_ptr_Uniform_type_RWStructuredBuffer_uint %g_rwbuffer {{%\d+}}
29+
// CHECK: [[ac2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ac1]] %int_0 %uint_0
30+
// CHECK: OpLoad %uint [[ac2]]
31+
return g_rwbuffer[input.idx][0];
32+
}

tools/clang/test/CodeGenSPIRV/type.rwstructured-buffer.unbounded.array.counter.hlsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ RWStructuredBuffer<uint> g_rwbuffer[] : register(u0, space2);
1717
float4 main(PSInput input) : SV_TARGET
1818
{
1919
// Correctly increment the counter.
20-
// CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_Uniform_int %counter_var_g_rwbuffer {{%\d+}} %uint_0
21-
// CHECK: OpAtomicIAdd %int [[ac]] %uint_1 %uint_0 %int_1
20+
// CHECK: [[ac1:%\d+]] = OpAccessChain %_ptr_Uniform_type_ACSBuffer_counter %counter_var_g_rwbuffer {{%\d+}}
21+
// CHECK: [[ac2:%\d+]] = OpAccessChain %_ptr_Uniform_int [[ac1]] %uint_0
22+
// CHECK: OpAtomicIAdd %int [[ac2]] %uint_1 %uint_0 %int_1
2223
g_rwbuffer[input.idx].IncrementCounter();
2324

2425
// Correctly access the buffer.

tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ TEST_F(FileTest, RWStructuredBufferArrayCounterConstIndex) {
161161
TEST_F(FileTest, RWStructuredBufferArrayCounterFlattened) {
162162
runFileTest("type.rwstructured-buffer.array.counter.flatten.hlsl");
163163
}
164+
TEST_F(FileTest, RWStructuredBufferArrayCounterIndirect) {
165+
runFileTest("type.rwstructured-buffer.array.counter.indirect.hlsl");
166+
}
164167
TEST_F(FileTest, AppendStructuredBufferArrayError) {
165168
runFileTest("type.append-structured-buffer.array.hlsl");
166169
}

0 commit comments

Comments
 (0)