diff --git a/tools/clang/unittests/HLSLExec/CoopVec.h b/tools/clang/unittests/HLSLExec/CoopVec.h new file mode 100644 index 0000000000..f166c61f67 --- /dev/null +++ b/tools/clang/unittests/HLSLExec/CoopVec.h @@ -0,0 +1,359 @@ +#pragma once + +#if HAVE_COOPVEC_API + +#include +#include +#include + +#include "dxc/Support/microcom.h" + +#include "CoopVecAPI.h" + +struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { +private: + DXC_MICROCOM_REF_FIELD(RefCount) + dxc::DxcDllSupport &DxcSupport; + +public: + LinAlgHeaderIncludeHandler() = delete; + LinAlgHeaderIncludeHandler(dxc::DxcDllSupport &DxcSupport) + : RefCount(0), DxcSupport(DxcSupport) {} + + DXC_MICROCOM_ADDREF_RELEASE_IMPL(RefCount) + + HRESULT STDMETHODCALLTYPE LoadSource(LPCWSTR Filename, + IDxcBlob **IncludeSource) { + if (wcscmp(Filename, L"dx/linalg.h") == 0 || + wcscmp(Filename, L".\\dx\\linalg.h") == 0) { + WEX::Common::String ParamValue; + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue( + L"LinAlgHeader", ParamValue))) { + return E_FAIL; + } + if (ParamValue.IsEmpty()) { + return E_FAIL; + } + LPCWSTR RealHeaderPath = + reinterpret_cast(ParamValue.GetBuffer()); + + CComPtr HeaderUtils; + + IFT(DxcSupport.CreateInstance(CLSID_DxcUtils, &HeaderUtils)); + + IDxcBlobEncoding *HeaderBlob; + IFT(HeaderUtils->LoadFile(RealHeaderPath, nullptr, &HeaderBlob)); + + *IncludeSource = HeaderBlob; + + return S_OK; + } + return E_FAIL; + } + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID IID, void **Object) override { +// FIXME: This is a workaround for a warning-as-error about unused parameters. 
+#pragma warning(push) +#pragma warning(disable : 4100) + return DoBasicQueryInterface(this, IID, Object); +#pragma warning(pop) + } +}; + +namespace CoopVecHelpers { +template +static std::vector CreateAllOnesInputMatrix(uint32_t Width, + uint32_t Height) { + std::vector InputMatrix(Width * Height); + for (uint32_t i = 0; i < Width * Height; i++) { + if constexpr (std::is_same_v || + std::is_same_v) { + InputMatrix[i] = 1; + } else if constexpr (std::is_same_v) { + InputMatrix[i] = ConvertFloat32ToFloat16(1.0f); + } else if constexpr (std::is_same_v) { + InputMatrix[i] = 1.0f; + } else { + WEX::Logging::Log::Error(L"Unsupported input type"); + break; + } + } + + // Convert to uint8_t vector + std::vector Uint8InputMatrix(InputMatrix.size() * sizeof(EltTy)); + std::memcpy(Uint8InputMatrix.data(), InputMatrix.data(), + InputMatrix.size() * sizeof(EltTy)); + return Uint8InputMatrix; +} + +template +static std::vector CreateInputVector(uint32_t NumThreads, + uint32_t EltsPerThread) { + std::vector InputVector(NumThreads * EltsPerThread); + std::fill(InputVector.begin(), InputVector.end(), EltTy(0)); + if (EltsPerThread < 2) { + WEX::Logging::Log::Error(L"EltsPerThread must be at least 2"); + return std::vector(); + } + for (uint32_t TID = 0; TID < NumThreads; TID++) { + if constexpr (std::is_same_v || + std::is_same_v) { + InputVector[TID * EltsPerThread + 0] = 1; + InputVector[TID * EltsPerThread + 1] = 1; + } else if constexpr (std::is_same_v) { + InputVector[TID * EltsPerThread + 0] = ConvertFloat32ToFloat16(1.0f); + InputVector[TID * EltsPerThread + 1] = ConvertFloat32ToFloat16(1.0f); + } else if constexpr (std::is_same_v) { + InputVector[TID * EltsPerThread + 0] = 1.0f; + InputVector[TID * EltsPerThread + 1] = 1.0f; + } else { + WEX::Logging::Log::Error(L"Unsupported input type"); + break; + } + } + + // Convert to uint8_t vector + std::vector Uint8InputVector(InputVector.size() * sizeof(EltTy)); + std::memcpy(Uint8InputVector.data(), InputVector.data(), + 
InputVector.size() * sizeof(EltTy)); + return Uint8InputVector; +} + +template +static std::vector CreateInputBias(uint32_t NumElts) { + std::vector InputBias(NumElts); + if constexpr (std::is_same_v || + std::is_same_v) { + std::fill(InputBias.begin(), InputBias.end(), EltTy(1)); + } else if constexpr (std::is_same_v) { + std::fill(InputBias.begin(), InputBias.end(), + ConvertFloat32ToFloat16(1.0f)); + } else if constexpr (std::is_same_v) { + std::fill(InputBias.begin(), InputBias.end(), 1); + } else { + WEX::Logging::Log::Error(L"Unsupported bias type"); + } + // Convert to uint8_t vector + std::vector Uint8InputBias(InputBias.size() * sizeof(EltTy)); + std::memcpy(Uint8InputBias.data(), InputBias.data(), + InputBias.size() * sizeof(EltTy)); + return Uint8InputBias; +} + +static std::wstring +DataTypeToFilterString(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + return L"SINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return L"UINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + return L"SINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + return L"UINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"SINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"UINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"SINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"UINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"FLOAT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"FLOAT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + return L"FLOAT_E4M3"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return L"FLOAT_E5M2"; + default: + return L""; + } +} + +static bool IsDataTypeInFilter(const wchar_t *FilterKey, + D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + WEX::Common::String ParamValue; + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(FilterKey, + ParamValue))) { + // Filter 
not set, so treat as no filter + return true; + } + if (ParamValue.IsEmpty()) { + // Empty filter, so treat as no filter + return true; + } + + // Check if the filter matches the target data type + LPCWSTR FilterString = reinterpret_cast(ParamValue.GetBuffer()); + return DataTypeToFilterString(DataType) == FilterString; +} + +static std::wstring +MatrixLayoutToFilterString(D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + switch (MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + return L"ROW_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + return L"COLUMN_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: + return L"MUL_OPTIMAL"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: + return L"OUTER_PRODUCT_OPTIMAL"; + default: + return L""; + } +} + +static bool +IsMatrixLayoutInFilter(const wchar_t *FilterKey, + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + WEX::Common::String ParamValue; + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(FilterKey, + ParamValue))) { + // Filter not set, so treat as no filter + return true; + } + if (ParamValue.IsEmpty()) { + // Empty filter, so treat as no filter + return true; + } + + // Check if the filter matches the target data type + LPCWSTR FilterString = reinterpret_cast(ParamValue.GetBuffer()); + return MatrixLayoutToFilterString(MatrixLayout) == FilterString; +} + +static std::wstring MatrixLayoutToHlslLayoutString( + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + switch (MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + return L"MATRIX_LAYOUT_ROW_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + return L"MATRIX_LAYOUT_COLUMN_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: + return L"MATRIX_LAYOUT_MUL_OPTIMAL"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: + return L"MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL"; + default: + return L""; + } +} + +// This multiplier is 
used to compute the row/column stride for a matrix +// given it's element size. +static int +GetStrideMultiplierForMatrixDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return 1; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return 2; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return 4; + default: + WEX::Logging::Log::Error(L"Unsupported matrix data type"); + return 1; + } +} + +static int GetNumPackedElementsForInputDataType( + D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) { + // Int8 packed types are the only ones that have more than 1 element per + // shader variable + switch (InputInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return 4; + default: + return 1; + } +} + +// This type is used in generated HLSL source to represent the vector type +// for the given data type. 
+static std::wstring +GetHlslDataTypeForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + // Only the 16-bit and 32-bit integer types and FLOAT16/FLOAT32 have a + // spelling here; the 8-bit, packed and FLOAT_E* values take the error path + // and yield an empty string. + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"int16_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"uint16_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"int32_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"uint32_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"half"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"float"; + default: + WEX::Logging::Log::Error(L"Unsupported input data type"); + return L""; + } +} + +// Maps an interpretation to its DATA_TYPE_* identifier (presumably the name +// emitted into generated HLSL -- confirm against the caller). The names follow +// the enumerator names except for the two 8-bit float types, which become +// DATA_TYPE_FLOAT8_E4M3 and DATA_TYPE_FLOAT8_E5M2. An unknown value logs an +// error and yields an empty string. +static std::wstring +GetHlslInterpretationForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE Interpretation) { + switch (Interpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + return L"DATA_TYPE_SINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return L"DATA_TYPE_UINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + return L"DATA_TYPE_SINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + return L"DATA_TYPE_UINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"DATA_TYPE_SINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"DATA_TYPE_UINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"DATA_TYPE_SINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"DATA_TYPE_UINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"DATA_TYPE_FLOAT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"DATA_TYPE_FLOAT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + return L"DATA_TYPE_FLOAT8_E4M3"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return L"DATA_TYPE_FLOAT8_E5M2"; + default: + WEX::Logging::Log::Error(L"Unsupported interpretation"); + return L""; + } +} + +// The returned data type is used for matrix conversion. It is hard-coded +// for the test framework where all integer matrices start as SINT8 and +// all FP matrices start as FLOAT32. 
+static D3D12_LINEAR_ALGEBRA_DATATYPE +GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) { + // Every integer interpretation (8/16/32-bit, signed, unsigned or packed) + // converts from SINT8; all remaining values convert from FLOAT32. + switch (MatrixInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + default: + return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; + } +} +} // namespace CoopVecHelpers + +#endif // HAVE_COOPVEC_API diff --git a/tools/clang/unittests/HLSLExec/CoopVecAPI.h b/tools/clang/unittests/HLSLExec/CoopVecAPI.h new file mode 100644 index 0000000000..16c1105edc --- /dev/null +++ b/tools/clang/unittests/HLSLExec/CoopVecAPI.h @@ -0,0 +1,178 @@ +#pragma once +// clang-format off + +#if !defined(D3D12_PREVIEW_SDK_VERSION) || D3D12_PREVIEW_SDK_VERSION < 717 + +#ifdef __ID3D12GraphicsCommandList10_INTERFACE_DEFINED__ +#define HAVE_COOPVEC_API 1 + +// This file contains the definitions of the D3D12 cooperative vector API. +// It is used to test the cooperative vector API on older SDKs. + +constexpr int D3D12_FEATURE_D3D12_OPTIONS_EXPERIMENTAL = 9; +constexpr int D3D12_FEATURE_COOPERATIVE_VECTOR = 11; + +// -------------------------------------------------------------------------------------------------------------------------------- +// Experimental Feature: D3D12CooperativeVectorExperiment +// +// Use with D3D12CooperativeVectorExperiment to enable cooperative vector experimental feature. +// +// Enabling D3D12CooperativeVectorExperiment needs no configuration struct, pass NULL in the pConfigurationStructs array. 
+// +// -------------------------------------------------------------------------------------------------------------------------------- +static const UUID D3D12CooperativeVectorExperiment = { /* 384748be-cca5-471e-a125-5cc997e04d39 */ + 0x384748be, + 0xcca5, + 0x471e, + {0xa1, 0x25, 0x5c, 0xc9, 0x97, 0xe0, 0x4d, 0x39} +}; + +/* interface __MIDL_itf_d3d12_0000_0082 */ +/* [local] */ + +typedef +enum D3D12_COOPERATIVE_VECTOR_TIER + { + D3D12_COOPERATIVE_VECTOR_TIER_NOT_SUPPORTED = 0, + D3D12_COOPERATIVE_VECTOR_TIER_1_0 = 0x10, + D3D12_COOPERATIVE_VECTOR_TIER_1_1 = 0x11 + } D3D12_COOPERATIVE_VECTOR_TIER; + +// NOTE(review): the enumerator values below are explicit and non-contiguous; +// they presumably mirror the preview SDK's d3d12.h and must not be renumbered +// -- confirm against that header. +typedef +enum D3D12_LINEAR_ALGEBRA_DATATYPE + { + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16 = 2, + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16 = 3, + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 = 4, + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 = 5, + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 = 7, + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 = 8, + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED = 16, + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED = 17, + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8 = 18, + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 = 19, + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 = 20, + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 = 21 + } D3D12_LINEAR_ALGEBRA_DATATYPE; + +typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS_EXPERIMENTAL + { + _Out_ D3D12_COOPERATIVE_VECTOR_TIER CooperativeVectorTier; + } D3D12_FEATURE_DATA_D3D12_OPTIONS_EXPERIMENTAL; + +typedef struct D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL + { + D3D12_LINEAR_ALGEBRA_DATATYPE InputType; + D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation; + D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation; + D3D12_LINEAR_ALGEBRA_DATATYPE BiasInterpretation; + D3D12_LINEAR_ALGEBRA_DATATYPE OutputType; + BOOL TransposeSupported; + } D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL; + +typedef struct D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE + { + D3D12_LINEAR_ALGEBRA_DATATYPE InputType; + D3D12_LINEAR_ALGEBRA_DATATYPE AccumulationType; + } 
D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE; + +// Each *PropCount field is followed by the array pointer it sizes. The Mul test +// added in this change calls CheckFeatureSupport twice with this struct: once +// zero-initialized to read the counts, then again with +// pMatrixVectorMulAddProperties pointing at caller-allocated storage. +typedef struct D3D12_FEATURE_DATA_COOPERATIVE_VECTOR + { + _Inout_ UINT MatrixVectorMulAddPropCount; + _Out_ D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL *pMatrixVectorMulAddProperties; + _Inout_ UINT OuterProductAccumulatePropCount; + _Out_ D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE *pOuterProductAccumulateProperties; + _Inout_ UINT VectorAccumulatePropCount; + _Out_ D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE *pVectorAccumulateProperties; + } D3D12_FEATURE_DATA_COOPERATIVE_VECTOR; + +typedef +enum D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT + { + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR = 0, + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR = ( D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR + 1 ) , + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL = ( D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR + 1 ) , + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL = ( D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL + 1 ) + } D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT; + +typedef struct D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO + { + _Inout_ UINT DestSize; + _In_ D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT DestLayout; + _In_ UINT DestStride; + _In_ UINT NumRows; + _In_ UINT NumColumns; + _In_ D3D12_LINEAR_ALGEBRA_DATATYPE DestDataType; + } D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO; + +typedef struct D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA + { + _Inout_ D3D12_GPU_VIRTUAL_ADDRESS DestVA; + _In_ D3D12_GPU_VIRTUAL_ADDRESS SrcVA; + } D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA; + +typedef struct D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO + { + _In_ UINT SrcSize; + _In_ D3D12_LINEAR_ALGEBRA_DATATYPE SrcDataType; + _In_ D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT SrcLayout; + _In_ UINT SrcStride; + } D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO; + +typedef struct D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO + { + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO DestInfo; + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO SrcInfo; 
+ D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA DataDesc; + } D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO; + + + +#ifndef __ID3D12DevicePreview_INTERFACE_DEFINED__ +#define __ID3D12DevicePreview_INTERFACE_DEFINED__ + +EXTERN_C const IID IID_ID3D12DevicePreview; + +MIDL_INTERFACE("55ea41d3-6bf5-4332-bbf9-905e6b4e2930") +ID3D12DevicePreview : public IUnknown +{ +public: + virtual void STDMETHODCALLTYPE GetLinearAlgebraMatrixConversionDestinationInfo( + _Inout_ D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO *pDesc) = 0; + +}; + +#endif /* __ID3D12DevicePreview_INTERFACE_DEFINED__ */ + + +#ifndef __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ +#define __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ + +EXTERN_C const IID IID_ID3D12GraphicsCommandList11; + +// NOTE(review): Reserved0..Reserved2 are placeholder virtual methods declared +// ahead of ConvertLinearAlgebraMatrix; presumably they keep the vtable slot +// order in step with the real interface -- confirm before reordering or +// removing them. +MIDL_INTERFACE("f0dcfabc-a84a-4fe3-b3b9-eab26b306c38") +ID3D12GraphicsCommandList11 : public ID3D12GraphicsCommandList10 +{ +public: + virtual void STDMETHODCALLTYPE Reserved0() = 0; + virtual void STDMETHODCALLTYPE Reserved1() = 0; + virtual void STDMETHODCALLTYPE Reserved2() = 0; + + virtual void STDMETHODCALLTYPE ConvertLinearAlgebraMatrix( + _In_ const D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO *pDesc, + _In_ UINT DescCount) = 0; + +}; + +#endif /* __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ */ + +#else // __ID3D12GraphicsCommandList10_INTERFACE_DEFINED__ +// The used d3d12.h header does not support ID3D12GraphicsCommandList10, +// so we cannot define ID3D12GraphicsCommandList11. 
+#define HAVE_COOPVEC_API 0 +#endif // __ID3D12GraphicsCommandList10_INTERFACE_DEFINED__ + +#else // D3D12_PREVIEW_SDK_VERSION < 717 +// Preview header has CoopVec support +#define HAVE_COOPVEC_API 1 +#endif // D3D12_PREVIEW_SDK_VERSION < 717 diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 1bef0b4f8d..55d569dd8d 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -51,6 +51,7 @@ // https://msdn.microsoft.com/en-us/library/windows/desktop/dn899120(v=vs.85).aspx // https://developer.microsoft.com/en-US/windows/downloads/windows-10-sdk // + #include #include #include @@ -63,6 +64,8 @@ #include #include #include "LongVectors.h" +#include "CoopVecAPI.h" +#include "CoopVec.h" // clang-format on #pragma comment(lib, "d3dcompiler.lib") @@ -617,6 +620,9 @@ class ExecutionTest { TEST_METHOD(LongVector_Clamp_uint64); TEST_METHOD(LongVector_Initialize_uint64); + TEST_METHOD(CoopVec_Mul); + TEST_METHOD(CoopVec_OuterProduct); + dxc::DxcDllSupport m_support; bool m_D3DInitCompleted = false; @@ -752,7 +758,7 @@ class ExecutionTest { #endif } - bool UseDebugIfaces() { return true; } + bool UseDebugIfaces() { return false; } bool SaveImages() { return GetTestParamBool(L"SaveImages"); } @@ -775,6 +781,42 @@ class ExecutionTest { void RunResourceTest(ID3D12Device *pDevice, const char *pShader, const wchar_t *sm, bool isDynamic); + void runCoopVecMulTest(); + void runCoopVecOuterProductTest(); + +#if HAVE_COOPVEC_API + struct CoopVecMulSubtestConfig { + int InputPerThread; + int OutputPerThread; + int NumThreads; + int NumLevels; + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout; + bool Bias; + }; + + void + runCoopVecMulTestConfig(ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps); + void runCoopVecMulSubtest(ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, + CoopVecMulSubtestConfig &Config); + 
+ struct CoopVecOuterProductSubtestConfig { + int DimM; // Row Count + int DimN; // Column Count + int NumThreads; + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout; + }; + + void runCoopVecOuterProductTestConfig( + ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps); + void runCoopVecOuterProductSubtest( + ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, + CoopVecOuterProductSubtestConfig &Config); +#endif // HAVE_COOPVEC_API + template void WaveIntrinsicsActivePrefixTest(TableParameter *pParameterList, size_t numParameter, bool isPrefix); @@ -834,7 +876,8 @@ class ExecutionTest { void CompileFromText(LPCSTR pText, LPCWSTR pEntryPoint, LPCWSTR pTargetProfile, ID3DBlob **ppBlob, - LPCWSTR *pOptions = nullptr, int numOptions = 0) { + LPCWSTR *pOptions = nullptr, int numOptions = 0, + IDxcIncludeHandler *pIncludeHandler = nullptr) { VERIFY_SUCCEEDED(m_support.Initialize()); CComPtr pCompiler; CComPtr pLibrary; @@ -847,7 +890,7 @@ class ExecutionTest { pText, (UINT32)strlen(pText), CP_UTF8, &pTextBlob)); VERIFY_SUCCEEDED(pCompiler->Compile(pTextBlob, L"hlsl.hlsl", pEntryPoint, pTargetProfile, pOptions, numOptions, - nullptr, 0, nullptr, &pResult)); + nullptr, 0, pIncludeHandler, &pResult)); VERIFY_SUCCEEDED(pResult->GetStatus(&resultCode)); if (FAILED(resultCode)) { #ifndef _HLK_CONF @@ -882,7 +925,8 @@ class ExecutionTest { ID3D12RootSignature *pRootSignature, LPCSTR pShader, LPCWSTR pTargetProfile, ID3D12PipelineState **ppComputeState, - LPCWSTR *pOptions = nullptr, int numOptions = 0) { + LPCWSTR *pOptions = nullptr, int numOptions = 0, + IDxcIncludeHandler *pIncludeHandler = nullptr) { CComPtr pComputeShader; // Load and compile shaders. 
@@ -892,7 +936,7 @@ class ExecutionTest { #endif } else { CompileFromText(pShader, L"main", pTargetProfile, &pComputeShader, - pOptions, numOptions); + pOptions, numOptions, pIncludeHandler); } // Describe and create the compute pipeline state object (PSO). @@ -1729,6 +1773,21 @@ class ExecutionTest { #endif } + bool DoesDeviceSupportCooperativeVector(ID3D12Device *Device) { +#if HAVE_COOPVEC_API + D3D12_FEATURE_DATA_D3D12_OPTIONS_EXPERIMENTAL O; + if (FAILED(Device->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS_EXPERIMENTAL, &O, + sizeof(O)))) + return false; + return O.CooperativeVectorTier != + D3D12_COOPERATIVE_VECTOR_TIER_NOT_SUPPORTED; +#else + UNREFERENCED_PARAMETER(Device); + return false; +#endif + } + bool IsFallbackPathEnabled() { // Enable fallback paths with: /p:"EnableFallback=1" UINT EnableFallbackValue = 0; @@ -1841,8 +1900,18 @@ class ExecutionTest { if (pD3D12EnableExperimentalFeatures == nullptr) { return HRESULT_FROM_WIN32(GetLastError()); } - return pD3D12EnableExperimentalFeatures(1, &D3D12ExperimentalShaderModelsID, - nullptr, nullptr); + + std::vector Features; + + Features.push_back(D3D12ExperimentalShaderModels); + +#if HAVE_COOPVEC_API + if (GetTestParamBool(L"CooperativeVectorExperimental")) { + Features.push_back(D3D12CooperativeVectorExperiment); + } +#endif + return pD3D12EnableExperimentalFeatures((UINT)Features.size(), + Features.data(), nullptr, nullptr); } static HRESULT EnableExperimentalShaderModels() { @@ -11912,6 +11981,1374 @@ VERIFY_SUCCEEDED(DoArraysMatch(OutputVector, ExpectedVector, TestConfig.Tolerance)); } +// Runs a set of tests for the Cooperative Vector Mul and MulAdd operations. +// The device will be queried for supported configurations and then each +// supported configuration will be tested against multiple matrix and vector +// sizes. To help reproduce individual test failures, the test will log the +// configuration it is running and the results of each test. 
The following +// filters can be used to limit test execution to a specific set of +// configurations: +// +// - CoopVecMatrixInterp: SINT8, FLOAT16, FLOAT_E4M3, ... +// - CoopVecMatrixLayout: ROW_MAJOR, COLUMN_MAJOR, MUL_OPTIMAL, +// OUTER_PRODUCT_OPTIMAL +// - CoopVecBiasInterp: SINT32, FLOAT16, FLOAT_E4M3, ... +// - CoopVecInputInterp: SINT8, FLOAT16, FLOAT_E4M3, ... +// - CoopVecInputType: SINT8, UINT8, SINT16, UINT16, SINT32, UINT32, FLOAT16, +// FLOAT32, ... +// - CoopVecOutputType: SINT32, UINT32, FLOAT16, FLOAT32, ... +// +// Filter example: +// TE.exe ... -p:CoopVecMatrixInterp=FLOAT16 +// -p:CoopVecMatrixLayout=MUL_OPTIMAL +// +// The current implementation will always write the final output data as float. +void ExecutionTest::runCoopVecMulTest() { +#if !HAVE_COOPVEC_API + WEX::Logging::Log::Comment( + "Cooperative vector API not supported in build configuration. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; +#else + // Create device and verify coopvec support + CComPtr D3DDevice; + if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { + return; + } + if (!DoesDeviceSupportCooperativeVector(D3DDevice)) { + WEX::Logging::Log::Comment( + "Device does not support cooperative vector. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + + // Query coopvec feature data. First call gets the size of the arrays. The + // second call populates the arrays using memory we allocate. 
+ D3D12_FEATURE_DATA_COOPERATIVE_VECTOR DevOptions = {}; + VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &DevOptions, + sizeof(DevOptions))); + + // Allocate memory for the arrays in DevOptions + std::vector MulAddProps( + DevOptions.MatrixVectorMulAddPropCount); + DevOptions.pMatrixVectorMulAddProperties = MulAddProps.data(); + + VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &DevOptions, + sizeof(DevOptions))); + + // Test each supported data type and matrix layout + for (auto MulAddConfig : MulAddProps) { + // Filter on preview test support + bool PreviewConfig = false; + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + PreviewConfig = true; + } + + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + PreviewConfig = true; + } + + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + PreviewConfig = true; + } + + if (MulAddConfig.InputType == 
D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + PreviewConfig = true; + } + + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + PreviewConfig = true; + } + + if (!PreviewConfig) { + continue; + } + + // Apply filters + bool IsInFilter = + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecMatrixInterp", + MulAddConfig.MatrixInterpretation) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecBiasInterp", + MulAddConfig.BiasInterpretation) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecInputInterp", + MulAddConfig.InputInterpretation) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecInputType", + MulAddConfig.InputType) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecOutputType", + MulAddConfig.OutputType); + if (!IsInFilter) { + continue; + } + + // Run the test + runCoopVecMulTestConfig(D3DDevice, MulAddConfig); + } +#endif // HAVE_COOPVEC_API +} + +#if HAVE_COOPVEC_API +void ExecutionTest::runCoopVecMulTestConfig( + ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps) { + + LogCommentFmt( + L"Running test for MatrixInterpretation: %s, BiasInterpretation: %s, " + L"InputInterpretation: %s, InputType: %s, OutputType: %s", + CoopVecHelpers::DataTypeToFilterString(MulProps.MatrixInterpretation) + .c_str(), + CoopVecHelpers::DataTypeToFilterString(MulProps.BiasInterpretation) + .c_str(), + 
CoopVecHelpers::DataTypeToFilterString(MulProps.InputInterpretation) + .c_str(), + CoopVecHelpers::DataTypeToFilterString(MulProps.InputType).c_str(), + CoopVecHelpers::DataTypeToFilterString(MulProps.OutputType).c_str()); + + constexpr CoopVecMulSubtestConfig TestConfigs[] = { + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {32, 8, 32, 1, 
D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + }; + + for (auto Config : TestConfigs) { + if ((MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) && + (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR || + Config.MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR)) { + continue; + } + + bool IsInFilter = CoopVecHelpers::IsMatrixLayoutInFilter( + L"CoopVecMatrixLayout", Config.MatrixLayout); + if (!IsInFilter) { + continue; + } + + runCoopVecMulSubtest(D3DDevice, MulProps, Config); + } +} + +void ExecutionTest::runCoopVecMulSubtest( + ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, + CoopVecMulSubtestConfig &Config) { + + LogCommentFmt( + L"Running test for InputPerThread: %d, OutputPerThread: %d, NumThreads: " + L"%d, NumLevels: %d, Bias: %s, MatrixLayout: %s", + Config.InputPerThread, Config.OutputPerThread, Config.NumThreads, + Config.NumLevels, Config.Bias ? 
L"true" : L"false", + CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); + + const int OutputBufferSize = (Config.OutputPerThread * Config.NumThreads * 4); + + // Create root signature with a single root entry for all SRVs and UAVs + CComPtr RootSignature; + { + CD3DX12_DESCRIPTOR_RANGE Ranges[2]; + Ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 3, 0, + 0); // InputVector, InputMatrix, InputBias + Ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // OutputBuffer + CreateRootSignatureFromRanges(D3DDevice, &RootSignature, Ranges, 2, nullptr, + 0); + } + + // Create descriptor heap with space for 4 descriptors: 3 SRVs and 1 UAV + CComPtr DescriptorHeap; + { + D3D12_DESCRIPTOR_HEAP_DESC Desc = {}; + Desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + Desc.NumDescriptors = 4; + Desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + VERIFY_SUCCEEDED( + D3DDevice->CreateDescriptorHeap(&Desc, IID_PPV_ARGS(&DescriptorHeap))); + } + CD3DX12_CPU_DESCRIPTOR_HANDLE BaseHandle( + DescriptorHeap->GetCPUDescriptorHandleForHeapStart()); + + // Create the compute pipeline state for the CoopVec shader + CComPtr ComputePipelineState; + { + std::string ShaderSource = R"( +#include "dx/linalg.h" + +ByteAddressBuffer InputVector : register(t0); +ByteAddressBuffer InputBias : register(t1); +ByteAddressBuffer InputMatrix : register(t2); +RWByteAddressBuffer OutputBuffer: register(u0); + +[shader("compute")] +[numthreads(NUM_THREADS, 1, 1)] +void main(uint threadIdx : SV_GroupThreadID) +{ + using namespace dx::linalg; + + // Ensure 4-byte alignment for vector loads + uint inputOffset = (INPUT_PER_THREAD * threadIdx * (sizeof(INPUT_DATA_TYPE) / INPUT_DIVISOR)); + inputOffset = (inputOffset + 3) & ~3; // Align to 4 bytes + vector input = InputVector.Load >(inputOffset); + + MatrixRef mat = { InputMatrix, 0, STRIDE }; + + vector accum; + + if (USE_BIAS) { + VectorRef biasVec = { InputBias, 0 }; + accum = MulAdd(mat, MakeInterpretedVector(input), biasVec); + 
} else { + accum = Mul(mat, MakeInterpretedVector(input)); + } + + vector result = (vector)accum; + + // Ensure 4-byte alignment for vector store + uint outputOffset = OUTPUT_PER_THREAD * threadIdx * sizeof(float); + outputOffset = (outputOffset + 3) & ~3; // Align to 4 bytes + OutputBuffer.Store >(outputOffset, result); +} + )"; + + auto CreateDefineFromInt = [](const wchar_t *Name, int Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + auto CreateDefineFromString = [](const wchar_t *Name, + const std::wstring &Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + int Stride = 0; + const std::wstring HlslMatrixLayout = + CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); + int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType( + MulProps.MatrixInterpretation); + switch (Config.MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + Stride = Config.InputPerThread * StrideMultiplier; + break; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + Stride = Config.OutputPerThread * StrideMultiplier; + break; + } + + const int InputDivisor = + CoopVecHelpers::GetNumPackedElementsForInputDataType( + MulProps.InputInterpretation); + const std::wstring InputDataType = + CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.InputType); + const std::wstring AccumDataType = + CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.BiasInterpretation); + const std::wstring MatrixDataTypeEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + MulProps.MatrixInterpretation); + const std::wstring InputInterpretationEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + MulProps.InputInterpretation); + const std::wstring AccumInterpretationEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + MulProps.BiasInterpretation); + + auto InputPerThreadDefine = + 
CreateDefineFromInt(L"INPUT_PER_THREAD", Config.InputPerThread); + auto OutputPerThreadDefine = + CreateDefineFromInt(L"OUTPUT_PER_THREAD", Config.OutputPerThread); + auto NumThreadsDefine = + CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads); + auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride); + auto InputDataTypeDefine = + CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType); + auto InputDivisorDefine = + CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor); + auto AccumDataTypeDefine = + CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType); + auto InputInterpretationEnumDefine = CreateDefineFromString( + L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum); + auto HlslMatrixLayoutDefine = + CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout); + auto MatrixDataTypeEnumDefine = + CreateDefineFromString(L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum); + auto UseBiasDefine = CreateDefineFromInt(L"USE_BIAS", Config.Bias ? 1 : 0); + auto AccumInterpretationEnumDefine = CreateDefineFromString( + L"ACCUM_INTERPRETATION_ENUM", AccumInterpretationEnum); + + LPCWSTR Options[] = { + L"-enable-16bit-types", + InputPerThreadDefine.c_str(), + OutputPerThreadDefine.c_str(), + NumThreadsDefine.c_str(), + StrideDefine.c_str(), + InputDataTypeDefine.c_str(), + InputDivisorDefine.c_str(), + AccumDataTypeDefine.c_str(), + InputInterpretationEnumDefine.c_str(), + HlslMatrixLayoutDefine.c_str(), + MatrixDataTypeEnumDefine.c_str(), + UseBiasDefine.c_str(), + AccumInterpretationEnumDefine.c_str(), + }; + + CComPtr IncludeHandler = + new LinAlgHeaderIncludeHandler(m_support); + + CreateComputePSO(D3DDevice, RootSignature, ShaderSource.c_str(), L"cs_6_9", + &ComputePipelineState, Options, _countof(Options), + IncludeHandler); + } + + // Create a command list for the compute shader. 
+ CComPtr CommandList; + CComPtr CommandAllocator; + CComPtr CommandQueue; + FenceObj FO; + CreateCommandQueue(D3DDevice, L"CoopVec Test Command Queue", &CommandQueue, + D3D12_COMMAND_LIST_TYPE_DIRECT); + InitFenceObj(D3DDevice, &FO); + VERIFY_SUCCEEDED(D3DDevice->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&CommandAllocator))); + VERIFY_SUCCEEDED(D3DDevice->CreateCommandList( + 0, D3D12_COMMAND_LIST_TYPE_DIRECT, CommandAllocator, ComputePipelineState, + IID_PPV_ARGS(&CommandList))); + + // Setup input data + auto ExpectedOutputBuffer = + std::make_unique(Config.OutputPerThread * Config.NumThreads); + + // Setup input matrix as all-ones in sint8 format. This will later be + // converted to the appropriate data type by the matrix conversion API. + CComPtr InputMatrixSRVResource, InputMatrixSRVUploadResource; + std::vector InputMatrix; + if (MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || + MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix( + Config.InputPerThread, Config.OutputPerThread); + } else if (MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + // Matrix source data is fp32, which gets converted to fp16 during matrix + // conversion + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix( + Config.InputPerThread, Config.OutputPerThread); + } else { + WEX::Logging::Log::Error(L"Unsupported matrix data type"); + return; + } + + CreateTestResources(D3DDevice, CommandList, InputMatrix.data(), + InputMatrix.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputMatrix.size()), + 
&InputMatrixSRVResource, &InputMatrixSRVUploadResource); + + // Create input vector of an appropriate type. All integer types start as + // SINT8 for now. + CComPtr InputVecSRVResource, InputVecSRVUploadResource; + std::vector InputVector; + + if ((MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 && + (MulProps.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED)) || + MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + InputVector = CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); + } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + InputVector = + CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); + } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + InputVector = CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); + } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + InputVector = CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); + } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { + InputVector = CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); + } else { + WEX::Logging::Log::Error(L"Unsupported input data type"); + return; + } + if (InputVector.size() % 4 != 0) { + // Align size to 4 bytes for ByteAddressBuffer + InputVector.resize(InputVector.size() + 4 - (InputVector.size() % 4)); + } + CreateTestResources(D3DDevice, CommandList, InputVector.data(), + InputVector.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputVector.size()), + &InputVecSRVResource, &InputVecSRVUploadResource); + + // This increments 
baseHandle + CreateRawSRV(D3DDevice, BaseHandle, + (UINT)(InputVector.size() / sizeof(int32_t)), + InputVecSRVResource); + + // Create input bias + CComPtr InputBiasSRVResource, InputBiasSRVUploadResource; + std::vector InputBias; + + if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || + MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + InputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + InputBias = + CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { + InputBias = + CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + InputBias = CoopVecHelpers::CreateInputBias( + Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + InputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else { + WEX::Logging::Log::Error(L"Unsupported bias data type"); + return; + } + + if (InputBias.size() % 4 != 0) { + // Align size to 4 bytes for ByteAddressBuffer + InputBias.resize(InputBias.size() + 4 - (InputBias.size() % 4)); + } + CreateTestResources(D3DDevice, CommandList, InputBias.data(), + InputBias.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputBias.size()), + &InputBiasSRVResource, &InputBiasSRVUploadResource); + + // This increments baseHandle + CreateRawSRV(D3DDevice, BaseHandle, + (UINT)(InputBias.size() / sizeof(int32_t)), + InputBiasSRVResource); + + // Calculate reference output + // FIXME: This does not capture all cases, but is sufficient for the preview + // feature set + if (MulProps.MatrixInterpretation == 
D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8) { + // The input bias is really an array of int32_t + std::vector InputBiasI32(InputBias.size() / sizeof(int32_t)); + std::memcpy(InputBiasI32.data(), InputBias.data(), InputBias.size()); + + // The input vector is really an array of float if our vector input type is + // FLOAT32 + std::vector InputVectorF32(InputVector.size() / sizeof(int32_t)); + if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + std::memcpy(InputVectorF32.data(), InputVector.data(), + InputVector.size()); + } + + for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + for (int OutputIdx = 0; OutputIdx < Config.OutputPerThread; ++OutputIdx) { + int Acc = 0; + + for (int InputIdx = 0; InputIdx < Config.InputPerThread; ++InputIdx) { + int InputElem; + if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + InputElem = (int) + InputVectorF32[ThreadIdx * Config.InputPerThread + InputIdx]; + } else { + InputElem = + InputVector[ThreadIdx * Config.InputPerThread + InputIdx]; + } + int const MatrixElem = + InputMatrix[OutputIdx * Config.InputPerThread + InputIdx]; + Acc += InputElem * MatrixElem; + } + + if (Config.Bias) { + Acc += InputBiasI32[OutputIdx]; + } + + float Result = float(Acc); + ExpectedOutputBuffer[ThreadIdx * Config.OutputPerThread + OutputIdx] = + Result; + } + } + } else if (MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + // The input bias/vector is really an array of float16 + std::vector InputVectorFP16( + InputVector.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(InputVectorFP16.data(), InputVector.data(), InputVector.size()); + + std::vector InputBiasFP16( + InputBias.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(InputBiasFP16.data(), InputBias.data(), InputBias.size()); + + // The 
CPU reference matrix is float + std::vector InputMatrixFP32(InputMatrix.size() / sizeof(float)); + std::memcpy(InputMatrixFP32.data(), InputMatrix.data(), InputMatrix.size()); + + for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + for (int OutputIdx = 0; OutputIdx < Config.OutputPerThread; ++OutputIdx) { + float Acc = 0; + + for (int InputIdx = 0; InputIdx < Config.InputPerThread; ++InputIdx) { + float const InputElem = ConvertFloat16ToFloat32( + InputVectorFP16[ThreadIdx * Config.InputPerThread + InputIdx]); + float const MatrixElem = + InputMatrixFP32[OutputIdx * Config.InputPerThread + InputIdx]; + Acc += InputElem * MatrixElem; + } + + if (Config.Bias) { + Acc += ConvertFloat16ToFloat32(InputBiasFP16[OutputIdx]); + } + + float Result = Acc; + ExpectedOutputBuffer[ThreadIdx * Config.OutputPerThread + OutputIdx] = + Result; + } + } + } + + CComPtr ConvertedMatrixResource; + { + // Create source matrix info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {}; + ConvertInfo.SrcInfo.SrcDataType = + CoopVecHelpers::GetMatrixSrcDataType(MulProps.MatrixInterpretation); + ConvertInfo.SrcInfo.SrcLayout = + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + + // Create destination matrix info + ConvertInfo.DestInfo.DestSize = 0; // Will be populated by driver + int SrcEltSize = 0; + int DestEltSize = 0; + switch (MulProps.MatrixInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + SrcEltSize = 1; + DestEltSize = 1; + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; + SrcEltSize = 4; // FP32 + DestEltSize = 2; // FP16 + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + ConvertInfo.DestInfo.DestDataType = + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; + SrcEltSize = 4; // FP32 + DestEltSize = 1; // FP8 + break; + case 
D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + ConvertInfo.DestInfo.DestDataType = + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; + SrcEltSize = 4; // FP32 + DestEltSize = 1; // FP8 + break; + } + ConvertInfo.SrcInfo.SrcStride = Config.InputPerThread * SrcEltSize; + ConvertInfo.SrcInfo.SrcSize = + Config.InputPerThread * Config.OutputPerThread * SrcEltSize; + + ConvertInfo.DestInfo.DestLayout = Config.MatrixLayout; + ConvertInfo.DestInfo.DestStride = 0; + ConvertInfo.DestInfo.NumRows = Config.OutputPerThread; + ConvertInfo.DestInfo.NumColumns = Config.InputPerThread; + + if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { + ConvertInfo.DestInfo.DestStride = Config.InputPerThread * DestEltSize; + } else if (Config.MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { + ConvertInfo.DestInfo.DestStride = Config.OutputPerThread * DestEltSize; + } + + // Get destination size using preview interface + { + CComPtr PreviewDevice; + VERIFY_SUCCEEDED(D3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&PreviewDevice)); + + // Query required destination size + PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + &ConvertInfo.DestInfo); + } + + // Create resource to hold matrix copy + CreateTestResources( + D3DDevice, CommandList, nullptr, 0, + CD3DX12_RESOURCE_DESC::Buffer(ConvertInfo.DestInfo.DestSize), + &ConvertedMatrixResource, nullptr); + + // Set up data descriptors + ConvertInfo.DataDesc.DestVA = + ConvertedMatrixResource->GetGPUVirtualAddress(); + ConvertInfo.DataDesc.SrcVA = InputMatrixSRVResource->GetGPUVirtualAddress(); + + // Get command list interface and perform conversion + CComPtr CommandList11; + VERIFY_SUCCEEDED(CommandList->QueryInterface( + __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11)); + CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); + + // This increments baseHandle + if ((ConvertInfo.DestInfo.DestSize % 4) != 0) { + WEX::Logging::Log::Error(L"DestSize is not 
aligned to 4 bytes"); + return; + } + CreateRawSRV(D3DDevice, BaseHandle, + ConvertInfo.DestInfo.DestSize / sizeof(int32_t), + ConvertedMatrixResource); + } + + CComPtr UavResource; + CComPtr UavUploadResource; + CComPtr UavReadResource; + + // Create buffer for output and fill with 0xFF to make it obvious if it's not + // written in the shader. + std::vector OutputBufferInit(OutputBufferSize); + std::fill(OutputBufferInit.begin(), OutputBufferInit.end(), (uint8_t)0xFF); + + CreateTestUavs(D3DDevice, CommandList, OutputBufferInit.data(), + OutputBufferSize, &UavResource, &UavUploadResource, + &UavReadResource); + CreateRawUAV(D3DDevice, BaseHandle, OutputBufferSize / 4, UavResource); + + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, ComputePipelineState)); + + SetDescriptorHeap(CommandList, DescriptorHeap); + + CD3DX12_GPU_DESCRIPTOR_HANDLE ResHandle( + DescriptorHeap->GetGPUDescriptorHandleForHeapStart()); + + CommandList->SetComputeRootSignature(RootSignature); + CommandList->SetComputeRootDescriptorTable(0, ResHandle); + CommandList->SetPipelineState(ComputePipelineState); + CommandList->Dispatch(1, 1, 1); + RecordTransitionBarrier(CommandList, UavResource, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COPY_SOURCE); + CommandList->CopyResource(UavReadResource, UavResource); + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + + { + MappedData MappedData(UavReadResource, OutputBufferSize); + + float *ResultBuffer = (float *)MappedData.data(); + bool Equal = true; + for (int i = 0; i < OutputBufferSize / sizeof(float); i++) { + if (isnan(ResultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || + fabs(ResultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { + LogErrorFmt(L"Result mismatch at index %d", i); + 
LogErrorFmt(L"ResultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, + ResultBuffer[i], i, ExpectedOutputBuffer[i]); + Equal = false; + break; + } + } + VERIFY_IS_TRUE(Equal); + } +} +#endif // HAVE_COOPVEC_API + +TEST_F(ExecutionTest, CoopVec_Mul) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + runCoopVecMulTest(); +} + +void ExecutionTest::runCoopVecOuterProductTest() { +#if !HAVE_COOPVEC_API + WEX::Logging::Log::Comment( + "Cooperative vector API not supported in build configuration. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; +#else + // Create device and verify coopvec support + CComPtr D3DDevice; + if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { + return; + } + if (!DoesDeviceSupportCooperativeVector(D3DDevice)) { + WEX::Logging::Log::Comment( + "Device does not support cooperative vector. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + + // Query coopvec feature data. First call gets the size of the arrays. The + // second call populates the arrays using memory we allocate. 
+ D3D12_FEATURE_DATA_COOPERATIVE_VECTOR DevOptions = {}; + VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &DevOptions, + sizeof(DevOptions))); + + // Allocate memory for the arrays in DevOptions + std::vector AccumulateProps( + DevOptions.OuterProductAccumulatePropCount); + DevOptions.pOuterProductAccumulateProperties = AccumulateProps.data(); + + VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &DevOptions, + sizeof(DevOptions))); + + // Test each supported data type and matrix layout + for (auto AccumulateConfig : AccumulateProps) { + // Run the test + runCoopVecOuterProductTestConfig(D3DDevice, AccumulateConfig); + } +#endif // HAVE_COOPVEC_API +} + +#if HAVE_COOPVEC_API +void ExecutionTest::runCoopVecOuterProductTestConfig( + ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps) { + LogCommentFmt( + L"Running test for InputType: %s, AccumulationType: %s", + CoopVecHelpers::DataTypeToFilterString(AccumulateProps.InputType).c_str(), + CoopVecHelpers::DataTypeToFilterString(AccumulateProps.AccumulationType) + .c_str()); + + constexpr CoopVecOuterProductSubtestConfig TestConfigs[] = { + {4, 4, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL}, + }; + + for (auto Config : TestConfigs) { + if ((AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) && + (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR || + Config.MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR)) { + continue; + } + + runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config); + } +} + +void ExecutionTest::runCoopVecOuterProductSubtest( + ID3D12Device *D3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, + CoopVecOuterProductSubtestConfig &Config) { + + LogCommentFmt( + 
L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s", + Config.DimM, Config.DimN, Config.NumThreads, + CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); + + // Create root signature with a single root entry for all SRVs and UAVs + CComPtr RootSignature; + { + CD3DX12_DESCRIPTOR_RANGE ranges[2]; + ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0, + 0); // InputVector1, InputVector2 + ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // AccumMatrix + CreateRootSignatureFromRanges(D3DDevice, &RootSignature, ranges, 2, nullptr, + 0); + } + + // Create descriptor heap with space for 3 descriptors: 2 SRVs and 1 UAV + CComPtr DescriptorHeap; + { + D3D12_DESCRIPTOR_HEAP_DESC Desc = {}; + Desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + Desc.NumDescriptors = 3; + Desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + VERIFY_SUCCEEDED( + D3DDevice->CreateDescriptorHeap(&Desc, IID_PPV_ARGS(&DescriptorHeap))); + } + CD3DX12_CPU_DESCRIPTOR_HANDLE BaseHandle( + DescriptorHeap->GetCPUDescriptorHandleForHeapStart()); + + // Create a compute pipeline state object. 
+ CComPtr ComputePipelineState; + { + std::string ShaderSource = R"( +#include "dx/linalg.h" + +ByteAddressBuffer InputVector1 : register(t0); +ByteAddressBuffer InputVector2 : register(t1); +RWByteAddressBuffer AccumMatrix : register(u0); + +[shader("compute")] +[numthreads(NUM_THREADS, 1, 1)] +void main(uint threadIdx : SV_GroupThreadID) +{ +#if 1 + using namespace dx::linalg; + + // Ensure 4-byte alignment for vector loads + uint inputOffset1 = (DIM_M * threadIdx * sizeof(INPUT_DATA_TYPE)); + inputOffset1 = (inputOffset1 + 3) & ~3; // Align to 4 bytes + vector input1 = InputVector1.Load >(inputOffset1); + + uint inputOffset2 = (DIM_N * threadIdx * sizeof(INPUT_DATA_TYPE)); + inputOffset2 = (inputOffset2 + 3) & ~3; // Align to 4 bytes + vector input2 = InputVector2.Load >(inputOffset2); + + RWMatrixRef mat = { AccumMatrix, 0, STRIDE }; + + OuterProductAccumulate(input1, input2, mat); +#endif +} + )"; + + auto CreateDefineFromInt = [](const wchar_t *Name, int Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + auto CreateDefineFromString = [](const wchar_t *Name, + const wchar_t *Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + int Stride = 0; + const std::wstring HlslMatrixLayout = + CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); + int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType( + AccumulateProps.AccumulationType); + switch (Config.MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + Stride = Config.DimN * StrideMultiplier; + break; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + Stride = Config.DimM * StrideMultiplier; + break; + } + + const int InputDivisor = + CoopVecHelpers::GetNumPackedElementsForInputDataType( + AccumulateProps.InputType); + const std::wstring InputDataType = + CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType); + const 
std::wstring AccumDataType = + CoopVecHelpers::GetHlslDataTypeForDataType( + AccumulateProps.AccumulationType); + const std::wstring MatrixDataTypeEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + AccumulateProps.AccumulationType); + const std::wstring InputInterpretationEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + AccumulateProps.InputType); + + auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM); + auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN); + auto NumThreadsDefine = + CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads); + auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride); + auto InputDataTypeDefine = + CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str()); + auto InputDivisorDefine = + CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor); + auto AccumDataTypeDefine = + CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str()); + auto InputInterpretationEnumDefine = CreateDefineFromString( + L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str()); + auto HlslMatrixLayoutDefine = + CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str()); + auto MatrixDataTypeEnumDefine = CreateDefineFromString( + L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str()); + + LPCWSTR Options[] = { + L"-enable-16bit-types", + DimMDefine.c_str(), + DimNDefine.c_str(), + NumThreadsDefine.c_str(), + StrideDefine.c_str(), + InputDataTypeDefine.c_str(), + InputDivisorDefine.c_str(), + AccumDataTypeDefine.c_str(), + InputInterpretationEnumDefine.c_str(), + HlslMatrixLayoutDefine.c_str(), + MatrixDataTypeEnumDefine.c_str(), + }; + + CComPtr IncludeHandler = + new LinAlgHeaderIncludeHandler(m_support); + + CreateComputePSO(D3DDevice, RootSignature, ShaderSource.c_str(), L"cs_6_9", + &ComputePipelineState, Options, _countof(Options), + IncludeHandler); + } + + // Create a command list for the compute shader. 
+ CComPtr CommandList; + CComPtr CommandAllocator; + CComPtr CommandQueue; + FenceObj FO; + CreateCommandQueue(D3DDevice, L"CoopVec Test Command Queue", &CommandQueue, + D3D12_COMMAND_LIST_TYPE_DIRECT); + InitFenceObj(D3DDevice, &FO); + VERIFY_SUCCEEDED(D3DDevice->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&CommandAllocator))); + VERIFY_SUCCEEDED(D3DDevice->CreateCommandList( + 0, D3D12_COMMAND_LIST_TYPE_DIRECT, CommandAllocator, ComputePipelineState, + IID_PPV_ARGS(&CommandList))); + + // Setup input matrix as all-ones in sint8/fp32 format. This will later be + // converted to the appropriate data type by the matrix conversion API. + CComPtr InputMatrixSRVResource, InputMatrixSRVUploadResource; + std::vector InputMatrix; + if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, + Config.DimM); + } else if (AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + // Matrix source data is fp32, which gets converted to fp16 during matrix + // conversion + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, + Config.DimM); + } else { + WEX::Logging::Log::Error(L"Unsupported matrix data type"); + return; + } + + CreateTestResources(D3DDevice, CommandList, InputMatrix.data(), + InputMatrix.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputMatrix.size()), + &InputMatrixSRVResource, &InputMatrixSRVUploadResource); + + // Create input vectors + CComPtr InputVecSRVResource1, InputVecSRVUploadResource1; + std::vector InputVector1; + CComPtr InputVecSRVResource2, InputVecSRVUploadResource2; + std::vector InputVector2; + + if (AccumulateProps.InputType == 
D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + InputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + Config.DimM); + InputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + Config.DimN); + } else if (AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + InputVector1 = + CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.DimM); + InputVector2 = + CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.DimN); + } else if (AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + InputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + Config.DimM); + InputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + Config.DimN); + } else { + WEX::Logging::Log::Error(L"Unsupported input data type"); + return; + } + if (InputVector1.size() % 4 != 0) { + // Align size to 4 bytes for ByteAddressBuffer + InputVector1.resize(InputVector1.size() + 4 - (InputVector1.size() % 4)); + } + if (InputVector2.size() % 4 != 0) { + // Align size to 4 bytes for ByteAddressBuffer + InputVector2.resize(InputVector2.size() + 4 - (InputVector2.size() % 4)); + } + CreateTestResources(D3DDevice, CommandList, InputVector1.data(), + InputVector1.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputVector1.size()), + &InputVecSRVResource1, &InputVecSRVUploadResource1); + CreateTestResources(D3DDevice, CommandList, InputVector2.data(), + InputVector2.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputVector2.size()), + &InputVecSRVResource2, &InputVecSRVUploadResource2); + + // This increments baseHandle + CreateRawSRV(D3DDevice, BaseHandle, + (UINT)(InputVector1.size() / sizeof(int32_t)), + InputVecSRVResource1); + CreateRawSRV(D3DDevice, BaseHandle, + (UINT)(InputVector2.size() / 
sizeof(int32_t)), + InputVecSRVResource2); + + // Calculate reference output + auto ExpectedOutputBufferI8 = + CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); + std::vector ExpectedOutputBuffer(ExpectedOutputBufferI8.size() / + sizeof(float)); + std::memcpy(ExpectedOutputBuffer.data(), ExpectedOutputBufferI8.data(), + ExpectedOutputBufferI8.size()); + + if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + std::vector InputVector1FP16( + InputVector1.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(InputVector1FP16.data(), InputVector1.data(), + InputVector1.size()); + + std::vector InputVector2FP16( + InputVector2.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(InputVector2FP16.data(), InputVector2.data(), + InputVector2.size()); + + for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + for (int M = 0; M < Config.DimM; ++M) { + for (int N = 0; N < Config.DimN; ++N) { + float acc = ConvertFloat16ToFloat32(InputVector1FP16[M]) * + ConvertFloat16ToFloat32(InputVector2FP16[N]); + ExpectedOutputBuffer[M * Config.DimN + N] += acc; + } + } + } + } else if (AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + std::vector InputVector1FP32(InputVector1.size() / sizeof(float)); + std::memcpy(InputVector1FP32.data(), InputVector1.data(), + InputVector1.size()); + + std::vector InputVector2FP32(InputVector2.size() / sizeof(float)); + std::memcpy(InputVector2FP32.data(), InputVector2.data(), + InputVector2.size()); + + for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + for (int M = 0; M < Config.DimM; ++M) { + for (int N = 0; N < Config.DimN; ++N) { + float Acc = InputVector1FP32[ThreadIdx * Config.DimM + M] * + InputVector2FP32[ThreadIdx * Config.DimN + N]; + ExpectedOutputBuffer[M * Config.DimN + N] += Acc; + } + } + } + } + + CComPtr ConvertedMatrixResource, ConvertedMatrixReadResource; + int ConvertedMatrixSize = 0; + { + // Create source matrix 
info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO SrcInfo = {}; + SrcInfo.SrcDataType = + CoopVecHelpers::GetMatrixSrcDataType(AccumulateProps.AccumulationType); + SrcInfo.SrcLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + + // Create destination matrix info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO DestInfo = {}; + DestInfo.DestSize = 0; // Will be populated by driver + int SrcEltSize = 0; + int DestEltSize = 0; + switch (AccumulateProps.AccumulationType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + SrcEltSize = 1; + DestEltSize = 1; + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; + SrcEltSize = 4; // FP32 + DestEltSize = 2; // FP16 + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; + SrcEltSize = 4; // FP32 + DestEltSize = 1; // FP8 + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; + SrcEltSize = 4; // FP32 + DestEltSize = 1; // FP8 + break; + } + SrcInfo.SrcStride = Config.DimM * SrcEltSize; + SrcInfo.SrcSize = Config.DimM * Config.DimN * SrcEltSize; + + DestInfo.DestLayout = Config.MatrixLayout; + DestInfo.DestStride = 0; + DestInfo.NumRows = Config.DimM; + DestInfo.NumColumns = Config.DimN; + + if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { + DestInfo.DestStride = Config.DimM * DestEltSize; + } else if (Config.MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { + DestInfo.DestStride = Config.DimM * DestEltSize; + } + + // Create conversion info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {}; + ConvertInfo.SrcInfo = SrcInfo; + ConvertInfo.DestInfo = DestInfo; + + // Get preview device interface + { + CComPtr PreviewDevice; + 
VERIFY_SUCCEEDED(D3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&PreviewDevice)); + + // Query required destination size + PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + &ConvertInfo.DestInfo); + } + + ConvertedMatrixSize = ConvertInfo.DestInfo.DestSize; + + // Hack to prevent read resource from being created with size 0 + std::vector TempData(ConvertInfo.DestInfo.DestSize); + CreateTestUavs(D3DDevice, CommandList, TempData.data(), TempData.size(), + &ConvertedMatrixResource, nullptr, + &ConvertedMatrixReadResource); + + // Set up data descriptors + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA DataDesc = {}; + DataDesc.DestVA = ConvertedMatrixResource->GetGPUVirtualAddress(); + DataDesc.SrcVA = InputMatrixSRVResource->GetGPUVirtualAddress(); + ConvertInfo.DataDesc = DataDesc; + + // Get command list interface and perform conversion + CComPtr CommandList11; + VERIFY_SUCCEEDED(CommandList->QueryInterface( + __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11)); + CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); + + // This increments baseHandle + if ((ConvertInfo.DestInfo.DestSize % 4) != 0) { + WEX::Logging::Log::Error(L"DestSize is not aligned to 4 bytes"); + return; + } + CreateRawUAV(D3DDevice, BaseHandle, + ConvertInfo.DestInfo.DestSize / sizeof(int32_t), + ConvertedMatrixResource); + } + + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, ComputePipelineState)); + + SetDescriptorHeap(CommandList, DescriptorHeap); + + CD3DX12_GPU_DESCRIPTOR_HANDLE ResHandle( + DescriptorHeap->GetGPUDescriptorHandleForHeapStart()); + + CommandList->SetComputeRootSignature(RootSignature); + CommandList->SetComputeRootDescriptorTable(0, ResHandle); + CommandList->SetPipelineState(ComputePipelineState); + CommandList->Dispatch(1, 1, 1); + CommandList->Close(); + 
ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, ComputePipelineState)); + + // Convert matrix to sint8/fp32 row-major format before reading back to the + // CPU. A new resource is created, along with a readback resource, for the + // matrix copy. + CComPtr MatrixRowMajorResource, MatrixRowMajorReadResource; + { + // Create source matrix info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {}; + ConvertInfo.SrcInfo.SrcLayout = Config.MatrixLayout; + ConvertInfo.SrcInfo.SrcSize = ConvertedMatrixSize; + ConvertInfo.SrcInfo.SrcDataType = AccumulateProps.AccumulationType; + ConvertInfo.SrcInfo.SrcStride = 0; // OUTER_PRODUCT_OPTIMAL + + // Create destination matrix info + ConvertInfo.DestInfo.DestSize = 0; // Will be populated by driver + ConvertInfo.DestInfo.DestLayout = + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + ConvertInfo.DestInfo.NumRows = Config.DimM; + ConvertInfo.DestInfo.NumColumns = Config.DimN; + + if (AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; + ConvertInfo.DestInfo.DestStride = Config.DimN * sizeof(float); + } else { + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + ConvertInfo.DestInfo.DestStride = Config.DimN * sizeof(int8_t); + } + + // Get destination size using preview interface + { + CComPtr PreviewDevice; + VERIFY_SUCCEEDED(D3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&PreviewDevice)); + + // Query required destination size + PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + 
&ConvertInfo.DestInfo); + } + + // Create resource to hold matrix copy and a readback resource for it + // Init vector is a hack to prevent read resource from being created with + // size 0 + // TODO: Fix CreateTestUavs to allow creating readback resource without init + // data + std::vector TempData(ConvertInfo.DestInfo.DestSize); + CreateTestUavs(D3DDevice, CommandList, TempData.data(), TempData.size(), + &MatrixRowMajorResource, nullptr, + &MatrixRowMajorReadResource); + + // Set up data descriptors + ConvertInfo.DataDesc.DestVA = + MatrixRowMajorResource->GetGPUVirtualAddress(); + ConvertInfo.DataDesc.SrcVA = + ConvertedMatrixResource->GetGPUVirtualAddress(); + + // Get command list interface and perform conversion + CComPtr CommandList11; + VERIFY_SUCCEEDED(CommandList->QueryInterface( + __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11)); + CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); + } + + RecordTransitionBarrier(CommandList, MatrixRowMajorResource, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COPY_SOURCE); + CommandList->CopyResource(MatrixRowMajorReadResource, MatrixRowMajorResource); + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + + { + MappedData MappedData(MatrixRowMajorReadResource, (UINT)InputMatrix.size()); + + float *ResultBuffer = (float *)MappedData.data(); + bool Equal = true; + for (int i = 0; i < (UINT)InputMatrix.size() / sizeof(float); i++) { + if (isnan(ResultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || + fabs(ResultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { + LogErrorFmt(L"Result mismatch at index %d", i); + LogErrorFmt(L"ResultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, + ResultBuffer[i], i, ExpectedOutputBuffer[i]); + Equal = false; + break; + } + } + VERIFY_IS_TRUE(Equal); + } +} +#endif // HAVE_COOPVEC_API + +TEST_F(ExecutionTest, CoopVec_OuterProduct) { + WEX::TestExecution::SetVerifyOutput verifySettings( + 
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + runCoopVecOuterProductTest(); +} + // This test expects a that retrieves a signal value from each of a // few resources that are initialized here. determines if it uses // the 6.6 Dynamic Resources feature. Values are read back from the result UAV