diff --git a/tools/clang/lib/AST/ASTContextHLSL.cpp b/tools/clang/lib/AST/ASTContextHLSL.cpp index c7a031a219..9e704f7e78 100644 --- a/tools/clang/lib/AST/ASTContextHLSL.cpp +++ b/tools/clang/lib/AST/ASTContextHLSL.cpp @@ -1386,19 +1386,27 @@ CXXRecordDecl *hlsl::DeclareVkBufferPointerType(ASTContext &context, DeclarationName(&context.Idents.get("Get")), true); CanQualType canQualType = recordDecl->getTypeForDecl()->getCanonicalTypeUnqualified(); - CreateConstructorDeclarationWithParams( + auto *copyConstructorDecl = CreateConstructorDeclarationWithParams( context, recordDecl, context.VoidTy, {context.getRValueReferenceType(canQualType)}, {"bufferPointer"}, - context.DeclarationNames.getCXXConstructorName(canQualType), false); - CreateConstructorDeclarationWithParams( + context.DeclarationNames.getCXXConstructorName(canQualType), false, true); + auto *addressConstructorDecl = CreateConstructorDeclarationWithParams( context, recordDecl, context.VoidTy, {context.UnsignedIntTy}, {"address"}, - context.DeclarationNames.getCXXConstructorName(canQualType), false); + context.DeclarationNames.getCXXConstructorName(canQualType), false, true); + hlsl::CreateFunctionTemplateDecl( + context, recordDecl, copyConstructorDecl, + Builder.getTemplateDecl()->getTemplateParameters()->begin(), 2); + hlsl::CreateFunctionTemplateDecl( + context, recordDecl, addressConstructorDecl, + Builder.getTemplateDecl()->getTemplateParameters()->begin(), 2); StringRef OpcodeGroup = GetHLOpcodeGroupName(HLOpcodeGroup::HLIntrinsic); unsigned Opcode = static_cast(IntrinsicOp::MOP_GetBufferContents); methodDecl->addAttr( HLSLIntrinsicAttr::CreateImplicit(context, OpcodeGroup, "", Opcode)); methodDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context)); + copyConstructorDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context)); + addressConstructorDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context)); return Builder.completeDefinition(); } diff --git a/tools/clang/lib/Sema/SemaExprCXX.cpp b/tools/clang/lib/Sema/SemaExprCXX.cpp index 4723bc93e9..5113c56205 100644 --- a/tools/clang/lib/Sema/SemaExprCXX.cpp +++ b/tools/clang/lib/Sema/SemaExprCXX.cpp @@ -1057,26 +1057,51 @@ Sema::BuildCXXTypeConstructExpr(TypeSourceInfo *TInfo, Expr *Arg = Exprs[0]; #ifdef ENABLE_SPIRV_CODEGEN if (hlsl::IsVKBufferPointerType(Ty) && Arg->getType()->isIntegerType()) { - for (auto *ctor : Ty->getAsCXXRecordDecl()->ctors()) { - if (auto *functionType = ctor->getType()->getAs()) { - if (functionType->getNumParams() != 1 || - !functionType->getParamType(0)->isIntegerType()) - continue; - - CanQualType argType = Arg->getType()->getCanonicalTypeUnqualified(); - if (!Arg->isRValue()) { - Arg = ImpCastExprToType(Arg, argType, CK_LValueToRValue).get(); - } - if (argType != Context.UnsignedLongLongTy) { - Arg = ImpCastExprToType(Arg, Context.UnsignedLongLongTy, - CK_IntegralCast) - .get(); - } - return CXXConstructExpr::Create( - Context, Ty, TyBeginLoc, ctor, false, {Arg}, false, false, false, - false, CXXConstructExpr::ConstructionKind::CK_Complete, - SourceRange(LParenLoc, RParenLoc)); + typedef DeclContext::specific_decl_iterator ft_iter; + auto *recordDecl = Ty->getAsCXXRecordDecl(); + auto *specDecl = cast(recordDecl); + auto *templatedDecl = + specDecl->getSpecializedTemplate()->getTemplatedDecl(); + auto functionTemplateDecls = + llvm::iterator_range(ft_iter(templatedDecl->decls_begin()), + ft_iter(templatedDecl->decls_end())); + for (auto *ftd : functionTemplateDecls) { + auto *fd = ftd->getTemplatedDecl(); + if (fd->getNumParams() != 1 || + !fd->getParamDecl(0)->getType()->isIntegerType()) + continue; + + void *insertPos; + auto templateArgs = ftd->getInjectedTemplateArgs(); + auto *functionDecl = ftd->findSpecialization(templateArgs, insertPos); + if (!functionDecl) { + DeclarationNameInfo DInfo(ftd->getDeclName(), + recordDecl->getLocation()); + auto *templateArgList = TemplateArgumentList::CreateCopy( + Context, templateArgs.data(), templateArgs.size()); + functionDecl = CXXConstructorDecl::Create( + Context, recordDecl, Arg->getLocStart(), DInfo, Ty, TInfo, false, + false, false, false); + functionDecl->setFunctionTemplateSpecialization(ftd, templateArgList, + insertPos); + } else if (functionDecl->getDeclKind() != Decl::Kind::CXXConstructor) { + continue; + } + + CanQualType argType = Arg->getType()->getCanonicalTypeUnqualified(); + if (!Arg->isRValue()) { + Arg = ImpCastExprToType(Arg, argType, CK_LValueToRValue).get(); + } + if (argType != Context.UnsignedLongLongTy) { + Arg = ImpCastExprToType(Arg, Context.UnsignedLongLongTy, + CK_IntegralCast) + .get(); } + return CXXConstructExpr::Create( + Context, Ty, TyBeginLoc, cast(functionDecl), + false, {Arg}, false, false, false, false, + CXXConstructExpr::ConstructionKind::CK_Complete, + SourceRange(LParenLoc, RParenLoc)); } } #endif diff --git a/tools/clang/test/CodeGenSPIRV/vk.buffer-pointer.from-uint.hlsl b/tools/clang/test/CodeGenSPIRV/vk.buffer-pointer.from-uint.hlsl new file mode 100644 index 0000000000..b44e1eca09 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/vk.buffer-pointer.from-uint.hlsl @@ -0,0 +1,46 @@ +// RUN: %dxc -spirv -Od -T cs_6_7 %s | FileCheck %s +// RUN: %dxc -spirv -Od -T cs_6_7 -DALIGN_16 %s | FileCheck %s +// RUN: %dxc -spirv -Od -T cs_6_7 -DNO_PC %s | FileCheck %s + +// Was getting bogus type errors with the defined changes + +#ifdef ALIGN_16 +typedef vk::BufferPointer BufferType; +#else +typedef vk::BufferPointer BufferType; +#endif +#ifndef NO_PC +struct PushConstantStruct { + BufferType push_buffer; +}; +[[vk::push_constant]] PushConstantStruct push_constant; +#endif + +RWStructuredBuffer output; + +// CHECK: [[INT:%[_0-9A-Za-z]*]] = OpTypeInt 32 1 +// CHECK: [[I0:%[_0-9A-Za-z]*]] = OpConstant [[INT]] 0 +// CHECK: [[UINT:%[_0-9A-Za-z]*]] = OpTypeInt 32 0 +// CHECK: [[U0:%[_0-9A-Za-z]*]] = OpConstant [[UINT]] 0 +// CHECK: [[PPUINT:%[_0-9A-Za-z]*]] = OpTypePointer PhysicalStorageBuffer [[UINT]] +// CHECK: [[PFPPUINT:%[_0-9A-Za-z]*]] = OpTypePointer Function [[PPUINT]] +// CHECK: [[PUUINT:%[_0-9A-Za-z]*]] = OpTypePointer Uniform [[UINT]] +// CHECK: [[OUTPUT:%[_0-9A-Za-z]*]] = OpVariable %{{[_0-9A-Za-z]*}} Uniform + +[numthreads(1, 1, 1)] +void main() { + uint64_t addr = 123; + vk::BufferPointer test = vk::BufferPointer(addr); + output[0] = test.Get(); +} + +// CHECK: [[TEST:%[_0-9A-Za-z]*]] = OpVariable [[PFPPUINT]] Function +// CHECK: [[X1:%[_0-9A-Za-z]*]] = OpConvertUToPtr [[PPUINT]] +// CHECK: OpStore [[TEST]] [[X1]] +// CHECK: [[X2:%[_0-9A-Za-z]*]] = OpLoad [[PPUINT]] [[TEST]] Aligned 32 +// CHECK: [[X3:%[_0-9A-Za-z]*]] = OpLoad [[UINT]] [[X2]] Aligned 4 +// CHECK: [[X4:%[_0-9A-Za-z]*]] = OpAccessChain [[PUUINT]] [[OUTPUT]] [[I0]] [[U0]] +// CHECK: OpStore [[X4]] [[X3]] +// CHECK: OpReturn +// CHECK: OpFunctionEnd +