Skip to content

Commit ec9d97a

Browse files
authored
[SPIRV] Fixes vk::BufferPointer cast methods. (#8365)
Addresses #7891.
1 parent e718062 commit ec9d97a

3 files changed

Lines changed: 143 additions & 17 deletions

File tree

tools/clang/lib/Sema/SemaHLSL.cpp

Lines changed: 97 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5368,8 +5368,8 @@ class HLSLExternalSource : public ExternalSemaSource {
53685368
/// use for the signature, with the first being the return type.</remarks>
53695369
bool MatchArguments(const IntrinsicDefIter &cursor, QualType objectType,
53705370
QualType objectElement, QualType functionTemplateTypeArg,
5371-
ArrayRef<Expr *> Args, std::vector<QualType> *,
5372-
size_t &badArgIdx);
5371+
unsigned functionTemplateIntArg, ArrayRef<Expr *> Args,
5372+
std::vector<QualType> *, size_t &badArgIdx);
53735373

53745374
/// <summary>Validate object element on intrinsic to catch case like integer
53755375
/// on Sample.</summary> <param name="tableName">Intrinsic function to
@@ -5418,6 +5418,19 @@ class HLSLExternalSource : public ExternalSemaSource {
54185418
nameIdentifier, argumentCount));
54195419
}
54205420

5421+
static unsigned GetIntegralTemplateArg(ASTContext &context,
5422+
const TemplateArgument &arg) {
5423+
if (arg.getKind() == TemplateArgument::Integral)
5424+
return arg.getAsIntegral().getZExtValue();
5425+
if (arg.getKind() == TemplateArgument::Expression) {
5426+
llvm::APSInt result;
5427+
Expr *expr = arg.getAsExpr();
5428+
if (expr != nullptr && expr->isIntegerConstantExpr(result, context))
5429+
return result.getZExtValue();
5430+
}
5431+
return 0;
5432+
}
5433+
54215434
bool AddOverloadedCallCandidates(UnresolvedLookupExpr *ULE,
54225435
ArrayRef<Expr *> Args,
54235436
OverloadCandidateSet &CandidateSet, Scope *S,
@@ -5512,11 +5525,22 @@ class HLSLExternalSource : public ExternalSemaSource {
55125525
"otherwise g_MaxIntrinsicParamCount needs to be updated for "
55135526
"wider signatures");
55145527

5528+
QualType templateTypeArg;
5529+
unsigned templateIntArg = 0;
55155530
std::vector<QualType> functionArgTypes;
55165531
size_t badArgIdx;
5532+
if (ULE->hasExplicitTemplateArgs() && ULE->getNumTemplateArgs() >= 1) {
5533+
const TemplateArgumentLoc &TypeArgLoc = ULE->getTemplateArgs()[0];
5534+
if (TypeArgLoc.getArgument().getKind() == TemplateArgument::Type)
5535+
templateTypeArg = TypeArgLoc.getArgument().getAsType();
5536+
if (ULE->getNumTemplateArgs() >= 2)
5537+
templateIntArg = GetIntegralTemplateArg(
5538+
*m_context, ULE->getTemplateArgs()[1].getArgument());
5539+
}
5540+
55175541
bool argsMatch =
5518-
MatchArguments(cursor, QualType(), QualType(), QualType(), Args,
5519-
&functionArgTypes, badArgIdx);
5542+
MatchArguments(cursor, QualType(), QualType(), templateTypeArg,
5543+
templateIntArg, Args, &functionArgTypes, badArgIdx);
55205544
if (!functionArgTypes.size())
55215545
return false;
55225546

@@ -6925,8 +6949,9 @@ bool HLSLExternalSource::IsValidObjectElement(LPCSTR tableName,
69256949

69266950
bool HLSLExternalSource::MatchArguments(
69276951
const IntrinsicDefIter &cursor, QualType objectType, QualType objectElement,
6928-
QualType functionTemplateTypeArg, ArrayRef<Expr *> Args,
6929-
std::vector<QualType> *argTypesVector, size_t &badArgIdx) {
6952+
QualType functionTemplateTypeArg, unsigned functionTemplateIntArg,
6953+
ArrayRef<Expr *> Args, std::vector<QualType> *argTypesVector,
6954+
size_t &badArgIdx) {
69306955
const HLSL_INTRINSIC *pIntrinsic = *cursor;
69316956
LPCSTR tableName = cursor.GetTableName();
69326957
IntrinsicOp builtinOp = IntrinsicOp::Num_Intrinsics;
@@ -7418,7 +7443,45 @@ bool HLSLExternalSource::MatchArguments(
74187443
if (i == 0 &&
74197444
(builtinOp == hlsl::IntrinsicOp::IOP_Vkreinterpret_pointer_cast ||
74207445
builtinOp == hlsl::IntrinsicOp::IOP_Vkstatic_pointer_cast)) {
7421-
pNewType = Args[0]->getType();
7446+
if (functionTemplateTypeArg.isNull()) {
7447+
badArgIdx = std::min(badArgIdx, i);
7448+
continue;
7449+
}
7450+
7451+
// Build BufferPointer<T, A> where T is the template type argument and
7452+
// A is the template alignment argument (or the alignment of the
7453+
// source pointer if none is given).
7454+
unsigned srcAlignment =
7455+
functionTemplateIntArg
7456+
? functionTemplateIntArg
7457+
: hlsl::GetVKBufferPointerAlignment(Args[0]->getType());
7458+
TemplateArgument TemplateArgs[] = {
7459+
TemplateArgument(functionTemplateTypeArg),
7460+
TemplateArgument(*m_context,
7461+
llvm::APSInt(llvm::APInt(32, srcAlignment)),
7462+
m_context->UnsignedIntTy)};
7463+
void *InsertPos = nullptr;
7464+
ClassTemplateSpecializationDecl *Spec =
7465+
m_vkBufferPointerTemplateDecl->findSpecialization(
7466+
llvm::ArrayRef<TemplateArgument>(TemplateArgs, 2), InsertPos);
7467+
if (!Spec) {
7468+
Spec = ClassTemplateSpecializationDecl::Create(
7469+
*m_context, TagDecl::TagKind::TTK_Struct,
7470+
m_vkBufferPointerTemplateDecl->getDeclContext(), SourceLocation(),
7471+
SourceLocation(), m_vkBufferPointerTemplateDecl, TemplateArgs, 2,
7472+
nullptr);
7473+
m_vkBufferPointerTemplateDecl->AddSpecialization(Spec, InsertPos);
7474+
Spec->setImplicit(true);
7475+
DXVERIFY_NOMSG(
7476+
false ==
7477+
getSema()->InstantiateClassTemplateSpecialization(
7478+
SourceLocation(), Spec,
7479+
TemplateSpecializationKind::TSK_ImplicitInstantiation, true));
7480+
}
7481+
7482+
pNewType = m_context->getTemplateSpecializationType(
7483+
TemplateName(m_vkBufferPointerTemplateDecl), TemplateArgs, 2,
7484+
m_context->getTypeDeclType(Spec));
74227485
} else {
74237486
badArgIdx = std::min(badArgIdx, i);
74247487
}
@@ -11155,11 +11218,18 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL(
1115511218
QualType objectType = m_context->getTagDeclType(functionParentRecord);
1115611219

1115711220
QualType functionTemplateTypeArg{};
11158-
if (ExplicitTemplateArgs != nullptr && ExplicitTemplateArgs->size() == 1) {
11221+
unsigned functionTemplateIntArg = 0;
11222+
if (ExplicitTemplateArgs != nullptr && ExplicitTemplateArgs->size() >= 1) {
1115911223
const TemplateArgument &firstTemplateArg =
1116011224
(*ExplicitTemplateArgs)[0].getArgument();
1116111225
if (firstTemplateArg.getKind() == TemplateArgument::ArgKind::Type)
1116211226
functionTemplateTypeArg = firstTemplateArg.getAsType();
11227+
if (ExplicitTemplateArgs->size() > 1) {
11228+
const TemplateArgument &secondTemplateArg =
11229+
(*ExplicitTemplateArgs)[1].getArgument();
11230+
functionTemplateIntArg =
11231+
GetIntegralTemplateArg(*m_context, secondTemplateArg);
11232+
}
1116311233
}
1116411234

1116511235
// Handle subscript overloads.
@@ -11233,7 +11303,8 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL(
1123311303
while (cursor != end) {
1123411304
size_t badArgIdx;
1123511305
if (!MatchArguments(cursor, objectType, objectElement,
11236-
functionTemplateTypeArg, Args, &argTypes, badArgIdx)) {
11306+
functionTemplateTypeArg, functionTemplateIntArg, Args,
11307+
&argTypes, badArgIdx)) {
1123711308
++cursor;
1123811309
continue;
1123911310
}
@@ -11276,8 +11347,9 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL(
1127611347
if (!IsNull &&
1127711348
getSema()->RequireCompleteType(Loc, functionTemplateTypeArg, 0))
1127811349
return Sema::TemplateDeductionResult::TDK_Invalid;
11279-
if (IsNull || !hlsl::IsHLSLNumericOrAggregateOfNumericType(
11280-
functionTemplateTypeArg)) {
11350+
if (IsNull || ExplicitTemplateArgs->size() > 1 ||
11351+
!hlsl::IsHLSLNumericOrAggregateOfNumericType(
11352+
functionTemplateTypeArg)) {
1128111353
getSema()->Diag(Loc, diag::err_hlsl_intrinsic_template_arg_numeric)
1128211354
<< intrinsicName;
1128311355
DiagnoseTypeElements(
@@ -12007,8 +12079,18 @@ static bool CheckBarrierCall(Sema &S, FunctionDecl *FD, CallExpr *CE,
1200712079
}
1200812080

1200912081
#ifdef ENABLE_SPIRV_CODEGEN
12010-
static bool CheckVKBufferPointerCast(Sema &S, FunctionDecl *FD, CallExpr *CE,
12011-
bool isStatic) {
12082+
static bool CheckVKBufferPointerCast(Sema &S, CallExpr *CE, bool isStatic) {
12083+
const auto *callee = dyn_cast<DeclRefExpr>(CE->getCallee()->IgnoreImpCasts());
12084+
if (callee && callee->hasExplicitTemplateArgs() &&
12085+
callee->getNumTemplateArgs() > 2) {
12086+
StringRef castName =
12087+
isStatic ? "static_pointer_cast" : "reinterpret_pointer_cast";
12088+
S.Diags.Report(CE->getExprLoc(),
12089+
diag::err_template_arg_list_different_arity)
12090+
<< /*too many*/ 1 << /*function template*/ 1 << castName;
12091+
return true;
12092+
}
12093+
1201212094
const Expr *argExpr = CE->getArg(0);
1201312095
QualType srcType = argExpr->getType();
1201412096
QualType destType = CE->getType();
@@ -12146,10 +12228,10 @@ void Sema::CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall) {
1214612228
break;
1214712229
#ifdef ENABLE_SPIRV_CODEGEN
1214812230
case hlsl::IntrinsicOp::IOP_Vkreinterpret_pointer_cast:
12149-
CheckVKBufferPointerCast(*this, FDecl, TheCall, false);
12231+
CheckVKBufferPointerCast(*this, TheCall, false);
1215012232
break;
1215112233
case hlsl::IntrinsicOp::IOP_Vkstatic_pointer_cast:
12152-
CheckVKBufferPointerCast(*this, FDecl, TheCall, true);
12234+
CheckVKBufferPointerCast(*this, TheCall, true);
1215312235
break;
1215412236
#endif
1215512237
default:
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: %dxc -T cs_6_6 -spirv -fspv-target-env=vulkan1.3 %s | FileCheck %s --check-prefix=GOOD
2+
// RUN: not %dxc -T cs_6_6 -spirv -fspv-target-env=vulkan1.3 -DBAD %s 2>&1 | FileCheck %s --check-prefix=BAD
3+
4+
struct Base {
5+
};
6+
7+
struct Derived : Base {
8+
int val;
9+
};
10+
11+
cbuffer Test {
12+
vk::BufferPointer<Derived, 32> derivedBuf;
13+
vk::BufferPointer<int> intBuf;
14+
};
15+
16+
[shader("compute")]
17+
[numthreads(256, 1, 1)]
18+
void main(in uint3 threadId : SV_DispatchThreadID) {
19+
#ifdef BAD
20+
vk::BufferPointer<Base, 64> derivedBufAsBase = vk::static_pointer_cast<Base, 64>(derivedBuf);
21+
vk::BufferPointer<float> intBufAsFloat = vk::static_pointer_cast<float>(intBuf);
22+
#else
23+
vk::BufferPointer<Base, 16> derivedBufAsBase = vk::static_pointer_cast<Base, 16>(derivedBuf);
24+
vk::BufferPointer<float> intBufAsFloat = vk::reinterpret_pointer_cast<float>(intBuf);
25+
#endif
26+
27+
intBuf.Get() = (int)intBufAsFloat.Get();
28+
}
29+
30+
// GOOD: [[INT:%[^ ]*]] = OpTypeInt 32 1
31+
// GOOD: [[I1:%[^ ]*]] = OpConstant [[INT]] 1
32+
// GOOD: [[PINT:%[^ ]*]] = OpTypePointer PhysicalStorageBuffer [[INT]]
33+
// GOOD: [[FLOAT:%[^ ]*]] = OpTypeFloat 32
34+
// GOOD: [[PFLOAT:%[^ ]*]] = OpTypePointer PhysicalStorageBuffer [[FLOAT]]
35+
// GOOD: %Test = OpVariable %{{[^ ]*}} Uniform
36+
// GOOD: [[V0:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %Test [[I1]]
37+
// GOOD: [[V1:%[^ ]*]] = OpLoad [[PINT]] [[V0]]
38+
// GOOD: [[V2:%[^ ]*]] = OpBitcast [[PFLOAT]] [[V1]]
39+
// GOOD: [[V3:%[^ ]*]] = OpLoad [[FLOAT]] [[V2]] Aligned 4
40+
// GOOD: [[V4:%[^ ]*]] = OpConvertFToS [[INT]] [[V3]]
41+
// GOOD: OpStore [[V1]] [[V4]] Aligned 4
42+
43+
// BAD: error: Vulkan buffer pointer cannot be cast to greater alignment
44+
// BAD: error: vk::static_pointer_cast() content type must be base class of argument's content type
45+

tools/clang/test/CodeGenSPIRV/vk.buffer-pointer.linked-list.hlsl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ struct TestPushConstant_t
5353
float4 MainPs(void) : SV_Target0
5454
{
5555
if (__has_feature(hlsl_vk_buffer_pointer)) {
56-
[[vk::aliased_pointer]] block_p g_p =
57-
vk::static_pointer_cast<block_t, 16>(g_PushConstants.root);
56+
[[vk::aliased_pointer]] block_p g_p = g_PushConstants.root;
5857
g_p = g_p.Get().next;
5958
uint64_t addr = (uint64_t)g_p;
6059
block_p copy1 = block_p(addr);

0 commit comments

Comments
 (0)