@@ -5368,8 +5368,8 @@ class HLSLExternalSource : public ExternalSemaSource {
53685368 /// use for the signature, with the first being the return type.</remarks>
53695369 bool MatchArguments(const IntrinsicDefIter &cursor, QualType objectType,
53705370 QualType objectElement, QualType functionTemplateTypeArg,
5371- ArrayRef<Expr *> Args, std::vector<QualType> * ,
5372- size_t &badArgIdx);
5371+ unsigned functionTemplateIntArg, ArrayRef<Expr *> Args,
5372+ std::vector<QualType> *, size_t &badArgIdx);
53735373
53745374 /// <summary>Validate object element on intrinsic to catch case like integer
53755375 /// on Sample.</summary> <param name="tableName">Intrinsic function to
@@ -5418,6 +5418,19 @@ class HLSLExternalSource : public ExternalSemaSource {
54185418 nameIdentifier, argumentCount));
54195419 }
54205420
5421+ static unsigned GetIntegralTemplateArg(ASTContext &context,
5422+ const TemplateArgument &arg) {
5423+ if (arg.getKind() == TemplateArgument::Integral)
5424+ return arg.getAsIntegral().getZExtValue();
5425+ if (arg.getKind() == TemplateArgument::Expression) {
5426+ llvm::APSInt result;
5427+ Expr *expr = arg.getAsExpr();
5428+ if (expr != nullptr && expr->isIntegerConstantExpr(result, context))
5429+ return result.getZExtValue();
5430+ }
5431+ return 0;
5432+ }
5433+
54215434 bool AddOverloadedCallCandidates(UnresolvedLookupExpr *ULE,
54225435 ArrayRef<Expr *> Args,
54235436 OverloadCandidateSet &CandidateSet, Scope *S,
@@ -5512,11 +5525,22 @@ class HLSLExternalSource : public ExternalSemaSource {
55125525 "otherwise g_MaxIntrinsicParamCount needs to be updated for "
55135526 "wider signatures");
55145527
5528+ QualType templateTypeArg;
5529+ unsigned templateIntArg = 0;
55155530 std::vector<QualType> functionArgTypes;
55165531 size_t badArgIdx;
5532+ if (ULE->hasExplicitTemplateArgs() && ULE->getNumTemplateArgs() >= 1) {
5533+ const TemplateArgumentLoc &TypeArgLoc = ULE->getTemplateArgs()[0];
5534+ if (TypeArgLoc.getArgument().getKind() == TemplateArgument::Type)
5535+ templateTypeArg = TypeArgLoc.getArgument().getAsType();
5536+ if (ULE->getNumTemplateArgs() >= 2)
5537+ templateIntArg = GetIntegralTemplateArg(
5538+ *m_context, ULE->getTemplateArgs()[1].getArgument());
5539+ }
5540+
55175541 bool argsMatch =
5518- MatchArguments(cursor, QualType(), QualType(), QualType(), Args ,
5519- &functionArgTypes, badArgIdx);
5542+ MatchArguments(cursor, QualType(), QualType(), templateTypeArg ,
5543+ templateIntArg, Args, &functionArgTypes, badArgIdx);
55205544 if (!functionArgTypes.size())
55215545 return false;
55225546
@@ -6925,8 +6949,9 @@ bool HLSLExternalSource::IsValidObjectElement(LPCSTR tableName,
69256949
69266950bool HLSLExternalSource::MatchArguments(
69276951 const IntrinsicDefIter &cursor, QualType objectType, QualType objectElement,
6928- QualType functionTemplateTypeArg, ArrayRef<Expr *> Args,
6929- std::vector<QualType> *argTypesVector, size_t &badArgIdx) {
6952+ QualType functionTemplateTypeArg, unsigned functionTemplateIntArg,
6953+ ArrayRef<Expr *> Args, std::vector<QualType> *argTypesVector,
6954+ size_t &badArgIdx) {
69306955 const HLSL_INTRINSIC *pIntrinsic = *cursor;
69316956 LPCSTR tableName = cursor.GetTableName();
69326957 IntrinsicOp builtinOp = IntrinsicOp::Num_Intrinsics;
@@ -7418,7 +7443,45 @@ bool HLSLExternalSource::MatchArguments(
74187443 if (i == 0 &&
74197444 (builtinOp == hlsl::IntrinsicOp::IOP_Vkreinterpret_pointer_cast ||
74207445 builtinOp == hlsl::IntrinsicOp::IOP_Vkstatic_pointer_cast)) {
7421- pNewType = Args[0]->getType();
7446+ if (functionTemplateTypeArg.isNull()) {
7447+ badArgIdx = std::min(badArgIdx, i);
7448+ continue;
7449+ }
7450+
7451+ // Build BufferPointer<T, A> where T is the template type argument and
7452+ // A is the template alignment argument (or the alignment of the
7453+ // source pointer if none is given).
7454+ unsigned srcAlignment =
7455+ functionTemplateIntArg
7456+ ? functionTemplateIntArg
7457+ : hlsl::GetVKBufferPointerAlignment(Args[0]->getType());
7458+ TemplateArgument TemplateArgs[] = {
7459+ TemplateArgument(functionTemplateTypeArg),
7460+ TemplateArgument(*m_context,
7461+ llvm::APSInt(llvm::APInt(32, srcAlignment)),
7462+ m_context->UnsignedIntTy)};
7463+ void *InsertPos = nullptr;
7464+ ClassTemplateSpecializationDecl *Spec =
7465+ m_vkBufferPointerTemplateDecl->findSpecialization(
7466+ llvm::ArrayRef<TemplateArgument>(TemplateArgs, 2), InsertPos);
7467+ if (!Spec) {
7468+ Spec = ClassTemplateSpecializationDecl::Create(
7469+ *m_context, TagDecl::TagKind::TTK_Struct,
7470+ m_vkBufferPointerTemplateDecl->getDeclContext(), SourceLocation(),
7471+ SourceLocation(), m_vkBufferPointerTemplateDecl, TemplateArgs, 2,
7472+ nullptr);
7473+ m_vkBufferPointerTemplateDecl->AddSpecialization(Spec, InsertPos);
7474+ Spec->setImplicit(true);
7475+ DXVERIFY_NOMSG(
7476+ false ==
7477+ getSema()->InstantiateClassTemplateSpecialization(
7478+ SourceLocation(), Spec,
7479+ TemplateSpecializationKind::TSK_ImplicitInstantiation, true));
7480+ }
7481+
7482+ pNewType = m_context->getTemplateSpecializationType(
7483+ TemplateName(m_vkBufferPointerTemplateDecl), TemplateArgs, 2,
7484+ m_context->getTypeDeclType(Spec));
74227485 } else {
74237486 badArgIdx = std::min(badArgIdx, i);
74247487 }
@@ -11155,11 +11218,18 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL(
1115511218 QualType objectType = m_context->getTagDeclType(functionParentRecord);
1115611219
1115711220 QualType functionTemplateTypeArg{};
11158- if (ExplicitTemplateArgs != nullptr && ExplicitTemplateArgs->size() == 1) {
11221+ unsigned functionTemplateIntArg = 0;
11222+ if (ExplicitTemplateArgs != nullptr && ExplicitTemplateArgs->size() >= 1) {
1115911223 const TemplateArgument &firstTemplateArg =
1116011224 (*ExplicitTemplateArgs)[0].getArgument();
1116111225 if (firstTemplateArg.getKind() == TemplateArgument::ArgKind::Type)
1116211226 functionTemplateTypeArg = firstTemplateArg.getAsType();
11227+ if (ExplicitTemplateArgs->size() > 1) {
11228+ const TemplateArgument &secondTemplateArg =
11229+ (*ExplicitTemplateArgs)[1].getArgument();
11230+ functionTemplateIntArg =
11231+ GetIntegralTemplateArg(*m_context, secondTemplateArg);
11232+ }
1116311233 }
1116411234
1116511235 // Handle subscript overloads.
@@ -11233,7 +11303,8 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL(
1123311303 while (cursor != end) {
1123411304 size_t badArgIdx;
1123511305 if (!MatchArguments(cursor, objectType, objectElement,
11236- functionTemplateTypeArg, Args, &argTypes, badArgIdx)) {
11306+ functionTemplateTypeArg, functionTemplateIntArg, Args,
11307+ &argTypes, badArgIdx)) {
1123711308 ++cursor;
1123811309 continue;
1123911310 }
@@ -11276,8 +11347,9 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL(
1127611347 if (!IsNull &&
1127711348 getSema()->RequireCompleteType(Loc, functionTemplateTypeArg, 0))
1127811349 return Sema::TemplateDeductionResult::TDK_Invalid;
11279- if (IsNull || !hlsl::IsHLSLNumericOrAggregateOfNumericType(
11280- functionTemplateTypeArg)) {
11350+ if (IsNull || ExplicitTemplateArgs->size() > 1 ||
11351+ !hlsl::IsHLSLNumericOrAggregateOfNumericType(
11352+ functionTemplateTypeArg)) {
1128111353 getSema()->Diag(Loc, diag::err_hlsl_intrinsic_template_arg_numeric)
1128211354 << intrinsicName;
1128311355 DiagnoseTypeElements(
@@ -12007,8 +12079,18 @@ static bool CheckBarrierCall(Sema &S, FunctionDecl *FD, CallExpr *CE,
1200712079}
1200812080
1200912081#ifdef ENABLE_SPIRV_CODEGEN
12010- static bool CheckVKBufferPointerCast(Sema &S, FunctionDecl *FD, CallExpr *CE,
12011- bool isStatic) {
12082+ static bool CheckVKBufferPointerCast(Sema &S, CallExpr *CE, bool isStatic) {
12083+ const auto *callee = dyn_cast<DeclRefExpr>(CE->getCallee()->IgnoreImpCasts());
12084+ if (callee && callee->hasExplicitTemplateArgs() &&
12085+ callee->getNumTemplateArgs() > 2) {
12086+ StringRef castName =
12087+ isStatic ? "static_pointer_cast" : "reinterpret_pointer_cast";
12088+ S.Diags.Report(CE->getExprLoc(),
12089+ diag::err_template_arg_list_different_arity)
12090+ << /*too many*/ 1 << /*function template*/ 1 << castName;
12091+ return true;
12092+ }
12093+
1201212094 const Expr *argExpr = CE->getArg(0);
1201312095 QualType srcType = argExpr->getType();
1201412096 QualType destType = CE->getType();
@@ -12146,10 +12228,10 @@ void Sema::CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall) {
1214612228 break;
1214712229#ifdef ENABLE_SPIRV_CODEGEN
1214812230 case hlsl::IntrinsicOp::IOP_Vkreinterpret_pointer_cast:
12149- CheckVKBufferPointerCast(*this, FDecl, TheCall, false);
12231+ CheckVKBufferPointerCast(*this, TheCall, false);
1215012232 break;
1215112233 case hlsl::IntrinsicOp::IOP_Vkstatic_pointer_cast:
12152- CheckVKBufferPointerCast(*this, FDecl, TheCall, true);
12234+ CheckVKBufferPointerCast(*this, TheCall, true);
1215312235 break;
1215412236#endif
1215512237 default:
0 commit comments