Skip to content

Commit 65cef0a

Browse files
authored
[spirv] Added handling for multi-dimensional CT buffers (#4507)
1 parent d8d5b26 commit 65cef0a

4 files changed

Lines changed: 130 additions & 26 deletions

File tree

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 84 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,52 @@ bool containOnlyVecWithFourFloats(QualType type, bool use16Bit) {
600600
return false;
601601
}
602602

603+
// Evaluates whether a given QualType is in fact an array and unroll
604+
// accordingly. Returns the appropriate RecordType for the provided QualType.
605+
// Will be the actual type of the array element, if it is indeed an array. If
606+
// the returned QualType is not null and startType was in fact an array, the out
607+
// parameter arraySizes contains the dimensions of each cascaded array, from
608+
// right to left order as defined in source. e.g.: float a[2][3] -> arraySizes
609+
// [3, 2].
610+
QualType unrollMultiDimensionalArray(const ASTContext &astContext,
611+
const QualType &startType,
612+
llvm::SmallVectorImpl<int> *arraySizes) {
613+
614+
QualType innerQualType = startType;
615+
616+
// Unroll a multidimensional array.
617+
const auto *arrayType = startType->getAsArrayTypeUnsafe();
618+
619+
while (arrayType) {
620+
// If we are here the top level is an array let's grab it's size.
621+
if (const auto *caType = astContext.getAsConstantArrayType(innerQualType)) {
622+
auto arrayExtend = static_cast<int>(caType->getSize().getZExtValue());
623+
arraySizes->push_back(arrayExtend);
624+
} else {
625+
// It's certainly an array, but we can't make it out it's dimension. So
626+
// mark it as runtime array.
627+
arraySizes->push_back(-1);
628+
}
629+
630+
// Grab the sub element and see if it's an element or another array.
631+
innerQualType = arrayType->getElementType();
632+
if (innerQualType->isArrayType()) {
633+
arrayType = innerQualType->getAsArrayTypeUnsafe();
634+
} else if (innerQualType->isRecordType()) {
635+
// If we reached the inner type, bail.
636+
break;
637+
} else {
638+
// In case we encountered anything else than the expected types, bail
639+
// and report the error.
640+
return QualType{};
641+
}
642+
}
643+
644+
std::reverse(arraySizes->begin(), arraySizes->end());
645+
646+
return innerQualType;
647+
}
648+
603649
} // anonymous namespace
604650

605651
std::string StageVar::getSemanticStr() const {
@@ -1132,8 +1178,9 @@ DeclResultIdMapper::createOrUpdateStringVar(const VarDecl *var) {
11321178
}
11331179

11341180
SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
1135-
const DeclContext *decl, int arraySize, const ContextUsageKind usageKind,
1136-
llvm::StringRef typeName, llvm::StringRef varName) {
1181+
const DeclContext *decl, llvm::ArrayRef<int> arraySize,
1182+
const ContextUsageKind usageKind, llvm::StringRef typeName,
1183+
llvm::StringRef varName) {
11371184
// cbuffers are translated into OpTypeStruct with Block decoration.
11381185
// tbuffers are translated into OpTypeStruct with BufferBlock decoration.
11391186
// Push constants are translated into OpTypeStruct with Block decoration.
@@ -1188,13 +1235,14 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
11881235
forTBuffer ? StructInterfaceType::StorageBuffer
11891236
: StructInterfaceType::UniformBuffer);
11901237

1191-
// Make an array if requested.
1192-
if (arraySize > 0) {
1193-
resultType = spvContext.getArrayType(resultType, arraySize,
1194-
/*ArrayStride*/ llvm::None);
1195-
} else if (arraySize == -1) {
1196-
resultType =
1197-
spvContext.getRuntimeArrayType(resultType, /*ArrayStride*/ llvm::None);
1238+
for (int size : arraySize) {
1239+
if (size != -1) {
1240+
resultType = spvContext.getArrayType(resultType, size,
1241+
/*ArrayStride*/ llvm::None);
1242+
} else {
1243+
resultType = spvContext.getRuntimeArrayType(resultType,
1244+
/*ArrayStride*/ llvm::None);
1245+
}
11981246
}
11991247

12001248
// Register the <type-id> for this decl
@@ -1223,6 +1271,17 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
12231271
return var;
12241272
}
12251273

1274+
SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
1275+
const DeclContext *decl, int arraySize, const ContextUsageKind usageKind,
1276+
llvm::StringRef typeName, llvm::StringRef varName) {
1277+
llvm::SmallVector<int, 1> arraySizes;
1278+
if (arraySize > 0)
1279+
arraySizes.push_back(arraySize);
1280+
1281+
return createStructOrStructArrayVarOfExplicitLayout(
1282+
decl, arraySizes, usageKind, typeName, varName);
1283+
}
1284+
12261285
void DeclResultIdMapper::createEnumConstant(const EnumConstantDecl *decl) {
12271286
const auto *valueDecl = dyn_cast<ValueDecl>(decl);
12281287
const auto enumConstant =
@@ -1291,23 +1350,22 @@ SpirvVariable *DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
12911350
assert(isConstantTextureBuffer(type));
12921351
const RecordType *recordType = nullptr;
12931352
const RecordType *templatedType = nullptr;
1294-
int arraySize = 0;
1295-
1296-
// In case we have an array of ConstantBuffer/TextureBuffer:
1297-
if (const auto *arrayType = type->getAsArrayTypeUnsafe()) {
1298-
const QualType elemType = arrayType->getElementType();
1299-
recordType = elemType->getAs<RecordType>();
1300-
templatedType =
1301-
hlsl::GetHLSLResourceResultType(elemType)->getAs<RecordType>();
1302-
if (const auto *caType = astContext.getAsConstantArrayType(type)) {
1303-
arraySize = static_cast<uint32_t>(caType->getSize().getZExtValue());
1304-
} else {
1305-
arraySize = -1;
1306-
}
1307-
} else {
1308-
recordType = type->getAs<RecordType>();
1309-
templatedType = hlsl::GetHLSLResourceResultType(type)->getAs<RecordType>();
1353+
1354+
llvm::SmallVector<int, 2> arraySizes;
1355+
QualType actualType =
1356+
unrollMultiDimensionalArray(astContext, type, &arraySizes);
1357+
if (actualType.isNull()) {
1358+
emitError("encountered unsupported type while decomposing "
1359+
"multi-dimensional array",
1360+
decl->getLocStart())
1361+
<< type;
1362+
return nullptr;
13101363
}
1364+
1365+
recordType = actualType->getAs<RecordType>();
1366+
templatedType =
1367+
hlsl::GetHLSLResourceResultType(actualType)->getAs<RecordType>();
1368+
13111369
if (!recordType) {
13121370
emitError("constant/texture buffer type %0 unimplemented",
13131371
decl->getLocStart())
@@ -1331,7 +1389,7 @@ SpirvVariable *DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
13311389
templatedType->getDecl()->getName().str();
13321390

13331391
SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout(
1334-
templatedType->getDecl(), arraySize, usageKind, structName,
1392+
templatedType->getDecl(), arraySizes, usageKind, structName,
13351393
decl->getName());
13361394

13371395
// We register the VarDecl here.

tools/clang/lib/SPIRV/DeclResultIdMapper.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "clang/AST/Attr.h"
1919
#include "clang/SPIRV/FeatureManager.h"
2020
#include "clang/SPIRV/SpirvBuilder.h"
21+
#include "llvm/ADT/ArrayRef.h"
2122
#include "llvm/ADT/DenseMap.h"
2223
#include "llvm/ADT/Optional.h"
2324
#include "llvm/ADT/SmallVector.h"
@@ -593,6 +594,25 @@ class DeclResultIdMapper {
593594
/// construction.
594595
bool finalizeStageIOLocations(bool forInput);
595596

597+
/// Creates a variable of struct type with explicit layout decorations.
598+
/// The sub-Decls in the given DeclContext will be treated as the struct
599+
/// fields. The struct type will be named as typeName, and the variable
600+
/// will be named as varName.
601+
///
602+
/// This method should only be used for cbuffers/ContantBuffers, tbuffers/
603+
/// TextureBuffers, and PushConstants. usageKind must be set properly
604+
/// depending on the usage kind.
605+
///
606+
/// If arraySize is 0, the variable will be created as a struct ; if arraySize
607+
/// is > 0, the variable will be created as an array; if arraySize is -1, the
608+
/// variable will be created as a runtime array.
609+
///
610+
/// Panics if the DeclContext is neither HLSLBufferDecl or RecordDecl.
611+
SpirvVariable *createStructOrStructArrayVarOfExplicitLayout(
612+
const DeclContext *decl, llvm::ArrayRef<int> arraySize,
613+
ContextUsageKind usageKind, llvm::StringRef typeName,
614+
llvm::StringRef varName);
615+
596616
/// Creates a variable of struct type with explicit layout decorations.
597617
/// The sub-Decls in the given DeclContext will be treated as the struct
598618
/// fields. The struct type will be named as typeName, and the variable
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: %dxc -T ps_6_0 -E main -fspv-target-env=universal1.5
2+
3+
struct Foo {
4+
float4 a;
5+
int2 b;
6+
};
7+
8+
// CHECK: [[s01:%\w+]] = OpTypeStruct %v4float %v2int
9+
// CHECK: [[s02:%\w+]] = OpTypeArray [[s01]] %uint_3
10+
// CHECK: [[s03:%\w+]] = OpTypeArray [[s02]] %uint_2
11+
// CHECK: [[s04:%\w+]] = OpTypePointer Uniform [[s03]]
12+
ConstantBuffer<Foo> myCB2[2][3] : register(b0, space1);
13+
14+
struct VSOutput {
15+
float2 TexCoord : TEXCOORD;
16+
};
17+
18+
float4 main(VSOutput input) : SV_TARGET {
19+
// CHECK: [[s05:%\w+]] = OpTypePointer Uniform %v4float
20+
// CHECK: [[s06:%\w+]] = OpVariable [[s04:%\w+]] Uniform
21+
// CHECK: OpAccessChain [[s05]] [[s06]] %int_1 %int_0 %int_0
22+
return float4(1.0, 1.0, 1.0, 1.0) * myCB2[1][0].a;
23+
}

tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ TEST_F(FileTest, TypeCBufferIncludingResource) {
9898
TEST_F(FileTest, ConstantBufferType) {
9999
runFileTest("type.constant-buffer.hlsl");
100100
}
101+
TEST_F(FileTest, ConstantBufferTypeMultiDimensionalArray) {
102+
runFileTest("type.constant-buffer.multiple-dimensions.hlsl");
103+
}
101104
TEST_F(FileTest, BindlessConstantBufferArrayType) {
102105
runFileTest("type.constant-buffer.bindless.array.hlsl", Expect::Success,
103106
/*legalization*/ false);

0 commit comments

Comments
 (0)