1+ #pragma once
2+
3+ #include < vector>
4+ #include < DirectXMath.h>
5+ #include < DirectXPackedVector.h>
6+
7+ #include " dxc/Support/microcom.h"
8+
9+ #include " CoopVecAPIExtensions.h"
10+
11+ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler {
12+ private:
13+ DXC_MICROCOM_REF_FIELD (m_dwRef)
14+ dxc::DxcDllSupport &DxcSupport;
15+
16+ public:
17+ LinAlgHeaderIncludeHandler () = delete ;
18+ LinAlgHeaderIncludeHandler (dxc::DxcDllSupport &DxcSupport) : m_dwRef(0 ), DxcSupport(DxcSupport) {}
19+
20+ DXC_MICROCOM_ADDREF_RELEASE_IMPL (m_dwRef)
21+
22+ HRESULT STDMETHODCALLTYPE LoadSource (LPCWSTR pFilename, IDxcBlob **ppIncludeSource) {
23+ if (wcscmp (pFilename, L" dx/linalg.h" ) == 0 || wcscmp (pFilename, L" .\\ dx\\ linalg.h" ) == 0 ) {
24+ WEX::Common::String ParamValue;
25+ if (FAILED (WEX::TestExecution::RuntimeParameters::TryGetValue (L" LinAlgHeader" ,
26+ ParamValue))) {
27+ return E_FAIL;
28+ }
29+ if (ParamValue.IsEmpty ()) {
30+ return E_FAIL;
31+ }
32+ LPCWSTR RealHeaderPath = reinterpret_cast <LPCWSTR>(ParamValue.GetBuffer ());
33+
34+ CComPtr<IDxcUtils> pHeaderUtils;
35+
36+ IFT (DxcSupport.CreateInstance (CLSID_DxcUtils, &pHeaderUtils));
37+
38+ IDxcBlobEncoding *pHeaderBlob;
39+ IFT (pHeaderUtils->LoadFile (RealHeaderPath, nullptr , &pHeaderBlob));
40+
41+ *ppIncludeSource = pHeaderBlob;
42+
43+ return S_OK;
44+ }
45+ return E_FAIL;
46+ }
47+
48+ HRESULT STDMETHODCALLTYPE QueryInterface (REFIID iid,
49+ void **ppvObject) override {
50+ // FIXME: This is a workaround for a warning-as-error about unused parameters.
51+ #pragma warning(push)
52+ #pragma warning(disable: 4100)
53+ return DoBasicQueryInterface<IDxcIncludeHandler>(this , iid, ppvObject);
54+ #pragma warning(pop)
55+ }
56+ };
57+
58+ struct CoopVecHelpers {
59+ template <typename EltTy>
60+ static std::vector<uint8_t > CreateAllOnesInputMatrix (uint32_t Width, uint32_t Height) {
61+ std::vector<EltTy> inputMatrix (Width * Height);
62+ for (uint32_t i = 0 ; i < Width * Height; i++) {
63+ if constexpr (std::is_same_v<EltTy, uint8_t > || std::is_same_v<EltTy, int8_t >) {
64+ inputMatrix[i] = 1 ;
65+ } else if constexpr (std::is_same_v<EltTy, DirectX::PackedVector::HALF>) {
66+ inputMatrix[i] = ConvertFloat32ToFloat16 (1 .0f );
67+ } else if constexpr (std::is_same_v<EltTy, float >) {
68+ inputMatrix[i] = 1 .0f ;
69+ }
70+ }
71+
72+ // Convert to uint8_t vector
73+ std::vector<uint8_t > uint8InputMatrix (inputMatrix.size () * sizeof (EltTy));
74+ std::memcpy (uint8InputMatrix.data (), inputMatrix.data (), inputMatrix.size () * sizeof (EltTy));
75+ return uint8InputMatrix;
76+ }
77+
78+ template <typename EltTy>
79+ static std::vector<uint8_t > CreateInputVector (uint32_t NumThreads, uint32_t EltsPerThread) {
80+ std::vector<EltTy> inputVector (NumThreads * EltsPerThread);
81+ std::fill (inputVector.begin (), inputVector.end (), EltTy (0 ));
82+ if (EltsPerThread < 2 ) {
83+ WEX::Logging::Log::Error (L" EltsPerThread must be at least 2" );
84+ return std::vector<uint8_t >();
85+ }
86+ for (uint32_t TID = 0 ; TID < NumThreads; TID++) {
87+ if constexpr (std::is_same_v<EltTy, uint8_t > || std::is_same_v<EltTy, int8_t >) {
88+ inputVector[TID * EltsPerThread + 0 ] = 1 ;
89+ inputVector[TID * EltsPerThread + 1 ] = 1 ;
90+ } else if constexpr (std::is_same_v<EltTy, DirectX::PackedVector::HALF>) {
91+ inputVector[TID * EltsPerThread + 0 ] = ConvertFloat32ToFloat16 (1 .0f );
92+ inputVector[TID * EltsPerThread + 1 ] = ConvertFloat32ToFloat16 (1 .0f );
93+ } else if constexpr (std::is_same_v<EltTy, float >) {
94+ inputVector[TID * EltsPerThread + 0 ] = 1 .0f ;
95+ inputVector[TID * EltsPerThread + 1 ] = 1 .0f ;
96+ } else {
97+ WEX::Logging::Log::Error (L" Unsupported input type" );
98+ }
99+ }
100+
101+ // Convert to uint8_t vector
102+ std::vector<uint8_t > uint8InputVector (inputVector.size () * sizeof (EltTy));
103+ std::memcpy (uint8InputVector.data (), inputVector.data (), inputVector.size () * sizeof (EltTy));
104+ return uint8InputVector;
105+ }
106+
107+ template <typename EltTy>
108+ static std::vector<uint8_t > CreateInputBias (uint32_t NumElts) {
109+ std::vector<EltTy> inputBias (NumElts);
110+ if constexpr (std::is_same_v<EltTy, uint8_t > || std::is_same_v<EltTy, int8_t >) {
111+ std::fill (inputBias.begin (), inputBias.end (), EltTy (1 ));
112+ } else if constexpr (std::is_same_v<EltTy, DirectX::PackedVector::HALF>) {
113+ std::fill (inputBias.begin (), inputBias.end (), ConvertFloat32ToFloat16 (1 .0f ));
114+ } else if constexpr (std::is_same_v<EltTy, int32_t >) {
115+ std::fill (inputBias.begin (), inputBias.end (), 1 );
116+ } else {
117+ WEX::Logging::Log::Error (L" Unsupported bias type" );
118+ }
119+ // Convert to uint8_t vector
120+ std::vector<uint8_t > uint8InputBias (inputBias.size () * sizeof (EltTy));
121+ std::memcpy (uint8InputBias.data (), inputBias.data (), inputBias.size () * sizeof (EltTy));
122+ return uint8InputBias;
123+ }
124+
125+ static std::wstring DataTypeToFilterString (D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
126+ switch (DataType) {
127+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
128+ return L" SINT8_T4_PACKED" ;
129+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED:
130+ return L" UINT8_T4_PACKED" ;
131+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
132+ return L" SINT8" ;
133+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
134+ return L" UINT8" ;
135+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
136+ return L" SINT16" ;
137+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
138+ return L" UINT16" ;
139+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
140+ return L" SINT32" ;
141+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
142+ return L" UINT32" ;
143+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
144+ return L" FLOAT32" ;
145+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
146+ return L" FLOAT16" ;
147+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
148+ return L" FLOAT_E4M3" ;
149+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
150+ return L" FLOAT_E5M2" ;
151+ default :
152+ return L" <UNKNOWN>" ;
153+ }
154+ }
155+
156+ static bool IsDataTypeInFilter (const wchar_t *FilterKey,
157+ D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
158+ WEX::Common::String ParamValue;
159+ if (FAILED (WEX::TestExecution::RuntimeParameters::TryGetValue (FilterKey,
160+ ParamValue))) {
161+ // Filter not set, so treat as no filter
162+ return true ;
163+ }
164+ if (ParamValue.IsEmpty ()) {
165+ // Empty filter, so treat as no filter
166+ return true ;
167+ }
168+
169+ // Check if the filter matches the target data type
170+ LPCWSTR FilterString = reinterpret_cast <LPCWSTR>(ParamValue.GetBuffer ());
171+ return DataTypeToFilterString (DataType) == FilterString;
172+ }
173+
174+ static std::wstring MatrixLayoutToFilterString (D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) {
175+ switch (MatrixLayout) {
176+ case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
177+ return L" ROW_MAJOR" ;
178+ case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
179+ return L" COLUMN_MAJOR" ;
180+ case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL:
181+ return L" MUL_OPTIMAL" ;
182+ case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL:
183+ return L" OUTER_PRODUCT_OPTIMAL" ;
184+ default :
185+ return L" <UNKNOWN>" ;
186+ }
187+ }
188+
189+ static bool IsMatrixLayoutInFilter (const wchar_t *FilterKey,
190+ D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) {
191+ WEX::Common::String ParamValue;
192+ if (FAILED (WEX::TestExecution::RuntimeParameters::TryGetValue (FilterKey,
193+ ParamValue))) {
194+ // Filter not set, so treat as no filter
195+ return true ;
196+ }
197+ if (ParamValue.IsEmpty ()) {
198+ // Empty filter, so treat as no filter
199+ return true ;
200+ }
201+
202+ // Check if the filter matches the target data type
203+ LPCWSTR FilterString = reinterpret_cast <LPCWSTR>(ParamValue.GetBuffer ());
204+ return MatrixLayoutToFilterString (MatrixLayout) == FilterString;
205+ }
206+
207+ static std::wstring MatrixLayoutToHlslLayoutString (D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) {
208+ switch (MatrixLayout) {
209+ case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
210+ return L" MATRIX_LAYOUT_ROW_MAJOR" ;
211+ case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
212+ return L" MATRIX_LAYOUT_COLUMN_MAJOR" ;
213+ case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL:
214+ return L" MATRIX_LAYOUT_MUL_OPTIMAL" ;
215+ case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL:
216+ return L" MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL" ;
217+ default :
218+ return L" <UNKNOWN>" ;
219+ }
220+ }
221+
222+ // This multiplier is used to compute the row/column stride for a matrix
223+ // given it's element size.
224+ static int GetStrideMultiplierForMatrixDataType (D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
225+ switch (DataType) {
226+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
227+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED:
228+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
229+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
230+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
231+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
232+ return 1 ;
233+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
234+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
235+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
236+ return 2 ;
237+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
238+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
239+ return 4 ;
240+ default :
241+ WEX::Logging::Log::Error (L" Unsupported matrix data type" );
242+ return 1 ;
243+ }
244+ }
245+
246+ static int GetNumPackedElementsForInputDataType (D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) {
247+ // Int8 packed types are the only ones that have more than 1 element per shader variable
248+ switch (InputInterpretation) {
249+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
250+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED:
251+ return 4 ;
252+ default :
253+ return 1 ;
254+ }
255+ }
256+
257+ // This type is used in generated HLSL source to represent the vector type
258+ // for the given data type.
259+ static std::wstring GetHlslDataTypeForDataType (D3D12_LINEAR_ALGEBRA_DATATYPE DataType, D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) {
260+ UNREFERENCED_PARAMETER (InputInterpretation);
261+ switch (DataType) {
262+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
263+ return L" int16_t" ;
264+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
265+ return L" uint16_t" ;
266+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
267+ return L" int32_t" ;
268+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
269+ return L" uint32_t" ;
270+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
271+ return L" half" ;
272+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
273+ return L" float" ;
274+ default :
275+ WEX::Logging::Log::Error (L" Unsupported input data type" );
276+ return L" <UNKNOWN>" ;
277+ }
278+ }
279+
280+ static std::wstring GetHlslInterpretationForDataType (D3D12_LINEAR_ALGEBRA_DATATYPE Interpretation) {
281+ switch (Interpretation) {
282+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
283+ return L" DATA_TYPE_SINT8_T4_PACKED" ;
284+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED:
285+ return L" DATA_TYPE_UINT8_T4_PACKED" ;
286+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
287+ return L" DATA_TYPE_SINT8" ;
288+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
289+ return L" DATA_TYPE_UINT8" ;
290+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
291+ return L" DATA_TYPE_SINT16" ;
292+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
293+ return L" DATA_TYPE_UINT16" ;
294+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
295+ return L" DATA_TYPE_SINT32" ;
296+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
297+ return L" DATA_TYPE_UINT32" ;
298+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
299+ return L" DATA_TYPE_FLOAT16" ;
300+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
301+ return L" DATA_TYPE_FLOAT32" ;
302+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
303+ return L" DATA_TYPE_FLOAT8_E4M3" ;
304+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
305+ return L" DATA_TYPE_FLOAT8_E5M2" ;
306+ default :
307+ WEX::Logging::Log::Error (L" Unsupported interpretation" );
308+ return L" <UNKNOWN>" ;
309+ }
310+ }
311+
312+ // The returned data type is used for matrix conversion. It is hard-coded
313+ // for the test framework where all integer matrices start as SINT8 and
314+ // all FP matrices start as FLOAT32.
315+ static D3D12_LINEAR_ALGEBRA_DATATYPE GetMatrixSrcDataType (D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) {
316+ switch (MatrixInterpretation) {
317+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
318+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED:
319+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
320+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
321+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
322+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
323+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
324+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
325+ return D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8;
326+ default :
327+ return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32;
328+ }
329+ }
330+ };
0 commit comments