Skip to content

Commit 2eae8d3

Browse files
jiaolujaebaek
andauthored
[SPIRV] Add support of [[vk::ext_type_def]] (#4068)
Support [[vk::ext_type_def]] and vk::ext_type. This is related #3919 Co-authored-by: Jaebaek Seo <[email protected]>
1 parent 676fe64 commit 2eae8d3

20 files changed

Lines changed: 345 additions & 37 deletions

tools/clang/include/clang/AST/HlslTypes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,9 @@ clang::CXXRecordDecl* DeclareTemplateTypeWithHandle(
340340

341341
clang::CXXRecordDecl* DeclareUIntTemplatedTypeWithHandle(
342342
clang::ASTContext& context, llvm::StringRef typeName, llvm::StringRef templateParamName);
343+
clang::CXXRecordDecl *DeclareUIntTemplatedTypeWithHandleInDeclContext(
344+
clang::ASTContext &context, clang::DeclContext *declContext,
345+
llvm::StringRef typeName, llvm::StringRef templateParamName);
343346
clang::CXXRecordDecl *DeclareConstantBufferViewType(clang::ASTContext& context, bool bTBuf);
344347
clang::CXXRecordDecl* DeclareRayQueryType(clang::ASTContext& context);
345348
clang::CXXRecordDecl *DeclareResourceType(clang::ASTContext &context,

tools/clang/include/clang/Basic/Attr.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,14 @@ def VKReferenceExt : InheritableAttr {
11451145
let Documentation = [Undocumented];
11461146
}
11471147

1148+
def VKTypeDefExt : InheritableAttr {
1149+
let Spellings = [CXX11<"vk", "ext_type_def">];
1150+
let Subjects = SubjectList<[Function], ErrorDiag>;
1151+
let Args = [UnsignedArgument<"id">, UnsignedArgument<"opcode">];
1152+
let LangOpts = [SPIRV];
1153+
let Documentation = [Undocumented];
1154+
}
1155+
11481156
// Global variables that are of scalar type
11491157
def ScalarGlobalVar : SubsetSubject<Var, [{S->hasGlobalStorage() && S->getType()->isScalarType()}]>;
11501158

tools/clang/include/clang/SPIRV/SpirvContext.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,12 @@ class SpirvContext {
288288
return rayQueryTypeKHR;
289289
}
290290

291+
const SpirvIntrinsicType *
292+
getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode,
293+
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);
294+
295+
SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId);
296+
291297
/// --- Hybrid type getter functions ---
292298
///
293299
/// Concrete SpirvType objects represent a SPIR-V type completely. Hybrid
@@ -467,6 +473,7 @@ class SpirvContext {
467473
llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
468474
llvm::SmallVector<const HybridPointerType *, 8> hybridPointerTypes;
469475
llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
476+
llvm::DenseMap<unsigned, SpirvIntrinsicType*> spirvIntrinsicTypes;
470477
const AccelerationStructureTypeNV *accelerationStructureTypeNV;
471478
const RayQueryTypeKHR *rayQueryTypeKHR;
472479

tools/clang/include/clang/SPIRV/SpirvInstruction.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,10 +1110,13 @@ class SpirvConstant : public SpirvInstruction {
11101110
}
11111111

11121112
bool isSpecConstant() const;
1113+
void setLiteral(bool literal = true) { literalConstant = literal; }
1114+
bool isLiteral() { return literalConstant; }
11131115

11141116
protected:
1115-
SpirvConstant(Kind, spv::Op, const SpirvType *);
1116-
SpirvConstant(Kind, spv::Op, QualType);
1117+
SpirvConstant(Kind, spv::Op, const SpirvType *, bool literal = false);
1118+
SpirvConstant(Kind, spv::Op, QualType, bool literal = false);
1119+
bool literalConstant;
11171120
};
11181121

11191122
class SpirvConstantBoolean : public SpirvConstant {
@@ -1141,7 +1144,7 @@ class SpirvConstantBoolean : public SpirvConstant {
11411144
class SpirvConstantInteger : public SpirvConstant {
11421145
public:
11431146
SpirvConstantInteger(QualType type, llvm::APInt value,
1144-
bool isSpecConst = false, bool literal = false);
1147+
bool isSpecConst = false);
11451148

11461149
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantInteger)
11471150

@@ -1155,12 +1158,9 @@ class SpirvConstantInteger : public SpirvConstant {
11551158
bool invokeVisitor(Visitor *v) override;
11561159

11571160
llvm::APInt getValue() const { return value; }
1158-
void setLiteral(bool l = true) { isLiteral = l; }
1159-
bool getLiteral() { return isLiteral; }
11601161

11611162
private:
11621163
llvm::APInt value;
1163-
bool isLiteral;
11641164
};
11651165

11661166
class SpirvConstantFloat : public SpirvConstant {

tools/clang/include/clang/SPIRV/SpirvType.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class SpirvType {
4949
TK_Function,
5050
TK_AccelerationStructureNV,
5151
TK_RayQueryKHR,
52+
TK_SpirvIntrinsicType,
5253
// Order matters: all the following are hybrid types
5354
TK_HybridStruct,
5455
TK_HybridPointer,
@@ -412,6 +413,37 @@ class RayQueryTypeKHR : public SpirvType {
412413
}
413414
};
414415

416+
class SpirvInstruction;
417+
struct SpvIntrinsicTypeOperand {
418+
SpvIntrinsicTypeOperand(SpirvType *type_operand)
419+
: operand_as_type(type_operand), isTypeOperand(true) {}
420+
SpvIntrinsicTypeOperand(SpirvInstruction *inst_operand)
421+
: operand_as_inst(inst_operand), isTypeOperand(false) {}
422+
union {
423+
SpirvType *operand_as_type;
424+
SpirvInstruction *operand_as_inst;
425+
};
426+
bool isTypeOperand;
427+
};
428+
429+
class SpirvIntrinsicType : public SpirvType {
430+
public:
431+
SpirvIntrinsicType(unsigned typeOp,
432+
llvm::ArrayRef<SpvIntrinsicTypeOperand> inOps);
433+
434+
static bool classof(const SpirvType *t) {
435+
return t->getKind() == TK_SpirvIntrinsicType;
436+
}
437+
unsigned getOpCode() const { return typeOpCode; }
438+
llvm::ArrayRef<SpvIntrinsicTypeOperand> getOperands() const {
439+
return operands;
440+
}
441+
442+
private:
443+
unsigned typeOpCode;
444+
llvm::SmallVector<SpvIntrinsicTypeOperand, 3> operands;
445+
};
446+
415447
class HybridType : public SpirvType {
416448
public:
417449
static bool classof(const SpirvType *t) {

tools/clang/lib/AST/ASTContextHLSL.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,8 +840,15 @@ CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams(
840840

841841
CXXRecordDecl* hlsl::DeclareUIntTemplatedTypeWithHandle(
842842
ASTContext& context, StringRef typeName, StringRef templateParamName) {
843+
return DeclareUIntTemplatedTypeWithHandleInDeclContext(
844+
context, context.getTranslationUnitDecl(), typeName, templateParamName);
845+
}
846+
847+
CXXRecordDecl *hlsl::DeclareUIntTemplatedTypeWithHandleInDeclContext(
848+
ASTContext &context, DeclContext *declContext, StringRef typeName,
849+
StringRef templateParamName) {
843850
// template<uint kind> FeedbackTexture2D[Array] { ... }
844-
BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), typeName);
851+
BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName);
845852
typeDeclBuilder.addIntegerTemplateParam(templateParamName, context.UnsignedIntTy);
846853
typeDeclBuilder.startDefinition();
847854
typeDeclBuilder.addField("h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.

tools/clang/lib/SPIRV/CapabilityVisitor.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,9 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
529529
}
530530
case spv::Op::OpRayQueryInitializeKHR: {
531531
auto rayQueryInst = dyn_cast<SpirvRayQueryOpKHR>(instr);
532-
if (rayQueryInst->hasCullFlags()) {
533-
addCapability(spv::Capability::RayTraversalPrimitiveCullingKHR);
532+
if (rayQueryInst && rayQueryInst->hasCullFlags()) {
533+
addCapability(
534+
spv::Capability::RayTraversalPrimitiveCullingKHR);
534535
}
535536

536537
break;

tools/clang/lib/SPIRV/EmitVisitor.cpp

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,10 +1884,9 @@ bool EmitVisitor::visit(SpirvIntrinsicInstruction *inst) {
18841884
}
18851885

18861886
for (const auto operand : inst->getOperands()) {
1887-
// TODO: Handle Literals with other types.
1888-
auto literalOperand = dyn_cast<SpirvConstantInteger>(operand);
1889-
if (literalOperand && literalOperand->getLiteral()) {
1890-
curInst.push_back(literalOperand->getValue().getZExtValue());
1887+
auto literalOperand = dyn_cast<SpirvConstant>(operand);
1888+
if (literalOperand && literalOperand->isLiteral()) {
1889+
typeHandler.emitLiteral(literalOperand, curInst);
18911890
} else {
18921891
curInst.push_back(getOrAssignResultId<SpirvInstruction>(operand));
18931892
}
@@ -2451,6 +2450,24 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
24512450
initTypeInstruction(spv::Op::OpTypeRayQueryKHR);
24522451
curTypeInst.push_back(id);
24532452
finalizeTypeInstruction();
2453+
} else if (const auto *spvIntrinsicType =
2454+
dyn_cast<SpirvIntrinsicType>(type)) {
2455+
initTypeInstruction(static_cast<spv::Op>(spvIntrinsicType->getOpCode()));
2456+
curTypeInst.push_back(id);
2457+
for (const SpvIntrinsicTypeOperand &operand :
2458+
spvIntrinsicType->getOperands()) {
2459+
if (operand.isTypeOperand) {
2460+
curTypeInst.push_back(emitType(operand.operand_as_type));
2461+
} else {
2462+
auto *literal = dyn_cast<SpirvConstant>(operand.operand_as_inst);
2463+
if (literal && literal->isLiteral()) {
2464+
emitLiteral(literal, curTypeInst);
2465+
} else {
2466+
curTypeInst.push_back(getOrAssignResultId(operand.operand_as_inst));
2467+
}
2468+
}
2469+
}
2470+
finalizeTypeInstruction();
24542471
}
24552472
// Hybrid Types
24562473
// Note: The type lowering pass should lower all types to SpirvTypes.
@@ -2467,6 +2484,50 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
24672484
return id;
24682485
}
24692486

2487+
template <typename vecType>
2488+
void EmitTypeHandler::emitIntLiteral(const SpirvConstantInteger *intLiteral,
2489+
vecType &outInst) {
2490+
const auto &literalVal = intLiteral->getValue();
2491+
bool positive = !literalVal.isNegative();
2492+
if (literalVal.getBitWidth() <= 32) {
2493+
outInst.push_back(positive ? literalVal.getZExtValue()
2494+
: literalVal.getSExtValue());
2495+
} else {
2496+
assert(literalVal.getBitWidth() == 64);
2497+
uint64_t val =
2498+
positive ? literalVal.getZExtValue() : literalVal.getSExtValue();
2499+
outInst.push_back(static_cast<unsigned>(val));
2500+
outInst.push_back(static_cast<unsigned>(val >> 32));
2501+
}
2502+
}
2503+
2504+
template <typename vecType>
2505+
void EmitTypeHandler::emitFloatLiteral(const SpirvConstantFloat *fLiteral,
2506+
vecType &outInst) {
2507+
const auto &literalVal = fLiteral->getValue();
2508+
const auto bitwidth = llvm::APFloat::getSizeInBits(literalVal.getSemantics());
2509+
if (bitwidth <= 32) {
2510+
outInst.push_back(literalVal.bitcastToAPInt().getZExtValue());
2511+
} else {
2512+
assert(bitwidth == 64);
2513+
uint64_t val = literalVal.bitcastToAPInt().getZExtValue();
2514+
outInst.push_back(static_cast<unsigned>(val));
2515+
outInst.push_back(static_cast<unsigned>(val >> 32));
2516+
}
2517+
}
2518+
2519+
template <typename VecType>
2520+
void EmitTypeHandler::emitLiteral(const SpirvConstant *literal,
2521+
VecType &outInst) {
2522+
if (auto boolLiteral = dyn_cast<SpirvConstantBoolean>(literal)) {
2523+
outInst.push_back(static_cast<unsigned>(boolLiteral->getValue()));
2524+
} else if (auto intLiteral = dyn_cast<SpirvConstantInteger>(literal)) {
2525+
emitIntLiteral(intLiteral, outInst);
2526+
} else if (auto fLiteral = dyn_cast<SpirvConstantFloat>(literal)) {
2527+
emitFloatLiteral(fLiteral, outInst);
2528+
}
2529+
}
2530+
24702531
void EmitTypeHandler::emitDecoration(uint32_t typeResultId,
24712532
spv::Decoration decoration,
24722533
llvm::ArrayRef<uint32_t> decorationParams,

tools/clang/lib/SPIRV/EmitVisitor.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ class EmitTypeHandler {
109109
uint32_t getOrCreateConstantComposite(SpirvConstantComposite *);
110110
uint32_t getOrCreateConstantNull(SpirvConstantNull *);
111111
uint32_t getOrCreateConstantBool(SpirvConstantBoolean *);
112+
template <typename vecType>
113+
void emitLiteral(const SpirvConstant *, vecType &outInst);
114+
template <typename vecType>
115+
void emitFloatLiteral(const SpirvConstantFloat *, vecType &outInst);
116+
template <typename vecType>
117+
void emitIntLiteral(const SpirvConstantInteger *, vecType &outInst);
112118

113119
private:
114120
void initTypeInstruction(spv::Op op);

tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,11 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
593593
if (name == "RayQuery")
594594
return spvContext.getRayQueryTypeKHR();
595595

596+
if (name == "ext_type") {
597+
auto typeId = hlsl::GetHLSLResourceTemplateUInt(type);
598+
return spvContext.getCreatedSpirvIntrinsicType(typeId);
599+
}
600+
596601
if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
597602
name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") {
598603
// StructureBuffer<S> will be translated into an OpTypeStruct with one

0 commit comments

Comments
 (0)