Skip to content

Commit 99c7810

Browse files
authored
Add alignment to NodeRecordType including DXIL metadata update (#6279) (#6374)
This change adds NodeRecordType alignment field to RDAT to make it possible to validate pointer and stride alignment in the runtime. This includes a change to DXIL metadata to preserve the record alignment without requiring recovery by looking for GetNodeRecordPtr. Fixes #6270 (cherry picked from commit 66ba5a1)
1 parent 9403822 commit 99c7810

18 files changed

Lines changed: 333 additions & 118 deletions

include/dxc/DXIL/DxilMetadataHelper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ class DxilMDHelper {
332332
// Node Record Type
333333
static const unsigned kDxilNodeRecordSizeTag = 0;
334334
static const unsigned kDxilNodeSVDispatchGridTag = 1;
335+
static const unsigned kDxilNodeRecordAlignmentTag = 2;
335336

336337
// GSState.
337338
static const unsigned kDxilGSStateNumFields = 5;
@@ -624,6 +625,7 @@ class DxilMDHelper {
624625
unsigned &payloadSizeInBytes);
625626

626627
llvm::MDTuple *EmitDxilNodeIOState(const NodeIOProperties &Node);
628+
llvm::MDTuple *EmitDxilNodeRecordType(const NodeRecordType &RecordType);
627629
hlsl::NodeIOProperties LoadDxilNodeIOState(const llvm::MDOperand &MDO);
628630
hlsl::NodeRecordType LoadDxilNodeRecordType(const llvm::MDOperand &MDO);
629631

include/dxc/DXIL/DxilNodeProps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct SVDispatchGrid {
4545
//
4646
struct NodeRecordType {
4747
unsigned size;
48+
unsigned alignment;
4849
SVDispatchGrid SV_DispatchGrid;
4950
};
5051

include/dxc/DxilContainer/RDAT_LibraryTypes.inl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ RDAT_ENUM_START(NodeAttribKind, uint32_t)
277277
RDAT_ENUM_VALUE(RecordDispatchGrid, 5)
278278
RDAT_ENUM_VALUE(OutputArraySize, 6)
279279
RDAT_ENUM_VALUE(AllowSparseNodes, 7)
280+
RDAT_ENUM_VALUE(RecordAlignmentInBytes, 8)
280281
RDAT_ENUM_VALUE_NODEF(LastValue)
281282
RDAT_ENUM_END()
282283

@@ -407,6 +408,10 @@ RDAT_STRUCT_TABLE(NodeShaderIOAttrib, NodeShaderIOAttribTable)
407408
getAttribKind() ==
408409
hlsl::RDAT::NodeAttribKind::AllowSparseNodes)
409410
RDAT_VALUE(uint32_t, AllowSparseNodes)
411+
RDAT_UNION_ELIF(RecordAlignmentInBytes,
412+
getAttribKind() ==
413+
hlsl::RDAT::NodeAttribKind::RecordAlignmentInBytes)
414+
RDAT_VALUE(uint32_t, RecordAlignmentInBytes)
410415
RDAT_UNION_ENDIF()
411416
RDAT_UNION_END()
412417
RDAT_STRUCT_END()

lib/DXIL/DxilMetadataHelper.cpp

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1967,6 +1967,7 @@ void DxilMDHelper::SerializeNodeProps(SmallVectorImpl<llvm::Metadata *> &MDVals,
19671967
nodeinput.RecordType.SV_DispatchGrid.ComponentType)));
19681968
MDVals.push_back(
19691969
Uint32ToConstMD(nodeinput.RecordType.SV_DispatchGrid.NumComponents));
1970+
MDVals.push_back(Uint32ToConstMD(nodeinput.RecordType.alignment));
19701971
}
19711972
for (auto &nodeoutput : props->OutputNodes) {
19721973
MDVals.push_back(Uint32ToConstMD(nodeoutput.Flags));
@@ -1983,6 +1984,7 @@ void DxilMDHelper::SerializeNodeProps(SmallVectorImpl<llvm::Metadata *> &MDVals,
19831984
MDVals.push_back(Int32ToConstMD(nodeoutput.MaxRecordsSharedWith));
19841985
MDVals.push_back(Uint32ToConstMD(nodeoutput.OutputArraySize));
19851986
MDVals.push_back(BoolToConstMD(nodeoutput.AllowSparseNodes));
1987+
MDVals.push_back(Uint32ToConstMD(nodeoutput.RecordType.alignment));
19861988
}
19871989
}
19881990

@@ -2019,6 +2021,10 @@ void DxilMDHelper::DeserializeNodeProps(const MDTuple *pProps, unsigned &idx,
20192021
ConstMDToUint32(pProps->getOperand(idx++)));
20202022
nodeinput.RecordType.SV_DispatchGrid.NumComponents =
20212023
ConstMDToUint32(pProps->getOperand(idx++));
2024+
if (pProps->getNumOperands() > idx) {
2025+
nodeinput.RecordType.alignment =
2026+
ConstMDToUint32(pProps->getOperand(idx++));
2027+
}
20222028
}
20232029

20242030
for (auto &nodeoutput : props->OutputNodes) {
@@ -2037,6 +2043,10 @@ void DxilMDHelper::DeserializeNodeProps(const MDTuple *pProps, unsigned &idx,
20372043
nodeoutput.MaxRecordsSharedWith = ConstMDToInt32(pProps->getOperand(idx++));
20382044
nodeoutput.OutputArraySize = ConstMDToUint32(pProps->getOperand(idx++));
20392045
nodeoutput.AllowSparseNodes = ConstMDToBool(pProps->getOperand(idx++));
2046+
if (pProps->getNumOperands() > idx) {
2047+
nodeoutput.RecordType.alignment =
2048+
ConstMDToUint32(pProps->getOperand(idx++));
2049+
}
20402050
}
20412051
}
20422052

@@ -2755,6 +2765,32 @@ void DxilMDHelper::EmitDxilNodeState(std::vector<llvm::Metadata *> &MDVals,
27552765
}
27562766
}
27572767

2768+
llvm::MDTuple *
2769+
DxilMDHelper::EmitDxilNodeRecordType(const NodeRecordType &RecordType) {
2770+
vector<Metadata *> MDVals;
2771+
MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilNodeRecordSizeTag));
2772+
MDVals.emplace_back(Uint32ToConstMD(RecordType.size));
2773+
2774+
if (RecordType.SV_DispatchGrid.NumComponents) {
2775+
MDVals.emplace_back(
2776+
Uint32ToConstMD(DxilMDHelper::kDxilNodeSVDispatchGridTag));
2777+
vector<Metadata *> SVDispatchGridVals;
2778+
SVDispatchGridVals.emplace_back(
2779+
Uint32ToConstMD(RecordType.SV_DispatchGrid.ByteOffset));
2780+
SVDispatchGridVals.emplace_back(Uint32ToConstMD(
2781+
static_cast<unsigned>(RecordType.SV_DispatchGrid.ComponentType)));
2782+
SVDispatchGridVals.emplace_back(
2783+
Uint32ToConstMD(RecordType.SV_DispatchGrid.NumComponents));
2784+
MDVals.emplace_back(MDNode::get(m_Ctx, SVDispatchGridVals));
2785+
}
2786+
if (RecordType.alignment) {
2787+
MDVals.emplace_back(
2788+
Uint32ToConstMD(DxilMDHelper::kDxilNodeRecordAlignmentTag));
2789+
MDVals.emplace_back(Uint32ToConstMD(RecordType.alignment));
2790+
}
2791+
return MDNode::get(m_Ctx, MDVals);
2792+
}
2793+
27582794
llvm::MDTuple *
27592795
DxilMDHelper::EmitDxilNodeIOState(const hlsl::NodeIOProperties &Node) {
27602796
vector<Metadata *> MDVals;
@@ -2763,24 +2799,7 @@ DxilMDHelper::EmitDxilNodeIOState(const hlsl::NodeIOProperties &Node) {
27632799

27642800
if (Node.RecordType.size) {
27652801
MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilNodeRecordTypeTag));
2766-
vector<Metadata *> NodeRecordTypeVals;
2767-
NodeRecordTypeVals.emplace_back(
2768-
Uint32ToConstMD(DxilMDHelper::kDxilNodeRecordSizeTag));
2769-
NodeRecordTypeVals.emplace_back(Uint32ToConstMD(Node.RecordType.size));
2770-
// If the record has a SV_DispatchGrid field
2771-
if (Node.RecordType.SV_DispatchGrid.NumComponents) {
2772-
NodeRecordTypeVals.emplace_back(
2773-
Uint32ToConstMD(DxilMDHelper::kDxilNodeSVDispatchGridTag));
2774-
vector<Metadata *> SVDispatchGridVals;
2775-
SVDispatchGridVals.emplace_back(
2776-
Uint32ToConstMD(Node.RecordType.SV_DispatchGrid.ByteOffset));
2777-
SVDispatchGridVals.emplace_back(Uint32ToConstMD(static_cast<unsigned>(
2778-
Node.RecordType.SV_DispatchGrid.ComponentType)));
2779-
SVDispatchGridVals.emplace_back(
2780-
Uint32ToConstMD(Node.RecordType.SV_DispatchGrid.NumComponents));
2781-
NodeRecordTypeVals.emplace_back(MDNode::get(m_Ctx, SVDispatchGridVals));
2782-
}
2783-
MDVals.emplace_back(MDNode::get(m_Ctx, NodeRecordTypeVals));
2802+
MDVals.emplace_back(EmitDxilNodeRecordType(Node.RecordType));
27842803
}
27852804

27862805
if (Node.Flags.IsOutputNode()) {
@@ -2856,6 +2875,9 @@ DxilMDHelper::LoadDxilNodeRecordType(const llvm::MDOperand &MDO) {
28562875
Record.SV_DispatchGrid.NumComponents =
28572876
ConstMDToUint32(pSVDTupleMD->getOperand(2));
28582877
} break;
2878+
case DxilMDHelper::kDxilNodeRecordAlignmentTag: {
2879+
Record.alignment = ConstMDToUint32(MDO);
2880+
} break;
28592881
default:
28602882
m_bExtraMetadata = true;
28612883
break;

lib/DxilContainer/DxilContainerAssembler.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,13 @@ class DxilRDATWriter : public DxilPartWriter {
14381438
N.RecordType.SV_DispatchGrid.NumComponents);
14391439
nodeAttribs.push_back(Builder.InsertRecord(nAttrib));
14401440
}
1441+
1442+
if (N.RecordType.alignment) {
1443+
nAttrib = {};
1444+
nAttrib.AttribKind = (uint32_t)NodeAttribKind::RecordAlignmentInBytes;
1445+
nAttrib.RecordAlignmentInBytes = N.RecordType.alignment;
1446+
nodeAttribs.push_back(Builder.InsertRecord(nAttrib));
1447+
}
14411448
}
14421449

14431450
ioNode.Attribs =

tools/clang/lib/CodeGen/CGHLSLMS.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2620,8 +2620,10 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
26202620
}
26212621

26222622
// Ex: For DispatchNodeInputRecord<MY_RECORD>, set size =
2623-
// size(MY_RECORD)
2623+
// size(MY_RECORD), alignment = alignof(MY_RECORD)
26242624
node.RecordType.size = CGM.getDataLayout().getTypeAllocSize(Type);
2625+
node.RecordType.alignment =
2626+
CGM.getDataLayout().getABITypeAlignment(Type);
26252627
// Iterate over fields of the MY_RECORD(example) struct
26262628
for (auto fieldDecl : RD->fields()) {
26272629
// Check if any of the fields have a semantic annotation =

tools/clang/test/CodeGenDXIL/hlsl/objects/NodeObjects/array-in-workgraphrecord-1.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
// CHECK: ![[NodeDispatchGrid]] = !{i32 1, i32 1, i32 1}
99
// CHECK: ![[NodeInputs]] = !{![[Input0:[0-9]+]]}
1010
// CHECK: ![[Input0]] = !{i32 1, i32 97, i32 2, ![[NodeRecordType:[0-9]+]]}
11-
// CHECK: ![[NodeRecordType]] = !{i32 0, i32 68}
11+
// CHECK: ![[NodeRecordType]] = !{i32 0, i32 68, i32 2, i32 4}
1212
// CHECK: ![[NumThreads]] = !{i32 32, i32 1, i32 1}
1313
// CHECK: ![[AutoBindingSpace]] = !{i32 0}
1414

tools/clang/test/CodeGenDXIL/hlsl/objects/NodeObjects/array-in-workgraphrecord-2.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
// CHECK: ![[NodeDispatchGrid]] = !{i32 1, i32 1, i32 1}
99
// CHECK: ![[NodeInputs]] = !{![[Input0:[0-9]+]]}
1010
// CHECK: ![[Input0]] = !{i32 1, i32 97, i32 2, ![[NodeRecordType:[0-9]+]]}
11-
// CHECK: ![[NodeRecordType]] = !{i32 0, i32 68}
11+
// CHECK: ![[NodeRecordType]] = !{i32 0, i32 68, i32 2, i32 4}
1212
// CHECK: ![[NumThreads]] = !{i32 32, i32 1, i32 1}
1313
// CHECK: ![[AutoBindingSpace]] = !{i32 0}
1414

0 commit comments

Comments
 (0)