Skip to content

Commit 5ac9138

Browse files
JoeCitizen, Jack Elliott, and tex3d
authored
Implementation of GroupSharedLimit to allow increased GroupSharedMemory (#7871)
Add a new HLSL attribute for Compute, Amplification, and Mesh shaders: GroupSharedLimit. This is used to limit the amount of group shared memory a shader is allowed to statically declare, and validation will fail if the limit is exceeded. There is no upper limit on this attribute, and it is expected that shader writers set the limit to the lowest common denominator for their target hardware and software use case (typically 48k or 64k for modern GPUs). If no attribute is declared, the existing 32k limit is used for compatibility with existing shaders. Extends the PSV structures to include the selected limit so that runtime validation can reject the shader if it exceeds what the device supports. --------- Co-authored-by: Jack Elliott <[email protected]> Co-authored-by: Tex Riddell <[email protected]>
1 parent d6b78b7 commit 5ac9138

21 files changed

Lines changed: 331 additions & 35 deletions

include/dxc/DXIL/DxilFunctionProps.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ struct DxilFunctionProps {
117117
memset(&Node, 0, sizeof(Node));
118118
Node.LaunchType = DXIL::NodeLaunchType::Invalid;
119119
Node.LocalRootArgumentsTableIndex = -1;
120+
groupSharedLimitBytes = 0;
120121
}
121122
union {
122123
// Geometry shader.
@@ -174,6 +175,8 @@ struct DxilFunctionProps {
174175
// numThreads shared between multiple shader types and node shaders.
175176
unsigned numThreads[3];
176177

178+
unsigned groupSharedLimitBytes;
179+
177180
struct NodeProps {
178181
DXIL::NodeLaunchType LaunchType = DXIL::NodeLaunchType::Invalid;
179182
bool IsProgramEntry;

include/dxc/DXIL/DxilMetadataHelper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ class DxilMDHelper {
320320
static const unsigned kDxilNodeOutputsTag = 21;
321321
static const unsigned kDxilNodeMaxDispatchGridTag = 22;
322322
static const unsigned kDxilRangedWaveSizeTag = 23;
323+
static const unsigned kDxilGroupSharedLimitTag = 24;
323324

324325
// Node Input/Output State.
325326
static const unsigned kDxilNodeOutputIDTag = 0;

include/dxc/DXIL/DxilModule.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,10 @@ class DxilModule {
254254
void SetNumThreads(unsigned x, unsigned y, unsigned z);
255255
unsigned GetNumThreads(unsigned idx) const;
256256

257+
unsigned GetGroupSharedLimit() const;
258+
// The total amount of group shared memory (in bytes) used by the shader.
259+
unsigned GetTGSMSizeInBytes() const;
260+
257261
// Compute shader
258262
DxilWaveSize &GetWaveSize();
259263
const DxilWaveSize &GetWaveSize() const;

include/dxc/DxilContainer/DxilPipelineStateValidation.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,10 @@ struct PSVRuntimeInfo3 : public PSVRuntimeInfo2 {
175175
uint32_t EntryFunctionName;
176176
};
177177

178+
struct PSVRuntimeInfo4 : public PSVRuntimeInfo3 {
179+
uint32_t NumBytesGroupSharedMemory;
180+
};
181+
178182
enum class PSVResourceType {
179183
Invalid = 0,
180184

@@ -474,7 +478,7 @@ class PSVSignatureElement {
474478
const uint32_t *SemanticIndexes) const;
475479
};
476480

477-
#define MAX_PSV_VERSION 3
481+
#define MAX_PSV_VERSION 4
478482

479483
struct PSVInitInfo {
480484
PSVInitInfo(uint32_t psvVersion) : PSVVersion(psvVersion) {}
@@ -491,7 +495,7 @@ struct PSVInitInfo {
491495
uint8_t SigPatchConstOrPrimVectors = 0;
492496
uint8_t SigOutputVectors[PSV_GS_MAX_STREAMS] = {0, 0, 0, 0};
493497

494-
static_assert(MAX_PSV_VERSION == 3, "otherwise this needs updating.");
498+
static_assert(MAX_PSV_VERSION == 4, "otherwise this needs updating.");
495499
uint32_t RuntimeInfoSize() const {
496500
switch (PSVVersion) {
497501
case 0:
@@ -500,10 +504,12 @@ struct PSVInitInfo {
500504
return sizeof(PSVRuntimeInfo1);
501505
case 2:
502506
return sizeof(PSVRuntimeInfo2);
507+
case 3:
508+
return sizeof(PSVRuntimeInfo3);
503509
default:
504510
break;
505511
}
506-
return sizeof(PSVRuntimeInfo3);
512+
return sizeof(PSVRuntimeInfo4);
507513
}
508514
uint32_t ResourceBindInfoSize() const {
509515
if (PSVVersion < 2)
@@ -519,6 +525,7 @@ class DxilPipelineStateValidation {
519525
PSVRuntimeInfo1 *m_pPSVRuntimeInfo1 = nullptr;
520526
PSVRuntimeInfo2 *m_pPSVRuntimeInfo2 = nullptr;
521527
PSVRuntimeInfo3 *m_pPSVRuntimeInfo3 = nullptr;
528+
PSVRuntimeInfo4 *m_pPSVRuntimeInfo4 = nullptr;
522529
uint32_t m_uResourceCount = 0;
523530
uint32_t m_uPSVResourceBindInfoSize = 0;
524531
void *m_pPSVResourceBindInfo = nullptr;
@@ -634,6 +641,8 @@ class DxilPipelineStateValidation {
634641

635642
PSVRuntimeInfo3 *GetPSVRuntimeInfo3() const { return m_pPSVRuntimeInfo3; }
636643

644+
PSVRuntimeInfo4 *GetPSVRuntimeInfo4() const { return m_pPSVRuntimeInfo4; }
645+
637646
uint32_t GetBindCount() const { return m_uResourceCount; }
638647

639648
template <typename _T>
@@ -949,6 +958,8 @@ DxilPipelineStateValidation::ReadOrWrite(const void *pBits, uint32_t *pSize,
949958
m_uPSVRuntimeInfoSize); // failure ok
950959
AssignDerived(&m_pPSVRuntimeInfo3, m_pPSVRuntimeInfo0,
951960
m_uPSVRuntimeInfoSize); // failure ok
961+
AssignDerived(&m_pPSVRuntimeInfo4, m_pPSVRuntimeInfo0,
962+
m_uPSVRuntimeInfoSize); // failure ok
952963

953964
// In RWMode::CalcSize, use temp runtime info to hold needed values from
954965
// initInfo
@@ -1137,11 +1148,13 @@ void SetupPSVInitInfo(PSVInitInfo &InitInfo, const DxilModule &DM);
11371148
void SetShaderProps(PSVRuntimeInfo0 *pInfo, const DxilModule &DM);
11381149
void SetShaderProps(PSVRuntimeInfo1 *pInfo1, const DxilModule &DM);
11391150
void SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM);
1151+
void SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM);
11401152

11411153
void PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
11421154
PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2,
1143-
PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind,
1144-
const char *EntryName, const char *Comment);
1155+
PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4,
1156+
uint8_t ShaderKind, const char *EntryName,
1157+
const char *Comment);
11451158

11461159
} // namespace hlsl
11471160

lib/DXIL/DxilMetadataHelper.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1624,6 +1624,13 @@ MDTuple *DxilMDHelper::EmitDxilEntryProperties(uint64_t rawShaderFlag,
16241624
}
16251625
MDVals.emplace_back(MDNode::get(m_Ctx, WaveSizeVal));
16261626
}
1627+
1628+
const hlsl::ShaderModel *SM = GetShaderModel();
1629+
if (SM->IsSMAtLeast(6, 10)) {
1630+
MDVals.emplace_back(
1631+
Uint32ToConstMD(DxilMDHelper::kDxilGroupSharedLimitTag));
1632+
MDVals.emplace_back(Uint32ToConstMD(props.groupSharedLimitBytes));
1633+
}
16271634
} break;
16281635
// Geometry shader.
16291636
case DXIL::ShaderKind::Geometry: {
@@ -1773,6 +1780,13 @@ void DxilMDHelper::LoadDxilEntryProperties(const MDOperand &MDO,
17731780
props.numThreads[2] = ConstMDToUint32(pNode->getOperand(2));
17741781
} break;
17751782

1783+
case DxilMDHelper::kDxilGroupSharedLimitTag: {
1784+
DXASSERT(props.IsCS(), "else invalid shader kind");
1785+
props.groupSharedLimitBytes = ConstMDToUint32(MDO);
1786+
if (!m_pSM->IsSMAtLeast(6, 10))
1787+
m_bExtraMetadata = true;
1788+
} break;
1789+
17761790
case DxilMDHelper::kDxilGSStateTag: {
17771791
DXASSERT(props.IsGS(), "else invalid shader kind");
17781792
auto &GS = props.ShaderProps.GS;

lib/DXIL/DxilModule.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,28 @@ unsigned DxilModule::GetNumThreads(unsigned idx) const {
412412
return props.numThreads[idx];
413413
}
414414

415+
unsigned DxilModule::GetGroupSharedLimit() const {
416+
DXASSERT(m_DxilEntryPropsMap.size() == 1 &&
417+
(m_pSM->IsCS() || m_pSM->IsMS() || m_pSM->IsAS()),
418+
"only works for CS/MS/AS profiles");
419+
const DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
420+
DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind);
421+
return props.groupSharedLimitBytes;
422+
}
423+
424+
unsigned DxilModule::GetTGSMSizeInBytes() const {
425+
const DataLayout &DL = m_pModule->getDataLayout();
426+
unsigned TGSMSize = 0;
427+
428+
for (GlobalVariable &GV : m_pModule->globals()) {
429+
if (GV.getType()->getAddressSpace() == DXIL::kTGSMAddrSpace) {
430+
TGSMSize += DL.getTypeAllocSize(GV.getType()->getElementType());
431+
}
432+
}
433+
434+
return TGSMSize;
435+
}
436+
415437
DxilWaveSize &DxilModule::GetWaveSize() {
416438
return const_cast<DxilWaveSize &>(
417439
static_cast<const DxilModule *>(this)->GetWaveSize());

lib/DxilContainer/DxilContainerAssembler.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,8 @@ class DxilPSVWriter : public DxilPartWriter {
798798
PSVRuntimeInfo1 *pInfo1 = m_PSV.GetPSVRuntimeInfo1();
799799
PSVRuntimeInfo2 *pInfo2 = m_PSV.GetPSVRuntimeInfo2();
800800
PSVRuntimeInfo3 *pInfo3 = m_PSV.GetPSVRuntimeInfo3();
801+
PSVRuntimeInfo4 *pInfo4 = m_PSV.GetPSVRuntimeInfo4();
802+
801803
if (pInfo)
802804
hlsl::SetShaderProps(pInfo, m_Module);
803805
if (pInfo1)
@@ -806,6 +808,8 @@ class DxilPSVWriter : public DxilPartWriter {
806808
hlsl::SetShaderProps(pInfo2, m_Module);
807809
if (pInfo3)
808810
pInfo3->EntryFunctionName = EntryFunctionName;
811+
if (pInfo4)
812+
hlsl::SetShaderProps(pInfo4, m_Module);
809813

810814
// Set resource binding information
811815
UINT uResIndex = 0;

lib/DxilContainer/DxilPipelineStateValidation.cpp

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ uint32_t hlsl::GetPSVVersion(uint32_t ValMajor, uint32_t ValMinor) {
3333
PSVVersion = 1;
3434
else if (DXIL::CompareVersions(ValMajor, ValMinor, 1, 8) < 0)
3535
PSVVersion = 2;
36+
else if (DXIL::CompareVersions(ValMajor, ValMinor, 1, 10) < 0)
37+
PSVVersion = 3;
3638
return PSVVersion;
3739
}
3840

@@ -305,6 +307,20 @@ void hlsl::SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM) {
305307
}
306308
}
307309

310+
void hlsl::SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM) {
311+
assert(pInfo4);
312+
const ShaderModel *SM = DM.GetShaderModel();
313+
switch (SM->GetKind()) {
314+
case ShaderModel::Kind::Compute:
315+
case ShaderModel::Kind::Mesh:
316+
case ShaderModel::Kind::Amplification:
317+
pInfo4->NumBytesGroupSharedMemory = DM.GetTGSMSizeInBytes();
318+
break;
319+
default:
320+
break;
321+
}
322+
}
323+
308324
void PSVResourceBindInfo0::Print(raw_ostream &OS) const {
309325
OS << "PSVResourceBindInfo:\n";
310326
OS << " Space: " << Space << "\n";
@@ -584,8 +600,9 @@ void PSVDependencyTable::Print(raw_ostream &OS, const char *InputSetName,
584600

585601
void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
586602
PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2,
587-
PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind,
588-
const char *EntryName, const char *Comment) {
603+
PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4,
604+
uint8_t ShaderKind, const char *EntryName,
605+
const char *Comment) {
589606
if (pInfo1 && pInfo1->ShaderStage != ShaderKind)
590607
ShaderKind = pInfo1->ShaderStage;
591608
OS << Comment << "PSVRuntimeInfo:\n";
@@ -808,13 +825,23 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
808825
OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << ","
809826
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
810827
}
828+
if (pInfo4) {
829+
OS << Comment
830+
<< " NumBytesGroupSharedMemory: " << pInfo4->NumBytesGroupSharedMemory
831+
<< "\n";
832+
}
811833
break;
812834
case PSVShaderKind::Amplification:
813835
OS << Comment << " Amplification Shader\n";
814836
if (pInfo2) {
815837
OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << ","
816838
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
817839
}
840+
if (pInfo4) {
841+
OS << Comment
842+
<< " NumBytesGroupSharedMemory: " << pInfo4->NumBytesGroupSharedMemory
843+
<< "\n";
844+
}
818845
break;
819846
case PSVShaderKind::Mesh:
820847
OS << Comment << " Mesh Shader\n";
@@ -841,6 +868,11 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
841868
OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << ","
842869
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
843870
}
871+
if (pInfo4) {
872+
OS << Comment
873+
<< " NumBytesGroupSharedMemory: " << pInfo4->NumBytesGroupSharedMemory
874+
<< "\n";
875+
}
844876
break;
845877
case PSVShaderKind::Library:
846878
case PSVShaderKind::Invalid:
@@ -887,9 +919,10 @@ void DxilPipelineStateValidation::PrintPSVRuntimeInfo(
887919
PSVRuntimeInfo1 *pInfo1 = m_pPSVRuntimeInfo1;
888920
PSVRuntimeInfo2 *pInfo2 = m_pPSVRuntimeInfo2;
889921
PSVRuntimeInfo3 *pInfo3 = m_pPSVRuntimeInfo3;
922+
PSVRuntimeInfo4 *pInfo4 = m_pPSVRuntimeInfo4;
890923

891924
hlsl::PrintPSVRuntimeInfo(
892-
OS, pInfo0, pInfo1, pInfo2, pInfo3, ShaderKind,
925+
OS, pInfo0, pInfo1, pInfo2, pInfo3, pInfo4, ShaderKind,
893926
m_pPSVRuntimeInfo3 ? m_StringTable.Get(pInfo3->EntryFunctionName) : "",
894927
Comment);
895928
}

lib/DxilValidation/DxilContainerValidation.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,8 @@ class PSVContentVerifier {
185185
unsigned PSVVersion);
186186
void VerifyViewIDDependence(PSVRuntimeInfo1 *PSV1, unsigned PSVVersion);
187187
void VerifyEntryProperties(const ShaderModel *SM, PSVRuntimeInfo0 *PSV0,
188-
PSVRuntimeInfo1 *PSV1, PSVRuntimeInfo2 *PSV2);
188+
PSVRuntimeInfo1 *PSV1, PSVRuntimeInfo2 *PSV2,
189+
PSVRuntimeInfo3 *PSV3, PSVRuntimeInfo4 *PSV4);
189190
void EmitMismatchError(StringRef Name, StringRef PartContent,
190191
StringRef ModuleContent) {
191192
ValCtx.EmitFormatError(ValidationRule::ContainerContentMatches,
@@ -409,16 +410,16 @@ void PSVContentVerifier::VerifyResources(unsigned PSVVersion) {
409410
VerifyResourceTable(DM.GetUAVs(), ResIndex, PSVVersion);
410411
}
411412

412-
void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM,
413-
PSVRuntimeInfo0 *PSV0,
414-
PSVRuntimeInfo1 *PSV1,
415-
PSVRuntimeInfo2 *PSV2) {
416-
PSVRuntimeInfo3 DMPSV;
417-
memset(&DMPSV, 0, sizeof(PSVRuntimeInfo3));
413+
void PSVContentVerifier::VerifyEntryProperties(
414+
const ShaderModel *SM, PSVRuntimeInfo0 *PSV0, PSVRuntimeInfo1 *PSV1,
415+
PSVRuntimeInfo2 *PSV2, PSVRuntimeInfo3 *PSV3, PSVRuntimeInfo4 *PSV4) {
416+
PSVRuntimeInfo4 DMPSV;
417+
memset(&DMPSV, 0, sizeof(PSVRuntimeInfo4));
418418

419419
hlsl::SetShaderProps((PSVRuntimeInfo0 *)&DMPSV, DM);
420420
hlsl::SetShaderProps((PSVRuntimeInfo1 *)&DMPSV, DM);
421421
hlsl::SetShaderProps((PSVRuntimeInfo2 *)&DMPSV, DM);
422+
hlsl::SetShaderProps((PSVRuntimeInfo4 *)&DMPSV, DM);
422423
if (PSV1) {
423424
// Init things not set in InitPSVRuntimeInfo.
424425
DMPSV.ShaderStage = static_cast<uint8_t>(SM->GetKind());
@@ -444,10 +445,14 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM,
444445
else
445446
Mismatched = memcmp(PSV0, &DMPSV, sizeof(PSVRuntimeInfo0)) != 0;
446447

448+
if (PSV4 &&
449+
PSV4->NumBytesGroupSharedMemory != DMPSV.NumBytesGroupSharedMemory)
450+
Mismatched = true;
451+
447452
if (Mismatched) {
448453
std::string Str;
449454
raw_string_ostream OS(Str);
450-
hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV,
455+
hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV, &DMPSV,
451456
static_cast<uint8_t>(SM->GetKind()),
452457
DM.GetEntryFunctionName().c_str(), "");
453458
OS.flush();
@@ -476,9 +481,11 @@ void PSVContentVerifier::Verify(unsigned ValMajor, unsigned ValMinor,
476481
PSVRuntimeInfo0 *PSV0 = PSV.GetPSVRuntimeInfo0();
477482
PSVRuntimeInfo1 *PSV1 = PSV.GetPSVRuntimeInfo1();
478483
PSVRuntimeInfo2 *PSV2 = PSV.GetPSVRuntimeInfo2();
484+
PSVRuntimeInfo3 *PSV3 = PSV.GetPSVRuntimeInfo3();
485+
PSVRuntimeInfo4 *PSV4 = PSV.GetPSVRuntimeInfo4();
479486

480487
const ShaderModel *SM = DM.GetShaderModel();
481-
VerifyEntryProperties(SM, PSV0, PSV1, PSV2);
488+
VerifyEntryProperties(SM, PSV0, PSV1, PSV2, PSV3, PSV4);
482489
if (PSVVersion > 0) {
483490
if (((PSV.GetSigInputElements() + PSV.GetSigOutputElements() +
484491
PSV.GetSigPatchConstOrPrimElements()) > 0) &&

lib/DxilValidation/DxilValidation.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3935,6 +3935,18 @@ static void ValidateGlobalVariables(ValidationContext &ValCtx) {
39353935
Rule = ValidationRule::SmMaxMSSMSize;
39363936
MaxSize = DXIL::kMaxMSSMSize;
39373937
}
3938+
3939+
// Check if the entry function has attribute to override TGSM size.
3940+
if (M.HasDxilEntryProps(M.GetEntryFunction())) {
3941+
DxilEntryProps &EntryProps = M.GetDxilEntryProps(M.GetEntryFunction());
3942+
if (EntryProps.props.IsCS()) {
3943+
unsigned SpecifiedTGSMSize = EntryProps.props.groupSharedLimitBytes;
3944+
if (SpecifiedTGSMSize > 0) {
3945+
MaxSize = SpecifiedTGSMSize;
3946+
}
3947+
}
3948+
}
3949+
39383950
if (TGSMSize > MaxSize) {
39393951
Module::global_iterator GI = M.GetModule()->global_end();
39403952
GlobalVariable *GV = &*GI;

0 commit comments

Comments
 (0)