@@ -600,6 +600,52 @@ bool containOnlyVecWithFourFloats(QualType type, bool use16Bit) {
600600 return false ;
601601}
602602
603+ // Evaluates whether a given QualType is in fact an array and unroll
604+ // accordingly. Returns the appropriate RecordType for the provided QualType.
605+ // Will be the actual type of the array element, if it is indeed an array. If
606+ // the returned QualType is not null and startType was in fact an array, the out
607+ // parameter arraySizes contains the dimensions of each cascaded array, from
608+ // right to left order as defined in source. e.g.: float a[2][3] -> arraySizes
609+ // [3, 2].
610+ QualType unrollMultiDimensionalArray (const ASTContext &astContext,
611+ const QualType &startType,
612+ llvm::SmallVectorImpl<int > *arraySizes) {
613+
614+ QualType innerQualType = startType;
615+
616+ // Unroll a multidimensional array.
617+ const auto *arrayType = startType->getAsArrayTypeUnsafe ();
618+
619+ while (arrayType) {
620+ // If we are here the top level is an array let's grab it's size.
621+ if (const auto *caType = astContext.getAsConstantArrayType (innerQualType)) {
622+ auto arrayExtend = static_cast <int >(caType->getSize ().getZExtValue ());
623+ arraySizes->push_back (arrayExtend);
624+ } else {
625+ // It's certainly an array, but we can't make it out it's dimension. So
626+ // mark it as runtime array.
627+ arraySizes->push_back (-1 );
628+ }
629+
630+ // Grab the sub element and see if it's an element or another array.
631+ innerQualType = arrayType->getElementType ();
632+ if (innerQualType->isArrayType ()) {
633+ arrayType = innerQualType->getAsArrayTypeUnsafe ();
634+ } else if (innerQualType->isRecordType ()) {
635+ // If we reached the inner type, bail.
636+ break ;
637+ } else {
638+ // In case we encountered anything else than the expected types, bail
639+ // and report the error.
640+ return QualType{};
641+ }
642+ }
643+
644+ std::reverse (arraySizes->begin (), arraySizes->end ());
645+
646+ return innerQualType;
647+ }
648+
603649} // anonymous namespace
604650
605651std::string StageVar::getSemanticStr () const {
@@ -1132,8 +1178,9 @@ DeclResultIdMapper::createOrUpdateStringVar(const VarDecl *var) {
11321178}
11331179
11341180SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout (
1135- const DeclContext *decl, int arraySize, const ContextUsageKind usageKind,
1136- llvm::StringRef typeName, llvm::StringRef varName) {
1181+ const DeclContext *decl, llvm::ArrayRef<int > arraySize,
1182+ const ContextUsageKind usageKind, llvm::StringRef typeName,
1183+ llvm::StringRef varName) {
11371184 // cbuffers are translated into OpTypeStruct with Block decoration.
11381185 // tbuffers are translated into OpTypeStruct with BufferBlock decoration.
11391186 // Push constants are translated into OpTypeStruct with Block decoration.
@@ -1188,13 +1235,14 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
11881235 forTBuffer ? StructInterfaceType::StorageBuffer
11891236 : StructInterfaceType::UniformBuffer);
11901237
1191- // Make an array if requested.
1192- if (arraySize > 0 ) {
1193- resultType = spvContext.getArrayType (resultType, arraySize,
1194- /* ArrayStride*/ llvm::None);
1195- } else if (arraySize == -1 ) {
1196- resultType =
1197- spvContext.getRuntimeArrayType (resultType, /* ArrayStride*/ llvm::None);
1238+ for (int size : arraySize) {
1239+ if (size != -1 ) {
1240+ resultType = spvContext.getArrayType (resultType, size,
1241+ /* ArrayStride*/ llvm::None);
1242+ } else {
1243+ resultType = spvContext.getRuntimeArrayType (resultType,
1244+ /* ArrayStride*/ llvm::None);
1245+ }
11981246 }
11991247
12001248 // Register the <type-id> for this decl
@@ -1223,6 +1271,17 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
12231271 return var;
12241272}
12251273
1274+ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout (
1275+ const DeclContext *decl, int arraySize, const ContextUsageKind usageKind,
1276+ llvm::StringRef typeName, llvm::StringRef varName) {
1277+ llvm::SmallVector<int , 1 > arraySizes;
1278+ if (arraySize > 0 )
1279+ arraySizes.push_back (arraySize);
1280+
1281+ return createStructOrStructArrayVarOfExplicitLayout (
1282+ decl, arraySizes, usageKind, typeName, varName);
1283+ }
1284+
12261285void DeclResultIdMapper::createEnumConstant (const EnumConstantDecl *decl) {
12271286 const auto *valueDecl = dyn_cast<ValueDecl>(decl);
12281287 const auto enumConstant =
@@ -1291,23 +1350,22 @@ SpirvVariable *DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
12911350 assert (isConstantTextureBuffer (type));
12921351 const RecordType *recordType = nullptr ;
12931352 const RecordType *templatedType = nullptr ;
1294- int arraySize = 0 ;
1295-
1296- // In case we have an array of ConstantBuffer/TextureBuffer:
1297- if (const auto *arrayType = type->getAsArrayTypeUnsafe ()) {
1298- const QualType elemType = arrayType->getElementType ();
1299- recordType = elemType->getAs <RecordType>();
1300- templatedType =
1301- hlsl::GetHLSLResourceResultType (elemType)->getAs <RecordType>();
1302- if (const auto *caType = astContext.getAsConstantArrayType (type)) {
1303- arraySize = static_cast <uint32_t >(caType->getSize ().getZExtValue ());
1304- } else {
1305- arraySize = -1 ;
1306- }
1307- } else {
1308- recordType = type->getAs <RecordType>();
1309- templatedType = hlsl::GetHLSLResourceResultType (type)->getAs <RecordType>();
1353+
1354+ llvm::SmallVector<int , 2 > arraySizes;
1355+ QualType actualType =
1356+ unrollMultiDimensionalArray (astContext, type, &arraySizes);
1357+ if (actualType.isNull ()) {
1358+ emitError (" encountered unsupported type while decomposing "
1359+ " multi-dimensional array" ,
1360+ decl->getLocStart ())
1361+ << type;
1362+ return nullptr ;
13101363 }
1364+
1365+ recordType = actualType->getAs <RecordType>();
1366+ templatedType =
1367+ hlsl::GetHLSLResourceResultType (actualType)->getAs <RecordType>();
1368+
13111369 if (!recordType) {
13121370 emitError (" constant/texture buffer type %0 unimplemented" ,
13131371 decl->getLocStart ())
@@ -1331,7 +1389,7 @@ SpirvVariable *DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
13311389 templatedType->getDecl ()->getName ().str ();
13321390
13331391 SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout (
1334- templatedType->getDecl (), arraySize , usageKind, structName,
1392+ templatedType->getDecl (), arraySizes , usageKind, structName,
13351393 decl->getName ());
13361394
13371395 // We register the VarDecl here.
0 commit comments