Skip to content

Commit 4d66ec8

Browse files
authored
Remove assumption that templates are never UDTs (#4752)
* Remove assumption that templates are never UDTs There was an assumtion in the HLSL sema code that a template specialization could never be a UDT. This assumption is incorrect now. I've reworked the code so that we instead assume built-in types are marked as `implicit` (which they all should and seem to be). Correcting this in `IsHLSLNumericUserDefinedType` resulted in some breakge in raytracing code generation because we used that method to deterimine if structures could be payloads or attributes. That was an incorrect API usage because we do have some builtin types that are allowed. The change here does the following: * Introduce `IsHLSLBuiltinRayAttributeStruct` which returns true for the builtin raytracing data types that behave like UDTs. * Introduce `IsHLSLCopyableAnnotatableRecord` returns true for user-defined trivially copyable structures and the builtin ray tracing types. * Adjust `IsHLSLNumericUserDefinedType` to do what the name says. * Consolidates implementations of `IsUserDefinedRecordType` across the project. * Adds new test cases for the ray tracing built in structs to cover diagnostic cases missed by the existing tests. The new `IsHLSLBuiltinRayAttributeStruct` is hacky and uses the type names (as the old code did). We should in the future insert an internal attribute on the types that can be used to denote them so that we don't need to match string names. Resolves #4735
1 parent 46db044 commit 4d66ec8

11 files changed

Lines changed: 141 additions & 44 deletions

File tree

tools/clang/include/clang/AST/HlslTypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,8 @@ bool IsHLSLBufferViewType(clang::QualType type);
407407
bool IsHLSLStructuredBufferType(clang::QualType type);
408408
bool IsHLSLNumericOrAggregateOfNumericType(clang::QualType type);
409409
bool IsHLSLNumericUserDefinedType(clang::QualType type);
410+
bool IsHLSLCopyableAnnotatableRecord(clang::QualType QT);
411+
bool IsHLSLBuiltinRayAttributeStruct(clang::QualType QT);
410412
bool IsHLSLAggregateType(clang::QualType type);
411413
clang::QualType GetHLSLResourceResultType(clang::QualType type);
412414
unsigned GetHLSLResourceTemplateUInt(clang::QualType type);

tools/clang/lib/AST/HlslTypes.cpp

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ bool IsHLSLNumericOrAggregateOfNumericType(clang::QualType type) {
9797
if (isa<RecordType>(Ty)) {
9898
if (IsHLSLVecMatType(type))
9999
return true;
100-
return IsHLSLNumericUserDefinedType(type);
100+
return IsHLSLCopyableAnnotatableRecord(type);
101101
} else if (type->isArrayType()) {
102102
return IsHLSLNumericOrAggregateOfNumericType(QualType(type->getArrayElementTypeNoTypeQual(), 0));
103103
}
@@ -111,14 +111,7 @@ bool IsHLSLNumericUserDefinedType(clang::QualType type) {
111111
const clang::Type *Ty = type.getCanonicalType().getTypePtr();
112112
if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
113113
const RecordDecl *RD = RT->getDecl();
114-
if (isa<ClassTemplateSpecializationDecl>(RD)) {
115-
return false; // UDT are not templates
116-
}
117-
// TODO: avoid check by name
118-
StringRef name = RD->getName();
119-
if (name == "ByteAddressBuffer" ||
120-
name == "RWByteAddressBuffer" ||
121-
name == "RaytracingAccelerationStructure")
114+
if (!IsUserDefinedRecordType(type))
122115
return false;
123116
for (auto member : RD->fields()) {
124117
if (!IsHLSLNumericOrAggregateOfNumericType(member->getType()))
@@ -129,15 +122,34 @@ bool IsHLSLNumericUserDefinedType(clang::QualType type) {
129122
return false;
130123
}
131124

125+
// In some cases we need record types that are annotatable and trivially
126+
// copyable from outside the shader. This excludes resource types which may be
127+
// trivially copyable inside the shader, and builtin matrix and vector types
128+
// which can't be annotated. But includes UDTs of trivially copyable data and
129+
// the builtin trivially copyable raytracing structs.
130+
bool IsHLSLCopyableAnnotatableRecord(clang::QualType QT) {
131+
return IsHLSLNumericUserDefinedType(QT) ||
132+
IsHLSLBuiltinRayAttributeStruct(QT);
133+
}
134+
135+
bool IsHLSLBuiltinRayAttributeStruct(clang::QualType QT) {
136+
QT = QT.getCanonicalType();
137+
const clang::Type *Ty = QT.getTypePtr();
138+
if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
139+
const RecordDecl *RD = RT->getDecl();
140+
if (RD->getName() == "BuiltInTriangleIntersectionAttributes" ||
141+
RD->getName() == "RayDesc")
142+
return true;
143+
}
144+
return false;
145+
}
146+
132147
// Aggregate types are arrays and user-defined structs
133148
bool IsHLSLAggregateType(clang::QualType type) {
134149
type = type.getCanonicalType();
135150
if (isa<clang::ArrayType>(type)) return true;
136151

137-
const RecordType *Record = dyn_cast<RecordType>(type);
138-
return Record != nullptr
139-
&& !IsHLSLVecMatType(type) && !IsHLSLResourceType(type)
140-
&& !dyn_cast<ClassTemplateSpecializationDecl>(Record->getAsCXXRecordDecl());
152+
return IsUserDefinedRecordType(type);
141153
}
142154

143155
clang::QualType GetElementTypeOrType(clang::QualType type) {
@@ -586,23 +598,17 @@ bool IsHLSLSubobjectType(clang::QualType type) {
586598
return GetHLSLSubobjectKind(type, kind, hgType);
587599
}
588600

589-
bool IsUserDefinedRecordType(clang::QualType type) {
590-
if (const auto *rt = type->getAs<RecordType>()) {
591-
// HLSL specific types
592-
if (hlsl::IsHLSLResourceType(type) || hlsl::IsHLSLVecMatType(type) ||
593-
isa<ExtVectorType>(type.getTypePtr()) || type->isBuiltinType() ||
594-
type->isArrayType()) {
595-
return false;
596-
}
597-
598-
// SubpassInput or SubpassInputMS type
599-
if (rt->getDecl()->getName() == "SubpassInput" ||
600-
rt->getDecl()->getName() == "SubpassInputMS") {
601+
bool IsUserDefinedRecordType(clang::QualType QT) {
602+
const clang::Type *Ty = QT.getCanonicalType().getTypePtr();
603+
if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
604+
const RecordDecl *RD = RT->getDecl();
605+
if (RD->isImplicit())
601606
return false;
602-
}
607+
if (auto TD = dyn_cast<ClassTemplateSpecializationDecl>(RD))
608+
if (TD->getSpecializedTemplate()->isImplicit())
609+
return false;
603610
return true;
604611
}
605-
606612
return false;
607613
}
608614

tools/clang/lib/CodeGen/CGHLSLMS.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2202,7 +2202,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
22022202
rayShaderHaveErrors = true;
22032203
}
22042204
if (ArgNo < 2) {
2205-
if (!IsHLSLNumericUserDefinedType(parmDecl->getType())) {
2205+
if (!IsHLSLCopyableAnnotatableRecord(parmDecl->getType())) {
22062206
Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
22072207
DiagnosticsEngine::Error,
22082208
"payload and attribute structures must be user defined types with only numeric contents."));
@@ -2230,7 +2230,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
22302230
rayShaderHaveErrors = true;
22312231
}
22322232
if (ArgNo < 1) {
2233-
if (!IsHLSLNumericUserDefinedType(parmDecl->getType())) {
2233+
if (!IsHLSLCopyableAnnotatableRecord(parmDecl->getType())) {
22342234
Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
22352235
DiagnosticsEngine::Error,
22362236
"ray payload parameter must be a user defined type with only numeric contents."));
@@ -2255,7 +2255,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
22552255
rayShaderHaveErrors = true;
22562256
}
22572257
if (ArgNo < 1) {
2258-
if (!IsHLSLNumericUserDefinedType(parmDecl->getType())) {
2258+
if (!IsHLSLCopyableAnnotatableRecord(parmDecl->getType())) {
22592259
Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
22602260
DiagnosticsEngine::Error,
22612261
"callable parameter must be a user defined type with only numeric contents."));

tools/clang/lib/SPIRV/AstTypeProbe.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -355,19 +355,6 @@ bool isResourceType(QualType type) {
355355
return hlsl::IsHLSLResourceType(type);
356356
}
357357

358-
bool isUserDefinedRecordType(const ASTContext &astContext, QualType type) {
359-
if (const auto *rt = type->getAs<RecordType>()) {
360-
if (rt->getDecl()->getName() == "mips_slice_type" ||
361-
rt->getDecl()->getName() == "sample_slice_type") {
362-
return false;
363-
}
364-
}
365-
return type->getAs<RecordType>() != nullptr && !isResourceType(type) &&
366-
!isMatrixOrArrayOfMatrix(astContext, type) &&
367-
!isScalarOrVectorType(type, nullptr, nullptr) &&
368-
!isArrayType(type, nullptr, nullptr);
369-
}
370-
371358
bool isOrContains16BitType(QualType type, bool enable16BitTypesOption) {
372359
// Primitive types
373360
{

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "dxc/DXIL/DxilConstants.h"
2121
#include "dxc/HlslIntrinsicOp.h"
2222
#include "spirv-tools/optimizer.hpp"
23+
#include "clang/AST/HlslTypes.h"
2324
#include "clang/AST/RecordLayout.h"
2425
#include "clang/SPIRV/AstTypeProbe.h"
2526
#include "clang/SPIRV/String.h"
@@ -2667,7 +2668,7 @@ SpirvInstruction *SpirvEmitter::doCallExpr(const CallExpr *callExpr,
26672668
dyn_cast<CXXMethodDecl>(operatorCall->getCalleeDecl())) {
26682669
QualType parentType =
26692670
QualType(cxxMethodDecl->getParent()->getTypeForDecl(), 0);
2670-
if (isUserDefinedRecordType(astContext, parentType)) {
2671+
if (hlsl::IsUserDefinedRecordType(parentType)) {
26712672
// If the parent is a user-defined record type
26722673
return processCall(callExpr);
26732674
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: %clang_cc1 -fsyntax-only -ffreestanding -HV 2021 -verify %s
2+
3+
ByteAddressBuffer In;
4+
RWBuffer<float> Out;
5+
6+
7+
[numthreads(1,1,1)]
8+
void CSMain()
9+
{
10+
RWBuffer<float> FB = In.Load<RWBuffer<float> >(0); // expected-error {{Explicit template arguments on intrinsic Load must be a single numeric type}}
11+
Out[0] = FB[0];
12+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: %dxc -E CSMain -T cs_6_6 -HV 2021 -fcgl %s | FileCheck %s
2+
template<typename T>
3+
struct MyStructA
4+
{
5+
T m_0;
6+
};
7+
8+
struct MyStructB
9+
{
10+
MyStructA<float> m_a;
11+
float m_1;
12+
float m_2;
13+
float m_3;
14+
};
15+
16+
ByteAddressBuffer g_bab;
17+
RWBuffer<float> result;
18+
19+
// This test verifies that templates can be used both as the argument to
20+
// ByteAddressBuffer::Load and as a member of a structure passed as an argument
21+
// to ByteAddressLoad as long as the specialized template conforms to the rules
22+
// for HLSL (must only contain integral and floating point members).
23+
// CHECK-NOT: error
24+
25+
[numthreads(1,1,1)]
26+
void CSMain()
27+
{
28+
// CHECK: call %"struct.MyStructA<float>"* @"dx.hl.op..%\22struct.MyStructA<float>\22* (i32, %dx.types.Handle, i32)"(i32 229, %dx.types.Handle %{{[0-9]+}}, i32 0)
29+
MyStructA<float> a = g_bab.Load<MyStructA<float> >(0);
30+
result[0] = a.m_0;
31+
32+
// CHECK: call %struct.MyStructB* @"dx.hl.op..%struct.MyStructB* (i32, %dx.types.Handle, i32)"(i32 229, %dx.types.Handle %{{[0-9]+}}, i32 1)
33+
MyStructB b = g_bab.Load<MyStructB>(1);
34+
result[1] = b.m_a.m_0;
35+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: %dxc -T lib_6_4 %s 2>&1 | FileCheck %s
2+
3+
// CHECK-NOT: error
4+
[shader("anyhit")]
5+
void anyhit_param0( inout RayDesc D1, RayDesc D2 ) { }
6+
7+
[shader("anyhit")]
8+
void anyhit_param1( inout BuiltInTriangleIntersectionAttributes A1, BuiltInTriangleIntersectionAttributes A2 ) { }
9+
10+
// CHECK: builtin-ray-types-anyhit.hlsl:15:37: error: payload and attribute structures must be user defined types with only numeric contents.
11+
// CHECK: builtin-ray-types-anyhit.hlsl:15:48: error: payload and attribute structures must be user defined types with only numeric contents.
12+
// CHECK: builtin-ray-types-anyhit.hlsl:15:6: error: shader must include inout payload structure parameter.
13+
// CHECK: builtin-ray-types-anyhit.hlsl:15:6: error: shader must include attributes structure parameter.
14+
[shader("anyhit")]
15+
void anyhit_param2( inout Texture2D A1, float4 A2 ) { }
16+
// CHECK-NOT: error
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: %dxc -T lib_6_4 %s 2>&1 | FileCheck %s
2+
3+
// CHECK-NOT: error
4+
5+
[shader("callable")]
6+
void callable0( inout RayDesc param ) {}
7+
8+
[shader("callable")]
9+
void callable1( inout BuiltInTriangleIntersectionAttributes param ) {}
10+
11+
// CHECK: builtin-ray-types-callable.hlsl:14:33: error: callable parameter must be a user defined type with only numeric contents.
12+
// CHECK: builtin-ray-types-callable.hlsl:14:6: error: shader must include inout parameter structure.
13+
[shader("callable")]
14+
void callable2( inout Texture2D param ) {}
15+
16+
// CHECK-NOT: error
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: %dxc -T lib_6_4 %s 2>&1 | FileCheck %s
2+
3+
// CHECK-NOT: error
4+
5+
[shader("miss")]
6+
void miss0(inout RayDesc PL) { }
7+
8+
[shader("miss")]
9+
void miss1(inout BuiltInTriangleIntersectionAttributes PL) { }
10+
11+
// CHECK: builtin-ray-types-miss.hlsl:15:28: error: ray payload parameter must be a user defined type with only numeric contents.
12+
// CHECK: builtin-ray-types-miss.hlsl:15:6: error: shader must include inout payload structure parameter.
13+
14+
[shader("miss")]
15+
void miss2(inout Texture2D PL) { }
16+
17+
// CHECK-NOT: error

0 commit comments

Comments
 (0)