Skip to content

Commit bece3d4

Browse files
authored
Report error for unsupported types of SV semantics (#3043)
1 parent 97b0b1a commit bece3d4

31 files changed

Lines changed: 915 additions & 35 deletions

include/dxc/DXIL/DxilSemantic.h

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#pragma once
1313

1414
#include "llvm/ADT/StringRef.h"
15+
#include "llvm/IR/Type.h"
1516

1617
#include "DxilConstants.h"
1718
#include "DxilShaderModel.h"
@@ -23,6 +24,33 @@ class Semantic {
2324
public:
2425
using Kind = DXIL::SemanticKind;
2526

27+
enum class CompTy {
28+
BoolTy = 1 << 0,
29+
HalfTy = 1 << 1,
30+
Int16Ty = 1 << 2,
31+
FloatTy = 1 << 3,
32+
Int32Ty = 1 << 4,
33+
DoubleTy = 1 << 5,
34+
Int64Ty = 1 << 6,
35+
HalfOrFloatTy = HalfTy | FloatTy,
36+
BoolOrInt32Ty = BoolTy | Int32Ty,
37+
BoolOrInt16Or32Ty = BoolTy | Int16Ty | Int32Ty,
38+
FloatOrInt32Ty = FloatTy | Int32Ty,
39+
Int16Or32Ty = Int16Ty | Int32Ty,
40+
AnyIntTy = BoolTy | Int16Ty | Int32Ty | Int64Ty,
41+
AnyFloatTy = HalfTy | FloatTy | DoubleTy,
42+
AnyTy = AnyIntTy | AnyFloatTy,
43+
};
44+
45+
enum class SizeClass {
46+
Unknown,
47+
Scalar,
48+
Vec2,
49+
Vec3,
50+
Vec4,
51+
Other
52+
};
53+
2654
static const int kUndefinedRow = -1;
2755
static const int kUndefinedCol = -1;
2856

@@ -41,13 +69,19 @@ class Semantic {
4169
const char *GetName() const;
4270
bool IsArbitrary() const;
4371
bool IsInvalid() const;
72+
bool IsSupportedType(llvm::Type *semTy) const;
73+
CompTy GetCompType(llvm::Type* ty) const;
74+
SizeClass GetCompCount(llvm::Type* ty) const;
4475

4576
private:
4677
Kind m_Kind; // Semantic kind.
47-
const char *m_pszName; // Canonical name (for system semantics).
78+
const char *m_pszName; // Canonical name (for system semantics).
79+
CompTy m_allowedTys; // Types allowed for the semantic
80+
SizeClass m_minCompCount; // Minimum component count that is allowed for a semantic
81+
SizeClass m_maxCompCount; // Maximum component count that is allowed for a semantic
4882

4983
Semantic() = delete;
50-
Semantic(Kind Kind, const char *pszName);
84+
Semantic(Kind Kind, const char *pszName, CompTy allowedTys, SizeClass minCompCount, SizeClass maxCompCount);
5185

5286
// Table of all semantic properties.
5387
static const unsigned kNumSemanticRecords = (unsigned)Kind::Invalid + 1;

lib/DXIL/DxilSemantic.cpp

Lines changed: 158 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "dxc/DXIL/DxilSemantic.h"
1212
#include "dxc/DXIL/DxilSignature.h"
1313
#include "dxc/DXIL/DxilShaderModel.h"
14+
#include "dxc/DXIL/DxilUtil.h"
1415
#include "dxc/Support/Global.h"
1516

1617
#include <string>
@@ -24,9 +25,15 @@ namespace hlsl {
2425
// Semantic class methods.
2526
//
2627
Semantic::Semantic(Kind Kind,
27-
const char *pszName)
28+
const char *pszName,
29+
CompTy allowedTys,
30+
SizeClass minCompCount,
31+
SizeClass maxCompCount)
2832
: m_Kind(Kind)
2933
, m_pszName(pszName)
34+
, m_allowedTys(allowedTys)
35+
, m_minCompCount(minCompCount)
36+
, m_maxCompCount(maxCompCount)
3037
{
3138
}
3239

@@ -113,41 +120,159 @@ bool Semantic::IsInvalid() const {
113120
return m_Kind == Kind::Invalid;
114121
}
115122

123+
Semantic::SizeClass Semantic::GetCompCount(llvm::Type* ty) const {
124+
125+
if (!ty->isVectorTy() && !dxilutil::IsIntegerOrFloatingPointType(ty))
126+
return SizeClass::Unknown;
127+
128+
if (ty->isVectorTy()) {
129+
if (ty->getVectorNumElements() == 1) {
130+
return SizeClass::Scalar;
131+
}
132+
else if (ty->getVectorNumElements() == 2) {
133+
return SizeClass::Vec2;
134+
}
135+
else if (ty->getVectorNumElements() == 3) {
136+
return SizeClass::Vec3;
137+
}
138+
else if (ty->getVectorNumElements() == 4) {
139+
return SizeClass::Vec4;
140+
}
141+
else {
142+
DXASSERT(false, "Unexpected number of vector elements.");
143+
return SizeClass::Unknown;
144+
}
145+
}
146+
147+
return SizeClass::Scalar;
148+
}
149+
150+
Semantic::CompTy Semantic::GetCompType(llvm::Type* ty) const {
151+
152+
if (!ty->isVectorTy() && !dxilutil::IsIntegerOrFloatingPointType(ty))
153+
return CompTy::AnyTy;
154+
155+
if (ty->isVectorTy())
156+
ty = ty->getScalarType();
157+
158+
// must be an integer or a floating point type here
159+
DXASSERT_NOMSG(dxilutil::IsIntegerOrFloatingPointType(ty));
160+
if (ty->getScalarType()->isIntegerTy()) {
161+
if (ty->getScalarSizeInBits() == 1) {
162+
return CompTy::BoolTy;
163+
} else if (ty->getScalarSizeInBits() == 16) {
164+
return CompTy::Int16Ty;
165+
} else if (ty->getScalarSizeInBits() == 32) {
166+
return CompTy::Int32Ty;
167+
} else {
168+
return CompTy::Int64Ty;
169+
}
170+
} else {
171+
if (ty->isHalfTy()) {
172+
return CompTy::HalfTy;
173+
} else if (ty->isFloatTy()) {
174+
return CompTy::FloatTy;
175+
} else {
176+
DXASSERT_NOMSG(ty->isDoubleTy());
177+
return CompTy::DoubleTy;
178+
}
179+
}
180+
}
181+
182+
static bool IsScalarOrVectorTy(llvm::Type* ty) {
183+
if (dxilutil::IsIntegerOrFloatingPointType(ty))
184+
return true;
185+
if (ty->isVectorTy() &&
186+
dxilutil::IsIntegerOrFloatingPointType(ty->getVectorElementType()))
187+
return true;
188+
return false;
189+
}
190+
191+
bool Semantic::IsSupportedType(llvm::Type* semTy) const {
192+
193+
if (m_Kind == Kind::Invalid)
194+
return false;
195+
196+
// Skip type checking for Arbitrary kind
197+
if (m_Kind == Kind::Arbitrary)
198+
return true;
199+
200+
if (!IsScalarOrVectorTy(semTy)) {
201+
// We only allow scalar or vector types as a valid semantic type except in some cases
202+
// such as Clip/Cull or Tessfactor.
203+
if (m_minCompCount == SizeClass::Other) {
204+
if (semTy->isArrayTy()) {
205+
semTy = semTy->getArrayElementType();
206+
// TessFactor or InsideTessFactor must either be float[2] or float
207+
if ((m_Kind == Kind::TessFactor ||
208+
m_Kind == Kind::InsideTessFactor) &&
209+
!dxilutil::IsIntegerOrFloatingPointType(semTy)) {
210+
return false;
211+
}
212+
// Clip/Cull can be array of scalar or vector
213+
if ((m_Kind == Kind::ClipDistance ||
214+
m_Kind == Kind::CullDistance) &&
215+
!IsScalarOrVectorTy(semTy)) {
216+
return false;
217+
}
218+
}
219+
else {
220+
// Do not support other types such as matrix.
221+
return false;
222+
}
223+
}
224+
else {
225+
return false;
226+
}
227+
}
228+
229+
if (((unsigned)m_allowedTys & (unsigned)GetCompType(semTy)) == 0)
230+
return false;
231+
232+
// Skip type-shape validation for semantics marked as Other
233+
if (m_minCompCount == SizeClass::Other)
234+
return true;
235+
236+
SizeClass compSzClass = GetCompCount(semTy);
237+
return compSzClass >= m_minCompCount &&
238+
compSzClass <= m_maxCompCount;
239+
}
240+
116241
typedef Semantic SP;
117242
const Semantic Semantic::ms_SemanticTable[kNumSemanticRecords] = {
118243
// Kind Name
119-
SP(Kind::Arbitrary, nullptr),
120-
SP(Kind::VertexID, "SV_VertexID"),
121-
SP(Kind::InstanceID, "SV_InstanceID"),
122-
SP(Kind::Position, "SV_Position"),
123-
SP(Kind::RenderTargetArrayIndex,"SV_RenderTargetArrayIndex"),
124-
SP(Kind::ViewPortArrayIndex, "SV_ViewportArrayIndex"),
125-
SP(Kind::ClipDistance, "SV_ClipDistance"),
126-
SP(Kind::CullDistance, "SV_CullDistance"),
127-
SP(Kind::OutputControlPointID, "SV_OutputControlPointID"),
128-
SP(Kind::DomainLocation, "SV_DomainLocation"),
129-
SP(Kind::PrimitiveID, "SV_PrimitiveID"),
130-
SP(Kind::GSInstanceID, "SV_GSInstanceID"),
131-
SP(Kind::SampleIndex, "SV_SampleIndex"),
132-
SP(Kind::IsFrontFace, "SV_IsFrontFace"),
133-
SP(Kind::Coverage, "SV_Coverage"),
134-
SP(Kind::InnerCoverage, "SV_InnerCoverage"),
135-
SP(Kind::Target, "SV_Target"),
136-
SP(Kind::Depth, "SV_Depth"),
137-
SP(Kind::DepthLessEqual, "SV_DepthLessEqual"),
138-
SP(Kind::DepthGreaterEqual, "SV_DepthGreaterEqual"),
139-
SP(Kind::StencilRef, "SV_StencilRef"),
140-
SP(Kind::DispatchThreadID, "SV_DispatchThreadID"),
141-
SP(Kind::GroupID, "SV_GroupID"),
142-
SP(Kind::GroupIndex, "SV_GroupIndex"),
143-
SP(Kind::GroupThreadID, "SV_GroupThreadID"),
144-
SP(Kind::TessFactor, "SV_TessFactor"),
145-
SP(Kind::InsideTessFactor, "SV_InsideTessFactor"),
146-
SP(Kind::ViewID, "SV_ViewID"),
147-
SP(Kind::Barycentrics, "SV_Barycentrics"),
148-
SP(Kind::ShadingRate, "SV_ShadingRate"),
149-
SP(Kind::CullPrimitive, "SV_CullPrimitive"),
150-
SP(Kind::Invalid, nullptr),
244+
SP(Kind::Arbitrary, nullptr, CompTy::AnyTy, SizeClass::Other, SizeClass::Other),
245+
SP(Kind::VertexID, "SV_VertexID", CompTy::Int16Or32Ty, SizeClass::Scalar, SizeClass::Scalar),
246+
SP(Kind::InstanceID, "SV_InstanceID", CompTy::Int16Or32Ty, SizeClass::Scalar, SizeClass::Scalar),
247+
SP(Kind::Position, "SV_Position", CompTy::HalfOrFloatTy, SizeClass::Vec4, SizeClass::Vec4),
248+
SP(Kind::RenderTargetArrayIndex,"SV_RenderTargetArrayIndex", CompTy::Int16Or32Ty, SizeClass::Scalar, SizeClass::Scalar),
249+
SP(Kind::ViewPortArrayIndex, "SV_ViewportArrayIndex", CompTy::Int16Or32Ty, SizeClass::Scalar, SizeClass::Scalar),
250+
SP(Kind::ClipDistance, "SV_ClipDistance", CompTy::HalfOrFloatTy, SizeClass::Other, SizeClass::Other),
251+
SP(Kind::CullDistance, "SV_CullDistance", CompTy::HalfOrFloatTy, SizeClass::Other, SizeClass::Other),
252+
SP(Kind::OutputControlPointID, "SV_OutputControlPointID", CompTy::Int32Ty, SizeClass::Scalar, SizeClass::Scalar),
253+
SP(Kind::DomainLocation, "SV_DomainLocation", CompTy::FloatTy, SizeClass::Scalar, SizeClass::Vec3),
254+
SP(Kind::PrimitiveID, "SV_PrimitiveID", CompTy::Int32Ty, SizeClass::Scalar, SizeClass::Scalar),
255+
SP(Kind::GSInstanceID, "SV_GSInstanceID", CompTy::Int32Ty, SizeClass::Scalar, SizeClass::Scalar),
256+
SP(Kind::SampleIndex, "SV_SampleIndex", CompTy::Int32Ty, SizeClass::Scalar, SizeClass::Scalar),
257+
SP(Kind::IsFrontFace, "SV_IsFrontFace", CompTy::BoolOrInt32Ty, SizeClass::Scalar, SizeClass::Scalar),
258+
SP(Kind::Coverage, "SV_Coverage", CompTy::Int32Ty, SizeClass::Scalar, SizeClass::Scalar),
259+
SP(Kind::InnerCoverage, "SV_InnerCoverage", CompTy::Int32Ty, SizeClass::Scalar, SizeClass::Scalar),
260+
SP(Kind::Target, "SV_Target", CompTy::AnyTy, SizeClass::Scalar, SizeClass::Vec4),
261+
SP(Kind::Depth, "SV_Depth", CompTy::HalfOrFloatTy, SizeClass::Scalar, SizeClass::Scalar),
262+
SP(Kind::DepthLessEqual, "SV_DepthLessEqual", CompTy::HalfOrFloatTy, SizeClass::Scalar, SizeClass::Scalar),
263+
SP(Kind::DepthGreaterEqual, "SV_DepthGreaterEqual", CompTy::HalfOrFloatTy, SizeClass::Scalar, SizeClass::Scalar),
264+
SP(Kind::StencilRef, "SV_StencilRef", CompTy::Int16Or32Ty, SizeClass::Scalar, SizeClass::Scalar),
265+
SP(Kind::DispatchThreadID, "SV_DispatchThreadID", CompTy::Int16Or32Ty, SizeClass::Scalar, SizeClass::Vec3),
266+
SP(Kind::GroupID, "SV_GroupID", CompTy::Int16Or32Ty, SizeClass::Scalar, SizeClass::Vec3),
267+
SP(Kind::GroupIndex, "SV_GroupIndex", CompTy::Int16Or32Ty, SizeClass::Scalar, SizeClass::Scalar),
268+
SP(Kind::GroupThreadID, "SV_GroupThreadID", CompTy::Int16Or32Ty, SizeClass::Scalar, SizeClass::Vec3),
269+
SP(Kind::TessFactor, "SV_TessFactor", CompTy::HalfOrFloatTy, SizeClass::Other, SizeClass::Other),
270+
SP(Kind::InsideTessFactor, "SV_InsideTessFactor", CompTy::HalfOrFloatTy, SizeClass::Other, SizeClass::Other),
271+
SP(Kind::ViewID, "SV_ViewID", CompTy::Int32Ty, SizeClass::Scalar, SizeClass::Scalar),
272+
SP(Kind::Barycentrics, "SV_Barycentrics", CompTy::HalfOrFloatTy, SizeClass::Vec3, SizeClass::Vec3),
273+
SP(Kind::ShadingRate, "SV_ShadingRate", CompTy::Int16Or32Ty, SizeClass::Scalar, SizeClass::Scalar),
274+
SP(Kind::CullPrimitive, "SV_CullPrimitive", CompTy::BoolOrInt16Or32Ty, SizeClass::Scalar, SizeClass::Scalar),
275+
SP(Kind::Invalid, nullptr, CompTy::AnyTy, SizeClass::Other, SizeClass::Other),
151276
};
152277

153278
} // namespace hlsl

lib/HLSL/HLSignatureLower.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,50 @@ void HLSignatureLower::GenerateDxilInputsOutputs(DXIL::SignatureKind SK) {
12101210
}
12111211
}
12121212

1213+
bool HLSignatureLower::ValidateSemanticType(llvm::Function* F) {
1214+
bool result = true;
1215+
DxilFunctionAnnotation* funcAnnotation = HLM.GetFunctionAnnotation(F);
1216+
if (!funcAnnotation) {
1217+
return result;
1218+
}
1219+
for (Argument& arg : F->args()) {
1220+
DxilParameterAnnotation &paramAnnotation =
1221+
funcAnnotation->GetParameterAnnotation(arg.getArgNo());
1222+
llvm::StringRef semanticStr = paramAnnotation.GetSemanticString();
1223+
if (semanticStr.empty()) {
1224+
continue;
1225+
}
1226+
unsigned index;
1227+
StringRef baseSemanticStr;
1228+
Semantic::DecomposeNameAndIndex(semanticStr, &baseSemanticStr, &index);
1229+
const Semantic* semantic = Semantic::GetByName(baseSemanticStr);
1230+
Type* argTy = arg.getType();
1231+
1232+
if (argTy->isPointerTy())
1233+
argTy = cast<PointerType>(argTy)->getPointerElementType();
1234+
1235+
if (argTy->isArrayTy()) {
1236+
// Array type for arguments with specific qualifiers are expected.
1237+
// In this case, we do validation on array's element type.
1238+
DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
1239+
if (inputQual == DxilParamInputQual::InputPatch ||
1240+
inputQual == DxilParamInputQual::InputPrimitive ||
1241+
inputQual == DxilParamInputQual::OutIndices ||
1242+
inputQual == DxilParamInputQual::OutPrimitives ||
1243+
inputQual == DxilParamInputQual::OutputPatch ||
1244+
inputQual == DxilParamInputQual::OutVertices) {
1245+
argTy = cast<ArrayType>(argTy)->getArrayElementType();
1246+
}
1247+
}
1248+
1249+
if (!semantic->IsSupportedType(argTy)) {
1250+
dxilutil::EmitErrorOnFunction(F->getContext(), F, "invalid type used for \'"+ semanticStr.str() + "\' semantic.");
1251+
result = false;
1252+
}
1253+
}
1254+
return result;
1255+
}
1256+
12131257
void HLSignatureLower::GenerateDxilCSInputs() {
12141258
OP *hlslOP = HLM.GetOP();
12151259

@@ -1721,6 +1765,12 @@ void HLSignatureLower::GenerateGetMeshPayloadOperation() {
17211765
}
17221766
// Lower signatures.
17231767
void HLSignatureLower::Run() {
1768+
1769+
// Generate error and exit if semantic type
1770+
// is not one of the allowed types
1771+
if (!ValidateSemanticType(Entry))
1772+
return;
1773+
17241774
DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
17251775
if (props.IsGraphics()) {
17261776
if (props.IsMS()) {

lib/HLSL/HLSignatureLower.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include <unordered_set>
1414
#include <unordered_map>
1515
#include "dxc/DXIL/DxilConstants.h"
16+
#include "llvm/IR/Function.h"
17+
#include "llvm/IR/Argument.h"
1618

1719
namespace llvm {
1820
class Value;
@@ -53,6 +55,7 @@ class HLSignatureLower {
5355
void GenerateDxilPrimOutputs();
5456
void GenerateDxilInputsOutputs(DXIL::SignatureKind SK);
5557
void GenerateDxilCSInputs();
58+
bool ValidateSemanticType(llvm::Function* F);
5659
void GenerateDxilPatchConstantLdSt();
5760
void GenerateDxilPatchConstantFunctionInputs();
5861
void GenerateClipPlanesForVS(llvm::Value *outPosition);

0 commit comments

Comments
 (0)