Skip to content

Commit cce6fe0

Browse files
[SPIR-V] HLSL2021: initial bitfield implementation (#4831)
This commit adds the logic in the SPIR-V backend to generate proper bitfields. Bitfield are packed using a first-fit method, linearly packing them, but not mixing types. Goal is to follow C/C++ rules. Bitfield merging was initially guessed from the offset stored in the FieldInfo. This offset is not always available and has a very specific meaning. When the struct is a function local variable, the layout rule is Void, meaning we shouldn't assume any kind of byte offset, but rely on construct index. This commit adds a fieldIndex member to the FieldInfo struct, and this field is used to determine if 2 fields are merged. When doing a buffer texture load, the struct must be extracted from a vector type, and rebuilt. This commit adds support for bitfield extraction for such types. Fixing this helped me see scalar assignment were also failling in some cases. Addressing bitfield extraction/insertion issues on with commit. Signed-off-by: Nathan Gauër <[email protected]> Co-authored-by: Cassandra Beckley <[email protected]>
1 parent 12515e8 commit cce6fe0

21 files changed

Lines changed: 931 additions & 207 deletions

tools/clang/include/clang/SPIRV/SpirvBuilder.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,20 +174,21 @@ class SpirvBuilder {
174174
SourceLocation loc,
175175
SourceRange range = {});
176176

177-
/// \brief Creates a load instruction loading the value of the given
178-
/// <result-type> from the given pointer. Returns the instruction pointer for
179-
/// the loaded value.
180-
SpirvLoad *createLoad(QualType resultType, SpirvInstruction *pointer,
181-
SourceLocation loc, SourceRange range = {});
177+
/// \brief Creates a load sequence loading the value of the given
178+
/// <result-type> from the given pointer (load + optional extraction,
179+
/// ex:bitfield). Returns the instruction pointer for the loaded value.
180+
SpirvInstruction *createLoad(QualType resultType, SpirvInstruction *pointer,
181+
SourceLocation loc, SourceRange range = {});
182182
SpirvLoad *createLoad(const SpirvType *resultType, SpirvInstruction *pointer,
183183
SourceLocation loc, SourceRange range = {});
184184

185185
/// \brief Creates an OpCopyObject instruction from the given pointer.
186186
SpirvCopyObject *createCopyObject(QualType resultType,
187187
SpirvInstruction *pointer, SourceLocation);
188188

189-
/// \brief Creates a store instruction storing the given value into the given
189+
/// \brief Creates a store sequence storing the given value into the given
190190
/// address. Returns the instruction pointer for the store instruction.
191+
/// This function handles storing to bitfields.
191192
SpirvStore *createStore(SpirvInstruction *address, SpirvInstruction *value,
192193
SourceLocation loc, SourceRange range = {});
193194

tools/clang/include/clang/SPIRV/SpirvInstruction.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "clang/AST/APValue.h"
1515
#include "clang/AST/Type.h"
1616
#include "clang/Basic/SourceLocation.h"
17+
#include "clang/SPIRV/SpirvType.h"
1718
#include "llvm/ADT/APFloat.h"
1819
#include "llvm/ADT/APInt.h"
1920
#include "llvm/ADT/Optional.h"
@@ -203,6 +204,7 @@ class SpirvInstruction {
203204

204205
void setRValue(bool rvalue = true) { isRValue_ = rvalue; }
205206
bool isRValue() const { return isRValue_; }
207+
bool isLValue() const { return !isRValue_; }
206208

207209
void setRelaxedPrecision() { isRelaxedPrecision_ = true; }
208210
bool isRelaxedPrecision() const { return isRelaxedPrecision_; }
@@ -213,6 +215,9 @@ class SpirvInstruction {
213215
void setPrecise(bool p = true) { isPrecise_ = p; }
214216
bool isPrecise() const { return isPrecise_; }
215217

218+
void setBitfieldInfo(const BitfieldInfo &info) { bitfieldInfo = info; }
219+
llvm::Optional<BitfieldInfo> getBitfieldInfo() const { return bitfieldInfo; }
220+
216221
/// Legalization-specific code
217222
///
218223
/// Note: the following two functions are currently needed in order to support
@@ -255,6 +260,7 @@ class SpirvInstruction {
255260
bool isRelaxedPrecision_;
256261
bool isNonUniform_;
257262
bool isPrecise_;
263+
llvm::Optional<BitfieldInfo> bitfieldInfo;
258264
};
259265

260266
/// \brief OpCapability instruction

tools/clang/include/clang/SPIRV/SpirvType.h

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ enum class StructInterfaceType : uint32_t {
3030
UniformBuffer = 2,
3131
};
3232

33+
struct BitfieldInfo {
34+
// Offset of the bitfield, in bits, from the basetype start.
35+
uint32_t offsetInBits;
36+
// Size of the bitfield, in bits.
37+
uint32_t sizeInBits;
38+
};
39+
3340
class SpirvType {
3441
public:
3542
enum Kind {
@@ -290,14 +297,16 @@ class StructType : public SpirvType {
290297
public:
291298
struct FieldInfo {
292299
public:
293-
FieldInfo(const SpirvType *type_, llvm::StringRef name_ = "",
300+
FieldInfo(const SpirvType *type_, uint32_t fieldIndex_,
301+
llvm::StringRef name_ = "",
294302
llvm::Optional<uint32_t> offset_ = llvm::None,
295303
llvm::Optional<uint32_t> matrixStride_ = llvm::None,
296304
llvm::Optional<bool> isRowMajor_ = llvm::None,
297305
bool relaxedPrecision = false, bool precise = false)
298-
: type(type_), name(name_), offset(offset_), sizeInBytes(llvm::None),
299-
matrixStride(matrixStride_), isRowMajor(isRowMajor_),
300-
isRelaxedPrecision(relaxedPrecision), isPrecise(precise) {
306+
: type(type_), fieldIndex(fieldIndex_), name(name_), offset(offset_),
307+
sizeInBytes(llvm::None), matrixStride(matrixStride_),
308+
isRowMajor(isRowMajor_), isRelaxedPrecision(relaxedPrecision),
309+
isPrecise(precise) {
301310
// A StructType may not contain any hybrid types.
302311
assert(!isa<HybridType>(type_));
303312
}
@@ -306,6 +315,10 @@ class StructType : public SpirvType {
306315

307316
// The field's type.
308317
const SpirvType *type;
318+
// The index of this field in the composite construct.
319+
// When the struct contains bitfields, StructType index and construct index
320+
// can diverge as we merge bitfields together.
321+
uint32_t fieldIndex;
309322
// The field's name.
310323
std::string name;
311324
// The integer offset in bytes for this field.
@@ -320,6 +333,8 @@ class StructType : public SpirvType {
320333
bool isRelaxedPrecision;
321334
// Whether this field is marked as 'precise'.
322335
bool isPrecise;
336+
// Information about the bitfield (if applicable).
337+
llvm::Optional<BitfieldInfo> bitfield;
323338
};
324339

325340
StructType(
@@ -467,9 +482,11 @@ class HybridStructType : public HybridType {
467482
clang::VKOffsetAttr *offset = nullptr,
468483
hlsl::ConstantPacking *packOffset = nullptr,
469484
const hlsl::RegisterAssignment *regC = nullptr,
470-
bool precise = false)
485+
bool precise = false,
486+
llvm::Optional<BitfieldInfo> bitfield = llvm::None)
471487
: astType(astType_), name(name_), vkOffsetAttr(offset),
472-
packOffsetAttr(packOffset), registerC(regC), isPrecise(precise) {}
488+
packOffsetAttr(packOffset), registerC(regC), isPrecise(precise),
489+
bitfield(std::move(bitfield)) {}
473490

474491
// The field's type.
475492
QualType astType;
@@ -483,6 +500,9 @@ class HybridStructType : public HybridType {
483500
const hlsl::RegisterAssignment *registerC;
484501
// Whether this field is marked as 'precise'.
485502
bool isPrecise;
503+
// Whether this field is a bitfield or not. If set to false, bitfield width
504+
// value is undefined.
505+
llvm::Optional<BitfieldInfo> bitfield;
486506
};
487507

488508
HybridStructType(

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4292,8 +4292,9 @@ void DeclResultIdMapper::storeOutStageVarsToStorage(
42924292
}
42934293
auto *ptrToOutputStageVar = spvBuilder.createAccessChain(
42944294
outputControlPointType, found->second, {ctrlPointID}, /*loc=*/{});
4295-
auto *load = spvBuilder.createLoad(outputControlPointType,
4296-
ptrToOutputStageVar, /*loc=*/{});
4295+
auto *load =
4296+
spvBuilder.createLoad(outputControlPointType, ptrToOutputStageVar,
4297+
/*loc=*/{});
42974298
spvBuilder.createStore(ptr, load, /*loc=*/{});
42984299
return;
42994300
}

tools/clang/lib/SPIRV/EmitVisitor.cpp

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include "clang/SPIRV/String.h"
2424
// clang-format on
2525

26+
#include <functional>
27+
2628
namespace clang {
2729
namespace spirv {
2830

@@ -2326,6 +2328,17 @@ EmitTypeHandler::getOrCreateConstantComposite(SpirvConstantComposite *inst) {
23262328
return inst->getResultId();
23272329
}
23282330

2331+
static inline bool
2332+
isFieldMergeWithPrevious(const StructType::FieldInfo &previous,
2333+
const StructType::FieldInfo &field) {
2334+
if (previous.fieldIndex == field.fieldIndex) {
2335+
// Right now, the only reason for those indices to be shared is if both
2336+
// are merged bitfields.
2337+
assert(previous.bitfield.hasValue() && field.bitfield.hasValue());
2338+
}
2339+
return previous.fieldIndex == field.fieldIndex;
2340+
}
2341+
23292342
uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
23302343
// First get the decorations that would apply to this type.
23312344
bool alreadyExists = false;
@@ -2447,24 +2460,32 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
24472460
}
24482461
// Structure types
24492462
else if (const auto *structType = dyn_cast<StructType>(type)) {
2450-
llvm::ArrayRef<StructType::FieldInfo> fields = structType->getFields();
2451-
size_t numFields = fields.size();
2463+
std::vector<std::reference_wrapper<const StructType::FieldInfo>>
2464+
fieldsToGenerate;
2465+
{
2466+
llvm::ArrayRef<StructType::FieldInfo> fields = structType->getFields();
2467+
for (size_t i = 0; i < fields.size(); ++i) {
2468+
if (i > 0 && isFieldMergeWithPrevious(fields[i - 1], fields[i]))
2469+
continue;
2470+
fieldsToGenerate.push_back(std::ref(fields[i]));
2471+
}
2472+
}
24522473

24532474
// Emit OpMemberName for the struct members.
2454-
for (size_t i = 0; i < numFields; ++i)
2455-
emitNameForType(fields[i].name, id, i);
2475+
for (size_t i = 0; i < fieldsToGenerate.size(); ++i)
2476+
emitNameForType(fieldsToGenerate[i].get().name, id, i);
24562477

24572478
llvm::SmallVector<uint32_t, 4> fieldTypeIds;
2458-
for (auto &field : fields) {
2459-
fieldTypeIds.push_back(emitType(field.type));
2460-
}
2479+
for (auto &field : fieldsToGenerate)
2480+
fieldTypeIds.push_back(emitType(field.get().type));
24612481

2462-
for (size_t i = 0; i < numFields; ++i) {
2463-
auto &field = fields[i];
2482+
for (size_t i = 0; i < fieldsToGenerate.size(); ++i) {
2483+
const auto &field = fieldsToGenerate[i].get();
24642484
// Offset decorations
2465-
if (field.offset.hasValue())
2485+
if (field.offset.hasValue()) {
24662486
emitDecoration(id, spv::Decoration::Offset, {field.offset.getValue()},
24672487
i);
2488+
}
24682489

24692490
// MatrixStride decorations
24702491
if (field.matrixStride.hasValue())

tools/clang/lib/SPIRV/LiteralTypeVisitor.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//===----------------------------------------------------------------------===//
99

1010
#include "LiteralTypeVisitor.h"
11+
#include "LowerTypeVisitor.h"
1112
#include "clang/SPIRV/AstTypeProbe.h"
1213
#include "clang/SPIRV/SpirvFunction.h"
1314

@@ -389,6 +390,11 @@ bool LiteralTypeVisitor::updateTypeForCompositeMembers(
389390
const auto *decl = structType->getDecl();
390391
size_t i = 0;
391392
for (const auto *field : decl->fields()) {
393+
// If the field is a bitfield, it might be squashed later when building
394+
// the SPIR-V type depending on context. This means indices starting
395+
// from this bitfield are not guaranteed, and we shouldn't touch them.
396+
if (field->isBitField())
397+
break;
392398
tryToUpdateInstLitType(constituents[i], field->getType());
393399
++i;
394400
}

0 commit comments

Comments
 (0)