Skip to content

Commit 88b942a

Browse files
committed
Rework DXIL op overload system
Add comments explaining the new system. Eliminate bool array in favor of array of masks for up to N dimensions. Add NumOverloadDims instead of two-mode system. Rework TypeSlots: - use enum, categorize basic, limit masks to used bits - void doesn't need a type slot (NumOverloadDims == 0 instead) - m_OverloadTypeName only contains basic type names Handle multi-overload in FixOverloadNames; new MayHaveNonCanonicalOverload is used to determine whether the overload name could need fixing. Extended overload is still a distinction because of the way the overloads must be wrapped in an unnamed StructType. However, it does not need a bit in the overload mask. Renamed GetVectorType to GetStructVectorType, since it's just used to get a struct for a particular vector type, not a vector type itself. In hctdb.py, no longer separate extended and vector overloads, just verify correctness of the incoming string, and add default vector overloads if necessary. In hctdb_instrhelp.py, update according to changes in hctdb.py, and eliminate needless, problematic, outdated comment printing.
1 parent fdaa369 commit 88b942a

4 files changed

Lines changed: 238 additions & 202 deletions

File tree

include/dxc/DXIL/DxilOperations.h

Lines changed: 62 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,25 @@ class OP {
5757
// caches.
5858
void RefreshCache();
5959

60+
// The single llvm::Type * "OverloadType" has one of these forms:
61+
// No overloads (NumOverloadDims == 0):
62+
// - TS_Void: VoidTy
63+
// For single overload dimension (NumOverloadDims == 1):
64+
// - TS_F*, TS_I*: a scalar numeric type (half, float, i1, i64, etc.),
65+
// - TS_UDT: a pointer to a StructType representing a User Defined Type,
66+
// - TS_Object: a named StructType representing a built-in object, or
67+
// - TS_Vector: a vector type (<4 x float>, <16 x i16>, etc.)
68+
// For multiple overload dimensions (TS_Extended, NumOverloadDims > 1):
69+
// - an unnamed StructType containing each type for the corresponding
70+
// dimension, such as: type { i32, <2 x float> }
71+
// - contained type options are the same as for single dimension.
72+
6073
llvm::Function *GetOpFunc(OpCode OpCode, llvm::Type *pOverloadType);
74+
75+
// N-dimension convenience version of GetOpFunc:
6176
llvm::Function *GetOpFunc(OpCode OpCode,
62-
llvm::ArrayRef<llvm::Type *> ExtendedOverloads);
77+
llvm::ArrayRef<llvm::Type *> OverloadTypes);
78+
6379
const llvm::SmallMapVector<llvm::Type *, llvm::Function *, 8> &
6480
GetOpFuncList(OpCode OpCode) const;
6581
bool IsDxilOpUsed(OpCode opcode) const;
@@ -84,7 +100,8 @@ class OP {
84100

85101
llvm::Type *GetResRetType(llvm::Type *pOverloadType);
86102
llvm::Type *GetCBufferRetType(llvm::Type *pOverloadType);
87-
llvm::Type *GetVectorType(unsigned numElements, llvm::Type *pOverloadType);
103+
llvm::Type *GetStructVectorType(unsigned numElements,
104+
llvm::Type *pOverloadType);
88105
bool IsResRetType(llvm::Type *Ty);
89106

90107
// Construct an unnamed struct type containing the set of member types.
@@ -145,6 +162,11 @@ class OP {
145162

146163
static bool IsDxilOpExtendedOverload(OpCode C);
147164

165+
// Return true if the overload name for this operation may be constructed
166+
// based on a type name that may not represent the same type in different
167+
// modules.
168+
static bool MayHaveNonCanonicalOverload(OpCode OC);
169+
148170
private:
149171
// Per-module properties.
150172
llvm::LLVMContext &m_Ctx;
@@ -168,15 +190,33 @@ class OP {
168190

169191
DXIL::LowPrecisionMode m_LowPrecisionMode;
170192

171-
static const unsigned kUserDefineTypeSlot = 9;
172-
static const unsigned kObjectTypeSlot = 10;
173-
static const unsigned kVectorTypeSlot = 11;
174-
static const unsigned kExtendedTypeSlot = 12;
175-
static const unsigned kNumTypeOverloads =
176-
13; // void, h,f,d, i1, i8,i16,i32,i64, udt, obj, vec, extended
193+
// Overload types are split into "basic" overload types and special types
194+
// Basic: void, half, float, double, i1, i8, i16, i32, i64
195+
// - These have one canonical overload per TypeSlot
196+
// Special: udt, obj, vec, extended
197+
// - These may have many overloads per type slot
198+
enum TypeSlot : unsigned {
199+
TS_F16 = 0,
200+
TS_F32 = 1,
201+
TS_F64 = 2,
202+
TS_I1 = 3,
203+
TS_I8 = 4,
204+
TS_I16 = 5,
205+
TS_I32 = 6,
206+
TS_I64 = 7,
207+
TS_BasicCount,
208+
TS_UDT = 8, // Ex: %"struct.MyStruct" *
209+
TS_Object = 9, // Ex: %"class.StructuredBuffer<Foo>"
210+
TS_Vector = 10, // Ex: <8 x i16>
211+
TS_MaskBitCount, // Types used in Mask end here
212+
// TS_Extended is only used to identify the unnamed struct type used to wrap
213+
// multiple overloads when using GetTypeSlot.
214+
TS_Extended, // Ex: type { float, <16 x i32> }
215+
TS_Invalid = UINT_MAX,
216+
};
177217

178-
llvm::Type *m_pResRetType[kNumTypeOverloads];
179-
llvm::Type *m_pCBufferRetType[kNumTypeOverloads];
218+
llvm::Type *m_pResRetType[TS_BasicCount];
219+
llvm::Type *m_pCBufferRetType[TS_BasicCount];
180220

181221
struct OpCodeCacheItem {
182222
llvm::SmallMapVector<llvm::Type *, llvm::Function *, 8> pOverloads;
@@ -190,10 +230,10 @@ class OP {
190230
struct OverloadMask {
191231
// mask of type slot bits as (1 << TypeSlot)
192232
uint16_t SlotMask;
193-
static_assert(kNumTypeOverloads <= (sizeof(SlotMask) * 8));
233+
static_assert(TS_MaskBitCount <= (sizeof(SlotMask) * 8));
194234
bool operator[](unsigned TypeSlot) const {
195-
return (TypeSlot < kNumTypeOverloads) ? (bool)(SlotMask & (1 << TypeSlot))
196-
: 0;
235+
return (TypeSlot < TS_MaskBitCount) ? (bool)(SlotMask & (1 << TypeSlot))
236+
: 0;
197237
}
198238
operator bool() const { return SlotMask != 0; }
199239
};
@@ -202,28 +242,21 @@ class OP {
202242
const char *pOpCodeName;
203243
OpCodeClass opCodeClass;
204244
const char *pOpCodeClassName;
205-
bool bAllowOverload[kNumTypeOverloads]; // void, h,f,d, i1, i8,i16,i32,i64,
206-
// udt, obj, vec, extended
207245
llvm::Attribute::AttrKind FuncAttr;
208246

209-
// Extended Type Overloads:
210-
// This is an encoding for a multi-dimensional overload.
211-
// 1. Only bAllowOverload[kExtendedTypeSlot] is set to true
212-
// 2. ExtendedOverloads defines allowed types for each overload index
213-
// 3. AllowedVectorElements defines allowed vector component types,
214-
// when kVectorTypeSlot bit is set for the corresponding overload index.
215-
// This includes when a single vector overload type is specified with
216-
// bAllowOverload[kVectorTypeSlot].
217-
218-
// A bit mask of allowed type slots per extended overload
219-
OverloadMask ExtendedOverloads[DXIL::kDxilMaxOloadDims];
220-
// A bit mask of allowed vector element types for the vector overload
221-
// or each corresponding extended vector overload.
247+
// Number of overload dimensions used by the operation.
248+
unsigned int NumOverloadDims;
249+
250+
// Mask of supported overload types for each overload dimension.
251+
OverloadMask AllowedOverloads[DXIL::kDxilMaxOloadDims];
252+
253+
// Mask of scalar components allowed for each demension where
254+
// AllowedOverloads[n][TS_Vector] is true.
222255
OverloadMask AllowedVectorElements[DXIL::kDxilMaxOloadDims];
223256
};
224257
static const OpCodeProperty m_OpCodeProps[(unsigned)OpCode::NumOpCodes];
225258

226-
static const char *m_OverloadTypeName[kNumTypeOverloads];
259+
static const char *m_OverloadTypeName[TS_BasicCount];
227260
static const char *m_NamePrefix;
228261
static const char *m_TypePrefix;
229262
static const char *m_MatrixTypePrefix;

0 commit comments

Comments
 (0)