Skip to content

Commit 48c3f9d

Browse files
[CoopVec] This change adds two CoopVec HLSL tests to the ExecutionTest framework.
ExecutionTest::CoopVec_MulAdd: Functional verification for the Mul() and MulAdd() HLSL APIs. The driver matrix conversion API is tested as well. These tests should be considered as work-in-progress at this point. They include coverage primarily for SINT8, FLOAT16, FLOAT_E4M3, and FLOAT_E5M2. The test queries the driver for all supported configurations and runs each one, with a filtering mechanism to limit the set of tests to the minimal feature set. The set of tests can be further filtered by the following TE parameters: CoopVecMatrixInterp: SINT8, FLOAT16, FLOAT_E4M3, ... CoopVecMatrixLayout: ROW_MAJOR, COLUMN_MAJOR, MUL_OPTIMAL, OUTER_PRODUCT_OPTIMAL CoopVecBiasInterp: SINT32, FLOAT16, FLOAT_E4M3, ... CoopVecInputInterp: SINT8, FLOAT16, FLOAT_E4M3, ... CoopVecInputType: SINT8, UINT8, SINT16, UINT16, SINT32, UINT32, FLOAT16, FLOAT32, ... CoopVecOutputType: SINT32, UINT32, FLOAT16, FLOAT32, ... Filter example: $ TE.exe ... -p:CoopVecMatrixInterp=FLOAT16 -p:CoopVecMatrixLayout=MUL_OPTIMAL Precision coverage is minimal at this point, using an all-ones input matrix and test vector with ones in the first two components. This is enough to test basic functionality, but more comprehensive tests are needed. ExecutionTest::CoopVec_OuterProduct: Functional verification for the OuterProductAccumulate() HLSL API. This test queries the driver for all supported configurations and runs each one. No filtering is currently implemented.
1 parent 98c9a93 commit 48c3f9d

3 files changed

Lines changed: 1664 additions & 10 deletions

File tree

Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
#pragma once
2+
3+
#include <vector>
4+
#include <DirectXMath.h>
5+
#include <DirectXPackedVector.h>
6+
7+
#include "dxc/Support/microcom.h"
8+
9+
#include "CoopVecAPIExtensions.h"
10+
11+
struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler {
12+
private:
13+
DXC_MICROCOM_REF_FIELD(m_dwRef)
14+
dxc::DxcDllSupport &DxcSupport;
15+
16+
public:
17+
LinAlgHeaderIncludeHandler() = delete;
18+
LinAlgHeaderIncludeHandler(dxc::DxcDllSupport &DxcSupport) : m_dwRef(0), DxcSupport(DxcSupport) {}
19+
20+
DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
21+
22+
HRESULT STDMETHODCALLTYPE LoadSource(LPCWSTR pFilename, IDxcBlob **ppIncludeSource) {
23+
if (wcscmp(pFilename, L"dx/linalg.h") == 0 || wcscmp(pFilename, L".\\dx\\linalg.h") == 0) {
24+
WEX::Common::String ParamValue;
25+
if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(L"LinAlgHeader",
26+
ParamValue))) {
27+
return E_FAIL;
28+
}
29+
if (ParamValue.IsEmpty()) {
30+
return E_FAIL;
31+
}
32+
LPCWSTR RealHeaderPath = reinterpret_cast<LPCWSTR>(ParamValue.GetBuffer());
33+
34+
CComPtr<IDxcUtils> pHeaderUtils;
35+
36+
IFT(DxcSupport.CreateInstance(CLSID_DxcUtils, &pHeaderUtils));
37+
38+
IDxcBlobEncoding *pHeaderBlob;
39+
IFT(pHeaderUtils->LoadFile(RealHeaderPath, nullptr, &pHeaderBlob));
40+
41+
*ppIncludeSource = pHeaderBlob;
42+
43+
return S_OK;
44+
}
45+
return E_FAIL;
46+
}
47+
48+
HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid,
49+
void **ppvObject) override {
50+
// FIXME: This is a workaround for a warning-as-error about unused parameters.
51+
#pragma warning(push)
52+
#pragma warning(disable: 4100)
53+
return DoBasicQueryInterface<IDxcIncludeHandler>(this, iid, ppvObject);
54+
#pragma warning(pop)
55+
}
56+
};
57+
58+
struct CoopVecHelpers {
59+
template <typename EltTy>
60+
static std::vector<uint8_t> CreateAllOnesInputMatrix(uint32_t Width, uint32_t Height) {
61+
std::vector<EltTy> inputMatrix(Width * Height);
62+
for (uint32_t i = 0; i < Width * Height; i++) {
63+
if constexpr (std::is_same_v<EltTy, uint8_t> || std::is_same_v<EltTy, int8_t>) {
64+
inputMatrix[i] = 1;
65+
} else if constexpr (std::is_same_v<EltTy, DirectX::PackedVector::HALF>) {
66+
inputMatrix[i] = ConvertFloat32ToFloat16(1.0f);
67+
} else if constexpr (std::is_same_v<EltTy, float>) {
68+
inputMatrix[i] = 1.0f;
69+
}
70+
}
71+
72+
// Convert to uint8_t vector
73+
std::vector<uint8_t> uint8InputMatrix(inputMatrix.size() * sizeof(EltTy));
74+
std::memcpy(uint8InputMatrix.data(), inputMatrix.data(), inputMatrix.size() * sizeof(EltTy));
75+
return uint8InputMatrix;
76+
}
77+
78+
template <typename EltTy>
79+
static std::vector<uint8_t> CreateInputVector(uint32_t NumThreads, uint32_t EltsPerThread) {
80+
std::vector<EltTy> inputVector(NumThreads * EltsPerThread);
81+
std::fill(inputVector.begin(), inputVector.end(), EltTy(0));
82+
if (EltsPerThread < 2) {
83+
WEX::Logging::Log::Error(L"EltsPerThread must be at least 2");
84+
return std::vector<uint8_t>();
85+
}
86+
for (uint32_t TID = 0; TID < NumThreads; TID++) {
87+
if constexpr (std::is_same_v<EltTy, uint8_t> || std::is_same_v<EltTy, int8_t>) {
88+
inputVector[TID * EltsPerThread + 0] = 1;
89+
inputVector[TID * EltsPerThread + 1] = 1;
90+
} else if constexpr (std::is_same_v<EltTy, DirectX::PackedVector::HALF>) {
91+
inputVector[TID * EltsPerThread + 0] = ConvertFloat32ToFloat16(1.0f);
92+
inputVector[TID * EltsPerThread + 1] = ConvertFloat32ToFloat16(1.0f);
93+
} else if constexpr (std::is_same_v<EltTy, float>) {
94+
inputVector[TID * EltsPerThread + 0] = 1.0f;
95+
inputVector[TID * EltsPerThread + 1] = 1.0f;
96+
} else {
97+
WEX::Logging::Log::Error(L"Unsupported input type");
98+
}
99+
}
100+
101+
// Convert to uint8_t vector
102+
std::vector<uint8_t> uint8InputVector(inputVector.size() * sizeof(EltTy));
103+
std::memcpy(uint8InputVector.data(), inputVector.data(), inputVector.size() * sizeof(EltTy));
104+
return uint8InputVector;
105+
}
106+
107+
template <typename EltTy>
108+
static std::vector<uint8_t> CreateInputBias(uint32_t NumElts) {
109+
std::vector<EltTy> inputBias(NumElts);
110+
if constexpr (std::is_same_v<EltTy, uint8_t> || std::is_same_v<EltTy, int8_t>) {
111+
std::fill(inputBias.begin(), inputBias.end(), EltTy(1));
112+
} else if constexpr (std::is_same_v<EltTy, DirectX::PackedVector::HALF>) {
113+
std::fill(inputBias.begin(), inputBias.end(), ConvertFloat32ToFloat16(1.0f));
114+
} else if constexpr (std::is_same_v<EltTy, int32_t>) {
115+
std::fill(inputBias.begin(), inputBias.end(), 1);
116+
} else {
117+
WEX::Logging::Log::Error(L"Unsupported bias type");
118+
}
119+
// Convert to uint8_t vector
120+
std::vector<uint8_t> uint8InputBias(inputBias.size() * sizeof(EltTy));
121+
std::memcpy(uint8InputBias.data(), inputBias.data(), inputBias.size() * sizeof(EltTy));
122+
return uint8InputBias;
123+
}
124+
125+
// Maps a linear-algebra data type to the name used for it in test filter
// parameters. Values without a mapping yield "<UNKNOWN>".
static std::wstring DataTypeToFilterString(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
  const wchar_t *Name = L"<UNKNOWN>";
  switch (DataType) {
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
    Name = L"SINT8_T4_PACKED";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED:
    Name = L"UINT8_T4_PACKED";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
    Name = L"SINT8";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
    Name = L"UINT8";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
    Name = L"SINT16";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
    Name = L"UINT16";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
    Name = L"SINT32";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
    Name = L"UINT32";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
    Name = L"FLOAT32";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
    Name = L"FLOAT16";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
    Name = L"FLOAT_E4M3";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
    Name = L"FLOAT_E5M2";
    break;
  default:
    break;
  }
  return Name;
}
155+
156+
static bool IsDataTypeInFilter(const wchar_t *FilterKey,
157+
D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
158+
WEX::Common::String ParamValue;
159+
if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(FilterKey,
160+
ParamValue))) {
161+
// Filter not set, so treat as no filter
162+
return true;
163+
}
164+
if (ParamValue.IsEmpty()) {
165+
// Empty filter, so treat as no filter
166+
return true;
167+
}
168+
169+
// Check if the filter matches the target data type
170+
LPCWSTR FilterString = reinterpret_cast<LPCWSTR>(ParamValue.GetBuffer());
171+
return DataTypeToFilterString(DataType) == FilterString;
172+
}
173+
174+
// Maps a matrix layout to the name used for it in test filter parameters.
// Values without a mapping yield "<UNKNOWN>".
static std::wstring MatrixLayoutToFilterString(D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) {
  const wchar_t *Name = L"<UNKNOWN>";
  switch (MatrixLayout) {
  case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
    Name = L"ROW_MAJOR";
    break;
  case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
    Name = L"COLUMN_MAJOR";
    break;
  case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL:
    Name = L"MUL_OPTIMAL";
    break;
  case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL:
    Name = L"OUTER_PRODUCT_OPTIMAL";
    break;
  default:
    break;
  }
  return Name;
}
188+
189+
static bool IsMatrixLayoutInFilter(const wchar_t *FilterKey,
190+
D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) {
191+
WEX::Common::String ParamValue;
192+
if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(FilterKey,
193+
ParamValue))) {
194+
// Filter not set, so treat as no filter
195+
return true;
196+
}
197+
if (ParamValue.IsEmpty()) {
198+
// Empty filter, so treat as no filter
199+
return true;
200+
}
201+
202+
// Check if the filter matches the target data type
203+
LPCWSTR FilterString = reinterpret_cast<LPCWSTR>(ParamValue.GetBuffer());
204+
return MatrixLayoutToFilterString(MatrixLayout) == FilterString;
205+
}
206+
207+
// Maps a matrix layout to the MATRIX_LAYOUT_* identifier string returned for
// it. Values without a mapping yield "<UNKNOWN>".
static std::wstring MatrixLayoutToHlslLayoutString(D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) {
  const wchar_t *Name = L"<UNKNOWN>";
  switch (MatrixLayout) {
  case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
    Name = L"MATRIX_LAYOUT_ROW_MAJOR";
    break;
  case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
    Name = L"MATRIX_LAYOUT_COLUMN_MAJOR";
    break;
  case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL:
    Name = L"MATRIX_LAYOUT_MUL_OPTIMAL";
    break;
  case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL:
    Name = L"MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL";
    break;
  default:
    break;
  }
  return Name;
}
221+
222+
// This multiplier is used to compute the row/column stride for a matrix
223+
// given it's element size.
224+
static int GetStrideMultiplierForMatrixDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
225+
switch (DataType) {
226+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
227+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED:
228+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
229+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
230+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
231+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
232+
return 1;
233+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
234+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
235+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
236+
return 2;
237+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
238+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
239+
return 4;
240+
default:
241+
WEX::Logging::Log::Error(L"Unsupported matrix data type");
242+
return 1;
243+
}
244+
}
245+
246+
// Returns how many elements one shader variable holds for the given input
// interpretation: 4 for the two T4_PACKED 8-bit integer types, 1 for
// everything else.
static int GetNumPackedElementsForInputDataType(D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) {
  const bool IsPackedInt8 =
      InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED ||
      InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED;
  return IsPackedInt8 ? 4 : 1;
}
256+
257+
// This type is used in generated HLSL source to represent the vector type
258+
// for the given data type.
259+
static std::wstring GetHlslDataTypeForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType, D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) {
260+
UNREFERENCED_PARAMETER(InputInterpretation);
261+
switch (DataType) {
262+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
263+
return L"int16_t";
264+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
265+
return L"uint16_t";
266+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
267+
return L"int32_t";
268+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
269+
return L"uint32_t";
270+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
271+
return L"half";
272+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
273+
return L"float";
274+
default:
275+
WEX::Logging::Log::Error(L"Unsupported input data type");
276+
return L"<UNKNOWN>";
277+
}
278+
}
279+
280+
// Returns the DATA_TYPE_* identifier string for the given interpretation.
// Interpretations without a mapping log an error and yield "<UNKNOWN>".
static std::wstring GetHlslInterpretationForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE Interpretation) {
  const wchar_t *Name = nullptr;
  switch (Interpretation) {
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
    Name = L"DATA_TYPE_SINT8_T4_PACKED";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED:
    Name = L"DATA_TYPE_UINT8_T4_PACKED";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
    Name = L"DATA_TYPE_SINT8";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
    Name = L"DATA_TYPE_UINT8";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
    Name = L"DATA_TYPE_SINT16";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
    Name = L"DATA_TYPE_UINT16";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
    Name = L"DATA_TYPE_SINT32";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
    Name = L"DATA_TYPE_UINT32";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
    Name = L"DATA_TYPE_FLOAT16";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
    Name = L"DATA_TYPE_FLOAT32";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
    Name = L"DATA_TYPE_FLOAT8_E4M3";
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
    Name = L"DATA_TYPE_FLOAT8_E5M2";
    break;
  default:
    break;
  }

  if (Name == nullptr) {
    WEX::Logging::Log::Error(L"Unsupported interpretation");
    return L"<UNKNOWN>";
  }
  return Name;
}
311+
312+
// The returned data type is used for matrix conversion. It is hard-coded
313+
// for the test framework where all integer matrices start as SINT8 and
314+
// all FP matrices start as FLOAT32.
315+
static D3D12_LINEAR_ALGEBRA_DATATYPE GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) {
316+
switch (MatrixInterpretation) {
317+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
318+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED:
319+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
320+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
321+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
322+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
323+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
324+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
325+
return D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8;
326+
default:
327+
return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32;
328+
}
329+
}
330+
};

0 commit comments

Comments
 (0)