Skip to content

Commit 96b9074

Browse files
committed
Rework DXIL op overload system
Add comments explaining the new system. Eliminate bool array in favor of array of masks for up to N dimensions. Add NumOverloadDims instead of two-mode system. Rework TypeSlots: - use enum, categorize basic, limit masks to used bits - void doesn't need a type slot (NumOverloadDims == 0 instead) - m_OverloadTypeName only contains basic type names Handle multi-overload in FixOverloadNames; new MayHaveNonCanonicalOverload is used to determine whether the overload name could need fixing. Extended overload is still a distinction because of the way the overloads must be wrapped in an unnamed StructType. However, it does not need a bit in the overload mask. Renamed GetVectorType to GetStructVectorType, since it's just used to get a struct for a particular vector type, not a vector type itself. In hctdb.py, no longer separate extended and vector overloads, just verify correctness of the incoming string, and add default vector overloads if necessary. In hctdb_instrhelp.py, update according to changes in hctdb.py, and eliminate needless, problematic, outdated comment printing.
1 parent fdaa369 commit 96b9074

4 files changed

Lines changed: 237 additions & 202 deletions

File tree

include/dxc/DXIL/DxilOperations.h

Lines changed: 61 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,25 @@ class OP {
5757
// caches.
5858
void RefreshCache();
5959

60+
// The single llvm::Type * "OverloadType" has one of these forms:
61+
// No overloads (NumOverloadDims == 0):
62+
// - TS_Void: VoidTy
63+
// For single overload dimension (NumOverloadDims == 1):
64+
// - TS_F*, TS_I*: a scalar numeric type (half, float, i1, i64, etc.),
65+
// - TS_UDT: a pointer to a StructType representing a User Defined Type,
66+
// - TS_Object: a named StructType representing a built-in object, or
67+
// - TS_Vector: a vector type (<4 x float>, <16 x i16>, etc.)
68+
// For multiple overload dimensions (TS_Extended, NumOverloadDims > 1):
69+
// - an unnamed StructType containing each type for the corresponding
70+
// dimension, such as: type { i32, <2 x float> }
71+
// - contained type options are the same as for single dimension.
72+
6073
llvm::Function *GetOpFunc(OpCode OpCode, llvm::Type *pOverloadType);
74+
75+
// N-dimension convenience version of GetOpFunc:
6176
llvm::Function *GetOpFunc(OpCode OpCode,
62-
llvm::ArrayRef<llvm::Type *> ExtendedOverloads);
77+
llvm::ArrayRef<llvm::Type *> OverloadTypes);
78+
6379
const llvm::SmallMapVector<llvm::Type *, llvm::Function *, 8> &
6480
GetOpFuncList(OpCode OpCode) const;
6581
bool IsDxilOpUsed(OpCode opcode) const;
@@ -84,7 +100,7 @@ class OP {
84100

85101
llvm::Type *GetResRetType(llvm::Type *pOverloadType);
86102
llvm::Type *GetCBufferRetType(llvm::Type *pOverloadType);
87-
llvm::Type *GetVectorType(unsigned numElements, llvm::Type *pOverloadType);
103+
llvm::Type *GetStructVectorType(unsigned numElements, llvm::Type *pOverloadType);
88104
bool IsResRetType(llvm::Type *Ty);
89105

90106
// Construct an unnamed struct type containing the set of member types.
@@ -145,6 +161,11 @@ class OP {
145161

146162
static bool IsDxilOpExtendedOverload(OpCode C);
147163

164+
// Return true if the overload name for this operation may be constructed
165+
// based on a type name that may not represent the same type in different
166+
// modules.
167+
static bool MayHaveNonCanonicalOverload(OpCode OC);
168+
148169
private:
149170
// Per-module properties.
150171
llvm::LLVMContext &m_Ctx;
@@ -168,15 +189,33 @@ class OP {
168189

169190
DXIL::LowPrecisionMode m_LowPrecisionMode;
170191

171-
static const unsigned kUserDefineTypeSlot = 9;
172-
static const unsigned kObjectTypeSlot = 10;
173-
static const unsigned kVectorTypeSlot = 11;
174-
static const unsigned kExtendedTypeSlot = 12;
175-
static const unsigned kNumTypeOverloads =
176-
13; // void, h,f,d, i1, i8,i16,i32,i64, udt, obj, vec, extended
192+
// Overload types are split into "basic" overload types and special types
193+
// Basic: void, half, float, double, i1, i8, i16, i32, i64
194+
// - These have one canonical overload per TypeSlot
195+
// Special: udt, obj, vec, extended
196+
// - These may have many overloads per type slot
197+
enum TypeSlot : unsigned {
198+
TS_F16 = 0,
199+
TS_F32 = 1,
200+
TS_F64 = 2,
201+
TS_I1 = 3,
202+
TS_I8 = 4,
203+
TS_I16 = 5,
204+
TS_I32 = 6,
205+
TS_I64 = 7,
206+
TS_BasicCount,
207+
TS_UDT = 8, // Ex: %"struct.MyStruct" *
208+
TS_Object = 9, // Ex: %"class.StructuredBuffer<Foo>"
209+
TS_Vector = 10, // Ex: <8 x i16>
210+
TS_MaskBitCount, // Types used in Mask end here
211+
// TS_Extended is only used to identify the unnamed struct type used to wrap
212+
// multiple overloads when using GetTypeSlot.
213+
TS_Extended, // Ex: type { float, <16 x i32> }
214+
TS_Invalid = UINT_MAX,
215+
};
177216

178-
llvm::Type *m_pResRetType[kNumTypeOverloads];
179-
llvm::Type *m_pCBufferRetType[kNumTypeOverloads];
217+
llvm::Type *m_pResRetType[TS_BasicCount];
218+
llvm::Type *m_pCBufferRetType[TS_BasicCount];
180219

181220
struct OpCodeCacheItem {
182221
llvm::SmallMapVector<llvm::Type *, llvm::Function *, 8> pOverloads;
@@ -190,10 +229,10 @@ class OP {
190229
struct OverloadMask {
191230
// mask of type slot bits as (1 << TypeSlot)
192231
uint16_t SlotMask;
193-
static_assert(kNumTypeOverloads <= (sizeof(SlotMask) * 8));
232+
static_assert(TS_MaskBitCount <= (sizeof(SlotMask) * 8));
194233
bool operator[](unsigned TypeSlot) const {
195-
return (TypeSlot < kNumTypeOverloads) ? (bool)(SlotMask & (1 << TypeSlot))
196-
: 0;
234+
return (TypeSlot < TS_MaskBitCount) ? (bool)(SlotMask & (1 << TypeSlot))
235+
: 0;
197236
}
198237
operator bool() const { return SlotMask != 0; }
199238
};
@@ -202,28 +241,21 @@ class OP {
202241
const char *pOpCodeName;
203242
OpCodeClass opCodeClass;
204243
const char *pOpCodeClassName;
205-
bool bAllowOverload[kNumTypeOverloads]; // void, h,f,d, i1, i8,i16,i32,i64,
206-
// udt, obj, vec, extended
207244
llvm::Attribute::AttrKind FuncAttr;
208245

209-
// Extended Type Overloads:
210-
// This is an encoding for a multi-dimensional overload.
211-
// 1. Only bAllowOverload[kExtendedTypeSlot] is set to true
212-
// 2. ExtendedOverloads defines allowed types for each overload index
213-
// 3. AllowedVectorElements defines allowed vector component types,
214-
// when kVectorTypeSlot bit is set for the corresponding overload index.
215-
// This includes when a single vector overload type is specified with
216-
// bAllowOverload[kVectorTypeSlot].
217-
218-
// A bit mask of allowed type slots per extended overload
219-
OverloadMask ExtendedOverloads[DXIL::kDxilMaxOloadDims];
220-
// A bit mask of allowed vector element types for the vector overload
221-
// or each corresponding extended vector overload.
246+
// Number of overload dimensions used by the operation.
247+
unsigned int NumOverloadDims;
248+
249+
// Mask of supported overload types for each overload dimension.
250+
OverloadMask AllowedOverloads[DXIL::kDxilMaxOloadDims];
251+
252+
// Mask of scalar components allowed for each demension where
253+
// AllowedOverloads[n][TS_Vector] is true.
222254
OverloadMask AllowedVectorElements[DXIL::kDxilMaxOloadDims];
223255
};
224256
static const OpCodeProperty m_OpCodeProps[(unsigned)OpCode::NumOpCodes];
225257

226-
static const char *m_OverloadTypeName[kNumTypeOverloads];
258+
static const char *m_OverloadTypeName[TS_BasicCount];
227259
static const char *m_NamePrefix;
228260
static const char *m_TypePrefix;
229261
static const char *m_MatrixTypePrefix;

0 commit comments

Comments
 (0)