Skip to content

Commit 139576e

Browse files
tex3dhekota
authored andcommitted
dxcopt: Support full container and restore extra data to module (#4845)
This modifies IDxcOptimizer::RunOptimizier to accept full DxilContainer input. When full container input is used, this restores some data that is stripped from the module and placed in various other container parts. Data restored: - Subobjects from RDAT - RootSignature from RTS0 - ViewID and I/O dependency data from PSV0 - Resource names and types/annotations from STAT Serialization of these to metadata in module bitcode output still requires hlsl-dxilemit step. (cherry picked from commit 2c3d965)
1 parent 389149f commit 139576e

6 files changed

Lines changed: 1597 additions & 68 deletions

File tree

include/dxc/DXIL/DxilModule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ class DxilModule {
204204
void StripDebugRelatedCode();
205205
void RemoveUnusedTypeAnnotations();
206206

207+
// Copy resource reflection back to this module's resources.
208+
void RestoreResourceReflection(const DxilModule &SourceDM);
209+
207210
// Helper to remove dx.* metadata with source and compile options.
208211
// If the parameter `bReplaceWithDummyData` is true, the named metadata
209212
// are replaced with valid empty data that satisfy tools.

include/dxc/DxilContainer/DxilContainerAssembler.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "llvm/ADT/StringRef.h"
1717

1818
struct IStream;
19+
class DxilPipelineStateValidation;
1920

2021
namespace llvm {
2122
class Module;
@@ -51,6 +52,16 @@ DxilPartWriter *NewFeatureInfoWriter(const DxilModule &M);
5152
DxilPartWriter *NewPSVWriter(const DxilModule &M, uint32_t PSVVersion = UINT_MAX);
5253
DxilPartWriter *NewRDATWriter(const DxilModule &M);
5354

55+
// Store serialized ViewID data from DxilModule to PipelineStateValidation.
56+
void StoreViewIDStateToPSV(const uint32_t *pInputData,
57+
unsigned InputSizeInUInts,
58+
DxilPipelineStateValidation &PSV);
59+
// Load ViewID state from PSV back to DxilModule view state vector.
60+
// Pass nullptr for pOutputData to compute and return needed OutputSizeInUInts.
61+
unsigned LoadViewIDStateFromPSV(unsigned *pOutputData,
62+
unsigned OutputSizeInUInts,
63+
const DxilPipelineStateValidation &PSV);
64+
5465
// Unaligned is for matching container for validator version < 1.7.
5566
DxilContainerWriter *NewDxilContainerWriter(bool bUnaligned = false);
5667

lib/DXIL/DxilModule.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1873,6 +1873,62 @@ void DxilModule::RemoveUnusedTypeAnnotations() {
18731873
}
18741874

18751875

1876+
template <typename _T>
1877+
static void CopyResourceInfo(_T &TargetRes, const _T &SourceRes,
1878+
DxilTypeSystem &TargetTypeSys,
1879+
const DxilTypeSystem &SourceTypeSys) {
1880+
if (TargetRes.GetKind() != SourceRes.GetKind() ||
1881+
TargetRes.GetLowerBound() != SourceRes.GetLowerBound() ||
1882+
TargetRes.GetRangeSize() != SourceRes.GetRangeSize() ||
1883+
TargetRes.GetSpaceID() != SourceRes.GetSpaceID()) {
1884+
DXASSERT(false, "otherwise, resource details don't match");
1885+
return;
1886+
}
1887+
1888+
if (TargetRes.GetGlobalName().empty() && !SourceRes.GetGlobalName().empty()) {
1889+
TargetRes.SetGlobalName(SourceRes.GetGlobalName());
1890+
}
1891+
1892+
if (TargetRes.GetGlobalSymbol() && SourceRes.GetGlobalSymbol() &&
1893+
SourceRes.GetGlobalSymbol()->hasName()) {
1894+
TargetRes.GetGlobalSymbol()->setName(
1895+
SourceRes.GetGlobalSymbol()->getName());
1896+
}
1897+
1898+
Type *Ty = SourceRes.GetHLSLType();
1899+
TargetRes.SetHLSLType(Ty);
1900+
TargetTypeSys.CopyTypeAnnotation(Ty, SourceTypeSys);
1901+
}
1902+
1903+
void DxilModule::RestoreResourceReflection(const DxilModule &SourceDM) {
1904+
DxilTypeSystem &TargetTypeSys = GetTypeSystem();
1905+
const DxilTypeSystem &SourceTypeSys = SourceDM.GetTypeSystem();
1906+
if (GetCBuffers().size() != SourceDM.GetCBuffers().size() ||
1907+
GetSRVs().size() != SourceDM.GetSRVs().size() ||
1908+
GetUAVs().size() != SourceDM.GetUAVs().size() ||
1909+
GetSamplers().size() != SourceDM.GetSamplers().size()) {
1910+
DXASSERT(false, "otherwise, resource lists don't match");
1911+
return;
1912+
}
1913+
for (unsigned i = 0; i < GetCBuffers().size(); ++i) {
1914+
CopyResourceInfo(GetCBuffer(i), SourceDM.GetCBuffer(i), TargetTypeSys,
1915+
SourceTypeSys);
1916+
}
1917+
for (unsigned i = 0; i < GetSRVs().size(); ++i) {
1918+
CopyResourceInfo(GetSRV(i), SourceDM.GetSRV(i), TargetTypeSys,
1919+
SourceTypeSys);
1920+
}
1921+
for (unsigned i = 0; i < GetUAVs().size(); ++i) {
1922+
CopyResourceInfo(GetUAV(i), SourceDM.GetUAV(i), TargetTypeSys,
1923+
SourceTypeSys);
1924+
}
1925+
for (unsigned i = 0; i < GetSamplers().size(); ++i) {
1926+
CopyResourceInfo(GetSampler(i), SourceDM.GetSampler(i), TargetTypeSys,
1927+
SourceTypeSys);
1928+
}
1929+
}
1930+
1931+
18761932
void DxilModule::LoadDxilResources(const llvm::MDOperand &MDO) {
18771933
if (MDO.get() == nullptr)
18781934
return;

lib/DxilContainer/DxilContainerAssembler.cpp

Lines changed: 175 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,180 @@ DxilPartWriter *hlsl::NewFeatureInfoWriter(const DxilModule &M) {
442442
return new DxilFeatureInfoWriter(M);
443443
}
444444

445+
446+
//////////////////////////////////////////////////////////
447+
// Utility code for serializing/deserializing ViewID state
448+
449+
// Code for ComputeSeriaizedViewIDStateSizeInUInts copied from
450+
// ComputeViewIdState. It could be moved into some common location if this
451+
// ViewID serialization/deserialization code were moved out of here.
452+
static unsigned RoundUpToUINT(unsigned x) { return (x + 31) / 32; }
453+
static unsigned ComputeSeriaizedViewIDStateSizeInUInts(
454+
const PSVShaderKind SK, const bool bUsesViewID,
455+
const unsigned InputScalars, const unsigned OutputScalars[4],
456+
const unsigned PCScalars) {
457+
// Compute serialized state size in UINTs.
458+
unsigned NumStreams = SK == PSVShaderKind::Geometry ? 4 : 1;
459+
unsigned Size = 0;
460+
Size += 1; // #Inputs.
461+
for (unsigned StreamId = 0; StreamId < NumStreams; StreamId++) {
462+
Size += 1; // #Outputs for stream StreamId.
463+
unsigned NumOutputs = OutputScalars[StreamId];
464+
unsigned NumOutUINTs = RoundUpToUINT(NumOutputs);
465+
if (bUsesViewID) {
466+
Size += NumOutUINTs; // m_OutputsDependentOnViewId[StreamId]
467+
}
468+
Size += InputScalars * NumOutUINTs; // m_InputsContributingToOutputs[StreamId]
469+
}
470+
if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Domain || SK == PSVShaderKind::Mesh) {
471+
Size += 1; // #PatchConstant.
472+
unsigned NumPCUINTs = RoundUpToUINT(PCScalars);
473+
if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Mesh) {
474+
if (bUsesViewID) {
475+
Size += NumPCUINTs; // m_PCOrPrimOutputsDependentOnViewId
476+
}
477+
Size += InputScalars * NumPCUINTs; // m_InputsContributingToPCOrPrimOutputs
478+
} else {
479+
unsigned NumOutputs = OutputScalars[0];
480+
unsigned NumOutUINTs = RoundUpToUINT(NumOutputs);
481+
Size += PCScalars * NumOutUINTs; // m_PCInputsContributingToOutputs
482+
}
483+
}
484+
return Size;
485+
}
486+
487+
static const uint32_t *CopyViewIDStateForOutputToPSV(
488+
const uint32_t *pSrc, uint32_t InputScalars, uint32_t OutputScalars,
489+
PSVComponentMask ViewIDMask, PSVDependencyTable IOTable) {
490+
unsigned MaskDwords = PSVComputeMaskDwordsFromVectors(PSVALIGN4(OutputScalars) / 4);
491+
if (ViewIDMask.IsValid()) {
492+
DXASSERT_NOMSG(!IOTable.Table || ViewIDMask.NumVectors == IOTable.OutputVectors);
493+
memcpy(ViewIDMask.Mask, pSrc, 4 * MaskDwords);
494+
pSrc += MaskDwords;
495+
}
496+
if (IOTable.IsValid() && IOTable.InputVectors && IOTable.OutputVectors) {
497+
DXASSERT_NOMSG((InputScalars <= IOTable.InputVectors * 4) && (IOTable.InputVectors * 4 - InputScalars < 4));
498+
DXASSERT_NOMSG((OutputScalars <= IOTable.OutputVectors * 4) && (IOTable.OutputVectors * 4 - OutputScalars < 4));
499+
memcpy(IOTable.Table, pSrc, 4 * MaskDwords * InputScalars);
500+
pSrc += MaskDwords * InputScalars;
501+
}
502+
return pSrc;
503+
}
504+
505+
static uint32_t *CopyViewIDStateForOutputFromPSV(uint32_t *pOutputData,
506+
const unsigned InputScalars,
507+
const unsigned OutputScalars,
508+
PSVComponentMask ViewIDMask,
509+
PSVDependencyTable IOTable) {
510+
unsigned MaskDwords = PSVComputeMaskDwordsFromVectors(PSVALIGN4(OutputScalars) / 4);
511+
if (ViewIDMask.IsValid()) {
512+
DXASSERT_NOMSG(!IOTable.Table || ViewIDMask.NumVectors == IOTable.OutputVectors);
513+
for (unsigned i = 0; i < MaskDwords; i++)
514+
*(pOutputData++) = ViewIDMask.Mask[i];
515+
}
516+
if (IOTable.IsValid() && IOTable.InputVectors && IOTable.OutputVectors) {
517+
DXASSERT_NOMSG((InputScalars <= IOTable.InputVectors * 4) && (IOTable.InputVectors * 4 - InputScalars < 4));
518+
DXASSERT_NOMSG((OutputScalars <= IOTable.OutputVectors * 4) && (IOTable.OutputVectors * 4 - OutputScalars < 4));
519+
for (unsigned i = 0; i < MaskDwords * InputScalars; i++)
520+
*(pOutputData++) = IOTable.Table[i];
521+
}
522+
return pOutputData;
523+
}
524+
525+
void hlsl::StoreViewIDStateToPSV(const uint32_t *pInputData,
526+
unsigned InputSizeInUInts,
527+
DxilPipelineStateValidation &PSV) {
528+
PSVRuntimeInfo1 *pInfo1 = PSV.GetPSVRuntimeInfo1();
529+
DXASSERT(pInfo1, "otherwise, PSV does not meet version requirement.");
530+
PSVShaderKind SK = static_cast<PSVShaderKind>(pInfo1->ShaderStage);
531+
const unsigned OutputStreams = SK == PSVShaderKind::Geometry ? 4 : 1;
532+
const uint32_t *pSrc = pInputData;
533+
const uint32_t InputScalars = *(pSrc++);
534+
uint32_t OutputScalars[4];
535+
for (unsigned streamIndex = 0; streamIndex < OutputStreams; streamIndex++) {
536+
OutputScalars[streamIndex] = *(pSrc++);
537+
pSrc = CopyViewIDStateForOutputToPSV(
538+
pSrc, InputScalars, OutputScalars[streamIndex],
539+
PSV.GetViewIDOutputMask(streamIndex),
540+
PSV.GetInputToOutputTable(streamIndex));
541+
}
542+
if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Mesh) {
543+
const uint32_t PCScalars = *(pSrc++);
544+
pSrc = CopyViewIDStateForOutputToPSV(pSrc, InputScalars, PCScalars,
545+
PSV.GetViewIDPCOutputMask(),
546+
PSV.GetInputToPCOutputTable());
547+
} else if (SK == PSVShaderKind::Domain) {
548+
const uint32_t PCScalars = *(pSrc++);
549+
pSrc = CopyViewIDStateForOutputToPSV(pSrc, PCScalars, OutputScalars[0],
550+
PSVComponentMask(),
551+
PSV.GetPCInputToOutputTable());
552+
}
553+
DXASSERT(pSrc - pInputData == InputSizeInUInts,
554+
"otherwise, different amout of data written than expected.");
555+
}
556+
557+
// This function is defined close to the serialization code in DxilPSVWriter to
558+
// reduce the chance of a mismatch. It could be defined elsewhere, but it would
559+
// make sense to move both the serialization and deserialization out of here and
560+
// into a common location.
561+
unsigned hlsl::LoadViewIDStateFromPSV(unsigned *pOutputData,
562+
unsigned OutputSizeInUInts,
563+
const DxilPipelineStateValidation &PSV) {
564+
PSVRuntimeInfo1 *pInfo1 = PSV.GetPSVRuntimeInfo1();
565+
if (!pInfo1) {
566+
return 0;
567+
}
568+
PSVShaderKind SK = static_cast<PSVShaderKind>(pInfo1->ShaderStage);
569+
const unsigned OutputStreams = SK == PSVShaderKind::Geometry ? 4 : 1;
570+
const unsigned InputScalars = pInfo1->SigInputVectors * 4;
571+
unsigned OutputScalars[4];
572+
for (unsigned streamIndex = 0; streamIndex < OutputStreams; streamIndex++) {
573+
OutputScalars[streamIndex] = pInfo1->SigOutputVectors[streamIndex] * 4;
574+
}
575+
unsigned PCScalars = 0;
576+
if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Mesh ||
577+
SK == PSVShaderKind::Domain) {
578+
PCScalars = pInfo1->SigPatchConstOrPrimVectors * 4;
579+
}
580+
if (pOutputData == nullptr) {
581+
return ComputeSeriaizedViewIDStateSizeInUInts(
582+
SK, pInfo1->UsesViewID != 0, InputScalars, OutputScalars, PCScalars);
583+
}
584+
585+
// Fill in serialized viewid buffer.
586+
DXASSERT(ComputeSeriaizedViewIDStateSizeInUInts(
587+
SK, pInfo1->UsesViewID != 0, InputScalars, OutputScalars,
588+
PCScalars) == OutputSizeInUInts,
589+
"otherwise, OutputSize doesn't match computed size.");
590+
unsigned *pStartOutputData = pOutputData;
591+
*(pOutputData++) = InputScalars;
592+
for (unsigned streamIndex = 0; streamIndex < OutputStreams; streamIndex++) {
593+
*(pOutputData++) = OutputScalars[streamIndex];
594+
pOutputData = CopyViewIDStateForOutputFromPSV(
595+
pOutputData, InputScalars, OutputScalars[streamIndex],
596+
PSV.GetViewIDOutputMask(streamIndex),
597+
PSV.GetInputToOutputTable(streamIndex));
598+
}
599+
if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Mesh) {
600+
*(pOutputData++) = PCScalars;
601+
pOutputData = CopyViewIDStateForOutputFromPSV(
602+
pOutputData, InputScalars, PCScalars, PSV.GetViewIDPCOutputMask(),
603+
PSV.GetInputToPCOutputTable());
604+
} else if (SK == PSVShaderKind::Domain) {
605+
*(pOutputData++) = PCScalars;
606+
pOutputData = CopyViewIDStateForOutputFromPSV(
607+
pOutputData, PCScalars, OutputScalars[0], PSVComponentMask(),
608+
PSV.GetPCInputToOutputTable());
609+
}
610+
DXASSERT(pOutputData - pStartOutputData == OutputSizeInUInts,
611+
"otherwise, OutputSizeInUInts didn't match size written.");
612+
return pOutputData - pStartOutputData;
613+
}
614+
615+
616+
//////////////////////////////////////////////////////////
617+
// DxilPSVWriter - Writes PSV0 part
618+
445619
class DxilPSVWriter : public DxilPartWriter {
446620
private:
447621
const DxilModule &m_Module;
@@ -509,22 +683,6 @@ class DxilPSVWriter : public DxilPartWriter {
509683
E.DynamicMaskAndStream |= (SE.GetDynIdxCompMask()) & 0xF;
510684
}
511685

512-
const uint32_t *CopyViewIDState(const uint32_t *pSrc, uint32_t InputScalars, uint32_t OutputScalars, PSVComponentMask ViewIDMask, PSVDependencyTable IOTable) {
513-
unsigned MaskDwords = PSVComputeMaskDwordsFromVectors(PSVALIGN4(OutputScalars) / 4);
514-
if (ViewIDMask.IsValid()) {
515-
DXASSERT_NOMSG(!IOTable.Table || ViewIDMask.NumVectors == IOTable.OutputVectors);
516-
memcpy(ViewIDMask.Mask, pSrc, 4 * MaskDwords);
517-
pSrc += MaskDwords;
518-
}
519-
if (IOTable.IsValid() && IOTable.InputVectors && IOTable.OutputVectors) {
520-
DXASSERT_NOMSG((InputScalars <= IOTable.InputVectors * 4) && (IOTable.InputVectors * 4 - InputScalars < 4));
521-
DXASSERT_NOMSG((OutputScalars <= IOTable.OutputVectors * 4) && (IOTable.OutputVectors * 4 - OutputScalars < 4));
522-
memcpy(IOTable.Table, pSrc, 4 * MaskDwords * InputScalars);
523-
pSrc += MaskDwords * InputScalars;
524-
}
525-
return pSrc;
526-
}
527-
528686
public:
529687
DxilPSVWriter(const DxilModule &mod, uint32_t PSVVersion = UINT_MAX)
530688
: m_Module(mod),
@@ -840,23 +998,7 @@ class DxilPSVWriter : public DxilPartWriter {
840998
// Gather ViewID dependency information
841999
auto &viewState = m_Module.GetSerializedViewIdState();
8421000
if (!viewState.empty()) {
843-
const uint32_t *pSrc = viewState.data();
844-
const uint32_t InputScalars = *(pSrc++);
845-
uint32_t OutputScalars[4];
846-
for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) {
847-
OutputScalars[streamIndex] = *(pSrc++);
848-
pSrc = CopyViewIDState(pSrc, InputScalars, OutputScalars[streamIndex], m_PSV.GetViewIDOutputMask(streamIndex), m_PSV.GetInputToOutputTable(streamIndex));
849-
if (!SM->IsGS())
850-
break;
851-
}
852-
if (SM->IsHS() || SM->IsMS()) {
853-
const uint32_t PCScalars = *(pSrc++);
854-
pSrc = CopyViewIDState(pSrc, InputScalars, PCScalars, m_PSV.GetViewIDPCOutputMask(), m_PSV.GetInputToPCOutputTable());
855-
} else if (SM->IsDS()) {
856-
const uint32_t PCScalars = *(pSrc++);
857-
pSrc = CopyViewIDState(pSrc, PCScalars, OutputScalars[0], PSVComponentMask(), m_PSV.GetPCInputToOutputTable());
858-
}
859-
DXASSERT_NOMSG(viewState.data() + viewState.size() == pSrc);
1001+
StoreViewIDStateToPSV(viewState.data(), (unsigned)viewState.size(), m_PSV);
8601002
}
8611003
}
8621004

0 commit comments

Comments
 (0)