1313#define __DXCAPI_USE_H__
1414
1515#include " dxc/dxcapi.h"
16+ #include < cstdlib> // for getenv
1617#include < string>
18+ #include < filesystem> // C++17 and later
19+ #include < dxc/Support/Global.h> // for hresult handling with DXC_FAILED
1720
1821namespace dxc {
1922
@@ -26,17 +29,15 @@ class DxcDllSupport {
2629 HMODULE m_dll;
2730 DxcCreateInstanceProc m_createFn;
2831 DxcCreateInstance2Proc m_createFn2;
29- std::string DxilDLLPath = " " ;
3032
31- HRESULT InitializeInternal (LPCSTR dllName, LPCSTR fnName) {
33+ HRESULT virtual InitializeInternal (LPCSTR dllName, LPCSTR fnName) {
3234 if (m_dll != nullptr )
3335 return S_OK;
3436
3537#ifdef _WIN32
3638 m_dll = LoadLibraryA (dllName);
3739 if (m_dll == nullptr )
3840 return HRESULT_FROM_WIN32 (GetLastError ());
39-
4041 m_createFn = (DxcCreateInstanceProc)GetProcAddress (m_dll, fnName);
4142
4243 if (m_createFn == nullptr ) {
@@ -77,8 +78,6 @@ class DxcDllSupport {
7778 }
7879
7980public:
80- LPCSTR GetDxilDLLPath () { return DxilDLLPath.data (); }
81- void SetDxilDLLPath (LPCSTR p) { DxilDLLPath = p; }
8281 DxcDllSupport () : m_dll(nullptr ), m_createFn(nullptr ), m_createFn2(nullptr ) {}
8382
8483 DxcDllSupport (DxcDllSupport &&other) {
@@ -90,14 +89,13 @@ class DxcDllSupport {
9089 other.m_createFn2 = nullptr ;
9190 }
9291
93- ~DxcDllSupport () { Cleanup (); }
92+ virtual ~DxcDllSupport () { Cleanup (); }
9493
95- HRESULT Initialize () {
96- // load dxcompiler.dll
94+ HRESULT virtual Initialize () {
9795 return InitializeInternal (kDxCompilerLib , " DxcCreateInstance" );
9896 }
9997
100- HRESULT InitializeForDll (LPCSTR dll, LPCSTR entryPoint) {
98+ HRESULT virtual InitializeForDll (LPCSTR dll, LPCSTR entryPoint) {
10199 return InitializeInternal (dll, entryPoint);
102100 }
103101
@@ -106,7 +104,8 @@ class DxcDllSupport {
106104 return CreateInstance (clsid, __uuidof (TInterface), (IUnknown **)pResult);
107105 }
108106
109- HRESULT CreateInstance (REFCLSID clsid, REFIID riid, IUnknown **pResult) {
107+ HRESULT virtual CreateInstance (REFCLSID clsid, REFIID riid,
108+ IUnknown **pResult) {
110109 if (pResult == nullptr )
111110 return E_POINTER;
112111 if (m_dll == nullptr )
@@ -122,7 +121,7 @@ class DxcDllSupport {
122121 (IUnknown **)pResult);
123122 }
124123
125- HRESULT CreateInstance2 (IMalloc *pMalloc, REFCLSID clsid, REFIID riid,
124+ HRESULT virtual CreateInstance2 (IMalloc *pMalloc, REFCLSID clsid, REFIID riid,
126125 IUnknown **pResult) {
127126 if (pResult == nullptr )
128127 return E_POINTER;
@@ -134,11 +133,21 @@ class DxcDllSupport {
134133 return hr;
135134 }
136135
137- bool HasCreateWithMalloc () const { return m_createFn2 != nullptr ; }
136+ bool virtual HasCreateWithMalloc () const { return m_createFn2 != nullptr ; }
137+
138+ bool virtual IsEnabled () const { return m_dll != nullptr ; }
138139
139- bool IsEnabled () const { return m_dll != nullptr ; }
140+ bool virtual GetCreateInstanceProcs (DxcCreateInstanceProc *pCreateFn,
141+ DxcCreateInstance2Proc *pCreateFn2) const {
142+ if (pCreateFn == nullptr || pCreateFn2 == nullptr ||
143+ m_createFn == nullptr )
144+ return false ;
145+ *pCreateFn = m_createFn;
146+ *pCreateFn2 = m_createFn2;
147+ return true ;
148+ }
140149
141- void Cleanup () {
150+ void virtual Cleanup () {
142151 if (m_dll != nullptr ) {
143152 m_createFn = nullptr ;
144153 m_createFn2 = nullptr ;
@@ -151,7 +160,7 @@ class DxcDllSupport {
151160 }
152161 }
153162
154- HMODULE Detach () {
163+ HMODULE virtual Detach () {
155164 HMODULE hModule = m_dll;
156165 m_dll = nullptr ;
157166 return hModule;
@@ -184,6 +193,56 @@ void WriteOperationErrorsToConsole(IDxcOperationResult *pResult,
184193void WriteOperationResultToConsole (IDxcOperationResult *pRewriteResult,
185194 bool outputWarnings);
186195
196+ class DxcDllExtValidationSupport : public DxcDllSupport {
197+ // this instance of DxcDllSupport manages the lifetime of
198+ // dxil.dll
199+ DxcDllSupport *m_DxilSupport = nullptr ;
200+
201+ std::string DxilDLLPathExt = " " ;
202+ bool InitializationSuccess = false ;
203+ // override DxcDllSupport's implementation of InitializeInternal,
204+ // adding the environment variable value check for a path to a dxil.dll
205+ // for external validation
206+ HRESULT InitializeInternal (LPCSTR dllName, LPCSTR fnName){
207+
208+ // Load dxcompiler.dll
209+ HRESULT result = m_DxilSupport->InitializeForDll (dllName, fnName);
210+ InitializationSuccess = DXC_FAILED (result) ? false : true ;
211+ if (!InitializationSuccess){
212+ return result;
213+ }
214+
215+ // now handle internal or external dxil.dll
216+ const char *envVal = std::getenv (" DXC_DXIL_DLL_PATH" );
217+ bool ValidateInternally = false ;
218+ if (!envVal || std::string (envVal).empty ()) {
219+ ValidateInternally = true ;
220+ }
221+
222+ if (!ValidateInternally){
223+ std::string DllPathStr (envVal);
224+ DxilDLLPathExt = DllPathStr;
225+ std::filesystem::path DllPath (DllPathStr);
226+
227+ // Check if path is absolute and exists
228+ if (!DllPath.is_absolute () || !std::filesystem::exists (DllPath)) {
229+ InitializationSuccess = false ;
230+ // TODO: Ideally emit some diagnostic that the given absolute path doesn't exist
231+ return HRESULT_FROM_WIN32 (GetLastError ());
232+ }
233+ result = m_DxilSupport->InitializeForDll (DllPathStr.data (), fnName);
234+ if (DXC_FAILED (result)) {
235+ InitializationSuccess = false ;
236+ }
237+ }
238+ // nothing to do if we are validating internally, dxcompiler.dll
239+ // is loaded and it'll take care of validation.
240+ return InitializationSuccess;
241+ }
242+
243+ std::string GetDxilDLLPathExt () { return DxilDLLPathExt; }
244+
245+ };
187246} // namespace dxc
188247
189248#endif
0 commit comments