Skip to content

Commit 5a31765

Browse files
adam-yangtex3d
authored andcommitted
Added way for caller to replace args in PDB utils (#3595)
(cherry picked from commit 640c9af)
1 parent 6590576 commit 5a31765

3 files changed

Lines changed: 181 additions & 46 deletions

File tree

include/dxc/dxcapi.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,11 @@ struct IDxcVersionInfo3 : public IUnknown {
582582
) = 0;
583583
};
584584

585+
struct DxcArgPair {
586+
const WCHAR *pName;
587+
const WCHAR *pValue;
588+
};
589+
585590
CROSS_PLATFORM_UUIDOF(IDxcPdbUtils, "E6C9647E-9D6A-4C3B-B94C-524B5A6C343D")
586591
struct IDxcPdbUtils : public IUnknown {
587592
virtual HRESULT STDMETHODCALLTYPE Load(_In_ IDxcBlob *pPdbOrDxil) = 0;
@@ -616,6 +621,8 @@ struct IDxcPdbUtils : public IUnknown {
616621

617622
virtual HRESULT STDMETHODCALLTYPE SetCompiler(_In_ IDxcCompiler3 *pCompiler) = 0;
618623
virtual HRESULT STDMETHODCALLTYPE CompileForFullPDB(_COM_Outptr_ IDxcResult **ppResult) = 0;
624+
virtual HRESULT STDMETHODCALLTYPE OverrideArgs(_In_ DxcArgPair *pArgPairs, UINT32 uNumArgPairs) = 0;
625+
virtual HRESULT STDMETHODCALLTYPE OverrideRootSignature(_In_ const WCHAR *pRootSignature) = 0;
619626
};
620627

621628
// Note: __declspec(selectany) requires 'extern'

tools/clang/tools/dxcompiler/dxcpdbutils.cpp

Lines changed: 126 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,7 @@ struct DxcPdbUtils : public IDxcPdbUtils, public IDxcPixDxilDebugInfoFactory
267267
CComPtr<IDxcBlob> m_pDebugProgramBlob;
268268
CComPtr<IDxcBlob> m_ContainerBlob;
269269
std::vector<Source_File> m_SourceFiles;
270-
std::vector<std::wstring> m_Defines;
271-
std::vector<std::wstring> m_Args;
272-
std::vector<std::wstring> m_Flags;
270+
273271
std::wstring m_EntryPoint;
274272
std::wstring m_TargetProfile;
275273
std::wstring m_Name;
@@ -290,26 +288,33 @@ struct DxcPdbUtils : public IDxcPdbUtils, public IDxcPixDxilDebugInfoFactory
290288
std::wstring Value;
291289
};
292290
std::vector<ArgPair> m_ArgPairs;
291+
std::vector<std::wstring> m_Defines;
292+
std::vector<std::wstring> m_Args;
293+
std::vector<std::wstring> m_Flags;
293294

294-
void Reset() {
295-
m_pDebugProgramBlob = nullptr;
296-
m_InputBlob = nullptr;
297-
m_ContainerBlob = nullptr;
298-
m_SourceFiles.clear();
295+
void ResetAllArgs() {
296+
m_ArgPairs.clear();
299297
m_Defines.clear();
300298
m_Args.clear();
301299
m_Flags.clear();
302300
m_EntryPoint.clear();
303301
m_TargetProfile.clear();
302+
}
303+
304+
void Reset() {
305+
m_pDebugProgramBlob = nullptr;
306+
m_InputBlob = nullptr;
307+
m_ContainerBlob = nullptr;
308+
m_SourceFiles.clear();
304309
m_Name.clear();
305310
m_MainFileName.clear();
306311
m_HashBlob = nullptr;
307312
m_HasVersionInfo = false;
308313
m_VersionInfo = {};
309314
m_VersionCommitSha.clear();
310315
m_VersionString.clear();
311-
m_ArgPairs.clear();
312316
m_pCachedRecompileResult = nullptr;
317+
ResetAllArgs();
313318
}
314319

315320
bool HasSources() const {
@@ -503,38 +508,7 @@ struct DxcPdbUtils : public IDxcPdbUtils, public IDxcPixDxilDebugInfoFactory
503508
newPair.Name = ToWstring(pair.Name);
504509
newPair.Value = ToWstring(pair.Value);
505510
}
506-
507-
bool excludeFromFlags = false;
508-
if (newPair.Name == L"E") {
509-
m_EntryPoint = newPair.Value;
510-
excludeFromFlags = true;
511-
}
512-
else if (newPair.Name == L"T") {
513-
m_TargetProfile = newPair.Value;
514-
excludeFromFlags = true;
515-
}
516-
else if (newPair.Name == L"D") {
517-
m_Defines.push_back(newPair.Value);
518-
excludeFromFlags = true;
519-
}
520-
521-
std::wstring nameWithDash;
522-
if (newPair.Name.size())
523-
nameWithDash = std::wstring(L"-") + newPair.Name;
524-
525-
if (!excludeFromFlags) {
526-
if (nameWithDash.size())
527-
m_Flags.push_back(nameWithDash);
528-
if (newPair.Value.size())
529-
m_Flags.push_back(newPair.Value);
530-
}
531-
532-
if (nameWithDash.size())
533-
m_Args.push_back(nameWithDash);
534-
if (newPair.Value.size())
535-
m_Args.push_back(newPair.Value);
536-
537-
m_ArgPairs.push_back( std::move(newPair) );
511+
AddArgPair(std::move(newPair));
538512
}
539513

540514
// Entry point might have been omitted. Set it to main by default.
@@ -593,6 +567,40 @@ struct DxcPdbUtils : public IDxcPdbUtils, public IDxcPixDxilDebugInfoFactory
593567
return S_OK;
594568
}
595569

570+
void AddArgPair(ArgPair &&newPair) {
571+
bool excludeFromFlags = false;
572+
if (newPair.Name == L"E") {
573+
m_EntryPoint = newPair.Value;
574+
excludeFromFlags = true;
575+
}
576+
else if (newPair.Name == L"T") {
577+
m_TargetProfile = newPair.Value;
578+
excludeFromFlags = true;
579+
}
580+
else if (newPair.Name == L"D") {
581+
m_Defines.push_back(newPair.Value);
582+
excludeFromFlags = true;
583+
}
584+
585+
std::wstring nameWithDash;
586+
if (newPair.Name.size())
587+
nameWithDash = std::wstring(L"-") + newPair.Name;
588+
589+
if (!excludeFromFlags) {
590+
if (nameWithDash.size())
591+
m_Flags.push_back(nameWithDash);
592+
if (newPair.Value.size())
593+
m_Flags.push_back(newPair.Value);
594+
}
595+
596+
if (nameWithDash.size())
597+
m_Args.push_back(nameWithDash);
598+
if (newPair.Value.size())
599+
m_Args.push_back(newPair.Value);
600+
601+
m_ArgPairs.push_back( std::move(newPair) );
602+
}
603+
596604
public:
597605
DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
598606
DXC_MICROCOM_TM_ALLOC(DxcPdbUtils)
@@ -737,6 +745,63 @@ struct DxcPdbUtils : public IDxcPdbUtils, public IDxcPixDxilDebugInfoFactory
737745
return m_pDebugProgramBlob != nullptr;
738746
}
739747

748+
virtual HRESULT STDMETHODCALLTYPE OverrideArgs(_In_ DxcArgPair *pArgPairs, UINT32 uNumArgPairs) override {
749+
try {
750+
DxcThreadMalloc TM(m_pMalloc);
751+
752+
ResetAllArgs();
753+
754+
for (UINT32 i = 0; i < uNumArgPairs; i++) {
755+
ArgPair newPair;
756+
newPair.Name = pArgPairs[i].pName ? pArgPairs[i].pName : L"";
757+
newPair.Value = pArgPairs[i].pValue ? pArgPairs[i].pValue : L"";
758+
AddArgPair(std::move(newPair));
759+
}
760+
761+
// Clear the cached compile result
762+
m_pCachedRecompileResult = nullptr;
763+
}
764+
CATCH_CPP_RETURN_HRESULT()
765+
766+
return S_OK;
767+
}
768+
769+
virtual HRESULT STDMETHODCALLTYPE OverrideRootSignature(_In_ const WCHAR *pRootSignature) override {
770+
try {
771+
DxcThreadMalloc TM(m_pMalloc);
772+
773+
std::vector<ArgPair> newArgPairs;
774+
for (ArgPair &pair : m_ArgPairs) {
775+
if (pair.Name == L"rootsig-define") {
776+
continue;
777+
}
778+
newArgPairs.push_back(pair);
779+
}
780+
781+
ResetAllArgs();
782+
783+
for (ArgPair &newArg : newArgPairs) {
784+
AddArgPair(std::move(newArg));
785+
}
786+
787+
ArgPair rsPair;
788+
rsPair.Name = L"rootsig-define";
789+
rsPair.Value = L"__DXC_RS_DEFINE";
790+
AddArgPair(std::move(rsPair));
791+
792+
ArgPair defPair;
793+
defPair.Name = L"D";
794+
defPair.Value = std::wstring(L"__DXC_RS_DEFINE=") + pRootSignature;
795+
AddArgPair(std::move(defPair));
796+
797+
// Clear the cached compile result
798+
m_pCachedRecompileResult = nullptr;
799+
}
800+
CATCH_CPP_RETURN_HRESULT()
801+
802+
return S_OK;
803+
}
804+
740805
virtual HRESULT STDMETHODCALLTYPE CompileForFullPDB(_COM_Outptr_ IDxcResult **ppResult) {
741806
if (!ppResult) return E_POINTER;
742807
*ppResult = nullptr;
@@ -752,13 +817,28 @@ struct DxcPdbUtils : public IDxcPdbUtils, public IDxcPixDxilDebugInfoFactory
752817

753818
DxcThreadMalloc TM(m_pMalloc);
754819

820+
std::vector<std::wstring> new_args_storage;
821+
for (unsigned i = 0; i < m_ArgPairs.size(); i++) {
822+
std::wstring name = m_ArgPairs[i].Name;
823+
std::wstring value = m_ArgPairs[i].Value;
824+
825+
if (name == L"Zs") continue;
826+
if (name == L"Zi") continue;
827+
828+
if (name.size()) {
829+
name.insert(name.begin(), L'-');
830+
new_args_storage.push_back(std::move(name));
831+
}
832+
if (value.size()) {
833+
new_args_storage.push_back(std::move(value));
834+
}
835+
}
836+
new_args_storage.push_back(L"-Zi");
837+
755838
std::vector<const WCHAR *> new_args;
756-
for (unsigned i = 0; i < m_Args.size(); i++) {
757-
if (m_Args[i] == L"/Zs" || m_Args[i] == L"-Zs")
758-
continue;
759-
new_args.push_back(m_Args[i].c_str());
839+
for (std::wstring &arg : new_args_storage) {
840+
new_args.push_back(arg.c_str());
760841
}
761-
new_args.push_back(L"-Zi");
762842

763843
assert(m_MainFileName.size());
764844
if (m_MainFileName.size())

tools/clang/unittests/HLSL/CompilerTest.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,6 +1288,54 @@ static void VerifyPdbUtil(dxc::DxcDllSupport &dllSupport,
12881288
CComPtr<IDxcBlob> pFullPdb;
12891289
VERIFY_SUCCEEDED(pPdbUtils->GetFullPDB(&pFullPdb));
12901290

1291+
// Save a copy of the arg pairs
1292+
std::vector<std::pair< std::wstring, std::wstring> > pairsStorage;
1293+
UINT32 uNumArgsPairs = 0;
1294+
VERIFY_SUCCEEDED(pPdbUtils->GetArgPairCount(&uNumArgsPairs));
1295+
for (UINT32 i = 0; i < uNumArgsPairs; i++) {
1296+
CComBSTR pName, pValue;
1297+
VERIFY_SUCCEEDED(pPdbUtils->GetArgPair(i, &pName, &pValue));
1298+
std::pair< std::wstring, std::wstring> pairStorage;
1299+
pairStorage.first = pName ? pName : L"";
1300+
pairStorage.second = pValue ? pValue : L"";
1301+
pairsStorage.push_back(pairStorage);
1302+
}
1303+
1304+
// Set an obviously wrong RS and verify compilation fails
1305+
{
1306+
VERIFY_SUCCEEDED(pPdbUtils->OverrideRootSignature(L""));
1307+
CComPtr<IDxcResult> pResult;
1308+
VERIFY_SUCCEEDED(pPdbUtils->CompileForFullPDB(&pResult));
1309+
1310+
HRESULT result = S_OK;
1311+
VERIFY_SUCCEEDED(pResult->GetStatus(&result));
1312+
VERIFY_FAILED(result);
1313+
1314+
CComPtr<IDxcBlobEncoding> pErr;
1315+
VERIFY_SUCCEEDED(pResult->GetErrorBuffer(&pErr));
1316+
}
1317+
1318+
// Set an obviously wrong set of args and verify compilation fails
1319+
{
1320+
1321+
std::vector<DxcArgPair> pairs;
1322+
for (auto &p : pairsStorage) {
1323+
DxcArgPair pair = {};
1324+
pair.pName = p.first.c_str();
1325+
pair.pValue = p.second.c_str();
1326+
pairs.push_back(pair);
1327+
}
1328+
1329+
VERIFY_SUCCEEDED(pPdbUtils->OverrideArgs(pairs.data(), pairs.size()));
1330+
1331+
CComPtr<IDxcResult> pResult;
1332+
VERIFY_SUCCEEDED(pPdbUtils->CompileForFullPDB(&pResult));
1333+
1334+
HRESULT result = S_OK;
1335+
VERIFY_SUCCEEDED(pResult->GetStatus(&result));
1336+
VERIFY_SUCCEEDED(result);
1337+
}
1338+
12911339
auto ReplaceDebugFlagPair = [](const std::vector<std::pair<const WCHAR *, const WCHAR *> > &List) -> std::vector<std::pair<const WCHAR *, const WCHAR *> > {
12921340
std::vector<std::pair<const WCHAR *, const WCHAR *> > ret;
12931341
for (unsigned i = 0; i < List.size(); i++) {

0 commit comments

Comments
 (0)