Skip to content

Commit 03cc4f8

Browse files
authored
[spirv] Add vk::image_format attribute for Buffers, RWBuffers and RWTextures (#3395)
According to Vulkan specification when using `OpImageRead/OpImageWrite`, the `OpTypeImage` (`Buffers`, `RWBuffers`, `RWTextures`) must have a format that matches the format on the API side, unless the StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat is added and `Unknown` is used as the format. This pull request addressess #2498 for the format part by adding an attribute `[[vk::image_format("<image format as spelled in SPIR-V spec>")]].` Example of the syntax: ``` [[vk::image_format("rgba8")]] RWBuffer<float4> Buf; [[vk::image_format("rg16f")]] RWTexture2D<float2> Tex; RWTexture2D<float2> Tex2; // Works like before ``` The `image_format` only applies to **global variables** of type `Buffer`, `RWBuffer`, `RWTexture`. For variables and function parameters it is propagated by the inlining pass in legalization. This required a small change to one of the passes in SPIRV-Tools, that should be also checked by someone more familiar with the codebase: KhronosGroup/SPIRV-Tools#4126 Note that this does not fix the handling of unspecified format (that case still works like before, using `R32f`, etc. based on the type in shader), although it should be still fixed to add the StorageImageReadWithoutFormat and/or StorageImageWriteWithoutFormat and use Undefined. But I think the ability to specify the format is more urgent. Design note from Jaebaek: Since the `image_format` attribute only applies to **global variables**, under the DXC architecture only `DeclResultIdMapper` can check the attribute when it handles `VarDecl`s. It means we have to pass the `image_format` information to `LowerTypeVisitor` because it cannot access to `VarDecl`. In order to pass the `image_format`, we use `SpirvContext` that can be accessed by `SpirvEmitter` and all visitors. We use `SpirvVariable` to `spv::ImageFormat` mapping because the attribute only applies to **global variables** (not to image types). See how we use `llvm::DenseMap<const SpirvVariable *, spv::ImageFormat> spvVarToImageFormat`.
1 parent 1ca7797 commit 03cc4f8

14 files changed

Lines changed: 478 additions & 11 deletions

File tree

docs/SPIR-V.rst

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,73 @@ are translated into SPIR-V ``OpTypeImage``, with parameters:
786786
The meanings of the headers in the above table is explained in ``OpTypeImage``
787787
of the SPIR-V spec.
788788

789+
Vulkan specific Image Formats
790+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
791+
792+
Since HLSL lacks the syntax for fully specifying image formats for textures in
793+
SPIR-V, we introduce ``[[vk::image_format("FORMAT")]]`` attribute for texture types.
794+
For example,
795+
796+
.. code:: hlsl
797+
[[vk::image_format("rgba8")]]
798+
RWBuffer<float4> Buf;
799+
800+
[[vk::image_format("rg16f")]]
801+
RWTexture2D<float2> Tex;
802+
803+
RWTexture2D<float2> Tex2; // Works like before
804+
805+
``rgba8`` means ``Rgba8`` `SPIR-V Image Format <https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_image_format_a_image_format>`_.
806+
The following table lists the mapping between ``FORMAT`` of
807+
``[[vk::image_format("FORMAT")]]`` and its corresponding SPIR-V Image Format.
808+
809+
======================= ============================================
810+
FORMAT SPIR-V Image Format
811+
======================= ============================================
812+
``unknown`` ``Unknown``
813+
``rgba32f`` ``Rgba32f``
814+
``rgba16f`` ``Rgba16f``
815+
``r32f`` ``R32f``
816+
``rgba8`` ``Rgba8``
817+
``rgba8snorm`` ``Rgba8Snorm``
818+
``rg32f`` ``Rg32f``
819+
``rg16f`` ``Rg16f``
820+
``r11g11b10f`` ``R11fG11fB10f``
821+
``r16f`` ``R16f``
822+
``rgba16`` ``Rgba16``
823+
``rgb10a2`` ``Rgb10A2``
824+
``rg16`` ``Rg16``
825+
``rg8`` ``Rg8``
826+
``r16`` ``R16``
827+
``r8`` ``R8``
828+
``rgba16snorm`` ``Rgba16Snorm``
829+
``rg16snorm`` ``Rg16Snorm``
830+
``rg8snorm`` ``Rg8Snorm``
831+
``r16snorm`` ``R16Snorm``
832+
``r8snorm`` ``R8Snorm``
833+
``rgba32i`` ``Rgba32i``
834+
``rgba16i`` ``Rgba16i``
835+
``rgba8i`` ``Rgba8i``
836+
``r32i`` ``R32i``
837+
``rg32i`` ``Rg32i``
838+
``rg16i`` ``Rg16i``
839+
``rg8i`` ``Rg8i``
840+
``r16i`` ``R16i``
841+
``r8i`` ``R8i``
842+
``rgba32ui`` ``Rgba32ui``
843+
``rgba16ui`` ``Rgba16ui``
844+
``rgba8ui`` ``Rgba8ui``
845+
``r32ui`` ``R32ui``
846+
``rgb10a2ui`` ``Rgb10a2ui``
847+
``rg32ui`` ``Rg32ui``
848+
``rg16ui`` ``Rg16ui``
849+
``rg8ui`` ``Rg8ui``
850+
``r16ui`` ``R16ui``
851+
``r8ui`` ``R8ui``
852+
``r64ui`` ``R64ui``
853+
``r64i`` ``R64i``
854+
======================= ============================================
855+
789856
Constant/Texture/Structured/Byte Buffers
790857
----------------------------------------
791858

tools/clang/include/clang/Basic/Attr.td

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,32 @@ def ConstantTextureBuffer
937937
S->getType()->getAs<RecordType>()->getDecl()->getName() ==
938938
"TextureBuffer")}]>;
939939

940+
// Global variable with "RWTexture" type
941+
def RWTexture
942+
: SubsetSubject<
943+
Var, [{S->hasGlobalStorage() && S->getType()->getAs<RecordType>() &&
944+
S->getType()->getAs<RecordType>()->getDecl() &&
945+
(S->getType()->getAs<RecordType>()->getDecl()->getName() ==
946+
"RWTexture1D" ||
947+
S->getType()->getAs<RecordType>()->getDecl()->getName() ==
948+
"RWTexture1DArray" ||
949+
S->getType()->getAs<RecordType>()->getDecl()->getName() ==
950+
"RWTexture2D" ||
951+
S->getType()->getAs<RecordType>()->getDecl()->getName() ==
952+
"RWTexture2DArray" ||
953+
S->getType()->getAs<RecordType>()->getDecl()->getName() ==
954+
"RWTexture3D")}]>;
955+
956+
// Global variable with "[RW]Buffer" type
957+
def Buffer
958+
: SubsetSubject<
959+
Var, [{S->hasGlobalStorage() && S->getType()->getAs<RecordType>() &&
960+
S->getType()->getAs<RecordType>()->getDecl() &&
961+
(S->getType()->getAs<RecordType>()->getDecl()->getName() ==
962+
"Buffer" ||
963+
S->getType()->getAs<RecordType>()->getDecl()->getName() ==
964+
"RWBuffer")}]>;
965+
940966
def VKBuiltIn : InheritableAttr {
941967
let Spellings = [CXX11<"vk", "builtin">];
942968
let Subjects = SubjectList<[Function, ParmVar, Field], ErrorDiag>;
@@ -997,6 +1023,29 @@ def VKOffset : InheritableAttr {
9971023
let Documentation = [Undocumented];
9981024
}
9991025

1026+
def VKImageFormat : InheritableAttr {
1027+
let Spellings = [CXX11<"vk", "image_format">];
1028+
let Subjects = SubjectList<[RWTexture, Buffer],
1029+
ErrorDiag, "ExpectedRWTextureOrBuffer">;
1030+
let Args = [EnumArgument<"ImageFormat", "ImageFormatType",
1031+
["unknown", "rgba32f", "rgba16f", "r32f", "rgba8", "rgba8snorm",
1032+
"rg32f", "rg16f", "r11g11b10f", "r16f", "rgba16", "rgb10a2",
1033+
"rg16", "rg8", "r16", "r8", "rgba16snorm", "rg16snorm", "rg8snorm",
1034+
"r16snorm", "r8snorm", "rgba32i", "rgba16i", "rgba8i", "r32i",
1035+
"rg32i", "rg16i", "rg8i", "r16i", "r8i", "rgba32ui", "rgba16ui", "rgba8ui",
1036+
"r32ui", "rgb10a2ui", "rg32ui", "rg16ui", "rg8ui", "r16ui",
1037+
"r8ui", "r64ui", "r64i"],
1038+
["unknown", "rgba32f", "rgba16f", "r32f", "rgba8", "rgba8snorm",
1039+
"rg32f", "rg16f", "r11g11b10f", "r16f", "rgba16", "rgb10a2",
1040+
"rg16", "rg8", "r16", "r8", "rgba16snorm", "rg16snorm", "rg8snorm",
1041+
"r16snorm", "r8snorm", "rgba32i", "rgba16i", "rgba8i", "r32i",
1042+
"rg32i", "rg16i", "rg8i", "r16i", "r8i", "rgba32ui", "rgba16ui", "rgba8ui",
1043+
"r32ui", "rgb10a2ui", "rg32ui", "rg16ui", "rg8ui", "r16ui",
1044+
"r8ui", "r64ui", "r64i"]>];
1045+
let LangOpts = [SPIRV];
1046+
let Documentation = [Undocumented];
1047+
}
1048+
10001049
def SubpassInput : SubsetSubject<
10011050
Var,
10021051
[{S->hasGlobalStorage() && S->getType()->getAs<RecordType>() &&

tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2332,6 +2332,7 @@ def warn_attribute_wrong_decl_type : Warning<
23322332
"global variables of scalar type|"
23332333
"global variables of struct type|"
23342334
"global variables, cbuffers, and tbuffers|"
2335+
"RWTextures, Buffers and RWBuffers|"
23352336
"RWStructuredBuffers, AppendStructuredBuffers, and ConsumeStructuredBuffers|"
23362337
"SubpassInput, SubpassInputMS|"
23372338
"cbuffer or ConstantBuffer|"

tools/clang/include/clang/SPIRV/SpirvContext.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,10 @@ class SpirvContext {
235235
ImageType::WithDepth, bool arrayed, bool ms,
236236
ImageType::WithSampler sampled,
237237
spv::ImageFormat);
238+
// Get ImageType whose attributes are the same with imageTypeWithUnknownFormat
239+
// but it has spv::ImageFormat format.
240+
const ImageType *getImageType(const ImageType *imageTypeWithUnknownFormat,
241+
spv::ImageFormat format);
238242
const SamplerType *getSamplerType() const { return samplerType; }
239243
const SampledImageType *getSampledImageType(const ImageType *image);
240244
const HybridSampledImageType *getSampledImageType(QualType image);
@@ -335,6 +339,20 @@ class SpirvContext {
335339
return currentLexicalScope;
336340
}
337341

342+
/// Function to add/get the mapping from a SPIR-V OpVariable to its image
343+
/// format.
344+
void registerImageFormatForSpirvVariable(const SpirvVariable *spvVar,
345+
spv::ImageFormat format) {
346+
assert(spvVar != nullptr);
347+
spvVarToImageFormat[spvVar] = format;
348+
}
349+
spv::ImageFormat getImageFormatForSpirvVariable(const SpirvVariable *spvVar) {
350+
auto itr = spvVarToImageFormat.find(spvVar);
351+
if (itr == spvVarToImageFormat.end())
352+
return spv::ImageFormat::Unknown;
353+
return itr->second;
354+
}
355+
338356
/// Function to add/get the mapping from a SPIR-V type to its Decl for
339357
/// a struct type.
340358
void registerStructDeclForSpirvType(const SpirvType *spvTy,
@@ -442,6 +460,9 @@ class SpirvContext {
442460
// Mapping from FunctionDecl to SPIR-V debug function.
443461
llvm::DenseMap<const FunctionDecl *, SpirvDebugFunction *>
444462
declToDebugFunction;
463+
464+
// Mapping from SPIR-V OpVariable to SPIR-V image format.
465+
llvm::DenseMap<const SpirvVariable *, spv::ImageFormat> spvVarToImageFormat;
445466
};
446467

447468
} // end namespace spirv

tools/clang/include/clang/Sema/AttributeList.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -854,18 +854,19 @@ enum AttributeDeclKind {
854854
ExpectedStructOrUnionOrTypedef,
855855
ExpectedStructOrTypedef,
856856
ExpectedObjectiveCInterfaceOrProtocol,
857-
ExpectedKernelFunction
857+
ExpectedKernelFunction,
858858
// SPIRV Change Begins
859-
,ExpectedField
860-
,ExpectedScalarGlobalVar
861-
,ExpectedStructGlobalVar
862-
,ExpectedGlobalVarOrCTBuffer
863-
,ExpectedCounterStructuredBuffer
864-
,ExpectedSubpassInput
865-
,ExpectedCTBuffer
859+
ExpectedField,
860+
ExpectedScalarGlobalVar,
861+
ExpectedStructGlobalVar,
862+
ExpectedGlobalVarOrCTBuffer,
863+
ExpectedRWTextureOrBuffer,
864+
ExpectedCounterStructuredBuffer,
865+
ExpectedSubpassInput,
866+
ExpectedCTBuffer,
866867
// SPIRV Change Ends
867868
// HLSL Change Begins - add attribute decl combinations
868-
,ExpectedVariableOrParam,
869+
ExpectedVariableOrParam,
869870
ExpectedFunctionOrParamOrField,
870871
ExpectedFunctionOrVariableOrParamOrFieldOrType
871872
// HLSL Change Ends

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,99 @@ SpirvLayoutRule getLayoutRuleForExternVar(QualType type,
398398
return SpirvLayoutRule::Void;
399399
}
400400

401+
spv::ImageFormat getSpvImageFormat(const VKImageFormatAttr *imageFormatAttr) {
402+
if (imageFormatAttr == nullptr)
403+
return spv::ImageFormat::Unknown;
404+
405+
switch (imageFormatAttr->getImageFormat()) {
406+
case VKImageFormatAttr::unknown:
407+
return spv::ImageFormat::Unknown;
408+
case VKImageFormatAttr::rgba32f:
409+
return spv::ImageFormat::Rgba32f;
410+
case VKImageFormatAttr::rgba16f:
411+
return spv::ImageFormat::Rgba16f;
412+
case VKImageFormatAttr::r32f:
413+
return spv::ImageFormat::R32f;
414+
case VKImageFormatAttr::rgba8:
415+
return spv::ImageFormat::Rgba8;
416+
case VKImageFormatAttr::rgba8snorm:
417+
return spv::ImageFormat::Rgba8Snorm;
418+
case VKImageFormatAttr::rg32f:
419+
return spv::ImageFormat::Rg32f;
420+
case VKImageFormatAttr::rg16f:
421+
return spv::ImageFormat::Rg16f;
422+
case VKImageFormatAttr::r11g11b10f:
423+
return spv::ImageFormat::R11fG11fB10f;
424+
case VKImageFormatAttr::r16f:
425+
return spv::ImageFormat::R16f;
426+
case VKImageFormatAttr::rgba16:
427+
return spv::ImageFormat::Rgba16;
428+
case VKImageFormatAttr::rgb10a2:
429+
return spv::ImageFormat::Rgb10A2;
430+
case VKImageFormatAttr::rg16:
431+
return spv::ImageFormat::Rg16;
432+
case VKImageFormatAttr::rg8:
433+
return spv::ImageFormat::Rg8;
434+
case VKImageFormatAttr::r16:
435+
return spv::ImageFormat::R16;
436+
case VKImageFormatAttr::r8:
437+
return spv::ImageFormat::R8;
438+
case VKImageFormatAttr::rgba16snorm:
439+
return spv::ImageFormat::Rgba16Snorm;
440+
case VKImageFormatAttr::rg16snorm:
441+
return spv::ImageFormat::Rg16Snorm;
442+
case VKImageFormatAttr::rg8snorm:
443+
return spv::ImageFormat::Rg8Snorm;
444+
case VKImageFormatAttr::r16snorm:
445+
return spv::ImageFormat::R16Snorm;
446+
case VKImageFormatAttr::r8snorm:
447+
return spv::ImageFormat::R8Snorm;
448+
case VKImageFormatAttr::rgba32i:
449+
return spv::ImageFormat::Rgba32i;
450+
case VKImageFormatAttr::rgba16i:
451+
return spv::ImageFormat::Rgba16i;
452+
case VKImageFormatAttr::rgba8i:
453+
return spv::ImageFormat::Rgba8i;
454+
case VKImageFormatAttr::r32i:
455+
return spv::ImageFormat::R32i;
456+
case VKImageFormatAttr::rg32i:
457+
return spv::ImageFormat::Rg32i;
458+
case VKImageFormatAttr::rg16i:
459+
return spv::ImageFormat::Rg16i;
460+
case VKImageFormatAttr::rg8i:
461+
return spv::ImageFormat::Rg8i;
462+
case VKImageFormatAttr::r16i:
463+
return spv::ImageFormat::R16i;
464+
case VKImageFormatAttr::r8i:
465+
return spv::ImageFormat::R8i;
466+
case VKImageFormatAttr::rgba32ui:
467+
return spv::ImageFormat::Rgba32ui;
468+
case VKImageFormatAttr::rgba16ui:
469+
return spv::ImageFormat::Rgba16ui;
470+
case VKImageFormatAttr::rgba8ui:
471+
return spv::ImageFormat::Rgba8ui;
472+
case VKImageFormatAttr::r32ui:
473+
return spv::ImageFormat::R32ui;
474+
case VKImageFormatAttr::rgb10a2ui:
475+
return spv::ImageFormat::Rgb10a2ui;
476+
case VKImageFormatAttr::rg32ui:
477+
return spv::ImageFormat::Rg32ui;
478+
case VKImageFormatAttr::rg16ui:
479+
return spv::ImageFormat::Rg16ui;
480+
case VKImageFormatAttr::rg8ui:
481+
return spv::ImageFormat::Rg8ui;
482+
case VKImageFormatAttr::r16ui:
483+
return spv::ImageFormat::R16ui;
484+
case VKImageFormatAttr::r8ui:
485+
return spv::ImageFormat::R8ui;
486+
case VKImageFormatAttr::r64ui:
487+
return spv::ImageFormat::R64ui;
488+
case VKImageFormatAttr::r64i:
489+
return spv::ImageFormat::R64i;
490+
}
491+
return spv::ImageFormat::Unknown;
492+
}
493+
401494
} // anonymous namespace
402495

403496
std::string StageVar::getSemanticStr() const {
@@ -847,6 +940,13 @@ SpirvVariable *DeclResultIdMapper::createExternVar(const VarDecl *var) {
847940
type, storageClass, var->hasAttr<HLSLPreciseAttr>(), name, llvm::None,
848941
loc);
849942
varInstr->setLayoutRule(rule);
943+
944+
// If this variable has [[vk::image_format("..")]] attribute, we have to keep
945+
// it in the SpirvContext and use it when we lower the QualType to SpirvType.
946+
auto spvImageFormat = getSpvImageFormat(var->getAttr<VKImageFormatAttr>());
947+
if (spvImageFormat != spv::ImageFormat::Unknown)
948+
spvContext.registerImageFormatForSpirvVariable(varInstr, spvImageFormat);
949+
850950
DeclSpirvInfo info(varInstr);
851951
astDecls[var] = info;
852952

tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
121121
if (var->hasBinding() && var->getHlslUserType().empty()) {
122122
var->setHlslUserType(getHlslResourceTypeName(var->getAstResultType()));
123123
}
124+
125+
auto spvImageFormat = spvContext.getImageFormatForSpirvVariable(var);
126+
if (spvImageFormat != spv::ImageFormat::Unknown) {
127+
if (const auto *imageType = dyn_cast<ImageType>(resultType)) {
128+
resultType = spvContext.getImageType(imageType, spvImageFormat);
129+
instr->setResultType(resultType);
130+
}
131+
}
124132
}
125133
const SpirvType *pointerType =
126134
spvContext.getPointerType(resultType, instr->getStorageClass());

tools/clang/lib/SPIRV/SpirvContext.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,17 @@ const SpirvType *SpirvContext::getMatrixType(const SpirvType *elemType,
184184
return ptr;
185185
}
186186

187+
const ImageType *
188+
SpirvContext::getImageType(const ImageType *imageTypeWithUnknownFormat,
189+
spv::ImageFormat format) {
190+
return getImageType(imageTypeWithUnknownFormat->getSampledType(),
191+
imageTypeWithUnknownFormat->getDimension(),
192+
imageTypeWithUnknownFormat->getDepth(),
193+
imageTypeWithUnknownFormat->isArrayedImage(),
194+
imageTypeWithUnknownFormat->isMSImage(),
195+
imageTypeWithUnknownFormat->withSampler(), format);
196+
}
197+
187198
const ImageType *SpirvContext::getImageType(const SpirvType *sampledType,
188199
spv::Dim dim,
189200
ImageType::WithDepth depth,

0 commit comments

Comments
 (0)