From 48c3f9d528f9adb86a999c085091c6306ef6680b Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Wed, 30 Apr 2025 08:54:40 -0400 Subject: [PATCH 01/17] [CoopVec] This change add two CoopVec HLSL tests to the ExecutionTest framework. ExecutionTest::CoopVec_MulAdd: Functional verification for the Mul() and MulAdd() HLSL APIs. The driver matrix conversion API is tested as well. These tests should be considered as work-in-progress as this point. They include coverage primarily for SINT8, FLOAT16, FLOAT_E4M3, and FLOAT_E5M2. The test queries the driver for all supported configurations and runs each one, with a filtering mechanism to limit the set of tests to the minimal feature set. The set of tests can be further filtered by the following TE parameters: CoopVecMatrixInterp: SINT8, FLOAT16, FLOAT_E4M3, ... CoopVecMatrixLayout: ROW_MAJOR, COLUMN_MAJOR, MUL_OPTIMAL, OUTER_PRODUCT_OPTIMAL CoopVecBiasInterp: SINT32, FLOAT16, FLOAT_E4M3, ... CoopVecInputInterp: SINT8, FLOAT16, FLOAT_E4M3, ... CoopVecInputType: SINT8, UINT8, SINT16, UINT16, SINT32, UINT32, FLOAT16, FLOAT32, ... CoopVecOutputType: SINT32, UINT32, FLOAT16, FLOAT32, ... Filter example: $ TE.exe ... -p:CoopVecMatrixInterp=FLOAT16 -p:CoopVecMatrixLayout=MUL_OPTIMAL Precision coverage is minimal at this point, using an all-ones input matrix and test vector with ones in the first two components. This is enough to test basic functionality, but more comprehensive tests are needed. ExecutionTest::CoopVec_OuterProduct: Functional verification for the OuterProductAccumulate() HLSL API. This test queries the driver for all supported configurations and runs each one. No filtering is currently implemented. --- tools/clang/unittests/HLSLExec/CoopVec.h | 330 +++++ .../unittests/HLSLExec/CoopVecAPIExtensions.h | 165 +++ .../unittests/HLSLExec/ExecutionTest.cpp | 1179 ++++++++++++++++- 3 files changed, 1664 insertions(+), 10 deletions(-) create mode 100644 tools/clang/unittests/HLSLExec/CoopVec.h create mode 100644 tools/clang/unittests/HLSLExec/CoopVecAPIExtensions.h diff --git a/tools/clang/unittests/HLSLExec/CoopVec.h b/tools/clang/unittests/HLSLExec/CoopVec.h new file mode 100644 index 0000000000..67c14ac987 --- /dev/null +++ b/tools/clang/unittests/HLSLExec/CoopVec.h @@ -0,0 +1,330 @@ +#pragma once + +#include +#include +#include + +#include "dxc/Support/microcom.h" + +#include "CoopVecAPIExtensions.h" + +struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { +private: + DXC_MICROCOM_REF_FIELD(m_dwRef) + dxc::DxcDllSupport &DxcSupport; + +public: + LinAlgHeaderIncludeHandler() = delete; + LinAlgHeaderIncludeHandler(dxc::DxcDllSupport &DxcSupport) : m_dwRef(0), DxcSupport(DxcSupport) {} + + DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef) + + HRESULT STDMETHODCALLTYPE LoadSource(LPCWSTR pFilename, IDxcBlob **ppIncludeSource) { + if (wcscmp(pFilename, L"dx/linalg.h") == 0 || wcscmp(pFilename, L".\\dx\\linalg.h") == 0) { + WEX::Common::String ParamValue; + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(L"LinAlgHeader", + ParamValue))) { + return E_FAIL; + } + if (ParamValue.IsEmpty()) { + return E_FAIL; + } + LPCWSTR RealHeaderPath = reinterpret_cast(ParamValue.GetBuffer()); + + CComPtr pHeaderUtils; + + IFT(DxcSupport.CreateInstance(CLSID_DxcUtils, &pHeaderUtils)); + + IDxcBlobEncoding *pHeaderBlob; + IFT(pHeaderUtils->LoadFile(RealHeaderPath, nullptr, &pHeaderBlob)); + + *ppIncludeSource = pHeaderBlob; + + return S_OK; + } + return E_FAIL; + } + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, + void **ppvObject) override { +// FIXME: This is a workaround for a warning-as-error about unused parameters. +#pragma warning(push) +#pragma warning(disable: 4100) + return DoBasicQueryInterface(this, iid, ppvObject); +#pragma warning(pop) + } +}; + +struct CoopVecHelpers { + template + static std::vector CreateAllOnesInputMatrix(uint32_t Width, uint32_t Height) { + std::vector inputMatrix(Width * Height); + for (uint32_t i = 0; i < Width * Height; i++) { + if constexpr (std::is_same_v || std::is_same_v) { + inputMatrix[i] = 1; + } else if constexpr (std::is_same_v) { + inputMatrix[i] = ConvertFloat32ToFloat16(1.0f); + } else if constexpr (std::is_same_v) { + inputMatrix[i] = 1.0f; + } + } + + // Convert to uint8_t vector + std::vector uint8InputMatrix(inputMatrix.size() * sizeof(EltTy)); + std::memcpy(uint8InputMatrix.data(), inputMatrix.data(), inputMatrix.size() * sizeof(EltTy)); + return uint8InputMatrix; + } + + template + static std::vector CreateInputVector(uint32_t NumThreads, uint32_t EltsPerThread) { + std::vector inputVector(NumThreads * EltsPerThread); + std::fill(inputVector.begin(), inputVector.end(), EltTy(0)); + if (EltsPerThread < 2) { + WEX::Logging::Log::Error(L"EltsPerThread must be at least 2"); + return std::vector(); + } + for (uint32_t TID = 0; TID < NumThreads; TID++) { + if constexpr (std::is_same_v || std::is_same_v) { + inputVector[TID * EltsPerThread + 0] = 1; + inputVector[TID * EltsPerThread + 1] = 1; + } else if constexpr (std::is_same_v) { + inputVector[TID * EltsPerThread + 0] = ConvertFloat32ToFloat16(1.0f); + inputVector[TID * EltsPerThread + 1] = ConvertFloat32ToFloat16(1.0f); + } else if constexpr (std::is_same_v) { + inputVector[TID * EltsPerThread + 0] = 1.0f; + inputVector[TID * EltsPerThread + 1] = 1.0f; + } else { + WEX::Logging::Log::Error(L"Unsupported input type"); + } + } + + // Convert to uint8_t vector + std::vector uint8InputVector(inputVector.size() * sizeof(EltTy)); + std::memcpy(uint8InputVector.data(), inputVector.data(), inputVector.size() * sizeof(EltTy)); + return uint8InputVector; + } + + template + static std::vector CreateInputBias(uint32_t NumElts) { + std::vector inputBias(NumElts); + if constexpr (std::is_same_v || std::is_same_v) { + std::fill(inputBias.begin(), inputBias.end(), EltTy(1)); + } else if constexpr (std::is_same_v) { + std::fill(inputBias.begin(), inputBias.end(), ConvertFloat32ToFloat16(1.0f)); + } else if constexpr (std::is_same_v) { + std::fill(inputBias.begin(), inputBias.end(), 1); + } else { + WEX::Logging::Log::Error(L"Unsupported bias type"); + } + // Convert to uint8_t vector + std::vector uint8InputBias(inputBias.size() * sizeof(EltTy)); + std::memcpy(uint8InputBias.data(), inputBias.data(), inputBias.size() * sizeof(EltTy)); + return uint8InputBias; + } + + static std::wstring DataTypeToFilterString(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + return L"SINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return L"UINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + return L"SINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + return L"UINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"SINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"UINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"SINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"UINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"FLOAT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"FLOAT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + return L"FLOAT_E4M3"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return L"FLOAT_E5M2"; + default: + return L""; + } + } + + static bool IsDataTypeInFilter(const wchar_t *FilterKey, + D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + WEX::Common::String ParamValue; + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(FilterKey, + ParamValue))) { + // Filter not set, so treat as no filter + return true; + } + if (ParamValue.IsEmpty()) { + // Empty filter, so treat as no filter + return true; + } + + // Check if the filter matches the target data type + LPCWSTR FilterString = reinterpret_cast(ParamValue.GetBuffer()); + return DataTypeToFilterString(DataType) == FilterString; + } + + static std::wstring MatrixLayoutToFilterString(D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + switch (MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + return L"ROW_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + return L"COLUMN_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: + return L"MUL_OPTIMAL"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: + return L"OUTER_PRODUCT_OPTIMAL"; + default: + return L""; + } + } + + static bool IsMatrixLayoutInFilter(const wchar_t *FilterKey, + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + WEX::Common::String ParamValue; + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(FilterKey, + ParamValue))) { + // Filter not set, so treat as no filter + return true; + } + if (ParamValue.IsEmpty()) { + // Empty filter, so treat as no filter + return true; + } + + // Check if the filter matches the target data type + LPCWSTR FilterString = reinterpret_cast(ParamValue.GetBuffer()); + return MatrixLayoutToFilterString(MatrixLayout) == FilterString; + } + + static std::wstring MatrixLayoutToHlslLayoutString(D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + switch (MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + return L"MATRIX_LAYOUT_ROW_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + return L"MATRIX_LAYOUT_COLUMN_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: + return L"MATRIX_LAYOUT_MUL_OPTIMAL"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: + return L"MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL"; + default: + return L""; + } + } + + // This multiplier is used to compute the row/column stride for a matrix + // given it's element size. + static int GetStrideMultiplierForMatrixDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return 1; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return 2; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return 4; + default: + WEX::Logging::Log::Error(L"Unsupported matrix data type"); + return 1; + } + } + + static int GetNumPackedElementsForInputDataType(D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) { + // Int8 packed types are the only ones that have more than 1 element per shader variable + switch (InputInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return 4; + default: + return 1; + } + } + + // This type is used in generated HLSL source to represent the vector type + // for the given data type. + static std::wstring GetHlslDataTypeForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType, D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) { + UNREFERENCED_PARAMETER(InputInterpretation); + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"int16_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"uint16_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"int32_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"uint32_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"half"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"float"; + default: + WEX::Logging::Log::Error(L"Unsupported input data type"); + return L""; + } + } + + static std::wstring GetHlslInterpretationForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE Interpretation) { + switch (Interpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + return L"DATA_TYPE_SINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return L"DATA_TYPE_UINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + return L"DATA_TYPE_SINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + return L"DATA_TYPE_UINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"DATA_TYPE_SINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"DATA_TYPE_UINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"DATA_TYPE_SINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"DATA_TYPE_UINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"DATA_TYPE_FLOAT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"DATA_TYPE_FLOAT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + return L"DATA_TYPE_FLOAT8_E4M3"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return L"DATA_TYPE_FLOAT8_E5M2"; + default: + WEX::Logging::Log::Error(L"Unsupported interpretation"); + return L""; + } + } + + // The returned data type is used for matrix conversion. It is hard-coded + // for the test framework where all integer matrices start as SINT8 and + // all FP matrices start as FLOAT32. + static D3D12_LINEAR_ALGEBRA_DATATYPE GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) { + switch (MatrixInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + default: + return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; + } + } +}; \ No newline at end of file diff --git a/tools/clang/unittests/HLSLExec/CoopVecAPIExtensions.h b/tools/clang/unittests/HLSLExec/CoopVecAPIExtensions.h new file mode 100644 index 0000000000..6166254294 --- /dev/null +++ b/tools/clang/unittests/HLSLExec/CoopVecAPIExtensions.h @@ -0,0 +1,165 @@ +#pragma once + +#if D3D12_PREVIEW_SDK_VERSION < 717 + +// This file contains the definitions of the D3D12 cooperative vector API extensions. +// It is used to test the cooperative vector API extensions on older SDKs. + +constexpr int D3D12_FEATURE_D3D12_OPTIONS_EXPERIMENTAL = 9; +constexpr int D3D12_FEATURE_COOPERATIVE_VECTOR = 11; + +// -------------------------------------------------------------------------------------------------------------------------------- +// Experimental Feature: D3D12CooperativeVectorExperiment +// +// Use with D3D12CooperativeVectorExperiment to enable cooperative vector experimental feature. +// +// Enabling D3D12CooperativeVectorExperiment needs no configuration struct, pass NULL in the pConfigurationStructs array. +// +// -------------------------------------------------------------------------------------------------------------------------------- +static const UUID D3D12CooperativeVectorExperiment = { /* 384748be-cca5-471e-a125-5cc997e04d39 */ + 0x384748be, + 0xcca5, + 0x471e, + {0xa1, 0x25, 0x5c, 0xc9, 0x97, 0xe0, 0x4d, 0x39} +}; + +/* interface __MIDL_itf_d3d12_0000_0082 */ +/* [local] */ + +typedef +enum D3D12_COOPERATIVE_VECTOR_TIER + { + D3D12_COOPERATIVE_VECTOR_TIER_NOT_SUPPORTED = 0, + D3D12_COOPERATIVE_VECTOR_TIER_1_0 = 0x10, + D3D12_COOPERATIVE_VECTOR_TIER_1_1 = 0x11 + } D3D12_COOPERATIVE_VECTOR_TIER; + +typedef +enum D3D12_LINEAR_ALGEBRA_DATATYPE + { + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16 = 2, + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16 = 3, + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 = 4, + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 = 5, + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 = 7, + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 = 8, + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED = 16, + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED = 17, + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8 = 18, + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 = 19, + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 = 20, + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 = 21 + } D3D12_LINEAR_ALGEBRA_DATATYPE; + +typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS_EXPERIMENTAL + { + _Out_ D3D12_COOPERATIVE_VECTOR_TIER CooperativeVectorTier; + } D3D12_FEATURE_DATA_D3D12_OPTIONS_EXPERIMENTAL; + +typedef struct D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL + { + D3D12_LINEAR_ALGEBRA_DATATYPE InputType; + D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation; + D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation; + D3D12_LINEAR_ALGEBRA_DATATYPE BiasInterpretation; + D3D12_LINEAR_ALGEBRA_DATATYPE OutputType; + BOOL TransposeSupported; + } D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL; + +typedef struct D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE + { + D3D12_LINEAR_ALGEBRA_DATATYPE InputType; + D3D12_LINEAR_ALGEBRA_DATATYPE AccumulationType; + } D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE; + +typedef struct D3D12_FEATURE_DATA_COOPERATIVE_VECTOR + { + _Inout_ UINT MatrixVectorMulAddPropCount; + _Out_ D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL *pMatrixVectorMulAddProperties; + _Inout_ UINT OuterProductAccumulatePropCount; + _Out_ D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE *pOuterProductAccumulateProperties; + _Inout_ UINT VectorAccumulatePropCount; + _Out_ D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE *pVectorAccumulateProperties; + } D3D12_FEATURE_DATA_COOPERATIVE_VECTOR; + +typedef +enum D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT + { + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR = 0, + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR = ( D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR + 1 ) , + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL = ( D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR + 1 ) , + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL = ( D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL + 1 ) + } D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT; + +typedef struct D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO + { + _Inout_ UINT DestSize; + _In_ D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT DestLayout; + _In_ UINT DestStride; + _In_ UINT NumRows; + _In_ UINT NumColumns; + _In_ D3D12_LINEAR_ALGEBRA_DATATYPE DestDataType; + } D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO; + +typedef struct D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA + { + _Inout_ D3D12_GPU_VIRTUAL_ADDRESS DestVA; + _In_ D3D12_GPU_VIRTUAL_ADDRESS SrcVA; + } D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA; + +typedef struct D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO + { + _In_ UINT SrcSize; + _In_ D3D12_LINEAR_ALGEBRA_DATATYPE SrcDataType; + _In_ D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT SrcLayout; + _In_ UINT SrcStride; + } D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO; + +typedef struct D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO + { + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO DestInfo; + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO SrcInfo; + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA DataDesc; + } D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO; + + + +#ifndef __ID3D12DevicePreview_INTERFACE_DEFINED__ +#define __ID3D12DevicePreview_INTERFACE_DEFINED__ + +EXTERN_C const IID IID_ID3D12DevicePreview; + +MIDL_INTERFACE("55ea41d3-6bf5-4332-bbf9-905e6b4e2930") +ID3D12DevicePreview : public IUnknown +{ +public: + virtual void STDMETHODCALLTYPE GetLinearAlgebraMatrixConversionDestinationInfo( + _Inout_ D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO *pDesc) = 0; + +}; + +#endif /* __ID3D12DevicePreview_INTERFACE_DEFINED__ */ + + +#ifndef __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ +#define __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ + +EXTERN_C const IID IID_ID3D12GraphicsCommandList11; + +MIDL_INTERFACE("f0dcfabc-a84a-4fe3-b3b9-eab26b306c38") +ID3D12GraphicsCommandList11 : public ID3D12GraphicsCommandList10 +{ +public: + virtual void STDMETHODCALLTYPE Reserved0() = 0; + virtual void STDMETHODCALLTYPE Reserved1() = 0; + virtual void STDMETHODCALLTYPE Reserved2() = 0; + + virtual void STDMETHODCALLTYPE ConvertLinearAlgebraMatrix( + _In_ const D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO *pDesc, + _In_ UINT DescCount) = 0; + +}; + +#endif /* __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ */ + +#endif // D3D12_PREVIEW_SDK_VERSION < 717 diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 1bef0b4f8d..011e28e42e 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -51,6 +51,7 @@ // https://msdn.microsoft.com/en-us/library/windows/desktop/dn899120(v=vs.85).aspx // https://developer.microsoft.com/en-US/windows/downloads/windows-10-sdk // + #include #include #include @@ -63,6 +64,8 @@ #include #include #include "LongVectors.h" +#include "CoopVec.h" +#include "CoopVecAPIExtensions.h" // clang-format on #pragma comment(lib, "d3dcompiler.lib") @@ -617,6 +620,10 @@ class ExecutionTest { TEST_METHOD(LongVector_Clamp_uint64); TEST_METHOD(LongVector_Initialize_uint64); + TEST_METHOD(CoopVec_Mul); + TEST_METHOD(CoopVec_OuterProduct); + + dxc::DxcDllSupport m_support; bool m_D3DInitCompleted = false; @@ -752,7 +759,7 @@ class ExecutionTest { #endif } - bool UseDebugIfaces() { return true; } + bool UseDebugIfaces() { return false; } bool SaveImages() { return GetTestParamBool(L"SaveImages"); } @@ -775,6 +782,30 @@ class ExecutionTest { void RunResourceTest(ID3D12Device *pDevice, const char *pShader, const wchar_t *sm, bool isDynamic); + struct CoopVecMulSubtestConfig { + int InputPerThread; + int OutputPerThread; + int NumThreads; + int NumLevels; + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout; + bool Bias; + }; + + void RunCoopVecMulTest(); + void RunCoopVecMulTestConfig(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps); + void RunCoopVecMulSubtest(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, CoopVecMulSubtestConfig &Config); + + struct CoopVecOuterProductSubtestConfig { + int DimM; // Row Count + int DimN; // Column Count + int NumThreads; + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout; + }; + + void RunCoopVecOuterProductTest(); + void RunCoopVecOuterProductTestConfig(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps); + void RunCoopVecOuterProductSubtest(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, CoopVecOuterProductSubtestConfig &Config); + template void WaveIntrinsicsActivePrefixTest(TableParameter *pParameterList, size_t numParameter, bool isPrefix); @@ -834,7 +865,8 @@ class ExecutionTest { void CompileFromText(LPCSTR pText, LPCWSTR pEntryPoint, LPCWSTR pTargetProfile, ID3DBlob **ppBlob, - LPCWSTR *pOptions = nullptr, int numOptions = 0) { + LPCWSTR *pOptions = nullptr, int numOptions = 0, + IDxcIncludeHandler *pIncludeHandler = nullptr) { VERIFY_SUCCEEDED(m_support.Initialize()); CComPtr pCompiler; CComPtr pLibrary; @@ -847,7 +879,7 @@ class ExecutionTest { pText, (UINT32)strlen(pText), CP_UTF8, &pTextBlob)); VERIFY_SUCCEEDED(pCompiler->Compile(pTextBlob, L"hlsl.hlsl", pEntryPoint, pTargetProfile, pOptions, numOptions, - nullptr, 0, nullptr, &pResult)); + nullptr, 0, pIncludeHandler, &pResult)); VERIFY_SUCCEEDED(pResult->GetStatus(&resultCode)); if (FAILED(resultCode)) { #ifndef _HLK_CONF @@ -882,7 +914,8 @@ class ExecutionTest { ID3D12RootSignature *pRootSignature, LPCSTR pShader, LPCWSTR pTargetProfile, ID3D12PipelineState **ppComputeState, - LPCWSTR *pOptions = nullptr, int numOptions = 0) { + LPCWSTR *pOptions = nullptr, int numOptions = 0, + IDxcIncludeHandler *pIncludeHandler = nullptr) { CComPtr pComputeShader; // Load and compile shaders. @@ -892,7 +925,7 @@ class ExecutionTest { #endif } else { CompileFromText(pShader, L"main", pTargetProfile, &pComputeShader, - pOptions, numOptions); + pOptions, numOptions, pIncludeHandler); } // Describe and create the compute pipeline state object (PSO). @@ -1729,6 +1762,14 @@ class ExecutionTest { #endif } + bool DoesDeviceSupportCooperativeVector(ID3D12Device *pDevice) { + D3D12_FEATURE_DATA_D3D12_OPTIONS_EXPERIMENTAL O; + if (FAILED(pDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS_EXPERIMENTAL, &O, sizeof(O)))) + return false; + return O.CooperativeVectorTier != D3D12_COOPERATIVE_VECTOR_TIER_NOT_SUPPORTED; + } + bool IsFallbackPathEnabled() { // Enable fallback paths with: /p:"EnableFallback=1" UINT EnableFallbackValue = 0; @@ -1841,7 +1882,15 @@ class ExecutionTest { if (pD3D12EnableExperimentalFeatures == nullptr) { return HRESULT_FROM_WIN32(GetLastError()); } - return pD3D12EnableExperimentalFeatures(1, &D3D12ExperimentalShaderModelsID, + + std::vector Features; + + Features.push_back(D3D12ExperimentalShaderModels); + if (GetTestParamBool(L"CooperativeVectorExperimental")) { + Features.push_back(D3D12CooperativeVectorExperiment); + } + + return pD3D12EnableExperimentalFeatures((UINT)Features.size(), Features.data(), nullptr, nullptr); } @@ -11912,10 +11961,1120 @@ VERIFY_SUCCEEDED(DoArraysMatch(OutputVector, ExpectedVector, TestConfig.Tolerance)); } -// This test expects a that retrieves a signal value from each of a -// few resources that are initialized here. determines if it uses -// the 6.6 Dynamic Resources feature. Values are read back from the result UAV -// and compared to the expected signals + +// Runs a set of tests for the Cooperative Vector Mul and MulAdd operations. +// The device will be queried for supported configurations and then each supported +// configuration will be tested against multiple matrix and vector sizes. To help +// reproduce individual test failures, the test will log the configuration it is +// running and the results of each test. The following filters can be used to limit +// test execution to a specific set of configurations: +// +// - CoopVecMatrixInterp: SINT8, FLOAT16, FLOAT_E4M3, ... +// - CoopVecMatrixLayout: ROW_MAJOR, COLUMN_MAJOR, MUL_OPTIMAL, OUTER_PRODUCT_OPTIMAL +// - CoopVecBiasInterp: SINT32, FLOAT16, FLOAT_E4M3, ... +// - CoopVecInputInterp: SINT8, FLOAT16, FLOAT_E4M3, ... +// - CoopVecInputType: SINT8, UINT8, SINT16, UINT16, SINT32, UINT32, FLOAT16, FLOAT32, ... +// - CoopVecOutputType: SINT32, UINT32, FLOAT16, FLOAT32, ... +// +// Filter example: +// TE.exe ... -p:CoopVecMatrixInterp=FLOAT16 -p:CoopVecMatrixLayout=MUL_OPTIMAL +// +// The current implementation will always write the final output data as float. +void ExecutionTest::RunCoopVecMulTest() { + // Create device and verify coopvec support + CComPtr pD3DDevice; + if (!CreateDevice(&pD3DDevice, D3D_SHADER_MODEL_6_9)) { + WEX::Logging::Log::Comment( + "Device does not support SM 6.9. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + if (!DoesDeviceSupportCooperativeVector(pD3DDevice)) { + WEX::Logging::Log::Comment( + "Device does not support cooperative vector. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + + // Query coopvec feature data. First call gets the size of the arrays. The + // second call populates the arrays using memory we allocate. + D3D12_FEATURE_DATA_COOPERATIVE_VECTOR devOptions = {}; + VERIFY_SUCCEEDED( + pD3DDevice->CheckFeatureSupport((D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, + &devOptions, sizeof(devOptions))); + + // Allocate memory for the arrays in devOptions + std::vector MulAddProps(devOptions.MatrixVectorMulAddPropCount); + devOptions.pMatrixVectorMulAddProperties = MulAddProps.data(); + + VERIFY_SUCCEEDED( + pD3DDevice->CheckFeatureSupport((D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, + &devOptions, sizeof(devOptions))); + + // Test each supported data type and matrix layout + for (auto MulAddConfig : MulAddProps) { + // Filter on preview test support + bool PreviewConfig = false; + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + PreviewConfig = true; + } + + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 && + MulAddConfig.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + PreviewConfig = true; + } + + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 && + MulAddConfig.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + PreviewConfig = true; + } + + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 && + MulAddConfig.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED && + MulAddConfig.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 && + MulAddConfig.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + PreviewConfig = true; + } + + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 && + MulAddConfig.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && + MulAddConfig.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 && + MulAddConfig.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && + MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + PreviewConfig = true; + } + + if (!PreviewConfig) { + continue; + } + + // Apply filters + bool IsInFilter = CoopVecHelpers::IsDataTypeInFilter(L"CoopVecMatrixInterp", MulAddConfig.MatrixInterpretation) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecBiasInterp", MulAddConfig.BiasInterpretation) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecInputInterp", MulAddConfig.InputInterpretation) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecInputType", MulAddConfig.InputType) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecOutputType", MulAddConfig.OutputType); + if (!IsInFilter) { + continue; + } + + + // Run the test + RunCoopVecMulTestConfig(pD3DDevice, MulAddConfig); + } +} + +void ExecutionTest::RunCoopVecMulTestConfig(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps) { + + LogCommentFmt(L"Running test for MatrixInterpretation: %s, BiasInterpretation: %s, InputInterpretation: %s, InputType: %s, OutputType: %s", + CoopVecHelpers::DataTypeToFilterString(MulProps.MatrixInterpretation).c_str(), + CoopVecHelpers::DataTypeToFilterString(MulProps.BiasInterpretation).c_str(), + CoopVecHelpers::DataTypeToFilterString(MulProps.InputInterpretation).c_str(), + CoopVecHelpers::DataTypeToFilterString(MulProps.InputType).c_str(), + CoopVecHelpers::DataTypeToFilterString(MulProps.OutputType).c_str()); + + constexpr CoopVecMulSubtestConfig TestConfigs[] = { + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, true}, + }; + + for (auto Config : TestConfigs) { + if ((MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) && + (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR || + Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR)) { + continue; + } + + bool IsInFilter = CoopVecHelpers::IsMatrixLayoutInFilter(L"CoopVecMatrixLayout", Config.MatrixLayout); + if (!IsInFilter) { + continue; + } + + RunCoopVecMulSubtest(pD3DDevice, MulProps, Config); + } +} + +void ExecutionTest::RunCoopVecMulSubtest(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, CoopVecMulSubtestConfig &Config) { + + LogCommentFmt(L"Running test for InputPerThread: %d, OutputPerThread: %d, NumThreads: %d, NumLevels: %d, Bias: %s, MatrixLayout: %s", + Config.InputPerThread, Config.OutputPerThread, Config.NumThreads, Config.NumLevels, Config.Bias ? L"true" : L"false", CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); + + const int OUTPUT_BUFFER_SIZE = (Config.OutputPerThread * Config.NumThreads * 4); + + // Create root signature with a single root entry for all SRVs and UAVs + CComPtr pRootSignature; + { + CD3DX12_DESCRIPTOR_RANGE ranges[2]; + ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 3, 0, 0); // InputVector, InputMatrix, InputBias + ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // OutputBuffer + CreateRootSignatureFromRanges(pD3DDevice, &pRootSignature, ranges, 2, nullptr, 0); + } + + // Create descriptor heap with space for 4 descriptors: 3 SRVs and 1 UAV + CComPtr pDescriptorHeap; + { + D3D12_DESCRIPTOR_HEAP_DESC desc = {}; + desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + desc.NumDescriptors = 4; + desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + VERIFY_SUCCEEDED(pD3DDevice->CreateDescriptorHeap(&desc, IID_PPV_ARGS(&pDescriptorHeap))); + } + CD3DX12_CPU_DESCRIPTOR_HANDLE baseHandle(pDescriptorHeap->GetCPUDescriptorHandleForHeapStart()); + + // Create the compute pipeline state for the CoopVec shader + CComPtr pComputePipelineState; + { + std::string shaderSource = R"( +#include "dx/linalg.h" + +ByteAddressBuffer InputVector : register(t0); +ByteAddressBuffer InputBias : register(t1); +ByteAddressBuffer InputMatrix : register(t2); +RWByteAddressBuffer OutputBuffer: register(u0); + +[shader("compute")] +[numthreads(NUM_THREADS, 1, 1)] +void main(uint threadIdx : SV_GroupThreadID) +{ + using namespace dx::linalg; + + // Ensure 4-byte alignment for vector loads + uint inputOffset = (INPUT_PER_THREAD * threadIdx); + inputOffset = (inputOffset + 3) & ~3; // Align to 4 bytes + vector input = InputVector.Load >(inputOffset); + + MatrixRef mat = { InputMatrix, 0, STRIDE }; + + vector accum; + + if (USE_BIAS) { + VectorRef biasVec = { InputBias, 0 }; + accum = MulAdd(mat, MakeInterpretedVector(input), biasVec); + } else { + accum = Mul(mat, MakeInterpretedVector(input)); + } + + vector result = (vector)accum; + + // Ensure 4-byte alignment for vector store + uint outputOffset = OUTPUT_PER_THREAD * threadIdx; + outputOffset = (outputOffset + 3) & ~3; // Align to 4 bytes + OutputBuffer.Store >(outputOffset * 4, result); +} + )"; + + auto CreateDefineFromInt = [](const wchar_t *Name, int Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + auto CreateDefineFromString = [](const wchar_t *Name, const std::wstring &Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + int Stride = 0; + const std::wstring HlslMatrixLayout = CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); + int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType(MulProps.MatrixInterpretation); + switch (Config.MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + Stride = Config.InputPerThread * StrideMultiplier; + break; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + Stride = Config.OutputPerThread * StrideMultiplier; + break; + } + + const int InputDivisor = CoopVecHelpers::GetNumPackedElementsForInputDataType(MulProps.InputInterpretation); + const std::wstring InputDataType = CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.InputType, MulProps.InputInterpretation); + const std::wstring AccumDataType = CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.BiasInterpretation, MulProps.BiasInterpretation); + const std::wstring MatrixDataTypeEnum = CoopVecHelpers::GetHlslInterpretationForDataType(MulProps.MatrixInterpretation); + const std::wstring InputInterpretationEnum = CoopVecHelpers::GetHlslInterpretationForDataType(MulProps.InputInterpretation); + const std::wstring AccumInterpretationEnum = CoopVecHelpers::GetHlslInterpretationForDataType(MulProps.BiasInterpretation); + + auto InputPerThreadDefine = CreateDefineFromInt(L"INPUT_PER_THREAD", Config.InputPerThread); + auto OutputPerThreadDefine = CreateDefineFromInt(L"OUTPUT_PER_THREAD", Config.OutputPerThread); + auto NumThreadsDefine = CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads); + auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride); + auto InputDataTypeDefine = CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType); + auto InputDivisorDefine = CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor); + auto AccumDataTypeDefine = CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType); + auto InputInterpretationEnumDefine = CreateDefineFromString(L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum); + auto HlslMatrixLayoutDefine = CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout); + auto MatrixDataTypeEnumDefine = CreateDefineFromString(L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum); + auto UseBiasDefine = CreateDefineFromInt(L"USE_BIAS", Config.Bias ? 1 : 0); + auto AccumInterpretationEnumDefine = CreateDefineFromString(L"ACCUM_INTERPRETATION_ENUM", AccumInterpretationEnum); + + LPCWSTR pOptions[] = { + L"-enable-16bit-types", + InputPerThreadDefine.c_str(), + OutputPerThreadDefine.c_str(), + NumThreadsDefine.c_str(), + StrideDefine.c_str(), + InputDataTypeDefine.c_str(), + InputDivisorDefine.c_str(), + AccumDataTypeDefine.c_str(), + InputInterpretationEnumDefine.c_str(), + HlslMatrixLayoutDefine.c_str(), + MatrixDataTypeEnumDefine.c_str(), + UseBiasDefine.c_str(), + AccumInterpretationEnumDefine.c_str(), + }; + + CComPtr pIncludeHandler = new LinAlgHeaderIncludeHandler(m_support); + + CreateComputePSO(pD3DDevice, pRootSignature, shaderSource.c_str(), + L"cs_6_9", &pComputePipelineState, pOptions, + _countof(pOptions), pIncludeHandler); + } + + // Create a command list for the compute shader. + CComPtr pCommandList; + CComPtr pCommandAllocator; + CComPtr pCommandQueue; + FenceObj FO; + CreateCommandQueue(pD3DDevice, L"CoopVec Test Command Queue", + &pCommandQueue, D3D12_COMMAND_LIST_TYPE_DIRECT); + InitFenceObj(pD3DDevice, &FO); + VERIFY_SUCCEEDED(pD3DDevice->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&pCommandAllocator))); + VERIFY_SUCCEEDED(pD3DDevice->CreateCommandList( + 0, D3D12_COMMAND_LIST_TYPE_DIRECT, pCommandAllocator, pComputePipelineState, + IID_PPV_ARGS(&pCommandList))); + + + // Setup input data + auto ExpectedOutputBuffer = std::make_unique(Config.OutputPerThread * Config.NumThreads); + + // Setup input matrix as all-ones in sint8 format. This will later be + // converted to the appropriate data type by the matrix conversion API. + CComPtr pInputMatrixSRVResource, pInputMatrixSRVUploadResource; + std::vector inputMatrix; + if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || + MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.InputPerThread, Config.OutputPerThread); + } else if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + // Matrix source data is fp32, which gets converted to fp16 during matrix conversion + inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.InputPerThread, Config.OutputPerThread); + } else { + WEX::Logging::Log::Error(L"Unsupported matrix data type"); + return; + } + + CreateTestResources(pD3DDevice, pCommandList, inputMatrix.data(), inputMatrix.size(), CD3DX12_RESOURCE_DESC::Buffer(inputMatrix.size()), &pInputMatrixSRVResource, &pInputMatrixSRVUploadResource); + + // Create input vector of an appropriate type. All integer types start as SINT8 for now. + CComPtr pInputVecSRVResource, pInputVecSRVUploadResource; + std::vector inputVector; + + if ((MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 && + (MulProps.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED)) || + MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + inputVector = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.InputPerThread); + } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + inputVector = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.InputPerThread); + } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + inputVector = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.InputPerThread); + } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + inputVector = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.InputPerThread); + } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { + inputVector = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.InputPerThread); + } else { + WEX::Logging::Log::Error(L"Unsupported input data type"); + return; + } + if (inputVector.size() % 4 != 0) { + // Align size to 4 bytes for ByteAddressBuffer + inputVector.resize(inputVector.size() + 4 - (inputVector.size() % 4)); + } + CreateTestResources(pD3DDevice, pCommandList, inputVector.data(), inputVector.size(), CD3DX12_RESOURCE_DESC::Buffer(inputVector.size()), &pInputVecSRVResource, &pInputVecSRVUploadResource); + + // This increments baseHandle + CreateRawSRV(pD3DDevice, baseHandle, (UINT)(inputVector.size() / sizeof(int32_t)), pInputVecSRVResource); + + // Create input bias + CComPtr pInputBiasSRVResource, pInputBiasSRVUploadResource; + std::vector inputBias; + + if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || + MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + inputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + inputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { + inputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + inputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + inputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else { + WEX::Logging::Log::Error(L"Unsupported bias data type"); + return; + } + + if (inputBias.size() % 4 != 0) { + // Align size to 4 bytes for ByteAddressBuffer + inputBias.resize(inputBias.size() + 4 - (inputBias.size() % 4)); + } + CreateTestResources(pD3DDevice, pCommandList, inputBias.data(), inputBias.size(), CD3DX12_RESOURCE_DESC::Buffer(inputBias.size()), &pInputBiasSRVResource, &pInputBiasSRVUploadResource); + + // This increments baseHandle + CreateRawSRV(pD3DDevice, baseHandle, (UINT)(inputBias.size() / sizeof(int32_t)), pInputBiasSRVResource); + + // Calculate reference output + // FIXME: This does not capture all cases, but is sufficient for the preview feature set + if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8) { + // The input bias is really an array of int32_t + std::vector inputBiasI32(inputBias.size() / sizeof(int32_t)); + std::memcpy(inputBiasI32.data(), inputBias.data(), inputBias.size()); + + // The input vector is really an array of float if our vector input type is FLOAT32 + std::vector inputVectorF32(inputVector.size() / sizeof(int32_t)); + if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + std::memcpy(inputVectorF32.data(), inputVector.data(), inputVector.size()); + } + + for (int threadIdx = 0; threadIdx < Config.NumThreads; ++threadIdx) { + for (int outputIdx = 0; outputIdx < Config.OutputPerThread; ++outputIdx) { + int acc = 0; + + for (int inputIdx = 0; inputIdx < Config.InputPerThread; ++inputIdx) { + int inputElem; + if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 ) { + inputElem = (int)inputVectorF32[threadIdx * Config.InputPerThread + inputIdx]; + } else { + inputElem = inputVector[threadIdx * Config.InputPerThread + inputIdx]; + } + int const matrixElem = inputMatrix[outputIdx * Config.InputPerThread + inputIdx]; + acc += inputElem * matrixElem; + } + + if (Config.Bias) { + acc += inputBiasI32[outputIdx]; + } + + float result = float(acc); + ExpectedOutputBuffer[threadIdx * Config.OutputPerThread + outputIdx] = result; + } + } + } else if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + // The input bias/vector is really an array of float16 + std::vector inputVectorFP16(inputVector.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(inputVectorFP16.data(), inputVector.data(), inputVector.size()); + + std::vector inputBiasFP16(inputBias.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(inputBiasFP16.data(), inputBias.data(), inputBias.size()); + + // The CPU reference matrix is float + std::vector inputMatrixFP32(inputMatrix.size() / sizeof(float)); + std::memcpy(inputMatrixFP32.data(), inputMatrix.data(), inputMatrix.size()); + + for (int threadIdx = 0; threadIdx < Config.NumThreads; ++threadIdx) { + for (int outputIdx = 0; outputIdx < Config.OutputPerThread; ++outputIdx) { + float acc = 0; + + for (int inputIdx = 0; inputIdx < Config.InputPerThread; ++inputIdx) { + float const inputElem = ConvertFloat16ToFloat32(inputVectorFP16[threadIdx * Config.InputPerThread + inputIdx]); + float const matrixElem = inputMatrixFP32[outputIdx * Config.InputPerThread + inputIdx]; + acc += inputElem * matrixElem; + } + + if (Config.Bias) { + acc += ConvertFloat16ToFloat32(inputBiasFP16[outputIdx]); + } + + float result = acc; + ExpectedOutputBuffer[threadIdx * Config.OutputPerThread + outputIdx] = result; + } + } + } + + CComPtr pConvertedMatrixResource; + { + // Create source matrix info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO convertInfo = {}; + convertInfo.SrcInfo.SrcDataType = CoopVecHelpers::GetMatrixSrcDataType(MulProps.MatrixInterpretation); + convertInfo.SrcInfo.SrcLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + + // Create destination matrix info + convertInfo.DestInfo.DestSize = 0; // Will be populated by driver + int srcEltSize = 0; + int destEltSize = 0; + switch (MulProps.MatrixInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + srcEltSize = 1; + destEltSize = 1; + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; + srcEltSize = 4; // FP32 + destEltSize = 2; // FP16 + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; + srcEltSize = 4; // FP32 + destEltSize = 1; // FP8 + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; + srcEltSize = 4; // FP32 + destEltSize = 1; // FP8 + break; + } + convertInfo.SrcInfo.SrcStride = Config.InputPerThread * srcEltSize; + convertInfo.SrcInfo.SrcSize = Config.InputPerThread * Config.OutputPerThread * srcEltSize; + + convertInfo.DestInfo.DestLayout = Config.MatrixLayout; + convertInfo.DestInfo.DestStride = 0; + convertInfo.DestInfo.NumRows = Config.OutputPerThread; + convertInfo.DestInfo.NumColumns = Config.InputPerThread; + + if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { + convertInfo.DestInfo.DestStride = Config.InputPerThread * destEltSize; + } else if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { + convertInfo.DestInfo.DestStride = Config.OutputPerThread * destEltSize; + } + + // Get destination size using preview interface + { + CComPtr previewDevice; + VERIFY_SUCCEEDED(pD3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), (void**)&previewDevice)); + + // Query required destination size + previewDevice->GetLinearAlgebraMatrixConversionDestinationInfo(&convertInfo.DestInfo); + } + + // Create resource to hold matrix copy + CreateTestResources(pD3DDevice, pCommandList, nullptr, 0, CD3DX12_RESOURCE_DESC::Buffer(convertInfo.DestInfo.DestSize), &pConvertedMatrixResource, nullptr); + + // Set up data descriptors + convertInfo.DataDesc.DestVA = pConvertedMatrixResource->GetGPUVirtualAddress(); + convertInfo.DataDesc.SrcVA = pInputMatrixSRVResource->GetGPUVirtualAddress(); + + // Get command list interface and perform conversion + CComPtr commandList11; + VERIFY_SUCCEEDED(pCommandList->QueryInterface(__uuidof(ID3D12GraphicsCommandList11), (void**)&commandList11)); + commandList11->ConvertLinearAlgebraMatrix(&convertInfo, 1); + + // This increments baseHandle + if ((convertInfo.DestInfo.DestSize % 4) != 0) { + WEX::Logging::Log::Error(L"DestSize is not aligned to 4 bytes"); + return; + } + CreateRawSRV(pD3DDevice, baseHandle, convertInfo.DestInfo.DestSize / sizeof(int32_t), pConvertedMatrixResource); + } + + CComPtr pUavResource; + CComPtr pUavUploadResource; + CComPtr pUavReadResource; + + // Create buffer for output and fill with 0xFF to make it obvious if it's not + // written in the shader. + std::vector outputBufferInit(OUTPUT_BUFFER_SIZE); + std::fill(outputBufferInit.begin(), outputBufferInit.end(), (uint8_t)0xFF); + + CreateTestUavs(pD3DDevice, pCommandList, outputBufferInit.data(), OUTPUT_BUFFER_SIZE, &pUavResource, &pUavUploadResource, &pUavReadResource); + CreateRawUAV(pD3DDevice, baseHandle, OUTPUT_BUFFER_SIZE / 4, pUavResource); + + + pCommandList->Close(); + ExecuteCommandList(pCommandQueue, pCommandList); + WaitForSignal(pCommandQueue, FO); + VERIFY_SUCCEEDED(pCommandAllocator->Reset()); + VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, pComputePipelineState)); + + SetDescriptorHeap(pCommandList, pDescriptorHeap); + + CD3DX12_GPU_DESCRIPTOR_HANDLE resHandle(pDescriptorHeap->GetGPUDescriptorHandleForHeapStart()); + + pCommandList->SetComputeRootSignature(pRootSignature); + pCommandList->SetComputeRootDescriptorTable(0, resHandle); + pCommandList->SetPipelineState(pComputePipelineState); + pCommandList->Dispatch(1, 1, 1); + RecordTransitionBarrier(pCommandList, pUavResource, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE); + pCommandList->CopyResource(pUavReadResource, pUavResource); + pCommandList->Close(); + ExecuteCommandList(pCommandQueue, pCommandList); + WaitForSignal(pCommandQueue, FO); + + + { + MappedData mappedData(pUavReadResource, OUTPUT_BUFFER_SIZE); + + float *resultBuffer = (float *)mappedData.data(); + bool equal = true; + for (int i = 0; i < OUTPUT_BUFFER_SIZE / sizeof(float); i++) { + if (isnan(resultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || fabs(resultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { + LogErrorFmt(L"Result mismatch at index %d", i); + LogErrorFmt(L"resultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, resultBuffer[i], i, ExpectedOutputBuffer[i]); + equal = false; + break; + } + } + VERIFY_IS_TRUE(equal); + } +} + +TEST_F(ExecutionTest, CoopVec_Mul) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + RunCoopVecMulTest(); +} + + +void ExecutionTest::RunCoopVecOuterProductTest() { + // Create device and verify coopvec support + CComPtr pD3DDevice; + if (!CreateDevice(&pD3DDevice, D3D_SHADER_MODEL_6_9)) { + WEX::Logging::Log::Comment( + "Device does not support SM 6.9. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + if (!DoesDeviceSupportCooperativeVector(pD3DDevice)) { + WEX::Logging::Log::Comment( + "Device does not support cooperative vector. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; + } + + // Query coopvec feature data. First call gets the size of the arrays. The + // second call populates the arrays using memory we allocate. + D3D12_FEATURE_DATA_COOPERATIVE_VECTOR devOptions = {}; + VERIFY_SUCCEEDED( + pD3DDevice->CheckFeatureSupport((D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, + &devOptions, sizeof(devOptions))); + + // Allocate memory for the arrays in devOptions + std::vector AccumulateProps(devOptions.OuterProductAccumulatePropCount); + devOptions.pOuterProductAccumulateProperties = AccumulateProps.data(); + + VERIFY_SUCCEEDED( + pD3DDevice->CheckFeatureSupport((D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, + &devOptions, sizeof(devOptions))); + + // Test each supported data type and matrix layout + for (auto AccumulateConfig : AccumulateProps) { + // Run the test + RunCoopVecOuterProductTestConfig(pD3DDevice, AccumulateConfig); + } +} + +void ExecutionTest::RunCoopVecOuterProductTestConfig(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps) { + UNREFERENCED_PARAMETER(pD3DDevice); + UNREFERENCED_PARAMETER(AccumulateProps); + + LogCommentFmt(L"Running test for InputType: %s, AccumulationType: %s", + CoopVecHelpers::DataTypeToFilterString(AccumulateProps.InputType).c_str(), + CoopVecHelpers::DataTypeToFilterString(AccumulateProps.AccumulationType).c_str()); + + constexpr CoopVecOuterProductSubtestConfig TestConfigs[] = { + {4, 4, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL }, + }; + + for (auto Config : TestConfigs) { + if ((AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) && + (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR || + Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR)) { + continue; + } + + RunCoopVecOuterProductSubtest(pD3DDevice, AccumulateProps, Config); + } +} + +void ExecutionTest::RunCoopVecOuterProductSubtest(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, CoopVecOuterProductSubtestConfig &Config) { + UNREFERENCED_PARAMETER(pD3DDevice); + UNREFERENCED_PARAMETER(AccumulateProps); + UNREFERENCED_PARAMETER(Config); + + LogCommentFmt(L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s", + Config.DimM, + Config.DimN, + Config.NumThreads, + CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); + + + // Create root signature with a single root entry for all SRVs and UAVs + CComPtr pRootSignature; + { + CD3DX12_DESCRIPTOR_RANGE ranges[2]; + ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0, 0); // InputVector1, InputVector2 + ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // AccumMatrix + CreateRootSignatureFromRanges(pD3DDevice, &pRootSignature, ranges, 2, nullptr, 0); + } + + // Create descriptor heap with space for 3 descriptors: 2 SRVs and 1 UAV + CComPtr pDescriptorHeap; + { + D3D12_DESCRIPTOR_HEAP_DESC desc = {}; + desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + desc.NumDescriptors = 3; + desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + VERIFY_SUCCEEDED(pD3DDevice->CreateDescriptorHeap(&desc, IID_PPV_ARGS(&pDescriptorHeap))); + } + CD3DX12_CPU_DESCRIPTOR_HANDLE baseHandle(pDescriptorHeap->GetCPUDescriptorHandleForHeapStart()); + + // Create a compute pipeline state object. + CComPtr pComputePipelineState; + { + std::string shaderSource = R"( +#include "dx/linalg.h" + +ByteAddressBuffer InputVector1 : register(t0); +ByteAddressBuffer InputVector2 : register(t1); +RWByteAddressBuffer AccumMatrix : register(u0); + +[shader("compute")] +[numthreads(NUM_THREADS, 1, 1)] +void main(uint threadIdx : SV_GroupThreadID) +{ +#if 1 + using namespace dx::linalg; + + // Ensure 4-byte alignment for vector loads + uint inputOffset1 = (DIM_M * threadIdx); + inputOffset1 = (inputOffset1 + 3) & ~3; // Align to 4 bytes + vector input1 = InputVector1.Load >(inputOffset1); + + uint inputOffset2 = (DIM_N * threadIdx); + inputOffset2 = (inputOffset2 + 3) & ~3; // Align to 4 bytes + vector input2 = InputVector2.Load >(inputOffset2); + + RWMatrixRef mat = { AccumMatrix, 0, STRIDE }; + + OuterProductAccumulate(input1, input2, mat); +#endif +} + )"; + + auto CreateDefineFromInt = [](const wchar_t *Name, int Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + auto CreateDefineFromString = [](const wchar_t *Name, const wchar_t *Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + int Stride = 0; + const std::wstring HlslMatrixLayout = CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); + int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType(AccumulateProps.AccumulationType); + switch (Config.MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + Stride = Config.DimN * StrideMultiplier; + break; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + Stride = Config.DimM * StrideMultiplier; + break; + } + + const int InputDivisor = CoopVecHelpers::GetNumPackedElementsForInputDataType(AccumulateProps.InputType); + const std::wstring InputDataType = CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType, AccumulateProps.InputType); + const std::wstring AccumDataType = CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.AccumulationType, AccumulateProps.AccumulationType); + const std::wstring MatrixDataTypeEnum = CoopVecHelpers::GetHlslInterpretationForDataType(AccumulateProps.AccumulationType); + const std::wstring InputInterpretationEnum = CoopVecHelpers::GetHlslInterpretationForDataType(AccumulateProps.InputType); + + auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM); + auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN); + auto NumThreadsDefine = CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads); + auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride); + auto InputDataTypeDefine = CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str()); + auto InputDivisorDefine = CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor); + auto AccumDataTypeDefine = CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str()); + auto InputInterpretationEnumDefine = CreateDefineFromString(L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str()); + auto HlslMatrixLayoutDefine = CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str()); + auto MatrixDataTypeEnumDefine = CreateDefineFromString(L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str()); + + LPCWSTR pOptions[] = { + L"-enable-16bit-types", + DimMDefine.c_str(), + DimNDefine.c_str(), + NumThreadsDefine.c_str(), + StrideDefine.c_str(), + InputDataTypeDefine.c_str(), + InputDivisorDefine.c_str(), + AccumDataTypeDefine.c_str(), + InputInterpretationEnumDefine.c_str(), + HlslMatrixLayoutDefine.c_str(), + MatrixDataTypeEnumDefine.c_str(), + }; + + CComPtr pIncludeHandler = new LinAlgHeaderIncludeHandler(m_support); + + CreateComputePSO(pD3DDevice, pRootSignature, shaderSource.c_str(), + L"cs_6_9", &pComputePipelineState, pOptions, + _countof(pOptions), pIncludeHandler); + } + + // Create a command list for the compute shader. + CComPtr pCommandList; + CComPtr pCommandAllocator; + CComPtr pCommandQueue; + FenceObj FO; + CreateCommandQueue(pD3DDevice, L"CoopVec Test Command Queue", + &pCommandQueue, D3D12_COMMAND_LIST_TYPE_DIRECT); + InitFenceObj(pD3DDevice, &FO); + VERIFY_SUCCEEDED(pD3DDevice->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&pCommandAllocator))); + VERIFY_SUCCEEDED(pD3DDevice->CreateCommandList( + 0, D3D12_COMMAND_LIST_TYPE_DIRECT, pCommandAllocator, pComputePipelineState, + IID_PPV_ARGS(&pCommandList))); + + // Setup input matrix as all-ones in sint8/fp32 format. This will later be + // converted to the appropriate data type by the matrix conversion API. + CComPtr pInputMatrixSRVResource, pInputMatrixSRVUploadResource; + std::vector inputMatrix; + if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); + } else if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + // Matrix source data is fp32, which gets converted to fp16 during matrix conversion + inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); + } else { + WEX::Logging::Log::Error(L"Unsupported matrix data type"); + return; + } + + CreateTestResources(pD3DDevice, pCommandList, inputMatrix.data(), inputMatrix.size(), CD3DX12_RESOURCE_DESC::Buffer(inputMatrix.size()), &pInputMatrixSRVResource, &pInputMatrixSRVUploadResource); + + // Create input vectors + CComPtr pInputVecSRVResource1, pInputVecSRVUploadResource1; + std::vector inputVector1; + CComPtr pInputVecSRVResource2, pInputVecSRVUploadResource2; + std::vector inputVector2; + + if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + inputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimM); + inputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimN); + } else if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + inputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimM); + inputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimN); + } else if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + inputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimM); + inputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimN); + } else { + WEX::Logging::Log::Error(L"Unsupported input data type"); + return; + } + if (inputVector1.size() % 4 != 0) { + // Align size to 4 bytes for ByteAddressBuffer + inputVector1.resize(inputVector1.size() + 4 - (inputVector1.size() % 4)); + } + if (inputVector2.size() % 4 != 0) { + // Align size to 4 bytes for ByteAddressBuffer + inputVector2.resize(inputVector2.size() + 4 - (inputVector2.size() % 4)); + } + CreateTestResources(pD3DDevice, pCommandList, inputVector1.data(), inputVector1.size(), CD3DX12_RESOURCE_DESC::Buffer(inputVector1.size()), &pInputVecSRVResource1, &pInputVecSRVUploadResource1); + CreateTestResources(pD3DDevice, pCommandList, inputVector2.data(), inputVector2.size(), CD3DX12_RESOURCE_DESC::Buffer(inputVector2.size()), &pInputVecSRVResource2, &pInputVecSRVUploadResource2); + + // This increments baseHandle + CreateRawSRV(pD3DDevice, baseHandle, (UINT)(inputVector1.size() / sizeof(int32_t)), pInputVecSRVResource1); + CreateRawSRV(pD3DDevice, baseHandle, (UINT)(inputVector2.size() / sizeof(int32_t)), pInputVecSRVResource2); + + // Calculate reference output + auto ExpectedOutputBufferI8 = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); + std::vector ExpectedOutputBuffer(ExpectedOutputBufferI8.size() / sizeof(float)); + std::memcpy(ExpectedOutputBuffer.data(), ExpectedOutputBufferI8.data(), ExpectedOutputBufferI8.size()); + + if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + std::vector inputVector1FP16(inputVector1.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(inputVector1FP16.data(), inputVector1.data(), inputVector1.size()); + + std::vector inputVector2FP16(inputVector2.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(inputVector2FP16.data(), inputVector2.data(), inputVector2.size()); + + for (int threadIdx = 0; threadIdx < Config.NumThreads; ++threadIdx) { + for (int m = 0; m < Config.DimM; ++m) { + for (int n = 0; n < Config.DimN; ++n) { + float acc = ConvertFloat16ToFloat32(inputVector1FP16[m]) * ConvertFloat16ToFloat32(inputVector2FP16[n]); + ExpectedOutputBuffer[m * Config.DimN + n] += acc; + } + } + } + } else if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + std::vector inputVector1FP32(inputVector1.size() / sizeof(float)); + std::memcpy(inputVector1FP32.data(), inputVector1.data(), inputVector1.size()); + + std::vector inputVector2FP32(inputVector2.size() / sizeof(float)); + std::memcpy(inputVector2FP32.data(), inputVector2.data(), inputVector2.size()); + + for (int threadIdx = 0; threadIdx < Config.NumThreads; ++threadIdx) { + for (int m = 0; m < Config.DimM; ++m) { + for (int n = 0; n < Config.DimN; ++n) { + float acc = inputVector1FP32[m] * inputVector2FP32[n]; + ExpectedOutputBuffer[m * Config.DimN + n] += acc; + } + } + } + } + + + CComPtr pConvertedMatrixResource, pConvertedMatrixReadResource; + int ConvertedMatrixSize = 0; + { + // Create source matrix info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO srcInfo = {}; + srcInfo.SrcDataType = CoopVecHelpers::GetMatrixSrcDataType(AccumulateProps.AccumulationType); + srcInfo.SrcLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + + // Create destination matrix info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO destInfo = {}; + destInfo.DestSize = 0; // Will be populated by driver + int srcEltSize = 0; + int destEltSize = 0; + switch (AccumulateProps.AccumulationType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + destInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + srcEltSize = 1; + destEltSize = 1; + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + destInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; + srcEltSize = 4; // FP32 + destEltSize = 2; // FP16 + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + destInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; + srcEltSize = 4; // FP32 + destEltSize = 1; // FP8 + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + destInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; + srcEltSize = 4; // FP32 + destEltSize = 1; // FP8 + break; + } + srcInfo.SrcStride = Config.DimM * srcEltSize; + srcInfo.SrcSize = Config.DimM * Config.DimN * srcEltSize; + + destInfo.DestLayout = Config.MatrixLayout; + destInfo.DestStride = 0; + destInfo.NumRows = Config.DimM; + destInfo.NumColumns = Config.DimN; + + if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { + destInfo.DestStride = Config.DimM * destEltSize; + } else if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { + destInfo.DestStride = Config.DimM * destEltSize; + } + + // Create conversion info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO convertInfo = {}; + convertInfo.SrcInfo = srcInfo; + convertInfo.DestInfo = destInfo; + + // Get preview device interface + { + CComPtr previewDevice; + VERIFY_SUCCEEDED(pD3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), (void**)&previewDevice)); + + // Query required destination size + previewDevice->GetLinearAlgebraMatrixConversionDestinationInfo(&convertInfo.DestInfo); + } + + ConvertedMatrixSize = convertInfo.DestInfo.DestSize; + + // Hack to prevent read resource from being created with size 0 + std::vector tempData(convertInfo.DestInfo.DestSize); + CreateTestUavs(pD3DDevice, pCommandList, tempData.data(), tempData.size(), &pConvertedMatrixResource, nullptr, &pConvertedMatrixReadResource); + + // Set up data descriptors + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA dataDesc = {}; + dataDesc.DestVA = pConvertedMatrixResource->GetGPUVirtualAddress(); + dataDesc.SrcVA = pInputMatrixSRVResource->GetGPUVirtualAddress(); + convertInfo.DataDesc = dataDesc; + + // Get command list interface and perform conversion + CComPtr commandList11; + VERIFY_SUCCEEDED(pCommandList->QueryInterface(__uuidof(ID3D12GraphicsCommandList11), (void**)&commandList11)); + commandList11->ConvertLinearAlgebraMatrix(&convertInfo, 1); + + // This increments baseHandle + if ((convertInfo.DestInfo.DestSize % 4) != 0) { + WEX::Logging::Log::Error(L"DestSize is not aligned to 4 bytes"); + return; + } + CreateRawUAV(pD3DDevice, baseHandle, convertInfo.DestInfo.DestSize / sizeof(int32_t), pConvertedMatrixResource); + } + + + pCommandList->Close(); + ExecuteCommandList(pCommandQueue, pCommandList); + WaitForSignal(pCommandQueue, FO); + VERIFY_SUCCEEDED(pCommandAllocator->Reset()); + VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, pComputePipelineState)); + + SetDescriptorHeap(pCommandList, pDescriptorHeap); + + CD3DX12_GPU_DESCRIPTOR_HANDLE resHandle(pDescriptorHeap->GetGPUDescriptorHandleForHeapStart()); + + pCommandList->SetComputeRootSignature(pRootSignature); + pCommandList->SetComputeRootDescriptorTable(0, resHandle); + pCommandList->SetPipelineState(pComputePipelineState); + pCommandList->Dispatch(1, 1, 1); + pCommandList->Close(); + ExecuteCommandList(pCommandQueue, pCommandList); + WaitForSignal(pCommandQueue, FO); + + VERIFY_SUCCEEDED(pCommandAllocator->Reset()); + VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, pComputePipelineState)); + + + // Convert matrix to sint8/fp32 row-major format before reading back to the CPU. + // A new resource is created, along with a readback resource, for the matrix copy. + CComPtr pMatrixRowMajorResource, pMatrixRowMajorReadResource; + { + // Create source matrix info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO convertInfo = {}; + convertInfo.SrcInfo.SrcLayout = Config.MatrixLayout; + convertInfo.SrcInfo.SrcSize = ConvertedMatrixSize; + convertInfo.SrcInfo.SrcDataType = AccumulateProps.AccumulationType; + convertInfo.SrcInfo.SrcStride = 0; // OUTER_PRODUCT_OPTIMAL + + // Create destination matrix info + convertInfo.DestInfo.DestSize = 0; // Will be populated by driver + convertInfo.DestInfo.DestLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + convertInfo.DestInfo.NumRows = Config.DimM; + convertInfo.DestInfo.NumColumns = Config.DimN; + + if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 || + AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; + convertInfo.DestInfo.DestStride = Config.DimN * sizeof(float); + } else { + convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + convertInfo.DestInfo.DestStride = Config.DimN * sizeof(int8_t); + } + + // Get destination size using preview interface + { + CComPtr previewDevice; + VERIFY_SUCCEEDED(pD3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), (void**)&previewDevice)); + + // Query required destination size + previewDevice->GetLinearAlgebraMatrixConversionDestinationInfo(&convertInfo.DestInfo); + } + + // Create resource to hold matrix copy and a readback resource for it + // Init vector is a hack to prevent read resource from being created with size 0 + // TODO: Fix CreateTestUavs to allow creating readback resource without init data + std::vector tempData(convertInfo.DestInfo.DestSize); + CreateTestUavs(pD3DDevice, pCommandList, tempData.data(), tempData.size(), &pMatrixRowMajorResource, nullptr, &pMatrixRowMajorReadResource); + + // Set up data descriptors + convertInfo.DataDesc.DestVA = pMatrixRowMajorResource->GetGPUVirtualAddress(); + convertInfo.DataDesc.SrcVA = pConvertedMatrixResource->GetGPUVirtualAddress(); + + // Get command list interface and perform conversion + CComPtr commandList11; + VERIFY_SUCCEEDED(pCommandList->QueryInterface(__uuidof(ID3D12GraphicsCommandList11), (void**)&commandList11)); + commandList11->ConvertLinearAlgebraMatrix(&convertInfo, 1); + } + + RecordTransitionBarrier(pCommandList, pMatrixRowMajorResource, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE); + pCommandList->CopyResource(pMatrixRowMajorReadResource, pMatrixRowMajorResource); + pCommandList->Close(); + ExecuteCommandList(pCommandQueue, pCommandList); + WaitForSignal(pCommandQueue, FO); + + { + MappedData mappedData(pMatrixRowMajorReadResource, (UINT)inputMatrix.size()); + + float *resultBuffer = (float *)mappedData.data(); + bool equal = true; + for (int i = 0; i < (UINT)inputMatrix.size() / sizeof(float); i++) { + if (isnan(resultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || fabs(resultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { + LogErrorFmt(L"Result mismatch at index %d", i); + LogErrorFmt(L"resultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, resultBuffer[i], i, ExpectedOutputBuffer[i]); + equal = false; + break; + } + } + VERIFY_IS_TRUE(equal); + } +} + +TEST_F(ExecutionTest, CoopVec_OuterProduct) { + WEX::TestExecution::SetVerifyOutput verifySettings( + WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + RunCoopVecOuterProductTest(); +} + + // This test expects a that retrieves a signal value from each of a + // few resources that are initialized here. determines if it uses + // the 6.6 Dynamic Resources feature. Values are read back from the result UAV + // and compared to the expected signals void ExecutionTest::RunResourceTest(ID3D12Device *pDevice, const char *pShader, const wchar_t *sm, bool isDynamic) { WEX::TestExecution::SetVerifyOutput verifySettings( From 0690152a76a4df0277199fc8a5ed020ef3e1076a Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Fri, 2 May 2025 18:52:02 -0400 Subject: [PATCH 02/17] clang-format --- tools/clang/unittests/HLSLExec/CoopVec.h | 334 +++--- .../unittests/HLSLExec/ExecutionTest.cpp | 1002 +++++++++++------ 2 files changed, 816 insertions(+), 520 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/CoopVec.h b/tools/clang/unittests/HLSLExec/CoopVec.h index 67c14ac987..3ba2c2babf 100644 --- a/tools/clang/unittests/HLSLExec/CoopVec.h +++ b/tools/clang/unittests/HLSLExec/CoopVec.h @@ -1,8 +1,8 @@ #pragma once -#include #include #include +#include #include "dxc/Support/microcom.h" @@ -15,21 +15,25 @@ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { public: LinAlgHeaderIncludeHandler() = delete; - LinAlgHeaderIncludeHandler(dxc::DxcDllSupport &DxcSupport) : m_dwRef(0), DxcSupport(DxcSupport) {} + LinAlgHeaderIncludeHandler(dxc::DxcDllSupport &DxcSupport) + : m_dwRef(0), DxcSupport(DxcSupport) {} DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef) - HRESULT STDMETHODCALLTYPE LoadSource(LPCWSTR pFilename, IDxcBlob **ppIncludeSource) { - if (wcscmp(pFilename, L"dx/linalg.h") == 0 || wcscmp(pFilename, L".\\dx\\linalg.h") == 0) { + HRESULT STDMETHODCALLTYPE LoadSource(LPCWSTR pFilename, + IDxcBlob **ppIncludeSource) { + if (wcscmp(pFilename, L"dx/linalg.h") == 0 || + wcscmp(pFilename, L".\\dx\\linalg.h") == 0) { WEX::Common::String ParamValue; - if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(L"LinAlgHeader", - ParamValue))) { + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue( + L"LinAlgHeader", ParamValue))) { return E_FAIL; } if (ParamValue.IsEmpty()) { return E_FAIL; } - LPCWSTR RealHeaderPath = reinterpret_cast(ParamValue.GetBuffer()); + LPCWSTR RealHeaderPath = + reinterpret_cast(ParamValue.GetBuffer()); CComPtr pHeaderUtils; @@ -46,10 +50,10 @@ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { } HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, - void **ppvObject) override { + void **ppvObject) override { // FIXME: This is a workaround for a warning-as-error about unused parameters. #pragma warning(push) -#pragma warning(disable: 4100) +#pragma warning(disable : 4100) return DoBasicQueryInterface(this, iid, ppvObject); #pragma warning(pop) } @@ -57,10 +61,12 @@ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { struct CoopVecHelpers { template - static std::vector CreateAllOnesInputMatrix(uint32_t Width, uint32_t Height) { + static std::vector CreateAllOnesInputMatrix(uint32_t Width, + uint32_t Height) { std::vector inputMatrix(Width * Height); for (uint32_t i = 0; i < Width * Height; i++) { - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v) { inputMatrix[i] = 1; } else if constexpr (std::is_same_v) { inputMatrix[i] = ConvertFloat32ToFloat16(1.0f); @@ -71,12 +77,14 @@ struct CoopVecHelpers { // Convert to uint8_t vector std::vector uint8InputMatrix(inputMatrix.size() * sizeof(EltTy)); - std::memcpy(uint8InputMatrix.data(), inputMatrix.data(), inputMatrix.size() * sizeof(EltTy)); + std::memcpy(uint8InputMatrix.data(), inputMatrix.data(), + inputMatrix.size() * sizeof(EltTy)); return uint8InputMatrix; } template - static std::vector CreateInputVector(uint32_t NumThreads, uint32_t EltsPerThread) { + static std::vector CreateInputVector(uint32_t NumThreads, + uint32_t EltsPerThread) { std::vector inputVector(NumThreads * EltsPerThread); std::fill(inputVector.begin(), inputVector.end(), EltTy(0)); if (EltsPerThread < 2) { @@ -84,7 +92,8 @@ struct CoopVecHelpers { return std::vector(); } for (uint32_t TID = 0; TID < NumThreads; TID++) { - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v) { inputVector[TID * EltsPerThread + 0] = 1; inputVector[TID * EltsPerThread + 1] = 1; } else if constexpr (std::is_same_v) { @@ -100,64 +109,69 @@ struct CoopVecHelpers { // Convert to uint8_t vector std::vector uint8InputVector(inputVector.size() * sizeof(EltTy)); - std::memcpy(uint8InputVector.data(), inputVector.data(), inputVector.size() * sizeof(EltTy)); + std::memcpy(uint8InputVector.data(), inputVector.data(), + inputVector.size() * sizeof(EltTy)); return uint8InputVector; } template static std::vector CreateInputBias(uint32_t NumElts) { std::vector inputBias(NumElts); - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v) { std::fill(inputBias.begin(), inputBias.end(), EltTy(1)); } else if constexpr (std::is_same_v) { - std::fill(inputBias.begin(), inputBias.end(), ConvertFloat32ToFloat16(1.0f)); + std::fill(inputBias.begin(), inputBias.end(), + ConvertFloat32ToFloat16(1.0f)); } else if constexpr (std::is_same_v) { - std::fill(inputBias.begin(), inputBias.end(), 1); + std::fill(inputBias.begin(), inputBias.end(), 1); } else { WEX::Logging::Log::Error(L"Unsupported bias type"); } // Convert to uint8_t vector std::vector uint8InputBias(inputBias.size() * sizeof(EltTy)); - std::memcpy(uint8InputBias.data(), inputBias.data(), inputBias.size() * sizeof(EltTy)); + std::memcpy(uint8InputBias.data(), inputBias.data(), + inputBias.size() * sizeof(EltTy)); return uint8InputBias; } - static std::wstring DataTypeToFilterString(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + static std::wstring + DataTypeToFilterString(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { switch (DataType) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - return L"SINT8_T4_PACKED"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: - return L"UINT8_T4_PACKED"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: - return L"SINT8"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: - return L"UINT8"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: - return L"SINT16"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: - return L"UINT16"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: - return L"SINT32"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: - return L"UINT32"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: - return L"FLOAT32"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: - return L"FLOAT16"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: - return L"FLOAT_E4M3"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: - return L"FLOAT_E5M2"; - default: - return L""; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + return L"SINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return L"UINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + return L"SINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + return L"UINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"SINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"UINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"SINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"UINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"FLOAT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"FLOAT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + return L"FLOAT_E4M3"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return L"FLOAT_E5M2"; + default: + return L""; } } static bool IsDataTypeInFilter(const wchar_t *FilterKey, D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { WEX::Common::String ParamValue; - if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(FilterKey, - ParamValue))) { + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue( + FilterKey, ParamValue))) { // Filter not set, so treat as no filter return true; } @@ -171,26 +185,28 @@ struct CoopVecHelpers { return DataTypeToFilterString(DataType) == FilterString; } - static std::wstring MatrixLayoutToFilterString(D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + static std::wstring + MatrixLayoutToFilterString(D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { switch (MatrixLayout) { - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: - return L"ROW_MAJOR"; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: - return L"COLUMN_MAJOR"; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: - return L"MUL_OPTIMAL"; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: - return L"OUTER_PRODUCT_OPTIMAL"; - default: - return L""; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + return L"ROW_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + return L"COLUMN_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: + return L"MUL_OPTIMAL"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: + return L"OUTER_PRODUCT_OPTIMAL"; + default: + return L""; } } - static bool IsMatrixLayoutInFilter(const wchar_t *FilterKey, - D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + static bool + IsMatrixLayoutInFilter(const wchar_t *FilterKey, + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { WEX::Common::String ParamValue; - if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(FilterKey, - ParamValue))) { + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue( + FilterKey, ParamValue))) { // Filter not set, so treat as no filter return true; } @@ -204,127 +220,135 @@ struct CoopVecHelpers { return MatrixLayoutToFilterString(MatrixLayout) == FilterString; } - static std::wstring MatrixLayoutToHlslLayoutString(D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + static std::wstring MatrixLayoutToHlslLayoutString( + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { switch (MatrixLayout) { - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: - return L"MATRIX_LAYOUT_ROW_MAJOR"; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: - return L"MATRIX_LAYOUT_COLUMN_MAJOR"; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: - return L"MATRIX_LAYOUT_MUL_OPTIMAL"; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: - return L"MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL"; - default: - return L""; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + return L"MATRIX_LAYOUT_ROW_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + return L"MATRIX_LAYOUT_COLUMN_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: + return L"MATRIX_LAYOUT_MUL_OPTIMAL"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: + return L"MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL"; + default: + return L""; } } // This multiplier is used to compute the row/column stride for a matrix // given it's element size. - static int GetStrideMultiplierForMatrixDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + static int + GetStrideMultiplierForMatrixDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { switch (DataType) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: - return 1; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: - return 2; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: - return 4; - default: - WEX::Logging::Log::Error(L"Unsupported matrix data type"); - return 1; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return 1; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return 2; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return 4; + default: + WEX::Logging::Log::Error(L"Unsupported matrix data type"); + return 1; } } - static int GetNumPackedElementsForInputDataType(D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) { - // Int8 packed types are the only ones that have more than 1 element per shader variable + static int GetNumPackedElementsForInputDataType( + D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) { + // Int8 packed types are the only ones that have more than 1 element per + // shader variable switch (InputInterpretation) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: - return 4; - default: - return 1; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return 4; + default: + return 1; } } // This type is used in generated HLSL source to represent the vector type // for the given data type. - static std::wstring GetHlslDataTypeForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType, D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) { + static std::wstring GetHlslDataTypeForDataType( + D3D12_LINEAR_ALGEBRA_DATATYPE DataType, + D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) { UNREFERENCED_PARAMETER(InputInterpretation); switch (DataType) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: - return L"int16_t"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: - return L"uint16_t"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: - return L"int32_t"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: - return L"uint32_t"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: - return L"half"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: - return L"float"; - default: - WEX::Logging::Log::Error(L"Unsupported input data type"); - return L""; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"int16_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"uint16_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"int32_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"uint32_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"half"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"float"; + default: + WEX::Logging::Log::Error(L"Unsupported input data type"); + return L""; } } - static std::wstring GetHlslInterpretationForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE Interpretation) { + static std::wstring GetHlslInterpretationForDataType( + D3D12_LINEAR_ALGEBRA_DATATYPE Interpretation) { switch (Interpretation) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - return L"DATA_TYPE_SINT8_T4_PACKED"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: - return L"DATA_TYPE_UINT8_T4_PACKED"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: - return L"DATA_TYPE_SINT8"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: - return L"DATA_TYPE_UINT8"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: - return L"DATA_TYPE_SINT16"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: - return L"DATA_TYPE_UINT16"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: - return L"DATA_TYPE_SINT32"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: - return L"DATA_TYPE_UINT32"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: - return L"DATA_TYPE_FLOAT16"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: - return L"DATA_TYPE_FLOAT32"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: - return L"DATA_TYPE_FLOAT8_E4M3"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: - return L"DATA_TYPE_FLOAT8_E5M2"; - default: - WEX::Logging::Log::Error(L"Unsupported interpretation"); - return L""; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + return L"DATA_TYPE_SINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return L"DATA_TYPE_UINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + return L"DATA_TYPE_SINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + return L"DATA_TYPE_UINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"DATA_TYPE_SINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"DATA_TYPE_UINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"DATA_TYPE_SINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"DATA_TYPE_UINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"DATA_TYPE_FLOAT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"DATA_TYPE_FLOAT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + return L"DATA_TYPE_FLOAT8_E4M3"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return L"DATA_TYPE_FLOAT8_E5M2"; + default: + WEX::Logging::Log::Error(L"Unsupported interpretation"); + return L""; } } // The returned data type is used for matrix conversion. It is hard-coded // for the test framework where all integer matrices start as SINT8 and // all FP matrices start as FLOAT32. - static D3D12_LINEAR_ALGEBRA_DATATYPE GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) { + static D3D12_LINEAR_ALGEBRA_DATATYPE + GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) { switch (MatrixInterpretation) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: - return D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; - default: - return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + default: + return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; } } }; \ No newline at end of file diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 011e28e42e..331dcd3542 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -623,7 +623,6 @@ class ExecutionTest { TEST_METHOD(CoopVec_Mul); TEST_METHOD(CoopVec_OuterProduct); - dxc::DxcDllSupport m_support; bool m_D3DInitCompleted = false; @@ -792,19 +791,28 @@ class ExecutionTest { }; void RunCoopVecMulTest(); - void RunCoopVecMulTestConfig(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps); - void RunCoopVecMulSubtest(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, CoopVecMulSubtestConfig &Config); + void + RunCoopVecMulTestConfig(ID3D12Device *pD3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps); + void RunCoopVecMulSubtest(ID3D12Device *pD3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, + CoopVecMulSubtestConfig &Config); struct CoopVecOuterProductSubtestConfig { - int DimM; // Row Count - int DimN; // Column Count + int DimM; // Row Count + int DimN; // Column Count int NumThreads; D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout; }; void RunCoopVecOuterProductTest(); - void RunCoopVecOuterProductTestConfig(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps); - void RunCoopVecOuterProductSubtest(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, CoopVecOuterProductSubtestConfig &Config); + void RunCoopVecOuterProductTestConfig( + ID3D12Device *pD3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps); + void RunCoopVecOuterProductSubtest( + ID3D12Device *pD3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, + CoopVecOuterProductSubtestConfig &Config); template void WaveIntrinsicsActivePrefixTest(TableParameter *pParameterList, @@ -1765,9 +1773,11 @@ class ExecutionTest { bool DoesDeviceSupportCooperativeVector(ID3D12Device *pDevice) { D3D12_FEATURE_DATA_D3D12_OPTIONS_EXPERIMENTAL O; if (FAILED(pDevice->CheckFeatureSupport( - (D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS_EXPERIMENTAL, &O, sizeof(O)))) + (D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS_EXPERIMENTAL, &O, + sizeof(O)))) return false; - return O.CooperativeVectorTier != D3D12_COOPERATIVE_VECTOR_TIER_NOT_SUPPORTED; + return O.CooperativeVectorTier != + D3D12_COOPERATIVE_VECTOR_TIER_NOT_SUPPORTED; } bool IsFallbackPathEnabled() { @@ -1884,14 +1894,14 @@ class ExecutionTest { } std::vector Features; - + Features.push_back(D3D12ExperimentalShaderModels); if (GetTestParamBool(L"CooperativeVectorExperimental")) { Features.push_back(D3D12CooperativeVectorExperiment); } - return pD3D12EnableExperimentalFeatures((UINT)Features.size(), Features.data(), - nullptr, nullptr); + return pD3D12EnableExperimentalFeatures((UINT)Features.size(), + Features.data(), nullptr, nullptr); } static HRESULT EnableExperimentalShaderModels() { @@ -11961,31 +11971,33 @@ VERIFY_SUCCEEDED(DoArraysMatch(OutputVector, ExpectedVector, TestConfig.Tolerance)); } - // Runs a set of tests for the Cooperative Vector Mul and MulAdd operations. -// The device will be queried for supported configurations and then each supported -// configuration will be tested against multiple matrix and vector sizes. To help -// reproduce individual test failures, the test will log the configuration it is -// running and the results of each test. The following filters can be used to limit -// test execution to a specific set of configurations: +// The device will be queried for supported configurations and then each +// supported configuration will be tested against multiple matrix and vector +// sizes. To help reproduce individual test failures, the test will log the +// configuration it is running and the results of each test. The following +// filters can be used to limit test execution to a specific set of +// configurations: // // - CoopVecMatrixInterp: SINT8, FLOAT16, FLOAT_E4M3, ... -// - CoopVecMatrixLayout: ROW_MAJOR, COLUMN_MAJOR, MUL_OPTIMAL, OUTER_PRODUCT_OPTIMAL +// - CoopVecMatrixLayout: ROW_MAJOR, COLUMN_MAJOR, MUL_OPTIMAL, +// OUTER_PRODUCT_OPTIMAL // - CoopVecBiasInterp: SINT32, FLOAT16, FLOAT_E4M3, ... // - CoopVecInputInterp: SINT8, FLOAT16, FLOAT_E4M3, ... -// - CoopVecInputType: SINT8, UINT8, SINT16, UINT16, SINT32, UINT32, FLOAT16, FLOAT32, ... +// - CoopVecInputType: SINT8, UINT8, SINT16, UINT16, SINT32, UINT32, FLOAT16, +// FLOAT32, ... // - CoopVecOutputType: SINT32, UINT32, FLOAT16, FLOAT32, ... // // Filter example: -// TE.exe ... -p:CoopVecMatrixInterp=FLOAT16 -p:CoopVecMatrixLayout=MUL_OPTIMAL +// TE.exe ... -p:CoopVecMatrixInterp=FLOAT16 +// -p:CoopVecMatrixLayout=MUL_OPTIMAL // // The current implementation will always write the final output data as float. void ExecutionTest::RunCoopVecMulTest() { // Create device and verify coopvec support CComPtr pD3DDevice; if (!CreateDevice(&pD3DDevice, D3D_SHADER_MODEL_6_9)) { - WEX::Logging::Log::Comment( - "Device does not support SM 6.9. Skipping."); + WEX::Logging::Log::Comment("Device does not support SM 6.9. Skipping."); WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); return; } @@ -11999,58 +12011,74 @@ void ExecutionTest::RunCoopVecMulTest() { // Query coopvec feature data. First call gets the size of the arrays. The // second call populates the arrays using memory we allocate. D3D12_FEATURE_DATA_COOPERATIVE_VECTOR devOptions = {}; - VERIFY_SUCCEEDED( - pD3DDevice->CheckFeatureSupport((D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, - &devOptions, sizeof(devOptions))); + VERIFY_SUCCEEDED(pD3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &devOptions, + sizeof(devOptions))); // Allocate memory for the arrays in devOptions - std::vector MulAddProps(devOptions.MatrixVectorMulAddPropCount); + std::vector MulAddProps( + devOptions.MatrixVectorMulAddPropCount); devOptions.pMatrixVectorMulAddProperties = MulAddProps.data(); - VERIFY_SUCCEEDED( - pD3DDevice->CheckFeatureSupport((D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, - &devOptions, sizeof(devOptions))); + VERIFY_SUCCEEDED(pD3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &devOptions, + sizeof(devOptions))); // Test each supported data type and matrix layout for (auto MulAddConfig : MulAddProps) { // Filter on preview test support bool PreviewConfig = false; if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && - MulAddConfig.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && - MulAddConfig.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && - MulAddConfig.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { PreviewConfig = true; } - + if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && - MulAddConfig.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 && - MulAddConfig.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && - MulAddConfig.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 && MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { PreviewConfig = true; } if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && - MulAddConfig.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 && - MulAddConfig.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && - MulAddConfig.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2 && MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { PreviewConfig = true; } if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 && - MulAddConfig.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED && - MulAddConfig.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 && - MulAddConfig.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { PreviewConfig = true; } if (MulAddConfig.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 && - MulAddConfig.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && - MulAddConfig.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 && - MulAddConfig.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && + MulAddConfig.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && + MulAddConfig.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 && + MulAddConfig.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 && MulAddConfig.OutputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { PreviewConfig = true; } @@ -12060,74 +12088,98 @@ void ExecutionTest::RunCoopVecMulTest() { } // Apply filters - bool IsInFilter = CoopVecHelpers::IsDataTypeInFilter(L"CoopVecMatrixInterp", MulAddConfig.MatrixInterpretation) && - CoopVecHelpers::IsDataTypeInFilter(L"CoopVecBiasInterp", MulAddConfig.BiasInterpretation) && - CoopVecHelpers::IsDataTypeInFilter(L"CoopVecInputInterp", MulAddConfig.InputInterpretation) && - CoopVecHelpers::IsDataTypeInFilter(L"CoopVecInputType", MulAddConfig.InputType) && - CoopVecHelpers::IsDataTypeInFilter(L"CoopVecOutputType", MulAddConfig.OutputType); + bool IsInFilter = + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecMatrixInterp", + MulAddConfig.MatrixInterpretation) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecBiasInterp", + MulAddConfig.BiasInterpretation) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecInputInterp", + MulAddConfig.InputInterpretation) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecInputType", + MulAddConfig.InputType) && + CoopVecHelpers::IsDataTypeInFilter(L"CoopVecOutputType", + MulAddConfig.OutputType); if (!IsInFilter) { continue; } - // Run the test RunCoopVecMulTestConfig(pD3DDevice, MulAddConfig); } } -void ExecutionTest::RunCoopVecMulTestConfig(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps) { +void ExecutionTest::RunCoopVecMulTestConfig( + ID3D12Device *pD3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps) { - LogCommentFmt(L"Running test for MatrixInterpretation: %s, BiasInterpretation: %s, InputInterpretation: %s, InputType: %s, OutputType: %s", - CoopVecHelpers::DataTypeToFilterString(MulProps.MatrixInterpretation).c_str(), - CoopVecHelpers::DataTypeToFilterString(MulProps.BiasInterpretation).c_str(), - CoopVecHelpers::DataTypeToFilterString(MulProps.InputInterpretation).c_str(), + LogCommentFmt( + L"Running test for MatrixInterpretation: %s, BiasInterpretation: %s, " + L"InputInterpretation: %s, InputType: %s, OutputType: %s", + CoopVecHelpers::DataTypeToFilterString(MulProps.MatrixInterpretation) + .c_str(), + CoopVecHelpers::DataTypeToFilterString(MulProps.BiasInterpretation) + .c_str(), + CoopVecHelpers::DataTypeToFilterString(MulProps.InputInterpretation) + .c_str(), CoopVecHelpers::DataTypeToFilterString(MulProps.InputType).c_str(), CoopVecHelpers::DataTypeToFilterString(MulProps.OutputType).c_str()); constexpr CoopVecMulSubtestConfig TestConfigs[] = { - {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, - {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, - {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, - {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, - {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, - {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, - {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, - {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, - {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, - {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, - {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, - {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, - {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, - {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, - {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, - {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, - {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, - {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, - {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, - {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, - {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, - {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, - {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, - {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, - {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, false}, - {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, true}, - {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, false}, - {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, true}, - {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, false}, - {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, true}, - {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, false}, - {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, }; for (auto Config : TestConfigs) { - if ((MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || - MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) && + if ((MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) && (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR || - Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR)) { + Config.MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR)) { continue; } - bool IsInFilter = CoopVecHelpers::IsMatrixLayoutInFilter(L"CoopVecMatrixLayout", Config.MatrixLayout); + bool IsInFilter = CoopVecHelpers::IsMatrixLayoutInFilter( + L"CoopVecMatrixLayout", Config.MatrixLayout); if (!IsInFilter) { continue; } @@ -12136,20 +12188,29 @@ void ExecutionTest::RunCoopVecMulTestConfig(ID3D12Device *pD3DDevice, D3D12_COOP } } -void ExecutionTest::RunCoopVecMulSubtest(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, CoopVecMulSubtestConfig &Config) { +void ExecutionTest::RunCoopVecMulSubtest( + ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, + CoopVecMulSubtestConfig &Config) { - LogCommentFmt(L"Running test for InputPerThread: %d, OutputPerThread: %d, NumThreads: %d, NumLevels: %d, Bias: %s, MatrixLayout: %s", - Config.InputPerThread, Config.OutputPerThread, Config.NumThreads, Config.NumLevels, Config.Bias ? L"true" : L"false", CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); + LogCommentFmt( + L"Running test for InputPerThread: %d, OutputPerThread: %d, NumThreads: " + L"%d, NumLevels: %d, Bias: %s, MatrixLayout: %s", + Config.InputPerThread, Config.OutputPerThread, Config.NumThreads, + Config.NumLevels, Config.Bias ? L"true" : L"false", + CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); - const int OUTPUT_BUFFER_SIZE = (Config.OutputPerThread * Config.NumThreads * 4); + const int OUTPUT_BUFFER_SIZE = + (Config.OutputPerThread * Config.NumThreads * 4); // Create root signature with a single root entry for all SRVs and UAVs CComPtr pRootSignature; { CD3DX12_DESCRIPTOR_RANGE ranges[2]; - ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 3, 0, 0); // InputVector, InputMatrix, InputBias + ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 3, 0, + 0); // InputVector, InputMatrix, InputBias ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // OutputBuffer - CreateRootSignatureFromRanges(pD3DDevice, &pRootSignature, ranges, 2, nullptr, 0); + CreateRootSignatureFromRanges(pD3DDevice, &pRootSignature, ranges, 2, + nullptr, 0); } // Create descriptor heap with space for 4 descriptors: 3 SRVs and 1 UAV @@ -12159,9 +12220,11 @@ void ExecutionTest::RunCoopVecMulSubtest(ID3D12Device *pD3DDevice, D3D12_COOPERA desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; desc.NumDescriptors = 4; desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; - VERIFY_SUCCEEDED(pD3DDevice->CreateDescriptorHeap(&desc, IID_PPV_ARGS(&pDescriptorHeap))); + VERIFY_SUCCEEDED(pD3DDevice->CreateDescriptorHeap( + &desc, IID_PPV_ARGS(&pDescriptorHeap))); } - CD3DX12_CPU_DESCRIPTOR_HANDLE baseHandle(pDescriptorHeap->GetCPUDescriptorHandleForHeapStart()); + CD3DX12_CPU_DESCRIPTOR_HANDLE baseHandle( + pDescriptorHeap->GetCPUDescriptorHandleForHeapStart()); // Create the compute pipeline state for the CoopVec shader CComPtr pComputePipelineState; @@ -12211,65 +12274,91 @@ void main(uint threadIdx : SV_GroupThreadID) return Stream.str(); }; - auto CreateDefineFromString = [](const wchar_t *Name, const std::wstring &Value) { + auto CreateDefineFromString = [](const wchar_t *Name, + const std::wstring &Value) { std::wstringstream Stream; Stream << L"-D" << Name << L"=" << Value; return Stream.str(); }; int Stride = 0; - const std::wstring HlslMatrixLayout = CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); - int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType(MulProps.MatrixInterpretation); + const std::wstring HlslMatrixLayout = + CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); + int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType( + MulProps.MatrixInterpretation); switch (Config.MatrixLayout) { - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: - Stride = Config.InputPerThread * StrideMultiplier; - break; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: - Stride = Config.OutputPerThread * StrideMultiplier; - break; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + Stride = Config.InputPerThread * StrideMultiplier; + break; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + Stride = Config.OutputPerThread * StrideMultiplier; + break; } - const int InputDivisor = CoopVecHelpers::GetNumPackedElementsForInputDataType(MulProps.InputInterpretation); - const std::wstring InputDataType = CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.InputType, MulProps.InputInterpretation); - const std::wstring AccumDataType = CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.BiasInterpretation, MulProps.BiasInterpretation); - const std::wstring MatrixDataTypeEnum = CoopVecHelpers::GetHlslInterpretationForDataType(MulProps.MatrixInterpretation); - const std::wstring InputInterpretationEnum = CoopVecHelpers::GetHlslInterpretationForDataType(MulProps.InputInterpretation); - const std::wstring AccumInterpretationEnum = CoopVecHelpers::GetHlslInterpretationForDataType(MulProps.BiasInterpretation); - - auto InputPerThreadDefine = CreateDefineFromInt(L"INPUT_PER_THREAD", Config.InputPerThread); - auto OutputPerThreadDefine = CreateDefineFromInt(L"OUTPUT_PER_THREAD", Config.OutputPerThread); - auto NumThreadsDefine = CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads); + const int InputDivisor = + CoopVecHelpers::GetNumPackedElementsForInputDataType( + MulProps.InputInterpretation); + const std::wstring InputDataType = + CoopVecHelpers::GetHlslDataTypeForDataType( + MulProps.InputType, MulProps.InputInterpretation); + const std::wstring AccumDataType = + CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.BiasInterpretation, + MulProps.BiasInterpretation); + const std::wstring MatrixDataTypeEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + MulProps.MatrixInterpretation); + const std::wstring InputInterpretationEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + MulProps.InputInterpretation); + const std::wstring AccumInterpretationEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + MulProps.BiasInterpretation); + + auto InputPerThreadDefine = + CreateDefineFromInt(L"INPUT_PER_THREAD", Config.InputPerThread); + auto OutputPerThreadDefine = + CreateDefineFromInt(L"OUTPUT_PER_THREAD", Config.OutputPerThread); + auto NumThreadsDefine = + CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads); auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride); - auto InputDataTypeDefine = CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType); - auto InputDivisorDefine = CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor); - auto AccumDataTypeDefine = CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType); - auto InputInterpretationEnumDefine = CreateDefineFromString(L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum); - auto HlslMatrixLayoutDefine = CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout); - auto MatrixDataTypeEnumDefine = CreateDefineFromString(L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum); + auto InputDataTypeDefine = + CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType); + auto InputDivisorDefine = + CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor); + auto AccumDataTypeDefine = + CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType); + auto InputInterpretationEnumDefine = CreateDefineFromString( + L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum); + auto HlslMatrixLayoutDefine = + CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout); + auto MatrixDataTypeEnumDefine = + CreateDefineFromString(L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum); auto UseBiasDefine = CreateDefineFromInt(L"USE_BIAS", Config.Bias ? 1 : 0); - auto AccumInterpretationEnumDefine = CreateDefineFromString(L"ACCUM_INTERPRETATION_ENUM", AccumInterpretationEnum); + auto AccumInterpretationEnumDefine = CreateDefineFromString( + L"ACCUM_INTERPRETATION_ENUM", AccumInterpretationEnum); LPCWSTR pOptions[] = { - L"-enable-16bit-types", - InputPerThreadDefine.c_str(), - OutputPerThreadDefine.c_str(), - NumThreadsDefine.c_str(), - StrideDefine.c_str(), - InputDataTypeDefine.c_str(), - InputDivisorDefine.c_str(), - AccumDataTypeDefine.c_str(), - InputInterpretationEnumDefine.c_str(), - HlslMatrixLayoutDefine.c_str(), - MatrixDataTypeEnumDefine.c_str(), - UseBiasDefine.c_str(), - AccumInterpretationEnumDefine.c_str(), + L"-enable-16bit-types", + InputPerThreadDefine.c_str(), + OutputPerThreadDefine.c_str(), + NumThreadsDefine.c_str(), + StrideDefine.c_str(), + InputDataTypeDefine.c_str(), + InputDivisorDefine.c_str(), + AccumDataTypeDefine.c_str(), + InputInterpretationEnumDefine.c_str(), + HlslMatrixLayoutDefine.c_str(), + MatrixDataTypeEnumDefine.c_str(), + UseBiasDefine.c_str(), + AccumInterpretationEnumDefine.c_str(), }; - CComPtr pIncludeHandler = new LinAlgHeaderIncludeHandler(m_support); + CComPtr pIncludeHandler = + new LinAlgHeaderIncludeHandler(m_support); CreateComputePSO(pD3DDevice, pRootSignature, shaderSource.c_str(), - L"cs_6_9", &pComputePipelineState, pOptions, - _countof(pOptions), pIncludeHandler); + L"cs_6_9", &pComputePipelineState, pOptions, + _countof(pOptions), pIncludeHandler); } // Create a command list for the compute shader. @@ -12277,60 +12366,81 @@ void main(uint threadIdx : SV_GroupThreadID) CComPtr pCommandAllocator; CComPtr pCommandQueue; FenceObj FO; - CreateCommandQueue(pD3DDevice, L"CoopVec Test Command Queue", - &pCommandQueue, D3D12_COMMAND_LIST_TYPE_DIRECT); + CreateCommandQueue(pD3DDevice, L"CoopVec Test Command Queue", &pCommandQueue, + D3D12_COMMAND_LIST_TYPE_DIRECT); InitFenceObj(pD3DDevice, &FO); VERIFY_SUCCEEDED(pD3DDevice->CreateCommandAllocator( D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&pCommandAllocator))); VERIFY_SUCCEEDED(pD3DDevice->CreateCommandList( - 0, D3D12_COMMAND_LIST_TYPE_DIRECT, pCommandAllocator, pComputePipelineState, - IID_PPV_ARGS(&pCommandList))); - + 0, D3D12_COMMAND_LIST_TYPE_DIRECT, pCommandAllocator, + pComputePipelineState, IID_PPV_ARGS(&pCommandList))); // Setup input data - auto ExpectedOutputBuffer = std::make_unique(Config.OutputPerThread * Config.NumThreads); + auto ExpectedOutputBuffer = + std::make_unique(Config.OutputPerThread * Config.NumThreads); // Setup input matrix as all-ones in sint8 format. This will later be // converted to the appropriate data type by the matrix conversion API. - CComPtr pInputMatrixSRVResource, pInputMatrixSRVUploadResource; + CComPtr pInputMatrixSRVResource, + pInputMatrixSRVUploadResource; std::vector inputMatrix; - if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || - MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || + if (MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.InputPerThread, Config.OutputPerThread); - } else if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || - MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || - MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { - // Matrix source data is fp32, which gets converted to fp16 during matrix conversion - inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.InputPerThread, Config.OutputPerThread); + inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix( + Config.InputPerThread, Config.OutputPerThread); + } else if (MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + // Matrix source data is fp32, which gets converted to fp16 during matrix + // conversion + inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix( + Config.InputPerThread, Config.OutputPerThread); } else { WEX::Logging::Log::Error(L"Unsupported matrix data type"); return; } - CreateTestResources(pD3DDevice, pCommandList, inputMatrix.data(), inputMatrix.size(), CD3DX12_RESOURCE_DESC::Buffer(inputMatrix.size()), &pInputMatrixSRVResource, &pInputMatrixSRVUploadResource); + CreateTestResources(pD3DDevice, pCommandList, inputMatrix.data(), + inputMatrix.size(), + CD3DX12_RESOURCE_DESC::Buffer(inputMatrix.size()), + &pInputMatrixSRVResource, &pInputMatrixSRVUploadResource); - // Create input vector of an appropriate type. All integer types start as SINT8 for now. + // Create input vector of an appropriate type. All integer types start as + // SINT8 for now. CComPtr pInputVecSRVResource, pInputVecSRVUploadResource; std::vector inputVector; - + if ((MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 && - (MulProps.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || - MulProps.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED)) || + (MulProps.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED)) || MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - inputVector = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.InputPerThread); + inputVector = CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { - inputVector = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.InputPerThread); + inputVector = + CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - inputVector = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.InputPerThread); + inputVector = CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { - inputVector = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.InputPerThread); + inputVector = CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { - inputVector = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.InputPerThread); + inputVector = CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.InputPerThread); } else { WEX::Logging::Log::Error(L"Unsupported input data type"); return; @@ -12339,27 +12449,41 @@ void main(uint threadIdx : SV_GroupThreadID) // Align size to 4 bytes for ByteAddressBuffer inputVector.resize(inputVector.size() + 4 - (inputVector.size() % 4)); } - CreateTestResources(pD3DDevice, pCommandList, inputVector.data(), inputVector.size(), CD3DX12_RESOURCE_DESC::Buffer(inputVector.size()), &pInputVecSRVResource, &pInputVecSRVUploadResource); + CreateTestResources(pD3DDevice, pCommandList, inputVector.data(), + inputVector.size(), + CD3DX12_RESOURCE_DESC::Buffer(inputVector.size()), + &pInputVecSRVResource, &pInputVecSRVUploadResource); // This increments baseHandle - CreateRawSRV(pD3DDevice, baseHandle, (UINT)(inputVector.size() / sizeof(int32_t)), pInputVecSRVResource); + CreateRawSRV(pD3DDevice, baseHandle, + (UINT)(inputVector.size() / sizeof(int32_t)), + pInputVecSRVResource); // Create input bias CComPtr pInputBiasSRVResource, pInputBiasSRVUploadResource; std::vector inputBias; - if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || - MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || + if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { inputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); - } else if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { - inputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); - } else if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { - inputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); - } else if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { - inputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); - } else if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + } else if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { + inputBias = + CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { + inputBias = + CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + inputBias = CoopVecHelpers::CreateInputBias( + Config.OutputPerThread); + } else if (MulProps.BiasInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { inputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); } else { WEX::Logging::Log::Error(L"Unsupported bias data type"); @@ -12370,22 +12494,30 @@ void main(uint threadIdx : SV_GroupThreadID) // Align size to 4 bytes for ByteAddressBuffer inputBias.resize(inputBias.size() + 4 - (inputBias.size() % 4)); } - CreateTestResources(pD3DDevice, pCommandList, inputBias.data(), inputBias.size(), CD3DX12_RESOURCE_DESC::Buffer(inputBias.size()), &pInputBiasSRVResource, &pInputBiasSRVUploadResource); + CreateTestResources(pD3DDevice, pCommandList, inputBias.data(), + inputBias.size(), + CD3DX12_RESOURCE_DESC::Buffer(inputBias.size()), + &pInputBiasSRVResource, &pInputBiasSRVUploadResource); // This increments baseHandle - CreateRawSRV(pD3DDevice, baseHandle, (UINT)(inputBias.size() / sizeof(int32_t)), pInputBiasSRVResource); + CreateRawSRV(pD3DDevice, baseHandle, + (UINT)(inputBias.size() / sizeof(int32_t)), + pInputBiasSRVResource); // Calculate reference output - // FIXME: This does not capture all cases, but is sufficient for the preview feature set + // FIXME: This does not capture all cases, but is sufficient for the preview + // feature set if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8) { // The input bias is really an array of int32_t std::vector inputBiasI32(inputBias.size() / sizeof(int32_t)); std::memcpy(inputBiasI32.data(), inputBias.data(), inputBias.size()); - // The input vector is really an array of float if our vector input type is FLOAT32 + // The input vector is really an array of float if our vector input type is + // FLOAT32 std::vector inputVectorF32(inputVector.size() / sizeof(int32_t)); if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - std::memcpy(inputVectorF32.data(), inputVector.data(), inputVector.size()); + std::memcpy(inputVectorF32.data(), inputVector.data(), + inputVector.size()); } for (int threadIdx = 0; threadIdx < Config.NumThreads; ++threadIdx) { @@ -12394,12 +12526,15 @@ void main(uint threadIdx : SV_GroupThreadID) for (int inputIdx = 0; inputIdx < Config.InputPerThread; ++inputIdx) { int inputElem; - if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 ) { - inputElem = (int)inputVectorF32[threadIdx * Config.InputPerThread + inputIdx]; + if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + inputElem = (int) + inputVectorF32[threadIdx * Config.InputPerThread + inputIdx]; } else { - inputElem = inputVector[threadIdx * Config.InputPerThread + inputIdx]; + inputElem = + inputVector[threadIdx * Config.InputPerThread + inputIdx]; } - int const matrixElem = inputMatrix[outputIdx * Config.InputPerThread + inputIdx]; + int const matrixElem = + inputMatrix[outputIdx * Config.InputPerThread + inputIdx]; acc += inputElem * matrixElem; } @@ -12408,17 +12543,23 @@ void main(uint threadIdx : SV_GroupThreadID) } float result = float(acc); - ExpectedOutputBuffer[threadIdx * Config.OutputPerThread + outputIdx] = result; + ExpectedOutputBuffer[threadIdx * Config.OutputPerThread + outputIdx] = + result; } } - } else if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || - MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || - MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + } else if (MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { // The input bias/vector is really an array of float16 - std::vector inputVectorFP16(inputVector.size() / sizeof(DirectX::PackedVector::HALF)); + std::vector inputVectorFP16( + inputVector.size() / sizeof(DirectX::PackedVector::HALF)); std::memcpy(inputVectorFP16.data(), inputVector.data(), inputVector.size()); - std::vector inputBiasFP16(inputBias.size() / sizeof(DirectX::PackedVector::HALF)); + std::vector inputBiasFP16( + inputBias.size() / sizeof(DirectX::PackedVector::HALF)); std::memcpy(inputBiasFP16.data(), inputBias.data(), inputBias.size()); // The CPU reference matrix is float @@ -12430,8 +12571,10 @@ void main(uint threadIdx : SV_GroupThreadID) float acc = 0; for (int inputIdx = 0; inputIdx < Config.InputPerThread; ++inputIdx) { - float const inputElem = ConvertFloat16ToFloat32(inputVectorFP16[threadIdx * Config.InputPerThread + inputIdx]); - float const matrixElem = inputMatrixFP32[outputIdx * Config.InputPerThread + inputIdx]; + float const inputElem = ConvertFloat16ToFloat32( + inputVectorFP16[threadIdx * Config.InputPerThread + inputIdx]); + float const matrixElem = + inputMatrixFP32[outputIdx * Config.InputPerThread + inputIdx]; acc += inputElem * matrixElem; } @@ -12440,7 +12583,8 @@ void main(uint threadIdx : SV_GroupThreadID) } float result = acc; - ExpectedOutputBuffer[threadIdx * Config.OutputPerThread + outputIdx] = result; + ExpectedOutputBuffer[threadIdx * Config.OutputPerThread + outputIdx] = + result; } } } @@ -12449,8 +12593,10 @@ void main(uint threadIdx : SV_GroupThreadID) { // Create source matrix info D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO convertInfo = {}; - convertInfo.SrcInfo.SrcDataType = CoopVecHelpers::GetMatrixSrcDataType(MulProps.MatrixInterpretation); - convertInfo.SrcInfo.SrcLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + convertInfo.SrcInfo.SrcDataType = + CoopVecHelpers::GetMatrixSrcDataType(MulProps.MatrixInterpretation); + convertInfo.SrcInfo.SrcLayout = + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; // Create destination matrix info convertInfo.DestInfo.DestSize = 0; // Will be populated by driver @@ -12465,22 +12611,25 @@ void main(uint threadIdx : SV_GroupThreadID) break; case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; - srcEltSize = 4; // FP32 + srcEltSize = 4; // FP32 destEltSize = 2; // FP16 break; case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: - convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; - srcEltSize = 4; // FP32 + convertInfo.DestInfo.DestDataType = + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; + srcEltSize = 4; // FP32 destEltSize = 1; // FP8 break; case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: - convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; - srcEltSize = 4; // FP32 + convertInfo.DestInfo.DestDataType = + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; + srcEltSize = 4; // FP32 destEltSize = 1; // FP8 break; } convertInfo.SrcInfo.SrcStride = Config.InputPerThread * srcEltSize; - convertInfo.SrcInfo.SrcSize = Config.InputPerThread * Config.OutputPerThread * srcEltSize; + convertInfo.SrcInfo.SrcSize = + Config.InputPerThread * Config.OutputPerThread * srcEltSize; convertInfo.DestInfo.DestLayout = Config.MatrixLayout; convertInfo.DestInfo.DestStride = 0; @@ -12489,29 +12638,38 @@ void main(uint threadIdx : SV_GroupThreadID) if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { convertInfo.DestInfo.DestStride = Config.InputPerThread * destEltSize; - } else if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { + } else if (Config.MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { convertInfo.DestInfo.DestStride = Config.OutputPerThread * destEltSize; } // Get destination size using preview interface { CComPtr previewDevice; - VERIFY_SUCCEEDED(pD3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), (void**)&previewDevice)); + VERIFY_SUCCEEDED(pD3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&previewDevice)); // Query required destination size - previewDevice->GetLinearAlgebraMatrixConversionDestinationInfo(&convertInfo.DestInfo); + previewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + &convertInfo.DestInfo); } // Create resource to hold matrix copy - CreateTestResources(pD3DDevice, pCommandList, nullptr, 0, CD3DX12_RESOURCE_DESC::Buffer(convertInfo.DestInfo.DestSize), &pConvertedMatrixResource, nullptr); + CreateTestResources( + pD3DDevice, pCommandList, nullptr, 0, + CD3DX12_RESOURCE_DESC::Buffer(convertInfo.DestInfo.DestSize), + &pConvertedMatrixResource, nullptr); // Set up data descriptors - convertInfo.DataDesc.DestVA = pConvertedMatrixResource->GetGPUVirtualAddress(); - convertInfo.DataDesc.SrcVA = pInputMatrixSRVResource->GetGPUVirtualAddress(); + convertInfo.DataDesc.DestVA = + pConvertedMatrixResource->GetGPUVirtualAddress(); + convertInfo.DataDesc.SrcVA = + pInputMatrixSRVResource->GetGPUVirtualAddress(); // Get command list interface and perform conversion CComPtr commandList11; - VERIFY_SUCCEEDED(pCommandList->QueryInterface(__uuidof(ID3D12GraphicsCommandList11), (void**)&commandList11)); + VERIFY_SUCCEEDED(pCommandList->QueryInterface( + __uuidof(ID3D12GraphicsCommandList11), (void **)&commandList11)); commandList11->ConvertLinearAlgebraMatrix(&convertInfo, 1); // This increments baseHandle @@ -12519,7 +12677,9 @@ void main(uint threadIdx : SV_GroupThreadID) WEX::Logging::Log::Error(L"DestSize is not aligned to 4 bytes"); return; } - CreateRawSRV(pD3DDevice, baseHandle, convertInfo.DestInfo.DestSize / sizeof(int32_t), pConvertedMatrixResource); + CreateRawSRV(pD3DDevice, baseHandle, + convertInfo.DestInfo.DestSize / sizeof(int32_t), + pConvertedMatrixResource); } CComPtr pUavResource; @@ -12531,46 +12691,52 @@ void main(uint threadIdx : SV_GroupThreadID) std::vector outputBufferInit(OUTPUT_BUFFER_SIZE); std::fill(outputBufferInit.begin(), outputBufferInit.end(), (uint8_t)0xFF); - CreateTestUavs(pD3DDevice, pCommandList, outputBufferInit.data(), OUTPUT_BUFFER_SIZE, &pUavResource, &pUavUploadResource, &pUavReadResource); + CreateTestUavs(pD3DDevice, pCommandList, outputBufferInit.data(), + OUTPUT_BUFFER_SIZE, &pUavResource, &pUavUploadResource, + &pUavReadResource); CreateRawUAV(pD3DDevice, baseHandle, OUTPUT_BUFFER_SIZE / 4, pUavResource); - pCommandList->Close(); ExecuteCommandList(pCommandQueue, pCommandList); WaitForSignal(pCommandQueue, FO); VERIFY_SUCCEEDED(pCommandAllocator->Reset()); - VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, pComputePipelineState)); + VERIFY_SUCCEEDED( + pCommandList->Reset(pCommandAllocator, pComputePipelineState)); SetDescriptorHeap(pCommandList, pDescriptorHeap); - CD3DX12_GPU_DESCRIPTOR_HANDLE resHandle(pDescriptorHeap->GetGPUDescriptorHandleForHeapStart()); + CD3DX12_GPU_DESCRIPTOR_HANDLE resHandle( + pDescriptorHeap->GetGPUDescriptorHandleForHeapStart()); pCommandList->SetComputeRootSignature(pRootSignature); pCommandList->SetComputeRootDescriptorTable(0, resHandle); pCommandList->SetPipelineState(pComputePipelineState); pCommandList->Dispatch(1, 1, 1); - RecordTransitionBarrier(pCommandList, pUavResource, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE); + RecordTransitionBarrier(pCommandList, pUavResource, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COPY_SOURCE); pCommandList->CopyResource(pUavReadResource, pUavResource); pCommandList->Close(); ExecuteCommandList(pCommandQueue, pCommandList); WaitForSignal(pCommandQueue, FO); - { MappedData mappedData(pUavReadResource, OUTPUT_BUFFER_SIZE); float *resultBuffer = (float *)mappedData.data(); bool equal = true; for (int i = 0; i < OUTPUT_BUFFER_SIZE / sizeof(float); i++) { - if (isnan(resultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || fabs(resultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { + if (isnan(resultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || + fabs(resultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { LogErrorFmt(L"Result mismatch at index %d", i); - LogErrorFmt(L"resultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, resultBuffer[i], i, ExpectedOutputBuffer[i]); + LogErrorFmt(L"resultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, + resultBuffer[i], i, ExpectedOutputBuffer[i]); equal = false; break; } } VERIFY_IS_TRUE(equal); - } + } } TEST_F(ExecutionTest, CoopVec_Mul) { @@ -12579,13 +12745,11 @@ TEST_F(ExecutionTest, CoopVec_Mul) { RunCoopVecMulTest(); } - void ExecutionTest::RunCoopVecOuterProductTest() { // Create device and verify coopvec support CComPtr pD3DDevice; if (!CreateDevice(&pD3DDevice, D3D_SHADER_MODEL_6_9)) { - WEX::Logging::Log::Comment( - "Device does not support SM 6.9. Skipping."); + WEX::Logging::Log::Comment("Device does not support SM 6.9. Skipping."); WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); return; } @@ -12599,17 +12763,18 @@ void ExecutionTest::RunCoopVecOuterProductTest() { // Query coopvec feature data. First call gets the size of the arrays. The // second call populates the arrays using memory we allocate. D3D12_FEATURE_DATA_COOPERATIVE_VECTOR devOptions = {}; - VERIFY_SUCCEEDED( - pD3DDevice->CheckFeatureSupport((D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, - &devOptions, sizeof(devOptions))); + VERIFY_SUCCEEDED(pD3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &devOptions, + sizeof(devOptions))); // Allocate memory for the arrays in devOptions - std::vector AccumulateProps(devOptions.OuterProductAccumulatePropCount); + std::vector AccumulateProps( + devOptions.OuterProductAccumulatePropCount); devOptions.pOuterProductAccumulateProperties = AccumulateProps.data(); - VERIFY_SUCCEEDED( - pD3DDevice->CheckFeatureSupport((D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, - &devOptions, sizeof(devOptions))); + VERIFY_SUCCEEDED(pD3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &devOptions, + sizeof(devOptions))); // Test each supported data type and matrix layout for (auto AccumulateConfig : AccumulateProps) { @@ -12618,23 +12783,30 @@ void ExecutionTest::RunCoopVecOuterProductTest() { } } -void ExecutionTest::RunCoopVecOuterProductTestConfig(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps) { +void ExecutionTest::RunCoopVecOuterProductTestConfig( + ID3D12Device *pD3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps) { UNREFERENCED_PARAMETER(pD3DDevice); UNREFERENCED_PARAMETER(AccumulateProps); - LogCommentFmt(L"Running test for InputType: %s, AccumulationType: %s", - CoopVecHelpers::DataTypeToFilterString(AccumulateProps.InputType).c_str(), - CoopVecHelpers::DataTypeToFilterString(AccumulateProps.AccumulationType).c_str()); + LogCommentFmt( + L"Running test for InputType: %s, AccumulationType: %s", + CoopVecHelpers::DataTypeToFilterString(AccumulateProps.InputType).c_str(), + CoopVecHelpers::DataTypeToFilterString(AccumulateProps.AccumulationType) + .c_str()); constexpr CoopVecOuterProductSubtestConfig TestConfigs[] = { - {4, 4, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL }, + {4, 4, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL}, }; for (auto Config : TestConfigs) { - if ((AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || - AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) && + if ((AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) && (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR || - Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR)) { + Config.MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR)) { continue; } @@ -12642,25 +12814,28 @@ void ExecutionTest::RunCoopVecOuterProductTestConfig(ID3D12Device *pD3DDevice, D } } -void ExecutionTest::RunCoopVecOuterProductSubtest(ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, CoopVecOuterProductSubtestConfig &Config) { +void ExecutionTest::RunCoopVecOuterProductSubtest( + ID3D12Device *pD3DDevice, + D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, + CoopVecOuterProductSubtestConfig &Config) { UNREFERENCED_PARAMETER(pD3DDevice); UNREFERENCED_PARAMETER(AccumulateProps); UNREFERENCED_PARAMETER(Config); - LogCommentFmt(L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s", - Config.DimM, - Config.DimN, - Config.NumThreads, - CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); - + LogCommentFmt( + L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s", + Config.DimM, Config.DimN, Config.NumThreads, + CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); // Create root signature with a single root entry for all SRVs and UAVs CComPtr pRootSignature; { CD3DX12_DESCRIPTOR_RANGE ranges[2]; - ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0, 0); // InputVector1, InputVector2 + ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0, + 0); // InputVector1, InputVector2 ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // AccumMatrix - CreateRootSignatureFromRanges(pD3DDevice, &pRootSignature, ranges, 2, nullptr, 0); + CreateRootSignatureFromRanges(pD3DDevice, &pRootSignature, ranges, 2, + nullptr, 0); } // Create descriptor heap with space for 3 descriptors: 2 SRVs and 1 UAV @@ -12670,9 +12845,11 @@ void ExecutionTest::RunCoopVecOuterProductSubtest(ID3D12Device *pD3DDevice, D3D1 desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; desc.NumDescriptors = 3; desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; - VERIFY_SUCCEEDED(pD3DDevice->CreateDescriptorHeap(&desc, IID_PPV_ARGS(&pDescriptorHeap))); + VERIFY_SUCCEEDED(pD3DDevice->CreateDescriptorHeap( + &desc, IID_PPV_ARGS(&pDescriptorHeap))); } - CD3DX12_CPU_DESCRIPTOR_HANDLE baseHandle(pDescriptorHeap->GetCPUDescriptorHandleForHeapStart()); + CD3DX12_CPU_DESCRIPTOR_HANDLE baseHandle( + pDescriptorHeap->GetCPUDescriptorHandleForHeapStart()); // Create a compute pipeline state object. CComPtr pComputePipelineState; @@ -12713,60 +12890,81 @@ void main(uint threadIdx : SV_GroupThreadID) return Stream.str(); }; - auto CreateDefineFromString = [](const wchar_t *Name, const wchar_t *Value) { + auto CreateDefineFromString = [](const wchar_t *Name, + const wchar_t *Value) { std::wstringstream Stream; Stream << L"-D" << Name << L"=" << Value; return Stream.str(); }; int Stride = 0; - const std::wstring HlslMatrixLayout = CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); - int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType(AccumulateProps.AccumulationType); + const std::wstring HlslMatrixLayout = + CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); + int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType( + AccumulateProps.AccumulationType); switch (Config.MatrixLayout) { - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: - Stride = Config.DimN * StrideMultiplier; - break; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: - Stride = Config.DimM * StrideMultiplier; - break; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + Stride = Config.DimN * StrideMultiplier; + break; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + Stride = Config.DimM * StrideMultiplier; + break; } - const int InputDivisor = CoopVecHelpers::GetNumPackedElementsForInputDataType(AccumulateProps.InputType); - const std::wstring InputDataType = CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType, AccumulateProps.InputType); - const std::wstring AccumDataType = CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.AccumulationType, AccumulateProps.AccumulationType); - const std::wstring MatrixDataTypeEnum = CoopVecHelpers::GetHlslInterpretationForDataType(AccumulateProps.AccumulationType); - const std::wstring InputInterpretationEnum = CoopVecHelpers::GetHlslInterpretationForDataType(AccumulateProps.InputType); + const int InputDivisor = + CoopVecHelpers::GetNumPackedElementsForInputDataType( + AccumulateProps.InputType); + const std::wstring InputDataType = + CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType, + AccumulateProps.InputType); + const std::wstring AccumDataType = + CoopVecHelpers::GetHlslDataTypeForDataType( + AccumulateProps.AccumulationType, AccumulateProps.AccumulationType); + const std::wstring MatrixDataTypeEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + AccumulateProps.AccumulationType); + const std::wstring InputInterpretationEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + AccumulateProps.InputType); auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM); auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN); - auto NumThreadsDefine = CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads); + auto NumThreadsDefine = + CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads); auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride); - auto InputDataTypeDefine = CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str()); - auto InputDivisorDefine = CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor); - auto AccumDataTypeDefine = CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str()); - auto InputInterpretationEnumDefine = CreateDefineFromString(L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str()); - auto HlslMatrixLayoutDefine = CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str()); - auto MatrixDataTypeEnumDefine = CreateDefineFromString(L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str()); + auto InputDataTypeDefine = + CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str()); + auto InputDivisorDefine = + CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor); + auto AccumDataTypeDefine = + CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str()); + auto InputInterpretationEnumDefine = CreateDefineFromString( + L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str()); + auto HlslMatrixLayoutDefine = + CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str()); + auto MatrixDataTypeEnumDefine = CreateDefineFromString( + L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str()); LPCWSTR pOptions[] = { - L"-enable-16bit-types", - DimMDefine.c_str(), - DimNDefine.c_str(), - NumThreadsDefine.c_str(), - StrideDefine.c_str(), - InputDataTypeDefine.c_str(), - InputDivisorDefine.c_str(), - AccumDataTypeDefine.c_str(), - InputInterpretationEnumDefine.c_str(), - HlslMatrixLayoutDefine.c_str(), - MatrixDataTypeEnumDefine.c_str(), + L"-enable-16bit-types", + DimMDefine.c_str(), + DimNDefine.c_str(), + NumThreadsDefine.c_str(), + StrideDefine.c_str(), + InputDataTypeDefine.c_str(), + InputDivisorDefine.c_str(), + AccumDataTypeDefine.c_str(), + InputInterpretationEnumDefine.c_str(), + HlslMatrixLayoutDefine.c_str(), + MatrixDataTypeEnumDefine.c_str(), }; - CComPtr pIncludeHandler = new LinAlgHeaderIncludeHandler(m_support); + CComPtr pIncludeHandler = + new LinAlgHeaderIncludeHandler(m_support); CreateComputePSO(pD3DDevice, pRootSignature, shaderSource.c_str(), - L"cs_6_9", &pComputePipelineState, pOptions, - _countof(pOptions), pIncludeHandler); + L"cs_6_9", &pComputePipelineState, pOptions, + _countof(pOptions), pIncludeHandler); } // Create a command list for the compute shader. @@ -12774,52 +12972,74 @@ void main(uint threadIdx : SV_GroupThreadID) CComPtr pCommandAllocator; CComPtr pCommandQueue; FenceObj FO; - CreateCommandQueue(pD3DDevice, L"CoopVec Test Command Queue", - &pCommandQueue, D3D12_COMMAND_LIST_TYPE_DIRECT); + CreateCommandQueue(pD3DDevice, L"CoopVec Test Command Queue", &pCommandQueue, + D3D12_COMMAND_LIST_TYPE_DIRECT); InitFenceObj(pD3DDevice, &FO); VERIFY_SUCCEEDED(pD3DDevice->CreateCommandAllocator( D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&pCommandAllocator))); VERIFY_SUCCEEDED(pD3DDevice->CreateCommandList( - 0, D3D12_COMMAND_LIST_TYPE_DIRECT, pCommandAllocator, pComputePipelineState, - IID_PPV_ARGS(&pCommandList))); + 0, D3D12_COMMAND_LIST_TYPE_DIRECT, pCommandAllocator, + pComputePipelineState, IID_PPV_ARGS(&pCommandList))); // Setup input matrix as all-ones in sint8/fp32 format. This will later be // converted to the appropriate data type by the matrix conversion API. - CComPtr pInputMatrixSRVResource, pInputMatrixSRVUploadResource; + CComPtr pInputMatrixSRVResource, + pInputMatrixSRVUploadResource; std::vector inputMatrix; if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); - } else if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || - AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || - AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { - // Matrix source data is fp32, which gets converted to fp16 during matrix conversion - inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); + inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, + Config.DimM); + } else if (AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + // Matrix source data is fp32, which gets converted to fp16 during matrix + // conversion + inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, + Config.DimM); } else { WEX::Logging::Log::Error(L"Unsupported matrix data type"); return; } - CreateTestResources(pD3DDevice, pCommandList, inputMatrix.data(), inputMatrix.size(), CD3DX12_RESOURCE_DESC::Buffer(inputMatrix.size()), &pInputMatrixSRVResource, &pInputMatrixSRVUploadResource); + CreateTestResources(pD3DDevice, pCommandList, inputMatrix.data(), + inputMatrix.size(), + CD3DX12_RESOURCE_DESC::Buffer(inputMatrix.size()), + &pInputMatrixSRVResource, &pInputMatrixSRVUploadResource); // Create input vectors CComPtr pInputVecSRVResource1, pInputVecSRVUploadResource1; std::vector inputVector1; CComPtr pInputVecSRVResource2, pInputVecSRVUploadResource2; std::vector inputVector2; - + if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - inputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimM); - inputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimN); - } else if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || - AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || - AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { - inputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimM); - inputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimN); - } else if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - inputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimM); - inputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimN); + inputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + Config.DimM); + inputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + Config.DimN); + } else if (AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + inputVector1 = + CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.DimM); + inputVector2 = + CoopVecHelpers::CreateInputVector( + Config.NumThreads, Config.DimN); + } else if (AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + inputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + Config.DimM); + inputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + Config.DimN); } else { WEX::Logging::Log::Error(L"Unsupported input data type"); return; @@ -12832,39 +13052,60 @@ void main(uint threadIdx : SV_GroupThreadID) // Align size to 4 bytes for ByteAddressBuffer inputVector2.resize(inputVector2.size() + 4 - (inputVector2.size() % 4)); } - CreateTestResources(pD3DDevice, pCommandList, inputVector1.data(), inputVector1.size(), CD3DX12_RESOURCE_DESC::Buffer(inputVector1.size()), &pInputVecSRVResource1, &pInputVecSRVUploadResource1); - CreateTestResources(pD3DDevice, pCommandList, inputVector2.data(), inputVector2.size(), CD3DX12_RESOURCE_DESC::Buffer(inputVector2.size()), &pInputVecSRVResource2, &pInputVecSRVUploadResource2); + CreateTestResources(pD3DDevice, pCommandList, inputVector1.data(), + inputVector1.size(), + CD3DX12_RESOURCE_DESC::Buffer(inputVector1.size()), + &pInputVecSRVResource1, &pInputVecSRVUploadResource1); + CreateTestResources(pD3DDevice, pCommandList, inputVector2.data(), + inputVector2.size(), + CD3DX12_RESOURCE_DESC::Buffer(inputVector2.size()), + &pInputVecSRVResource2, &pInputVecSRVUploadResource2); // This increments baseHandle - CreateRawSRV(pD3DDevice, baseHandle, (UINT)(inputVector1.size() / sizeof(int32_t)), pInputVecSRVResource1); - CreateRawSRV(pD3DDevice, baseHandle, (UINT)(inputVector2.size() / sizeof(int32_t)), pInputVecSRVResource2); + CreateRawSRV(pD3DDevice, baseHandle, + (UINT)(inputVector1.size() / sizeof(int32_t)), + pInputVecSRVResource1); + CreateRawSRV(pD3DDevice, baseHandle, + (UINT)(inputVector2.size() / sizeof(int32_t)), + pInputVecSRVResource2); // Calculate reference output - auto ExpectedOutputBufferI8 = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); - std::vector ExpectedOutputBuffer(ExpectedOutputBufferI8.size() / sizeof(float)); - std::memcpy(ExpectedOutputBuffer.data(), ExpectedOutputBufferI8.data(), ExpectedOutputBufferI8.size()); + auto ExpectedOutputBufferI8 = + CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); + std::vector ExpectedOutputBuffer(ExpectedOutputBufferI8.size() / + sizeof(float)); + std::memcpy(ExpectedOutputBuffer.data(), ExpectedOutputBufferI8.data(), + ExpectedOutputBufferI8.size()); if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { - std::vector inputVector1FP16(inputVector1.size() / sizeof(DirectX::PackedVector::HALF)); - std::memcpy(inputVector1FP16.data(), inputVector1.data(), inputVector1.size()); + std::vector inputVector1FP16( + inputVector1.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(inputVector1FP16.data(), inputVector1.data(), + inputVector1.size()); - std::vector inputVector2FP16(inputVector2.size() / sizeof(DirectX::PackedVector::HALF)); - std::memcpy(inputVector2FP16.data(), inputVector2.data(), inputVector2.size()); + std::vector inputVector2FP16( + inputVector2.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(inputVector2FP16.data(), inputVector2.data(), + inputVector2.size()); for (int threadIdx = 0; threadIdx < Config.NumThreads; ++threadIdx) { for (int m = 0; m < Config.DimM; ++m) { for (int n = 0; n < Config.DimN; ++n) { - float acc = ConvertFloat16ToFloat32(inputVector1FP16[m]) * ConvertFloat16ToFloat32(inputVector2FP16[n]); + float acc = ConvertFloat16ToFloat32(inputVector1FP16[m]) * + ConvertFloat16ToFloat32(inputVector2FP16[n]); ExpectedOutputBuffer[m * Config.DimN + n] += acc; } } } - } else if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + } else if (AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { std::vector inputVector1FP32(inputVector1.size() / sizeof(float)); - std::memcpy(inputVector1FP32.data(), inputVector1.data(), inputVector1.size()); + std::memcpy(inputVector1FP32.data(), inputVector1.data(), + inputVector1.size()); std::vector inputVector2FP32(inputVector2.size() / sizeof(float)); - std::memcpy(inputVector2FP32.data(), inputVector2.data(), inputVector2.size()); + std::memcpy(inputVector2FP32.data(), inputVector2.data(), + inputVector2.size()); for (int threadIdx = 0; threadIdx < Config.NumThreads; ++threadIdx) { for (int m = 0; m < Config.DimM; ++m) { @@ -12876,13 +13117,14 @@ void main(uint threadIdx : SV_GroupThreadID) } } - - CComPtr pConvertedMatrixResource, pConvertedMatrixReadResource; + CComPtr pConvertedMatrixResource, + pConvertedMatrixReadResource; int ConvertedMatrixSize = 0; { // Create source matrix info D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO srcInfo = {}; - srcInfo.SrcDataType = CoopVecHelpers::GetMatrixSrcDataType(AccumulateProps.AccumulationType); + srcInfo.SrcDataType = + CoopVecHelpers::GetMatrixSrcDataType(AccumulateProps.AccumulationType); srcInfo.SrcLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; // Create destination matrix info @@ -12899,17 +13141,17 @@ void main(uint threadIdx : SV_GroupThreadID) break; case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: destInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; - srcEltSize = 4; // FP32 + srcEltSize = 4; // FP32 destEltSize = 2; // FP16 break; case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: destInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; - srcEltSize = 4; // FP32 + srcEltSize = 4; // FP32 destEltSize = 1; // FP8 break; case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: destInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; - srcEltSize = 4; // FP32 + srcEltSize = 4; // FP32 destEltSize = 1; // FP8 break; } @@ -12923,7 +13165,8 @@ void main(uint threadIdx : SV_GroupThreadID) if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { destInfo.DestStride = Config.DimM * destEltSize; - } else if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { + } else if (Config.MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { destInfo.DestStride = Config.DimM * destEltSize; } @@ -12935,17 +13178,21 @@ void main(uint threadIdx : SV_GroupThreadID) // Get preview device interface { CComPtr previewDevice; - VERIFY_SUCCEEDED(pD3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), (void**)&previewDevice)); + VERIFY_SUCCEEDED(pD3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&previewDevice)); // Query required destination size - previewDevice->GetLinearAlgebraMatrixConversionDestinationInfo(&convertInfo.DestInfo); + previewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + &convertInfo.DestInfo); } ConvertedMatrixSize = convertInfo.DestInfo.DestSize; // Hack to prevent read resource from being created with size 0 std::vector tempData(convertInfo.DestInfo.DestSize); - CreateTestUavs(pD3DDevice, pCommandList, tempData.data(), tempData.size(), &pConvertedMatrixResource, nullptr, &pConvertedMatrixReadResource); + CreateTestUavs(pD3DDevice, pCommandList, tempData.data(), tempData.size(), + &pConvertedMatrixResource, nullptr, + &pConvertedMatrixReadResource); // Set up data descriptors D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA dataDesc = {}; @@ -12955,7 +13202,8 @@ void main(uint threadIdx : SV_GroupThreadID) // Get command list interface and perform conversion CComPtr commandList11; - VERIFY_SUCCEEDED(pCommandList->QueryInterface(__uuidof(ID3D12GraphicsCommandList11), (void**)&commandList11)); + VERIFY_SUCCEEDED(pCommandList->QueryInterface( + __uuidof(ID3D12GraphicsCommandList11), (void **)&commandList11)); commandList11->ConvertLinearAlgebraMatrix(&convertInfo, 1); // This increments baseHandle @@ -12963,19 +13211,22 @@ void main(uint threadIdx : SV_GroupThreadID) WEX::Logging::Log::Error(L"DestSize is not aligned to 4 bytes"); return; } - CreateRawUAV(pD3DDevice, baseHandle, convertInfo.DestInfo.DestSize / sizeof(int32_t), pConvertedMatrixResource); + CreateRawUAV(pD3DDevice, baseHandle, + convertInfo.DestInfo.DestSize / sizeof(int32_t), + pConvertedMatrixResource); } - pCommandList->Close(); ExecuteCommandList(pCommandQueue, pCommandList); WaitForSignal(pCommandQueue, FO); VERIFY_SUCCEEDED(pCommandAllocator->Reset()); - VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, pComputePipelineState)); + VERIFY_SUCCEEDED( + pCommandList->Reset(pCommandAllocator, pComputePipelineState)); SetDescriptorHeap(pCommandList, pDescriptorHeap); - CD3DX12_GPU_DESCRIPTOR_HANDLE resHandle(pDescriptorHeap->GetGPUDescriptorHandleForHeapStart()); + CD3DX12_GPU_DESCRIPTOR_HANDLE resHandle( + pDescriptorHeap->GetGPUDescriptorHandleForHeapStart()); pCommandList->SetComputeRootSignature(pRootSignature); pCommandList->SetComputeRootDescriptorTable(0, resHandle); @@ -12986,11 +13237,12 @@ void main(uint threadIdx : SV_GroupThreadID) WaitForSignal(pCommandQueue, FO); VERIFY_SUCCEEDED(pCommandAllocator->Reset()); - VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, pComputePipelineState)); - + VERIFY_SUCCEEDED( + pCommandList->Reset(pCommandAllocator, pComputePipelineState)); - // Convert matrix to sint8/fp32 row-major format before reading back to the CPU. - // A new resource is created, along with a readback resource, for the matrix copy. + // Convert matrix to sint8/fp32 row-major format before reading back to the + // CPU. A new resource is created, along with a readback resource, for the + // matrix copy. CComPtr pMatrixRowMajorResource, pMatrixRowMajorReadResource; { // Create source matrix info @@ -13002,14 +13254,19 @@ void main(uint threadIdx : SV_GroupThreadID) // Create destination matrix info convertInfo.DestInfo.DestSize = 0; // Will be populated by driver - convertInfo.DestInfo.DestLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + convertInfo.DestInfo.DestLayout = + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; convertInfo.DestInfo.NumRows = Config.DimM; convertInfo.DestInfo.NumColumns = Config.DimN; - if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 || - AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || - AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || - AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + if (AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; convertInfo.DestInfo.DestStride = Config.DimN * sizeof(float); } else { @@ -13020,43 +13277,58 @@ void main(uint threadIdx : SV_GroupThreadID) // Get destination size using preview interface { CComPtr previewDevice; - VERIFY_SUCCEEDED(pD3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), (void**)&previewDevice)); + VERIFY_SUCCEEDED(pD3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&previewDevice)); // Query required destination size - previewDevice->GetLinearAlgebraMatrixConversionDestinationInfo(&convertInfo.DestInfo); + previewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + &convertInfo.DestInfo); } // Create resource to hold matrix copy and a readback resource for it - // Init vector is a hack to prevent read resource from being created with size 0 - // TODO: Fix CreateTestUavs to allow creating readback resource without init data + // Init vector is a hack to prevent read resource from being created with + // size 0 + // TODO: Fix CreateTestUavs to allow creating readback resource without init + // data std::vector tempData(convertInfo.DestInfo.DestSize); - CreateTestUavs(pD3DDevice, pCommandList, tempData.data(), tempData.size(), &pMatrixRowMajorResource, nullptr, &pMatrixRowMajorReadResource); + CreateTestUavs(pD3DDevice, pCommandList, tempData.data(), tempData.size(), + &pMatrixRowMajorResource, nullptr, + &pMatrixRowMajorReadResource); // Set up data descriptors - convertInfo.DataDesc.DestVA = pMatrixRowMajorResource->GetGPUVirtualAddress(); - convertInfo.DataDesc.SrcVA = pConvertedMatrixResource->GetGPUVirtualAddress(); + convertInfo.DataDesc.DestVA = + pMatrixRowMajorResource->GetGPUVirtualAddress(); + convertInfo.DataDesc.SrcVA = + pConvertedMatrixResource->GetGPUVirtualAddress(); // Get command list interface and perform conversion CComPtr commandList11; - VERIFY_SUCCEEDED(pCommandList->QueryInterface(__uuidof(ID3D12GraphicsCommandList11), (void**)&commandList11)); + VERIFY_SUCCEEDED(pCommandList->QueryInterface( + __uuidof(ID3D12GraphicsCommandList11), (void **)&commandList11)); commandList11->ConvertLinearAlgebraMatrix(&convertInfo, 1); } - RecordTransitionBarrier(pCommandList, pMatrixRowMajorResource, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE); - pCommandList->CopyResource(pMatrixRowMajorReadResource, pMatrixRowMajorResource); + RecordTransitionBarrier(pCommandList, pMatrixRowMajorResource, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COPY_SOURCE); + pCommandList->CopyResource(pMatrixRowMajorReadResource, + pMatrixRowMajorResource); pCommandList->Close(); ExecuteCommandList(pCommandQueue, pCommandList); WaitForSignal(pCommandQueue, FO); { - MappedData mappedData(pMatrixRowMajorReadResource, (UINT)inputMatrix.size()); + MappedData mappedData(pMatrixRowMajorReadResource, + (UINT)inputMatrix.size()); float *resultBuffer = (float *)mappedData.data(); bool equal = true; for (int i = 0; i < (UINT)inputMatrix.size() / sizeof(float); i++) { - if (isnan(resultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || fabs(resultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { + if (isnan(resultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || + fabs(resultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { LogErrorFmt(L"Result mismatch at index %d", i); - LogErrorFmt(L"resultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, resultBuffer[i], i, ExpectedOutputBuffer[i]); + LogErrorFmt(L"resultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, + resultBuffer[i], i, ExpectedOutputBuffer[i]); equal = false; break; } @@ -13071,10 +13343,10 @@ TEST_F(ExecutionTest, CoopVec_OuterProduct) { RunCoopVecOuterProductTest(); } - // This test expects a that retrieves a signal value from each of a - // few resources that are initialized here. determines if it uses - // the 6.6 Dynamic Resources feature. Values are read back from the result UAV - // and compared to the expected signals +// This test expects a that retrieves a signal value from each of a +// few resources that are initialized here. determines if it uses +// the 6.6 Dynamic Resources feature. Values are read back from the result UAV +// and compared to the expected signals void ExecutionTest::RunResourceTest(ID3D12Device *pDevice, const char *pShader, const wchar_t *sm, bool isDynamic) { WEX::TestExecution::SetVerifyOutput verifySettings( From a7ce090f9ac1bb822152b9072ad556857b0af7e7 Mon Sep 17 00:00:00 2001 From: Damyan Pepper Date: Fri, 2 May 2025 20:46:09 -0700 Subject: [PATCH 03/17] Force Rebuild From d19295453e4b97b24003e7932fdd6ca3bcddb716 Mon Sep 17 00:00:00 2001 From: Damyan Pepper Date: Fri, 2 May 2025 20:49:11 -0700 Subject: [PATCH 04/17] Force Rebuild From 55b7a564da438ea3de7305a64c9c28ee4291cb14 Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Sun, 4 May 2025 12:24:02 -0400 Subject: [PATCH 05/17] Remove Hungarian notation --- .../unittests/HLSLExec/ExecutionTest.cpp | 928 +++++++++--------- 1 file changed, 456 insertions(+), 472 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 331dcd3542..59767a244c 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -792,9 +792,9 @@ class ExecutionTest { void RunCoopVecMulTest(); void - RunCoopVecMulTestConfig(ID3D12Device *pD3DDevice, + RunCoopVecMulTestConfig(ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps); - void RunCoopVecMulSubtest(ID3D12Device *pD3DDevice, + void RunCoopVecMulSubtest(ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, CoopVecMulSubtestConfig &Config); @@ -807,10 +807,10 @@ class ExecutionTest { void RunCoopVecOuterProductTest(); void RunCoopVecOuterProductTestConfig( - ID3D12Device *pD3DDevice, + ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps); void RunCoopVecOuterProductSubtest( - ID3D12Device *pD3DDevice, + ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, CoopVecOuterProductSubtestConfig &Config); @@ -1770,9 +1770,9 @@ class ExecutionTest { #endif } - bool DoesDeviceSupportCooperativeVector(ID3D12Device *pDevice) { + bool DoesDeviceSupportCooperativeVector(ID3D12Device *Device) { D3D12_FEATURE_DATA_D3D12_OPTIONS_EXPERIMENTAL O; - if (FAILED(pDevice->CheckFeatureSupport( + if (FAILED(Device->CheckFeatureSupport( (D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS_EXPERIMENTAL, &O, sizeof(O)))) return false; @@ -11995,13 +11995,13 @@ VERIFY_SUCCEEDED(DoArraysMatch(OutputVector, ExpectedVector, // The current implementation will always write the final output data as float. void ExecutionTest::RunCoopVecMulTest() { // Create device and verify coopvec support - CComPtr pD3DDevice; - if (!CreateDevice(&pD3DDevice, D3D_SHADER_MODEL_6_9)) { + CComPtr D3DDevice; + if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { WEX::Logging::Log::Comment("Device does not support SM 6.9. Skipping."); WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); return; } - if (!DoesDeviceSupportCooperativeVector(pD3DDevice)) { + if (!DoesDeviceSupportCooperativeVector(D3DDevice)) { WEX::Logging::Log::Comment( "Device does not support cooperative vector. Skipping."); WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); @@ -12010,19 +12010,19 @@ void ExecutionTest::RunCoopVecMulTest() { // Query coopvec feature data. First call gets the size of the arrays. The // second call populates the arrays using memory we allocate. - D3D12_FEATURE_DATA_COOPERATIVE_VECTOR devOptions = {}; - VERIFY_SUCCEEDED(pD3DDevice->CheckFeatureSupport( - (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &devOptions, - sizeof(devOptions))); + D3D12_FEATURE_DATA_COOPERATIVE_VECTOR DevOptions = {}; + VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &DevOptions, + sizeof(DevOptions))); - // Allocate memory for the arrays in devOptions + // Allocate memory for the arrays in DevOptions std::vector MulAddProps( - devOptions.MatrixVectorMulAddPropCount); - devOptions.pMatrixVectorMulAddProperties = MulAddProps.data(); + DevOptions.MatrixVectorMulAddPropCount); + DevOptions.pMatrixVectorMulAddProperties = MulAddProps.data(); - VERIFY_SUCCEEDED(pD3DDevice->CheckFeatureSupport( - (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &devOptions, - sizeof(devOptions))); + VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &DevOptions, + sizeof(DevOptions))); // Test each supported data type and matrix layout for (auto MulAddConfig : MulAddProps) { @@ -12104,12 +12104,12 @@ void ExecutionTest::RunCoopVecMulTest() { } // Run the test - RunCoopVecMulTestConfig(pD3DDevice, MulAddConfig); + RunCoopVecMulTestConfig(D3DDevice, MulAddConfig); } } void ExecutionTest::RunCoopVecMulTestConfig( - ID3D12Device *pD3DDevice, + ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps) { LogCommentFmt( @@ -12184,12 +12184,12 @@ void ExecutionTest::RunCoopVecMulTestConfig( continue; } - RunCoopVecMulSubtest(pD3DDevice, MulProps, Config); + RunCoopVecMulSubtest(D3DDevice, MulProps, Config); } } void ExecutionTest::RunCoopVecMulSubtest( - ID3D12Device *pD3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, + ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, CoopVecMulSubtestConfig &Config) { LogCommentFmt( @@ -12199,37 +12199,36 @@ void ExecutionTest::RunCoopVecMulSubtest( Config.NumLevels, Config.Bias ? L"true" : L"false", CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); - const int OUTPUT_BUFFER_SIZE = - (Config.OutputPerThread * Config.NumThreads * 4); + const int OutputBufferSize = (Config.OutputPerThread * Config.NumThreads * 4); // Create root signature with a single root entry for all SRVs and UAVs - CComPtr pRootSignature; + CComPtr RootSignature; { - CD3DX12_DESCRIPTOR_RANGE ranges[2]; - ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 3, 0, + CD3DX12_DESCRIPTOR_RANGE Ranges[2]; + Ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 3, 0, 0); // InputVector, InputMatrix, InputBias - ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // OutputBuffer - CreateRootSignatureFromRanges(pD3DDevice, &pRootSignature, ranges, 2, - nullptr, 0); + Ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // OutputBuffer + CreateRootSignatureFromRanges(D3DDevice, &RootSignature, Ranges, 2, nullptr, + 0); } // Create descriptor heap with space for 4 descriptors: 3 SRVs and 1 UAV - CComPtr pDescriptorHeap; + CComPtr DescriptorHeap; { - D3D12_DESCRIPTOR_HEAP_DESC desc = {}; - desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; - desc.NumDescriptors = 4; - desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; - VERIFY_SUCCEEDED(pD3DDevice->CreateDescriptorHeap( - &desc, IID_PPV_ARGS(&pDescriptorHeap))); + D3D12_DESCRIPTOR_HEAP_DESC Desc = {}; + Desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + Desc.NumDescriptors = 4; + Desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + VERIFY_SUCCEEDED( + D3DDevice->CreateDescriptorHeap(&Desc, IID_PPV_ARGS(&DescriptorHeap))); } - CD3DX12_CPU_DESCRIPTOR_HANDLE baseHandle( - pDescriptorHeap->GetCPUDescriptorHandleForHeapStart()); + CD3DX12_CPU_DESCRIPTOR_HANDLE BaseHandle( + DescriptorHeap->GetCPUDescriptorHandleForHeapStart()); // Create the compute pipeline state for the CoopVec shader - CComPtr pComputePipelineState; + CComPtr ComputePipelineState; { - std::string shaderSource = R"( + std::string ShaderSource = R"( #include "dx/linalg.h" ByteAddressBuffer InputVector : register(t0); @@ -12337,7 +12336,7 @@ void main(uint threadIdx : SV_GroupThreadID) auto AccumInterpretationEnumDefine = CreateDefineFromString( L"ACCUM_INTERPRETATION_ENUM", AccumInterpretationEnum); - LPCWSTR pOptions[] = { + LPCWSTR Options[] = { L"-enable-16bit-types", InputPerThreadDefine.c_str(), OutputPerThreadDefine.c_str(), @@ -12353,27 +12352,27 @@ void main(uint threadIdx : SV_GroupThreadID) AccumInterpretationEnumDefine.c_str(), }; - CComPtr pIncludeHandler = + CComPtr IncludeHandler = new LinAlgHeaderIncludeHandler(m_support); - CreateComputePSO(pD3DDevice, pRootSignature, shaderSource.c_str(), - L"cs_6_9", &pComputePipelineState, pOptions, - _countof(pOptions), pIncludeHandler); + CreateComputePSO(D3DDevice, RootSignature, ShaderSource.c_str(), L"cs_6_9", + &ComputePipelineState, Options, _countof(Options), + IncludeHandler); } // Create a command list for the compute shader. - CComPtr pCommandList; - CComPtr pCommandAllocator; - CComPtr pCommandQueue; + CComPtr CommandList; + CComPtr CommandAllocator; + CComPtr CommandQueue; FenceObj FO; - CreateCommandQueue(pD3DDevice, L"CoopVec Test Command Queue", &pCommandQueue, + CreateCommandQueue(D3DDevice, L"CoopVec Test Command Queue", &CommandQueue, D3D12_COMMAND_LIST_TYPE_DIRECT); - InitFenceObj(pD3DDevice, &FO); - VERIFY_SUCCEEDED(pD3DDevice->CreateCommandAllocator( - D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&pCommandAllocator))); - VERIFY_SUCCEEDED(pD3DDevice->CreateCommandList( - 0, D3D12_COMMAND_LIST_TYPE_DIRECT, pCommandAllocator, - pComputePipelineState, IID_PPV_ARGS(&pCommandList))); + InitFenceObj(D3DDevice, &FO); + VERIFY_SUCCEEDED(D3DDevice->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&CommandAllocator))); + VERIFY_SUCCEEDED(D3DDevice->CreateCommandList( + 0, D3D12_COMMAND_LIST_TYPE_DIRECT, CommandAllocator, ComputePipelineState, + IID_PPV_ARGS(&CommandList))); // Setup input data auto ExpectedOutputBuffer = @@ -12381,16 +12380,15 @@ void main(uint threadIdx : SV_GroupThreadID) // Setup input matrix as all-ones in sint8 format. This will later be // converted to the appropriate data type by the matrix conversion API. - CComPtr pInputMatrixSRVResource, - pInputMatrixSRVUploadResource; - std::vector inputMatrix; + CComPtr InputMatrixSRVResource, InputMatrixSRVUploadResource; + std::vector InputMatrix; if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix( + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix( Config.InputPerThread, Config.OutputPerThread); } else if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || @@ -12400,22 +12398,22 @@ void main(uint threadIdx : SV_GroupThreadID) D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { // Matrix source data is fp32, which gets converted to fp16 during matrix // conversion - inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix( + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix( Config.InputPerThread, Config.OutputPerThread); } else { WEX::Logging::Log::Error(L"Unsupported matrix data type"); return; } - CreateTestResources(pD3DDevice, pCommandList, inputMatrix.data(), - inputMatrix.size(), - CD3DX12_RESOURCE_DESC::Buffer(inputMatrix.size()), - &pInputMatrixSRVResource, &pInputMatrixSRVUploadResource); + CreateTestResources(D3DDevice, CommandList, InputMatrix.data(), + InputMatrix.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputMatrix.size()), + &InputMatrixSRVResource, &InputMatrixSRVUploadResource); // Create input vector of an appropriate type. All integer types start as // SINT8 for now. - CComPtr pInputVecSRVResource, pInputVecSRVUploadResource; - std::vector inputVector; + CComPtr InputVecSRVResource, InputVecSRVUploadResource; + std::vector InputVector; if ((MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 && (MulProps.InputInterpretation == @@ -12424,44 +12422,44 @@ void main(uint threadIdx : SV_GroupThreadID) D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED)) || MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - inputVector = CoopVecHelpers::CreateInputVector( + InputVector = CoopVecHelpers::CreateInputVector( Config.NumThreads, Config.InputPerThread); } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { - inputVector = + InputVector = CoopVecHelpers::CreateInputVector( Config.NumThreads, Config.InputPerThread); } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - inputVector = CoopVecHelpers::CreateInputVector( + InputVector = CoopVecHelpers::CreateInputVector( Config.NumThreads, Config.InputPerThread); } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { - inputVector = CoopVecHelpers::CreateInputVector( + InputVector = CoopVecHelpers::CreateInputVector( Config.NumThreads, Config.InputPerThread); } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { - inputVector = CoopVecHelpers::CreateInputVector( + InputVector = CoopVecHelpers::CreateInputVector( Config.NumThreads, Config.InputPerThread); } else { WEX::Logging::Log::Error(L"Unsupported input data type"); return; } - if (inputVector.size() % 4 != 0) { + if (InputVector.size() % 4 != 0) { // Align size to 4 bytes for ByteAddressBuffer - inputVector.resize(inputVector.size() + 4 - (inputVector.size() % 4)); + InputVector.resize(InputVector.size() + 4 - (InputVector.size() % 4)); } - CreateTestResources(pD3DDevice, pCommandList, inputVector.data(), - inputVector.size(), - CD3DX12_RESOURCE_DESC::Buffer(inputVector.size()), - &pInputVecSRVResource, &pInputVecSRVUploadResource); + CreateTestResources(D3DDevice, CommandList, InputVector.data(), + InputVector.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputVector.size()), + &InputVecSRVResource, &InputVecSRVUploadResource); // This increments baseHandle - CreateRawSRV(pD3DDevice, baseHandle, - (UINT)(inputVector.size() / sizeof(int32_t)), - pInputVecSRVResource); + CreateRawSRV(D3DDevice, BaseHandle, + (UINT)(InputVector.size() / sizeof(int32_t)), + InputVecSRVResource); // Create input bias - CComPtr pInputBiasSRVResource, pInputBiasSRVUploadResource; - std::vector inputBias; + CComPtr InputBiasSRVResource, InputBiasSRVUploadResource; + std::vector InputBias; if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || @@ -12469,82 +12467,82 @@ void main(uint threadIdx : SV_GroupThreadID) D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - inputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + InputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); } else if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { - inputBias = + InputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); } else if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { - inputBias = + InputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); } else if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { - inputBias = CoopVecHelpers::CreateInputBias( + InputBias = CoopVecHelpers::CreateInputBias( Config.OutputPerThread); } else if (MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - inputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); + InputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); } else { WEX::Logging::Log::Error(L"Unsupported bias data type"); return; } - if (inputBias.size() % 4 != 0) { + if (InputBias.size() % 4 != 0) { // Align size to 4 bytes for ByteAddressBuffer - inputBias.resize(inputBias.size() + 4 - (inputBias.size() % 4)); + InputBias.resize(InputBias.size() + 4 - (InputBias.size() % 4)); } - CreateTestResources(pD3DDevice, pCommandList, inputBias.data(), - inputBias.size(), - CD3DX12_RESOURCE_DESC::Buffer(inputBias.size()), - &pInputBiasSRVResource, &pInputBiasSRVUploadResource); + CreateTestResources(D3DDevice, CommandList, InputBias.data(), + InputBias.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputBias.size()), + &InputBiasSRVResource, &InputBiasSRVUploadResource); // This increments baseHandle - CreateRawSRV(pD3DDevice, baseHandle, - (UINT)(inputBias.size() / sizeof(int32_t)), - pInputBiasSRVResource); + CreateRawSRV(D3DDevice, BaseHandle, + (UINT)(InputBias.size() / sizeof(int32_t)), + InputBiasSRVResource); // Calculate reference output // FIXME: This does not capture all cases, but is sufficient for the preview // feature set if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8) { // The input bias is really an array of int32_t - std::vector inputBiasI32(inputBias.size() / sizeof(int32_t)); - std::memcpy(inputBiasI32.data(), inputBias.data(), inputBias.size()); + std::vector InputBiasI32(InputBias.size() / sizeof(int32_t)); + std::memcpy(InputBiasI32.data(), InputBias.data(), InputBias.size()); // The input vector is really an array of float if our vector input type is // FLOAT32 - std::vector inputVectorF32(inputVector.size() / sizeof(int32_t)); + std::vector InputVectorF32(InputVector.size() / sizeof(int32_t)); if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - std::memcpy(inputVectorF32.data(), inputVector.data(), - inputVector.size()); + std::memcpy(InputVectorF32.data(), InputVector.data(), + InputVector.size()); } - for (int threadIdx = 0; threadIdx < Config.NumThreads; ++threadIdx) { - for (int outputIdx = 0; outputIdx < Config.OutputPerThread; ++outputIdx) { - int acc = 0; + for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + for (int OutputIdx = 0; OutputIdx < Config.OutputPerThread; ++OutputIdx) { + int Acc = 0; - for (int inputIdx = 0; inputIdx < Config.InputPerThread; ++inputIdx) { - int inputElem; + for (int InputIdx = 0; InputIdx < Config.InputPerThread; ++InputIdx) { + int InputElem; if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - inputElem = (int) - inputVectorF32[threadIdx * Config.InputPerThread + inputIdx]; + InputElem = (int) + InputVectorF32[ThreadIdx * Config.InputPerThread + InputIdx]; } else { - inputElem = - inputVector[threadIdx * Config.InputPerThread + inputIdx]; + InputElem = + InputVector[ThreadIdx * Config.InputPerThread + InputIdx]; } - int const matrixElem = - inputMatrix[outputIdx * Config.InputPerThread + inputIdx]; - acc += inputElem * matrixElem; + int const MatrixElem = + InputMatrix[OutputIdx * Config.InputPerThread + InputIdx]; + Acc += InputElem * MatrixElem; } if (Config.Bias) { - acc += inputBiasI32[outputIdx]; + Acc += InputBiasI32[OutputIdx]; } - float result = float(acc); - ExpectedOutputBuffer[threadIdx * Config.OutputPerThread + outputIdx] = - result; + float Result = float(Acc); + ExpectedOutputBuffer[ThreadIdx * Config.OutputPerThread + OutputIdx] = + Result; } } } else if (MulProps.MatrixInterpretation == @@ -12554,188 +12552,186 @@ void main(uint threadIdx : SV_GroupThreadID) MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { // The input bias/vector is really an array of float16 - std::vector inputVectorFP16( - inputVector.size() / sizeof(DirectX::PackedVector::HALF)); - std::memcpy(inputVectorFP16.data(), inputVector.data(), inputVector.size()); + std::vector InputVectorFP16( + InputVector.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(InputVectorFP16.data(), InputVector.data(), InputVector.size()); - std::vector inputBiasFP16( - inputBias.size() / sizeof(DirectX::PackedVector::HALF)); - std::memcpy(inputBiasFP16.data(), inputBias.data(), inputBias.size()); + std::vector InputBiasFP16( + InputBias.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(InputBiasFP16.data(), InputBias.data(), InputBias.size()); // The CPU reference matrix is float - std::vector inputMatrixFP32(inputMatrix.size() / sizeof(float)); - std::memcpy(inputMatrixFP32.data(), inputMatrix.data(), inputMatrix.size()); - - for (int threadIdx = 0; threadIdx < Config.NumThreads; ++threadIdx) { - for (int outputIdx = 0; outputIdx < Config.OutputPerThread; ++outputIdx) { - float acc = 0; - - for (int inputIdx = 0; inputIdx < Config.InputPerThread; ++inputIdx) { - float const inputElem = ConvertFloat16ToFloat32( - inputVectorFP16[threadIdx * Config.InputPerThread + inputIdx]); - float const matrixElem = - inputMatrixFP32[outputIdx * Config.InputPerThread + inputIdx]; - acc += inputElem * matrixElem; + std::vector InputMatrixFP32(InputMatrix.size() / sizeof(float)); + std::memcpy(InputMatrixFP32.data(), InputMatrix.data(), InputMatrix.size()); + + for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + for (int OutputIdx = 0; OutputIdx < Config.OutputPerThread; ++OutputIdx) { + float Acc = 0; + + for (int InputIdx = 0; InputIdx < Config.InputPerThread; ++InputIdx) { + float const InputElem = ConvertFloat16ToFloat32( + InputVectorFP16[ThreadIdx * Config.InputPerThread + InputIdx]); + float const MatrixElem = + InputMatrixFP32[OutputIdx * Config.InputPerThread + InputIdx]; + Acc += InputElem * MatrixElem; } if (Config.Bias) { - acc += ConvertFloat16ToFloat32(inputBiasFP16[outputIdx]); + Acc += ConvertFloat16ToFloat32(InputBiasFP16[OutputIdx]); } - float result = acc; - ExpectedOutputBuffer[threadIdx * Config.OutputPerThread + outputIdx] = - result; + float Result = Acc; + ExpectedOutputBuffer[ThreadIdx * Config.OutputPerThread + OutputIdx] = + Result; } } } - CComPtr pConvertedMatrixResource; + CComPtr ConvertedMatrixResource; { // Create source matrix info - D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO convertInfo = {}; - convertInfo.SrcInfo.SrcDataType = + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {}; + ConvertInfo.SrcInfo.SrcDataType = CoopVecHelpers::GetMatrixSrcDataType(MulProps.MatrixInterpretation); - convertInfo.SrcInfo.SrcLayout = + ConvertInfo.SrcInfo.SrcLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; // Create destination matrix info - convertInfo.DestInfo.DestSize = 0; // Will be populated by driver - int srcEltSize = 0; - int destEltSize = 0; + ConvertInfo.DestInfo.DestSize = 0; // Will be populated by driver + int SrcEltSize = 0; + int DestEltSize = 0; switch (MulProps.MatrixInterpretation) { case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; - srcEltSize = 1; - destEltSize = 1; + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + SrcEltSize = 1; + DestEltSize = 1; break; case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: - convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; - srcEltSize = 4; // FP32 - destEltSize = 2; // FP16 + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; + SrcEltSize = 4; // FP32 + DestEltSize = 2; // FP16 break; case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: - convertInfo.DestInfo.DestDataType = + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; - srcEltSize = 4; // FP32 - destEltSize = 1; // FP8 + SrcEltSize = 4; // FP32 + DestEltSize = 1; // FP8 break; case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: - convertInfo.DestInfo.DestDataType = + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; - srcEltSize = 4; // FP32 - destEltSize = 1; // FP8 + SrcEltSize = 4; // FP32 + DestEltSize = 1; // FP8 break; } - convertInfo.SrcInfo.SrcStride = Config.InputPerThread * srcEltSize; - convertInfo.SrcInfo.SrcSize = - Config.InputPerThread * Config.OutputPerThread * srcEltSize; + ConvertInfo.SrcInfo.SrcStride = Config.InputPerThread * SrcEltSize; + ConvertInfo.SrcInfo.SrcSize = + Config.InputPerThread * Config.OutputPerThread * SrcEltSize; - convertInfo.DestInfo.DestLayout = Config.MatrixLayout; - convertInfo.DestInfo.DestStride = 0; - convertInfo.DestInfo.NumRows = Config.OutputPerThread; - convertInfo.DestInfo.NumColumns = Config.InputPerThread; + ConvertInfo.DestInfo.DestLayout = Config.MatrixLayout; + ConvertInfo.DestInfo.DestStride = 0; + ConvertInfo.DestInfo.NumRows = Config.OutputPerThread; + ConvertInfo.DestInfo.NumColumns = Config.InputPerThread; if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { - convertInfo.DestInfo.DestStride = Config.InputPerThread * destEltSize; + ConvertInfo.DestInfo.DestStride = Config.InputPerThread * DestEltSize; } else if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { - convertInfo.DestInfo.DestStride = Config.OutputPerThread * destEltSize; + ConvertInfo.DestInfo.DestStride = Config.OutputPerThread * DestEltSize; } // Get destination size using preview interface { - CComPtr previewDevice; - VERIFY_SUCCEEDED(pD3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), - (void **)&previewDevice)); + CComPtr PreviewDevice; + VERIFY_SUCCEEDED(D3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&PreviewDevice)); // Query required destination size - previewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( - &convertInfo.DestInfo); + PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + &ConvertInfo.DestInfo); } // Create resource to hold matrix copy CreateTestResources( - pD3DDevice, pCommandList, nullptr, 0, - CD3DX12_RESOURCE_DESC::Buffer(convertInfo.DestInfo.DestSize), - &pConvertedMatrixResource, nullptr); + D3DDevice, CommandList, nullptr, 0, + CD3DX12_RESOURCE_DESC::Buffer(ConvertInfo.DestInfo.DestSize), + &ConvertedMatrixResource, nullptr); // Set up data descriptors - convertInfo.DataDesc.DestVA = - pConvertedMatrixResource->GetGPUVirtualAddress(); - convertInfo.DataDesc.SrcVA = - pInputMatrixSRVResource->GetGPUVirtualAddress(); + ConvertInfo.DataDesc.DestVA = + ConvertedMatrixResource->GetGPUVirtualAddress(); + ConvertInfo.DataDesc.SrcVA = InputMatrixSRVResource->GetGPUVirtualAddress(); // Get command list interface and perform conversion - CComPtr commandList11; - VERIFY_SUCCEEDED(pCommandList->QueryInterface( - __uuidof(ID3D12GraphicsCommandList11), (void **)&commandList11)); - commandList11->ConvertLinearAlgebraMatrix(&convertInfo, 1); + CComPtr CommandList11; + VERIFY_SUCCEEDED(CommandList->QueryInterface( + __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11)); + CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); // This increments baseHandle - if ((convertInfo.DestInfo.DestSize % 4) != 0) { + if ((ConvertInfo.DestInfo.DestSize % 4) != 0) { WEX::Logging::Log::Error(L"DestSize is not aligned to 4 bytes"); return; } - CreateRawSRV(pD3DDevice, baseHandle, - convertInfo.DestInfo.DestSize / sizeof(int32_t), - pConvertedMatrixResource); + CreateRawSRV(D3DDevice, BaseHandle, + ConvertInfo.DestInfo.DestSize / sizeof(int32_t), + ConvertedMatrixResource); } - CComPtr pUavResource; - CComPtr pUavUploadResource; - CComPtr pUavReadResource; + CComPtr UavResource; + CComPtr UavUploadResource; + CComPtr UavReadResource; // Create buffer for output and fill with 0xFF to make it obvious if it's not // written in the shader. - std::vector outputBufferInit(OUTPUT_BUFFER_SIZE); - std::fill(outputBufferInit.begin(), outputBufferInit.end(), (uint8_t)0xFF); - - CreateTestUavs(pD3DDevice, pCommandList, outputBufferInit.data(), - OUTPUT_BUFFER_SIZE, &pUavResource, &pUavUploadResource, - &pUavReadResource); - CreateRawUAV(pD3DDevice, baseHandle, OUTPUT_BUFFER_SIZE / 4, pUavResource); - - pCommandList->Close(); - ExecuteCommandList(pCommandQueue, pCommandList); - WaitForSignal(pCommandQueue, FO); - VERIFY_SUCCEEDED(pCommandAllocator->Reset()); - VERIFY_SUCCEEDED( - pCommandList->Reset(pCommandAllocator, pComputePipelineState)); - - SetDescriptorHeap(pCommandList, pDescriptorHeap); - - CD3DX12_GPU_DESCRIPTOR_HANDLE resHandle( - pDescriptorHeap->GetGPUDescriptorHandleForHeapStart()); - - pCommandList->SetComputeRootSignature(pRootSignature); - pCommandList->SetComputeRootDescriptorTable(0, resHandle); - pCommandList->SetPipelineState(pComputePipelineState); - pCommandList->Dispatch(1, 1, 1); - RecordTransitionBarrier(pCommandList, pUavResource, + std::vector OutputBufferInit(OutputBufferSize); + std::fill(OutputBufferInit.begin(), OutputBufferInit.end(), (uint8_t)0xFF); + + CreateTestUavs(D3DDevice, CommandList, OutputBufferInit.data(), + OutputBufferSize, &UavResource, &UavUploadResource, + &UavReadResource); + CreateRawUAV(D3DDevice, BaseHandle, OutputBufferSize / 4, UavResource); + + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, ComputePipelineState)); + + SetDescriptorHeap(CommandList, DescriptorHeap); + + CD3DX12_GPU_DESCRIPTOR_HANDLE ResHandle( + DescriptorHeap->GetGPUDescriptorHandleForHeapStart()); + + CommandList->SetComputeRootSignature(RootSignature); + CommandList->SetComputeRootDescriptorTable(0, ResHandle); + CommandList->SetPipelineState(ComputePipelineState); + CommandList->Dispatch(1, 1, 1); + RecordTransitionBarrier(CommandList, UavResource, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE); - pCommandList->CopyResource(pUavReadResource, pUavResource); - pCommandList->Close(); - ExecuteCommandList(pCommandQueue, pCommandList); - WaitForSignal(pCommandQueue, FO); + CommandList->CopyResource(UavReadResource, UavResource); + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); { - MappedData mappedData(pUavReadResource, OUTPUT_BUFFER_SIZE); + MappedData MappedData(UavReadResource, OutputBufferSize); - float *resultBuffer = (float *)mappedData.data(); - bool equal = true; - for (int i = 0; i < OUTPUT_BUFFER_SIZE / sizeof(float); i++) { - if (isnan(resultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || - fabs(resultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { + float *ResultBuffer = (float *)MappedData.data(); + bool Equal = true; + for (int i = 0; i < OutputBufferSize / sizeof(float); i++) { + if (isnan(ResultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || + fabs(ResultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { LogErrorFmt(L"Result mismatch at index %d", i); - LogErrorFmt(L"resultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, - resultBuffer[i], i, ExpectedOutputBuffer[i]); - equal = false; + LogErrorFmt(L"ResultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, + ResultBuffer[i], i, ExpectedOutputBuffer[i]); + Equal = false; break; } } - VERIFY_IS_TRUE(equal); + VERIFY_IS_TRUE(Equal); } } @@ -12747,13 +12743,13 @@ TEST_F(ExecutionTest, CoopVec_Mul) { void ExecutionTest::RunCoopVecOuterProductTest() { // Create device and verify coopvec support - CComPtr pD3DDevice; - if (!CreateDevice(&pD3DDevice, D3D_SHADER_MODEL_6_9)) { + CComPtr D3DDevice; + if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { WEX::Logging::Log::Comment("Device does not support SM 6.9. Skipping."); WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); return; } - if (!DoesDeviceSupportCooperativeVector(pD3DDevice)) { + if (!DoesDeviceSupportCooperativeVector(D3DDevice)) { WEX::Logging::Log::Comment( "Device does not support cooperative vector. Skipping."); WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); @@ -12762,33 +12758,30 @@ void ExecutionTest::RunCoopVecOuterProductTest() { // Query coopvec feature data. First call gets the size of the arrays. The // second call populates the arrays using memory we allocate. - D3D12_FEATURE_DATA_COOPERATIVE_VECTOR devOptions = {}; - VERIFY_SUCCEEDED(pD3DDevice->CheckFeatureSupport( - (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &devOptions, - sizeof(devOptions))); + D3D12_FEATURE_DATA_COOPERATIVE_VECTOR DevOptions = {}; + VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &DevOptions, + sizeof(DevOptions))); - // Allocate memory for the arrays in devOptions + // Allocate memory for the arrays in DevOptions std::vector AccumulateProps( - devOptions.OuterProductAccumulatePropCount); - devOptions.pOuterProductAccumulateProperties = AccumulateProps.data(); + DevOptions.OuterProductAccumulatePropCount); + DevOptions.pOuterProductAccumulateProperties = AccumulateProps.data(); - VERIFY_SUCCEEDED(pD3DDevice->CheckFeatureSupport( - (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &devOptions, - sizeof(devOptions))); + VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_COOPERATIVE_VECTOR, &DevOptions, + sizeof(DevOptions))); // Test each supported data type and matrix layout for (auto AccumulateConfig : AccumulateProps) { // Run the test - RunCoopVecOuterProductTestConfig(pD3DDevice, AccumulateConfig); + RunCoopVecOuterProductTestConfig(D3DDevice, AccumulateConfig); } } void ExecutionTest::RunCoopVecOuterProductTestConfig( - ID3D12Device *pD3DDevice, + ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps) { - UNREFERENCED_PARAMETER(pD3DDevice); - UNREFERENCED_PARAMETER(AccumulateProps); - LogCommentFmt( L"Running test for InputType: %s, AccumulationType: %s", CoopVecHelpers::DataTypeToFilterString(AccumulateProps.InputType).c_str(), @@ -12810,17 +12803,14 @@ void ExecutionTest::RunCoopVecOuterProductTestConfig( continue; } - RunCoopVecOuterProductSubtest(pD3DDevice, AccumulateProps, Config); + RunCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config); } } void ExecutionTest::RunCoopVecOuterProductSubtest( - ID3D12Device *pD3DDevice, + ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, CoopVecOuterProductSubtestConfig &Config) { - UNREFERENCED_PARAMETER(pD3DDevice); - UNREFERENCED_PARAMETER(AccumulateProps); - UNREFERENCED_PARAMETER(Config); LogCommentFmt( L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s", @@ -12828,33 +12818,33 @@ void ExecutionTest::RunCoopVecOuterProductSubtest( CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); // Create root signature with a single root entry for all SRVs and UAVs - CComPtr pRootSignature; + CComPtr RootSignature; { CD3DX12_DESCRIPTOR_RANGE ranges[2]; ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0, 0); // InputVector1, InputVector2 ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // AccumMatrix - CreateRootSignatureFromRanges(pD3DDevice, &pRootSignature, ranges, 2, - nullptr, 0); + CreateRootSignatureFromRanges(D3DDevice, &RootSignature, ranges, 2, nullptr, + 0); } // Create descriptor heap with space for 3 descriptors: 2 SRVs and 1 UAV - CComPtr pDescriptorHeap; + CComPtr DescriptorHeap; { - D3D12_DESCRIPTOR_HEAP_DESC desc = {}; - desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; - desc.NumDescriptors = 3; - desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; - VERIFY_SUCCEEDED(pD3DDevice->CreateDescriptorHeap( - &desc, IID_PPV_ARGS(&pDescriptorHeap))); + D3D12_DESCRIPTOR_HEAP_DESC Desc = {}; + Desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + Desc.NumDescriptors = 3; + Desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + VERIFY_SUCCEEDED( + D3DDevice->CreateDescriptorHeap(&Desc, IID_PPV_ARGS(&DescriptorHeap))); } - CD3DX12_CPU_DESCRIPTOR_HANDLE baseHandle( - pDescriptorHeap->GetCPUDescriptorHandleForHeapStart()); + CD3DX12_CPU_DESCRIPTOR_HANDLE BaseHandle( + DescriptorHeap->GetCPUDescriptorHandleForHeapStart()); // Create a compute pipeline state object. - CComPtr pComputePipelineState; + CComPtr ComputePipelineState; { - std::string shaderSource = R"( + std::string ShaderSource = R"( #include "dx/linalg.h" ByteAddressBuffer InputVector1 : register(t0); @@ -12945,7 +12935,7 @@ void main(uint threadIdx : SV_GroupThreadID) auto MatrixDataTypeEnumDefine = CreateDefineFromString( L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str()); - LPCWSTR pOptions[] = { + LPCWSTR Options[] = { L"-enable-16bit-types", DimMDefine.c_str(), DimNDefine.c_str(), @@ -12959,36 +12949,35 @@ void main(uint threadIdx : SV_GroupThreadID) MatrixDataTypeEnumDefine.c_str(), }; - CComPtr pIncludeHandler = + CComPtr IncludeHandler = new LinAlgHeaderIncludeHandler(m_support); - CreateComputePSO(pD3DDevice, pRootSignature, shaderSource.c_str(), - L"cs_6_9", &pComputePipelineState, pOptions, - _countof(pOptions), pIncludeHandler); + CreateComputePSO(D3DDevice, RootSignature, ShaderSource.c_str(), L"cs_6_9", + &ComputePipelineState, Options, _countof(Options), + IncludeHandler); } // Create a command list for the compute shader. - CComPtr pCommandList; - CComPtr pCommandAllocator; - CComPtr pCommandQueue; + CComPtr CommandList; + CComPtr CommandAllocator; + CComPtr CommandQueue; FenceObj FO; - CreateCommandQueue(pD3DDevice, L"CoopVec Test Command Queue", &pCommandQueue, + CreateCommandQueue(D3DDevice, L"CoopVec Test Command Queue", &CommandQueue, D3D12_COMMAND_LIST_TYPE_DIRECT); - InitFenceObj(pD3DDevice, &FO); - VERIFY_SUCCEEDED(pD3DDevice->CreateCommandAllocator( - D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&pCommandAllocator))); - VERIFY_SUCCEEDED(pD3DDevice->CreateCommandList( - 0, D3D12_COMMAND_LIST_TYPE_DIRECT, pCommandAllocator, - pComputePipelineState, IID_PPV_ARGS(&pCommandList))); + InitFenceObj(D3DDevice, &FO); + VERIFY_SUCCEEDED(D3DDevice->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&CommandAllocator))); + VERIFY_SUCCEEDED(D3DDevice->CreateCommandList( + 0, D3D12_COMMAND_LIST_TYPE_DIRECT, CommandAllocator, ComputePipelineState, + IID_PPV_ARGS(&CommandList))); // Setup input matrix as all-ones in sint8/fp32 format. This will later be // converted to the appropriate data type by the matrix conversion API. - CComPtr pInputMatrixSRVResource, - pInputMatrixSRVUploadResource; - std::vector inputMatrix; + CComPtr InputMatrixSRVResource, InputMatrixSRVUploadResource; + std::vector InputMatrix; if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); } else if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || @@ -12998,29 +12987,29 @@ void main(uint threadIdx : SV_GroupThreadID) D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { // Matrix source data is fp32, which gets converted to fp16 during matrix // conversion - inputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); } else { WEX::Logging::Log::Error(L"Unsupported matrix data type"); return; } - CreateTestResources(pD3DDevice, pCommandList, inputMatrix.data(), - inputMatrix.size(), - CD3DX12_RESOURCE_DESC::Buffer(inputMatrix.size()), - &pInputMatrixSRVResource, &pInputMatrixSRVUploadResource); + CreateTestResources(D3DDevice, CommandList, InputMatrix.data(), + InputMatrix.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputMatrix.size()), + &InputMatrixSRVResource, &InputMatrixSRVUploadResource); // Create input vectors - CComPtr pInputVecSRVResource1, pInputVecSRVUploadResource1; - std::vector inputVector1; - CComPtr pInputVecSRVResource2, pInputVecSRVUploadResource2; - std::vector inputVector2; + CComPtr InputVecSRVResource1, InputVecSRVUploadResource1; + std::vector InputVector1; + CComPtr InputVecSRVResource2, InputVecSRVUploadResource2; + std::vector InputVector2; if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - inputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + InputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimM); - inputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + InputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimN); } else if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || @@ -13028,46 +13017,46 @@ void main(uint threadIdx : SV_GroupThreadID) D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { - inputVector1 = + InputVector1 = CoopVecHelpers::CreateInputVector( Config.NumThreads, Config.DimM); - inputVector2 = + InputVector2 = CoopVecHelpers::CreateInputVector( Config.NumThreads, Config.DimN); } else if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - inputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + InputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimM); - inputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, + InputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, Config.DimN); } else { WEX::Logging::Log::Error(L"Unsupported input data type"); return; } - if (inputVector1.size() % 4 != 0) { + if (InputVector1.size() % 4 != 0) { // Align size to 4 bytes for ByteAddressBuffer - inputVector1.resize(inputVector1.size() + 4 - (inputVector1.size() % 4)); + InputVector1.resize(InputVector1.size() + 4 - (InputVector1.size() % 4)); } - if (inputVector2.size() % 4 != 0) { + if (InputVector2.size() % 4 != 0) { // Align size to 4 bytes for ByteAddressBuffer - inputVector2.resize(inputVector2.size() + 4 - (inputVector2.size() % 4)); + InputVector2.resize(InputVector2.size() + 4 - (InputVector2.size() % 4)); } - CreateTestResources(pD3DDevice, pCommandList, inputVector1.data(), - inputVector1.size(), - CD3DX12_RESOURCE_DESC::Buffer(inputVector1.size()), - &pInputVecSRVResource1, &pInputVecSRVUploadResource1); - CreateTestResources(pD3DDevice, pCommandList, inputVector2.data(), - inputVector2.size(), - CD3DX12_RESOURCE_DESC::Buffer(inputVector2.size()), - &pInputVecSRVResource2, &pInputVecSRVUploadResource2); + CreateTestResources(D3DDevice, CommandList, InputVector1.data(), + InputVector1.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputVector1.size()), + &InputVecSRVResource1, &InputVecSRVUploadResource1); + CreateTestResources(D3DDevice, CommandList, InputVector2.data(), + InputVector2.size(), + CD3DX12_RESOURCE_DESC::Buffer(InputVector2.size()), + &InputVecSRVResource2, &InputVecSRVUploadResource2); // This increments baseHandle - CreateRawSRV(pD3DDevice, baseHandle, - (UINT)(inputVector1.size() / sizeof(int32_t)), - pInputVecSRVResource1); - CreateRawSRV(pD3DDevice, baseHandle, - (UINT)(inputVector2.size() / sizeof(int32_t)), - pInputVecSRVResource2); + CreateRawSRV(D3DDevice, BaseHandle, + (UINT)(InputVector1.size() / sizeof(int32_t)), + InputVecSRVResource1); + CreateRawSRV(D3DDevice, BaseHandle, + (UINT)(InputVector2.size() / sizeof(int32_t)), + InputVecSRVResource2); // Calculate reference output auto ExpectedOutputBufferI8 = @@ -13078,186 +13067,183 @@ void main(uint threadIdx : SV_GroupThreadID) ExpectedOutputBufferI8.size()); if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { - std::vector inputVector1FP16( - inputVector1.size() / sizeof(DirectX::PackedVector::HALF)); - std::memcpy(inputVector1FP16.data(), inputVector1.data(), - inputVector1.size()); - - std::vector inputVector2FP16( - inputVector2.size() / sizeof(DirectX::PackedVector::HALF)); - std::memcpy(inputVector2FP16.data(), inputVector2.data(), - inputVector2.size()); - - for (int threadIdx = 0; threadIdx < Config.NumThreads; ++threadIdx) { - for (int m = 0; m < Config.DimM; ++m) { - for (int n = 0; n < Config.DimN; ++n) { - float acc = ConvertFloat16ToFloat32(inputVector1FP16[m]) * - ConvertFloat16ToFloat32(inputVector2FP16[n]); - ExpectedOutputBuffer[m * Config.DimN + n] += acc; + std::vector InputVector1FP16( + InputVector1.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(InputVector1FP16.data(), InputVector1.data(), + InputVector1.size()); + + std::vector InputVector2FP16( + InputVector2.size() / sizeof(DirectX::PackedVector::HALF)); + std::memcpy(InputVector2FP16.data(), InputVector2.data(), + InputVector2.size()); + + for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + for (int M = 0; M < Config.DimM; ++M) { + for (int N = 0; N < Config.DimN; ++N) { + float acc = ConvertFloat16ToFloat32(InputVector1FP16[M]) * + ConvertFloat16ToFloat32(InputVector2FP16[N]); + ExpectedOutputBuffer[M * Config.DimN + N] += acc; } } } } else if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - std::vector inputVector1FP32(inputVector1.size() / sizeof(float)); - std::memcpy(inputVector1FP32.data(), inputVector1.data(), - inputVector1.size()); - - std::vector inputVector2FP32(inputVector2.size() / sizeof(float)); - std::memcpy(inputVector2FP32.data(), inputVector2.data(), - inputVector2.size()); - - for (int threadIdx = 0; threadIdx < Config.NumThreads; ++threadIdx) { - for (int m = 0; m < Config.DimM; ++m) { - for (int n = 0; n < Config.DimN; ++n) { - float acc = inputVector1FP32[m] * inputVector2FP32[n]; - ExpectedOutputBuffer[m * Config.DimN + n] += acc; + std::vector InputVector1FP32(InputVector1.size() / sizeof(float)); + std::memcpy(InputVector1FP32.data(), InputVector1.data(), + InputVector1.size()); + + std::vector InputVector2FP32(InputVector2.size() / sizeof(float)); + std::memcpy(InputVector2FP32.data(), InputVector2.data(), + InputVector2.size()); + + for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + for (int M = 0; M < Config.DimM; ++M) { + for (int N = 0; N < Config.DimN; ++N) { + float Acc = InputVector1FP32[M] * InputVector2FP32[N]; + ExpectedOutputBuffer[M * Config.DimN + N] += Acc; } } } } - CComPtr pConvertedMatrixResource, - pConvertedMatrixReadResource; + CComPtr ConvertedMatrixResource, ConvertedMatrixReadResource; int ConvertedMatrixSize = 0; { // Create source matrix info - D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO srcInfo = {}; - srcInfo.SrcDataType = + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO SrcInfo = {}; + SrcInfo.SrcDataType = CoopVecHelpers::GetMatrixSrcDataType(AccumulateProps.AccumulationType); - srcInfo.SrcLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + SrcInfo.SrcLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; // Create destination matrix info - D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO destInfo = {}; - destInfo.DestSize = 0; // Will be populated by driver - int srcEltSize = 0; - int destEltSize = 0; + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO DestInfo = {}; + DestInfo.DestSize = 0; // Will be populated by driver + int SrcEltSize = 0; + int DestEltSize = 0; switch (AccumulateProps.AccumulationType) { case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - destInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; - srcEltSize = 1; - destEltSize = 1; + DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + SrcEltSize = 1; + DestEltSize = 1; break; case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: - destInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; - srcEltSize = 4; // FP32 - destEltSize = 2; // FP16 + DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; + SrcEltSize = 4; // FP32 + DestEltSize = 2; // FP16 break; case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: - destInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; - srcEltSize = 4; // FP32 - destEltSize = 1; // FP8 + DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; + SrcEltSize = 4; // FP32 + DestEltSize = 1; // FP8 break; case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: - destInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; - srcEltSize = 4; // FP32 - destEltSize = 1; // FP8 + DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; + SrcEltSize = 4; // FP32 + DestEltSize = 1; // FP8 break; } - srcInfo.SrcStride = Config.DimM * srcEltSize; - srcInfo.SrcSize = Config.DimM * Config.DimN * srcEltSize; + SrcInfo.SrcStride = Config.DimM * SrcEltSize; + SrcInfo.SrcSize = Config.DimM * Config.DimN * SrcEltSize; - destInfo.DestLayout = Config.MatrixLayout; - destInfo.DestStride = 0; - destInfo.NumRows = Config.DimM; - destInfo.NumColumns = Config.DimN; + DestInfo.DestLayout = Config.MatrixLayout; + DestInfo.DestStride = 0; + DestInfo.NumRows = Config.DimM; + DestInfo.NumColumns = Config.DimN; if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { - destInfo.DestStride = Config.DimM * destEltSize; + DestInfo.DestStride = Config.DimM * DestEltSize; } else if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { - destInfo.DestStride = Config.DimM * destEltSize; + DestInfo.DestStride = Config.DimM * DestEltSize; } // Create conversion info - D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO convertInfo = {}; - convertInfo.SrcInfo = srcInfo; - convertInfo.DestInfo = destInfo; + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {}; + ConvertInfo.SrcInfo = SrcInfo; + ConvertInfo.DestInfo = DestInfo; // Get preview device interface { - CComPtr previewDevice; - VERIFY_SUCCEEDED(pD3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), - (void **)&previewDevice)); + CComPtr PreviewDevice; + VERIFY_SUCCEEDED(D3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&PreviewDevice)); // Query required destination size - previewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( - &convertInfo.DestInfo); + PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + &ConvertInfo.DestInfo); } - ConvertedMatrixSize = convertInfo.DestInfo.DestSize; + ConvertedMatrixSize = ConvertInfo.DestInfo.DestSize; // Hack to prevent read resource from being created with size 0 - std::vector tempData(convertInfo.DestInfo.DestSize); - CreateTestUavs(pD3DDevice, pCommandList, tempData.data(), tempData.size(), - &pConvertedMatrixResource, nullptr, - &pConvertedMatrixReadResource); + std::vector TempData(ConvertInfo.DestInfo.DestSize); + CreateTestUavs(D3DDevice, CommandList, TempData.data(), TempData.size(), + &ConvertedMatrixResource, nullptr, + &ConvertedMatrixReadResource); // Set up data descriptors - D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA dataDesc = {}; - dataDesc.DestVA = pConvertedMatrixResource->GetGPUVirtualAddress(); - dataDesc.SrcVA = pInputMatrixSRVResource->GetGPUVirtualAddress(); - convertInfo.DataDesc = dataDesc; + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DATA DataDesc = {}; + DataDesc.DestVA = ConvertedMatrixResource->GetGPUVirtualAddress(); + DataDesc.SrcVA = InputMatrixSRVResource->GetGPUVirtualAddress(); + ConvertInfo.DataDesc = DataDesc; // Get command list interface and perform conversion - CComPtr commandList11; - VERIFY_SUCCEEDED(pCommandList->QueryInterface( - __uuidof(ID3D12GraphicsCommandList11), (void **)&commandList11)); - commandList11->ConvertLinearAlgebraMatrix(&convertInfo, 1); + CComPtr CommandList11; + VERIFY_SUCCEEDED(CommandList->QueryInterface( + __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11)); + CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); // This increments baseHandle - if ((convertInfo.DestInfo.DestSize % 4) != 0) { + if ((ConvertInfo.DestInfo.DestSize % 4) != 0) { WEX::Logging::Log::Error(L"DestSize is not aligned to 4 bytes"); return; } - CreateRawUAV(pD3DDevice, baseHandle, - convertInfo.DestInfo.DestSize / sizeof(int32_t), - pConvertedMatrixResource); + CreateRawUAV(D3DDevice, BaseHandle, + ConvertInfo.DestInfo.DestSize / sizeof(int32_t), + ConvertedMatrixResource); } - pCommandList->Close(); - ExecuteCommandList(pCommandQueue, pCommandList); - WaitForSignal(pCommandQueue, FO); - VERIFY_SUCCEEDED(pCommandAllocator->Reset()); - VERIFY_SUCCEEDED( - pCommandList->Reset(pCommandAllocator, pComputePipelineState)); + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, ComputePipelineState)); - SetDescriptorHeap(pCommandList, pDescriptorHeap); + SetDescriptorHeap(CommandList, DescriptorHeap); - CD3DX12_GPU_DESCRIPTOR_HANDLE resHandle( - pDescriptorHeap->GetGPUDescriptorHandleForHeapStart()); + CD3DX12_GPU_DESCRIPTOR_HANDLE ResHandle( + DescriptorHeap->GetGPUDescriptorHandleForHeapStart()); - pCommandList->SetComputeRootSignature(pRootSignature); - pCommandList->SetComputeRootDescriptorTable(0, resHandle); - pCommandList->SetPipelineState(pComputePipelineState); - pCommandList->Dispatch(1, 1, 1); - pCommandList->Close(); - ExecuteCommandList(pCommandQueue, pCommandList); - WaitForSignal(pCommandQueue, FO); + CommandList->SetComputeRootSignature(RootSignature); + CommandList->SetComputeRootDescriptorTable(0, ResHandle); + CommandList->SetPipelineState(ComputePipelineState); + CommandList->Dispatch(1, 1, 1); + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); - VERIFY_SUCCEEDED(pCommandAllocator->Reset()); - VERIFY_SUCCEEDED( - pCommandList->Reset(pCommandAllocator, pComputePipelineState)); + VERIFY_SUCCEEDED(CommandAllocator->Reset()); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, ComputePipelineState)); // Convert matrix to sint8/fp32 row-major format before reading back to the // CPU. A new resource is created, along with a readback resource, for the // matrix copy. - CComPtr pMatrixRowMajorResource, pMatrixRowMajorReadResource; + CComPtr MatrixRowMajorResource, MatrixRowMajorReadResource; { // Create source matrix info - D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO convertInfo = {}; - convertInfo.SrcInfo.SrcLayout = Config.MatrixLayout; - convertInfo.SrcInfo.SrcSize = ConvertedMatrixSize; - convertInfo.SrcInfo.SrcDataType = AccumulateProps.AccumulationType; - convertInfo.SrcInfo.SrcStride = 0; // OUTER_PRODUCT_OPTIMAL + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {}; + ConvertInfo.SrcInfo.SrcLayout = Config.MatrixLayout; + ConvertInfo.SrcInfo.SrcSize = ConvertedMatrixSize; + ConvertInfo.SrcInfo.SrcDataType = AccumulateProps.AccumulationType; + ConvertInfo.SrcInfo.SrcStride = 0; // OUTER_PRODUCT_OPTIMAL // Create destination matrix info - convertInfo.DestInfo.DestSize = 0; // Will be populated by driver - convertInfo.DestInfo.DestLayout = + ConvertInfo.DestInfo.DestSize = 0; // Will be populated by driver + ConvertInfo.DestInfo.DestLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; - convertInfo.DestInfo.NumRows = Config.DimM; - convertInfo.DestInfo.NumColumns = Config.DimN; + ConvertInfo.DestInfo.NumRows = Config.DimM; + ConvertInfo.DestInfo.NumColumns = Config.DimN; if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 || @@ -13267,22 +13253,22 @@ void main(uint threadIdx : SV_GroupThreadID) D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { - convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; - convertInfo.DestInfo.DestStride = Config.DimN * sizeof(float); + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; + ConvertInfo.DestInfo.DestStride = Config.DimN * sizeof(float); } else { - convertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; - convertInfo.DestInfo.DestStride = Config.DimN * sizeof(int8_t); + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + ConvertInfo.DestInfo.DestStride = Config.DimN * sizeof(int8_t); } // Get destination size using preview interface { - CComPtr previewDevice; - VERIFY_SUCCEEDED(pD3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), - (void **)&previewDevice)); + CComPtr PreviewDevice; + VERIFY_SUCCEEDED(D3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&PreviewDevice)); // Query required destination size - previewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( - &convertInfo.DestInfo); + PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + &ConvertInfo.DestInfo); } // Create resource to hold matrix copy and a readback resource for it @@ -13290,50 +13276,48 @@ void main(uint threadIdx : SV_GroupThreadID) // size 0 // TODO: Fix CreateTestUavs to allow creating readback resource without init // data - std::vector tempData(convertInfo.DestInfo.DestSize); - CreateTestUavs(pD3DDevice, pCommandList, tempData.data(), tempData.size(), - &pMatrixRowMajorResource, nullptr, - &pMatrixRowMajorReadResource); + std::vector TempData(ConvertInfo.DestInfo.DestSize); + CreateTestUavs(D3DDevice, CommandList, TempData.data(), TempData.size(), + &MatrixRowMajorResource, nullptr, + &MatrixRowMajorReadResource); // Set up data descriptors - convertInfo.DataDesc.DestVA = - pMatrixRowMajorResource->GetGPUVirtualAddress(); - convertInfo.DataDesc.SrcVA = - pConvertedMatrixResource->GetGPUVirtualAddress(); + ConvertInfo.DataDesc.DestVA = + MatrixRowMajorResource->GetGPUVirtualAddress(); + ConvertInfo.DataDesc.SrcVA = + ConvertedMatrixResource->GetGPUVirtualAddress(); // Get command list interface and perform conversion - CComPtr commandList11; - VERIFY_SUCCEEDED(pCommandList->QueryInterface( - __uuidof(ID3D12GraphicsCommandList11), (void **)&commandList11)); - commandList11->ConvertLinearAlgebraMatrix(&convertInfo, 1); + CComPtr CommandList11; + VERIFY_SUCCEEDED(CommandList->QueryInterface( + __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11)); + CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); } - RecordTransitionBarrier(pCommandList, pMatrixRowMajorResource, + RecordTransitionBarrier(CommandList, MatrixRowMajorResource, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE); - pCommandList->CopyResource(pMatrixRowMajorReadResource, - pMatrixRowMajorResource); - pCommandList->Close(); - ExecuteCommandList(pCommandQueue, pCommandList); - WaitForSignal(pCommandQueue, FO); + CommandList->CopyResource(MatrixRowMajorReadResource, MatrixRowMajorResource); + CommandList->Close(); + ExecuteCommandList(CommandQueue, CommandList); + WaitForSignal(CommandQueue, FO); { - MappedData mappedData(pMatrixRowMajorReadResource, - (UINT)inputMatrix.size()); - - float *resultBuffer = (float *)mappedData.data(); - bool equal = true; - for (int i = 0; i < (UINT)inputMatrix.size() / sizeof(float); i++) { - if (isnan(resultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || - fabs(resultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { + MappedData MappedData(MatrixRowMajorReadResource, (UINT)InputMatrix.size()); + + float *ResultBuffer = (float *)MappedData.data(); + bool Equal = true; + for (int i = 0; i < (UINT)InputMatrix.size() / sizeof(float); i++) { + if (isnan(ResultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || + fabs(ResultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { LogErrorFmt(L"Result mismatch at index %d", i); - LogErrorFmt(L"resultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, - resultBuffer[i], i, ExpectedOutputBuffer[i]); - equal = false; + LogErrorFmt(L"ResultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, + ResultBuffer[i], i, ExpectedOutputBuffer[i]); + Equal = false; break; } } - VERIFY_IS_TRUE(equal); + VERIFY_IS_TRUE(Equal); } } From 6dc644bd4a75507d82b97e076463b9ce7cf4a3ec Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Mon, 5 May 2025 12:15:22 -0400 Subject: [PATCH 06/17] s/RunCoopVec/runCoopVec/ --- .../unittests/HLSLExec/ExecutionTest.cpp | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 59767a244c..dc555939e2 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -790,11 +790,11 @@ class ExecutionTest { bool Bias; }; - void RunCoopVecMulTest(); + void runCoopVecMulTest(); void - RunCoopVecMulTestConfig(ID3D12Device *D3DDevice, + runCoopVecMulTestConfig(ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps); - void RunCoopVecMulSubtest(ID3D12Device *D3DDevice, + void runCoopVecMulSubtest(ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, CoopVecMulSubtestConfig &Config); @@ -805,11 +805,11 @@ class ExecutionTest { D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout; }; - void RunCoopVecOuterProductTest(); - void RunCoopVecOuterProductTestConfig( + void runCoopVecOuterProductTest(); + void runCoopVecOuterProductTestConfig( ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps); - void RunCoopVecOuterProductSubtest( + void runCoopVecOuterProductSubtest( ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, CoopVecOuterProductSubtestConfig &Config); @@ -11993,7 +11993,7 @@ VERIFY_SUCCEEDED(DoArraysMatch(OutputVector, ExpectedVector, // -p:CoopVecMatrixLayout=MUL_OPTIMAL // // The current implementation will always write the final output data as float. -void ExecutionTest::RunCoopVecMulTest() { +void ExecutionTest::runCoopVecMulTest() { // Create device and verify coopvec support CComPtr D3DDevice; if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { @@ -12104,11 +12104,11 @@ void ExecutionTest::RunCoopVecMulTest() { } // Run the test - RunCoopVecMulTestConfig(D3DDevice, MulAddConfig); + runCoopVecMulTestConfig(D3DDevice, MulAddConfig); } } -void ExecutionTest::RunCoopVecMulTestConfig( +void ExecutionTest::runCoopVecMulTestConfig( ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps) { @@ -12184,11 +12184,11 @@ void ExecutionTest::RunCoopVecMulTestConfig( continue; } - RunCoopVecMulSubtest(D3DDevice, MulProps, Config); + runCoopVecMulSubtest(D3DDevice, MulProps, Config); } } -void ExecutionTest::RunCoopVecMulSubtest( +void ExecutionTest::runCoopVecMulSubtest( ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, CoopVecMulSubtestConfig &Config) { @@ -12738,10 +12738,10 @@ void main(uint threadIdx : SV_GroupThreadID) TEST_F(ExecutionTest, CoopVec_Mul) { WEX::TestExecution::SetVerifyOutput verifySettings( WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); - RunCoopVecMulTest(); + runCoopVecMulTest(); } -void ExecutionTest::RunCoopVecOuterProductTest() { +void ExecutionTest::runCoopVecOuterProductTest() { // Create device and verify coopvec support CComPtr D3DDevice; if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { @@ -12775,11 +12775,11 @@ void ExecutionTest::RunCoopVecOuterProductTest() { // Test each supported data type and matrix layout for (auto AccumulateConfig : AccumulateProps) { // Run the test - RunCoopVecOuterProductTestConfig(D3DDevice, AccumulateConfig); + runCoopVecOuterProductTestConfig(D3DDevice, AccumulateConfig); } } -void ExecutionTest::RunCoopVecOuterProductTestConfig( +void ExecutionTest::runCoopVecOuterProductTestConfig( ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps) { LogCommentFmt( @@ -12803,11 +12803,11 @@ void ExecutionTest::RunCoopVecOuterProductTestConfig( continue; } - RunCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config); + runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config); } } -void ExecutionTest::RunCoopVecOuterProductSubtest( +void ExecutionTest::runCoopVecOuterProductSubtest( ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, CoopVecOuterProductSubtestConfig &Config) { @@ -13324,7 +13324,7 @@ void main(uint threadIdx : SV_GroupThreadID) TEST_F(ExecutionTest, CoopVec_OuterProduct) { WEX::TestExecution::SetVerifyOutput verifySettings( WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); - RunCoopVecOuterProductTest(); + runCoopVecOuterProductTest(); } // This test expects a that retrieves a signal value from each of a From 5d4fb6b04be1df8ce571fccbfbfe5553c9fa0cc7 Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Mon, 5 May 2025 12:17:24 -0400 Subject: [PATCH 07/17] Fix header comment --- tools/clang/unittests/HLSLExec/CoopVecAPIExtensions.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/CoopVecAPIExtensions.h b/tools/clang/unittests/HLSLExec/CoopVecAPIExtensions.h index 6166254294..98c1cdb1c4 100644 --- a/tools/clang/unittests/HLSLExec/CoopVecAPIExtensions.h +++ b/tools/clang/unittests/HLSLExec/CoopVecAPIExtensions.h @@ -2,8 +2,8 @@ #if D3D12_PREVIEW_SDK_VERSION < 717 -// This file contains the definitions of the D3D12 cooperative vector API extensions. -// It is used to test the cooperative vector API extensions on older SDKs. +// This file contains the definitions of the D3D12 cooperative vector API. +// It is used to test the cooperative vector API on older SDKs. constexpr int D3D12_FEATURE_D3D12_OPTIONS_EXPERIMENTAL = 9; constexpr int D3D12_FEATURE_COOPERATIVE_VECTOR = 11; From 72c725d0c0b03da7118aa833c6060f529d243d0f Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Mon, 5 May 2025 12:21:45 -0400 Subject: [PATCH 08/17] Rename CoopVecAPIExtensions.h to CoopVecAPI.h --- tools/clang/unittests/HLSLExec/CoopVec.h | 2 +- .../HLSLExec/{CoopVecAPIExtensions.h => CoopVecAPI.h} | 0 tools/clang/unittests/HLSLExec/ExecutionTest.cpp | 6 +----- 3 files changed, 2 insertions(+), 6 deletions(-) rename tools/clang/unittests/HLSLExec/{CoopVecAPIExtensions.h => CoopVecAPI.h} (100%) diff --git a/tools/clang/unittests/HLSLExec/CoopVec.h b/tools/clang/unittests/HLSLExec/CoopVec.h index 3ba2c2babf..474a15d9bd 100644 --- a/tools/clang/unittests/HLSLExec/CoopVec.h +++ b/tools/clang/unittests/HLSLExec/CoopVec.h @@ -6,7 +6,7 @@ #include "dxc/Support/microcom.h" -#include "CoopVecAPIExtensions.h" +#include "CoopVecAPI.h" struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { private: diff --git a/tools/clang/unittests/HLSLExec/CoopVecAPIExtensions.h b/tools/clang/unittests/HLSLExec/CoopVecAPI.h similarity index 100% rename from tools/clang/unittests/HLSLExec/CoopVecAPIExtensions.h rename to tools/clang/unittests/HLSLExec/CoopVecAPI.h diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index dc555939e2..97c9f01ff2 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -65,7 +65,7 @@ #include #include "LongVectors.h" #include "CoopVec.h" -#include "CoopVecAPIExtensions.h" +#include "CoopVecAPI.h" // clang-format on #pragma comment(lib, "d3dcompiler.lib") @@ -11997,8 +11997,6 @@ void ExecutionTest::runCoopVecMulTest() { // Create device and verify coopvec support CComPtr D3DDevice; if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { - WEX::Logging::Log::Comment("Device does not support SM 6.9. Skipping."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); return; } if (!DoesDeviceSupportCooperativeVector(D3DDevice)) { @@ -12745,8 +12743,6 @@ void ExecutionTest::runCoopVecOuterProductTest() { // Create device and verify coopvec support CComPtr D3DDevice; if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { - WEX::Logging::Log::Comment("Device does not support SM 6.9. Skipping."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); return; } if (!DoesDeviceSupportCooperativeVector(D3DDevice)) { From 757d59d2531091573405b16331f0349f497f8d7f Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Mon, 5 May 2025 12:25:39 -0400 Subject: [PATCH 09/17] Disable clang-format on CoopVecAPI.h since it comes from d3d12.h --- tools/clang/unittests/HLSLExec/CoopVecAPI.h | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/clang/unittests/HLSLExec/CoopVecAPI.h b/tools/clang/unittests/HLSLExec/CoopVecAPI.h index 98c1cdb1c4..027b30af24 100644 --- a/tools/clang/unittests/HLSLExec/CoopVecAPI.h +++ b/tools/clang/unittests/HLSLExec/CoopVecAPI.h @@ -1,4 +1,5 @@ #pragma once +// clang-format off #if D3D12_PREVIEW_SDK_VERSION < 717 From 19ea49fcba1cd4fec1d90b9e1748618f21030264 Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Mon, 5 May 2025 12:27:07 -0400 Subject: [PATCH 10/17] Fix missing ThreadIdx increment in OuterProduct reference --- tools/clang/unittests/HLSLExec/ExecutionTest.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 97c9f01ff2..3616130d74 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -13095,7 +13095,8 @@ void main(uint threadIdx : SV_GroupThreadID) for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { for (int M = 0; M < Config.DimM; ++M) { for (int N = 0; N < Config.DimN; ++N) { - float Acc = InputVector1FP32[M] * InputVector2FP32[N]; + float Acc = InputVector1FP32[ThreadIdx * Config.DimM + M] * + InputVector2FP32[ThreadIdx * Config.DimN + N]; ExpectedOutputBuffer[M * Config.DimN + N] += Acc; } } From 3c92b660492304160bb3eb7a218249e2a05ea6b4 Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Mon, 5 May 2025 12:35:28 -0400 Subject: [PATCH 11/17] Address comment feedback in CoopVec.h --- tools/clang/unittests/HLSLExec/CoopVec.h | 10 ++++++---- tools/clang/unittests/HLSLExec/ExecutionTest.cpp | 11 +++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/CoopVec.h b/tools/clang/unittests/HLSLExec/CoopVec.h index 474a15d9bd..38fddf5911 100644 --- a/tools/clang/unittests/HLSLExec/CoopVec.h +++ b/tools/clang/unittests/HLSLExec/CoopVec.h @@ -59,7 +59,7 @@ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { } }; -struct CoopVecHelpers { +namespace CoopVecHelpers { template static std::vector CreateAllOnesInputMatrix(uint32_t Width, uint32_t Height) { @@ -72,6 +72,9 @@ struct CoopVecHelpers { inputMatrix[i] = ConvertFloat32ToFloat16(1.0f); } else if constexpr (std::is_same_v) { inputMatrix[i] = 1.0f; + } else { + WEX::Logging::Log::Error(L"Unsupported input type"); + break; } } @@ -104,6 +107,7 @@ struct CoopVecHelpers { inputVector[TID * EltsPerThread + 1] = 1.0f; } else { WEX::Logging::Log::Error(L"Unsupported input type"); + break; } } @@ -277,9 +281,7 @@ struct CoopVecHelpers { // This type is used in generated HLSL source to represent the vector type // for the given data type. static std::wstring GetHlslDataTypeForDataType( - D3D12_LINEAR_ALGEBRA_DATATYPE DataType, - D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) { - UNREFERENCED_PARAMETER(InputInterpretation); + D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { switch (DataType) { case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: return L"int16_t"; diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 3616130d74..d2d4707d9a 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -12297,10 +12297,10 @@ void main(uint threadIdx : SV_GroupThreadID) MulProps.InputInterpretation); const std::wstring InputDataType = CoopVecHelpers::GetHlslDataTypeForDataType( - MulProps.InputType, MulProps.InputInterpretation); + MulProps.InputType); const std::wstring AccumDataType = - CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.BiasInterpretation, - MulProps.BiasInterpretation); + CoopVecHelpers::GetHlslDataTypeForDataType( + MulProps.BiasInterpretation); const std::wstring MatrixDataTypeEnum = CoopVecHelpers::GetHlslInterpretationForDataType( MulProps.MatrixInterpretation); @@ -12901,11 +12901,10 @@ void main(uint threadIdx : SV_GroupThreadID) CoopVecHelpers::GetNumPackedElementsForInputDataType( AccumulateProps.InputType); const std::wstring InputDataType = - CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType, - AccumulateProps.InputType); + CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType); const std::wstring AccumDataType = CoopVecHelpers::GetHlslDataTypeForDataType( - AccumulateProps.AccumulationType, AccumulateProps.AccumulationType); + AccumulateProps.AccumulationType); const std::wstring MatrixDataTypeEnum = CoopVecHelpers::GetHlslInterpretationForDataType( AccumulateProps.AccumulationType); From 3db2db0fba83e40c17b78cce2e60a1a7b693b463 Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Mon, 5 May 2025 12:40:33 -0400 Subject: [PATCH 12/17] clang-format --- tools/clang/unittests/HLSLExec/CoopVec.h | 48 +++++++++---------- .../unittests/HLSLExec/ExecutionTest.cpp | 6 +-- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/CoopVec.h b/tools/clang/unittests/HLSLExec/CoopVec.h index 38fddf5911..fbffc57066 100644 --- a/tools/clang/unittests/HLSLExec/CoopVec.h +++ b/tools/clang/unittests/HLSLExec/CoopVec.h @@ -60,29 +60,29 @@ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { }; namespace CoopVecHelpers { - template - static std::vector CreateAllOnesInputMatrix(uint32_t Width, - uint32_t Height) { - std::vector inputMatrix(Width * Height); - for (uint32_t i = 0; i < Width * Height; i++) { - if constexpr (std::is_same_v || - std::is_same_v) { - inputMatrix[i] = 1; - } else if constexpr (std::is_same_v) { - inputMatrix[i] = ConvertFloat32ToFloat16(1.0f); - } else if constexpr (std::is_same_v) { - inputMatrix[i] = 1.0f; - } else { - WEX::Logging::Log::Error(L"Unsupported input type"); - break; - } +template +static std::vector CreateAllOnesInputMatrix(uint32_t Width, + uint32_t Height) { + std::vector inputMatrix(Width * Height); + for (uint32_t i = 0; i < Width * Height; i++) { + if constexpr (std::is_same_v || + std::is_same_v) { + inputMatrix[i] = 1; + } else if constexpr (std::is_same_v) { + inputMatrix[i] = ConvertFloat32ToFloat16(1.0f); + } else if constexpr (std::is_same_v) { + inputMatrix[i] = 1.0f; + } else { + WEX::Logging::Log::Error(L"Unsupported input type"); + break; } + } - // Convert to uint8_t vector - std::vector uint8InputMatrix(inputMatrix.size() * sizeof(EltTy)); - std::memcpy(uint8InputMatrix.data(), inputMatrix.data(), - inputMatrix.size() * sizeof(EltTy)); - return uint8InputMatrix; + // Convert to uint8_t vector + std::vector uint8InputMatrix(inputMatrix.size() * sizeof(EltTy)); + std::memcpy(uint8InputMatrix.data(), inputMatrix.data(), + inputMatrix.size() * sizeof(EltTy)); + return uint8InputMatrix; } template @@ -280,8 +280,8 @@ namespace CoopVecHelpers { // This type is used in generated HLSL source to represent the vector type // for the given data type. - static std::wstring GetHlslDataTypeForDataType( - D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + static std::wstring + GetHlslDataTypeForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { switch (DataType) { case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: return L"int16_t"; @@ -353,4 +353,4 @@ namespace CoopVecHelpers { return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; } } -}; \ No newline at end of file + }; // namespace CoopVecHelpers \ No newline at end of file diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index d2d4707d9a..a9a7e8eb16 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -12296,11 +12296,9 @@ void main(uint threadIdx : SV_GroupThreadID) CoopVecHelpers::GetNumPackedElementsForInputDataType( MulProps.InputInterpretation); const std::wstring InputDataType = - CoopVecHelpers::GetHlslDataTypeForDataType( - MulProps.InputType); + CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.InputType); const std::wstring AccumDataType = - CoopVecHelpers::GetHlslDataTypeForDataType( - MulProps.BiasInterpretation); + CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.BiasInterpretation); const std::wstring MatrixDataTypeEnum = CoopVecHelpers::GetHlslInterpretationForDataType( MulProps.MatrixInterpretation); From 5d6fc1ffd261e049199fe05fecadb8ee81f1e442 Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Tue, 6 May 2025 08:54:20 -0400 Subject: [PATCH 13/17] Fix vector indexing --- tools/clang/unittests/HLSLExec/ExecutionTest.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index a9a7e8eb16..59bf4b9323 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -12241,7 +12241,7 @@ void main(uint threadIdx : SV_GroupThreadID) using namespace dx::linalg; // Ensure 4-byte alignment for vector loads - uint inputOffset = (INPUT_PER_THREAD * threadIdx); + uint inputOffset = (INPUT_PER_THREAD * threadIdx * (sizeof(INPUT_DATA_TYPE) / INPUT_DIVISOR)); inputOffset = (inputOffset + 3) & ~3; // Align to 4 bytes vector input = InputVector.Load >(inputOffset); @@ -12259,9 +12259,9 @@ void main(uint threadIdx : SV_GroupThreadID) vector result = (vector)accum; // Ensure 4-byte alignment for vector store - uint outputOffset = OUTPUT_PER_THREAD * threadIdx; + uint outputOffset = OUTPUT_PER_THREAD * threadIdx * sizeof(float); outputOffset = (outputOffset + 3) & ~3; // Align to 4 bytes - OutputBuffer.Store >(outputOffset * 4, result); + OutputBuffer.Store >(outputOffset, result); } )"; @@ -12853,11 +12853,11 @@ void main(uint threadIdx : SV_GroupThreadID) using namespace dx::linalg; // Ensure 4-byte alignment for vector loads - uint inputOffset1 = (DIM_M * threadIdx); + uint inputOffset1 = (DIM_M * threadIdx * sizeof(INPUT_DATA_TYPE)); inputOffset1 = (inputOffset1 + 3) & ~3; // Align to 4 bytes vector input1 = InputVector1.Load >(inputOffset1); - uint inputOffset2 = (DIM_N * threadIdx); + uint inputOffset2 = (DIM_N * threadIdx * sizeof(INPUT_DATA_TYPE)); inputOffset2 = (inputOffset2 + 3) & ~3; // Align to 4 bytes vector input2 = InputVector2.Load >(inputOffset2); From f64db41beb9012b1817bdb4d84ea3422ac128e7d Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Tue, 6 May 2025 09:12:41 -0400 Subject: [PATCH 14/17] clang-format --- tools/clang/unittests/HLSLExec/CoopVec.h | 484 +++++++++++------------ 1 file changed, 242 insertions(+), 242 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/CoopVec.h b/tools/clang/unittests/HLSLExec/CoopVec.h index fbffc57066..2de6461420 100644 --- a/tools/clang/unittests/HLSLExec/CoopVec.h +++ b/tools/clang/unittests/HLSLExec/CoopVec.h @@ -83,274 +83,274 @@ static std::vector CreateAllOnesInputMatrix(uint32_t Width, std::memcpy(uint8InputMatrix.data(), inputMatrix.data(), inputMatrix.size() * sizeof(EltTy)); return uint8InputMatrix; - } - - template - static std::vector CreateInputVector(uint32_t NumThreads, - uint32_t EltsPerThread) { - std::vector inputVector(NumThreads * EltsPerThread); - std::fill(inputVector.begin(), inputVector.end(), EltTy(0)); - if (EltsPerThread < 2) { - WEX::Logging::Log::Error(L"EltsPerThread must be at least 2"); - return std::vector(); - } - for (uint32_t TID = 0; TID < NumThreads; TID++) { - if constexpr (std::is_same_v || - std::is_same_v) { - inputVector[TID * EltsPerThread + 0] = 1; - inputVector[TID * EltsPerThread + 1] = 1; - } else if constexpr (std::is_same_v) { - inputVector[TID * EltsPerThread + 0] = ConvertFloat32ToFloat16(1.0f); - inputVector[TID * EltsPerThread + 1] = ConvertFloat32ToFloat16(1.0f); - } else if constexpr (std::is_same_v) { - inputVector[TID * EltsPerThread + 0] = 1.0f; - inputVector[TID * EltsPerThread + 1] = 1.0f; - } else { - WEX::Logging::Log::Error(L"Unsupported input type"); - break; - } - } +} - // Convert to uint8_t vector - std::vector uint8InputVector(inputVector.size() * sizeof(EltTy)); - std::memcpy(uint8InputVector.data(), inputVector.data(), - inputVector.size() * sizeof(EltTy)); - return uint8InputVector; +template +static std::vector CreateInputVector(uint32_t NumThreads, + uint32_t EltsPerThread) { + std::vector inputVector(NumThreads * EltsPerThread); + std::fill(inputVector.begin(), inputVector.end(), EltTy(0)); + if (EltsPerThread < 2) { + WEX::Logging::Log::Error(L"EltsPerThread must be at least 2"); + return std::vector(); } - - template - static std::vector CreateInputBias(uint32_t NumElts) { - std::vector inputBias(NumElts); + for (uint32_t TID = 0; TID < NumThreads; TID++) { if constexpr (std::is_same_v || std::is_same_v) { - std::fill(inputBias.begin(), inputBias.end(), EltTy(1)); + inputVector[TID * EltsPerThread + 0] = 1; + inputVector[TID * EltsPerThread + 1] = 1; } else if constexpr (std::is_same_v) { - std::fill(inputBias.begin(), inputBias.end(), - ConvertFloat32ToFloat16(1.0f)); - } else if constexpr (std::is_same_v) { - std::fill(inputBias.begin(), inputBias.end(), 1); + inputVector[TID * EltsPerThread + 0] = ConvertFloat32ToFloat16(1.0f); + inputVector[TID * EltsPerThread + 1] = ConvertFloat32ToFloat16(1.0f); + } else if constexpr (std::is_same_v) { + inputVector[TID * EltsPerThread + 0] = 1.0f; + inputVector[TID * EltsPerThread + 1] = 1.0f; } else { - WEX::Logging::Log::Error(L"Unsupported bias type"); + WEX::Logging::Log::Error(L"Unsupported input type"); + break; } - // Convert to uint8_t vector - std::vector uint8InputBias(inputBias.size() * sizeof(EltTy)); - std::memcpy(uint8InputBias.data(), inputBias.data(), - inputBias.size() * sizeof(EltTy)); - return uint8InputBias; } - static std::wstring - DataTypeToFilterString(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { - switch (DataType) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - return L"SINT8_T4_PACKED"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: - return L"UINT8_T4_PACKED"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: - return L"SINT8"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: - return L"UINT8"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: - return L"SINT16"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: - return L"UINT16"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: - return L"SINT32"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: - return L"UINT32"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: - return L"FLOAT32"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: - return L"FLOAT16"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: - return L"FLOAT_E4M3"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: - return L"FLOAT_E5M2"; - default: - return L""; - } - } + // Convert to uint8_t vector + std::vector uint8InputVector(inputVector.size() * sizeof(EltTy)); + std::memcpy(uint8InputVector.data(), inputVector.data(), + inputVector.size() * sizeof(EltTy)); + return uint8InputVector; +} - static bool IsDataTypeInFilter(const wchar_t *FilterKey, - D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { - WEX::Common::String ParamValue; - if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue( - FilterKey, ParamValue))) { - // Filter not set, so treat as no filter - return true; - } - if (ParamValue.IsEmpty()) { - // Empty filter, so treat as no filter - return true; - } +template +static std::vector CreateInputBias(uint32_t NumElts) { + std::vector inputBias(NumElts); + if constexpr (std::is_same_v || + std::is_same_v) { + std::fill(inputBias.begin(), inputBias.end(), EltTy(1)); + } else if constexpr (std::is_same_v) { + std::fill(inputBias.begin(), inputBias.end(), + ConvertFloat32ToFloat16(1.0f)); + } else if constexpr (std::is_same_v) { + std::fill(inputBias.begin(), inputBias.end(), 1); + } else { + WEX::Logging::Log::Error(L"Unsupported bias type"); + } + // Convert to uint8_t vector + std::vector uint8InputBias(inputBias.size() * sizeof(EltTy)); + std::memcpy(uint8InputBias.data(), inputBias.data(), + inputBias.size() * sizeof(EltTy)); + return uint8InputBias; +} - // Check if the filter matches the target data type - LPCWSTR FilterString = reinterpret_cast(ParamValue.GetBuffer()); - return DataTypeToFilterString(DataType) == FilterString; +static std::wstring +DataTypeToFilterString(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + return L"SINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return L"UINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + return L"SINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + return L"UINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"SINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"UINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"SINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"UINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"FLOAT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"FLOAT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + return L"FLOAT_E4M3"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return L"FLOAT_E5M2"; + default: + return L""; } +} - static std::wstring - MatrixLayoutToFilterString(D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { - switch (MatrixLayout) { - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: - return L"ROW_MAJOR"; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: - return L"COLUMN_MAJOR"; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: - return L"MUL_OPTIMAL"; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: - return L"OUTER_PRODUCT_OPTIMAL"; - default: - return L""; - } +static bool IsDataTypeInFilter(const wchar_t *FilterKey, + D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + WEX::Common::String ParamValue; + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(FilterKey, + ParamValue))) { + // Filter not set, so treat as no filter + return true; + } + if (ParamValue.IsEmpty()) { + // Empty filter, so treat as no filter + return true; } - static bool - IsMatrixLayoutInFilter(const wchar_t *FilterKey, - D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { - WEX::Common::String ParamValue; - if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue( - FilterKey, ParamValue))) { - // Filter not set, so treat as no filter - return true; - } - if (ParamValue.IsEmpty()) { - // Empty filter, so treat as no filter - return true; - } + // Check if the filter matches the target data type + LPCWSTR FilterString = reinterpret_cast(ParamValue.GetBuffer()); + return DataTypeToFilterString(DataType) == FilterString; +} - // Check if the filter matches the target data type - LPCWSTR FilterString = reinterpret_cast(ParamValue.GetBuffer()); - return MatrixLayoutToFilterString(MatrixLayout) == FilterString; +static std::wstring +MatrixLayoutToFilterString(D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + switch (MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + return L"ROW_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + return L"COLUMN_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: + return L"MUL_OPTIMAL"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: + return L"OUTER_PRODUCT_OPTIMAL"; + default: + return L""; } +} - static std::wstring MatrixLayoutToHlslLayoutString( - D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { - switch (MatrixLayout) { - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: - return L"MATRIX_LAYOUT_ROW_MAJOR"; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: - return L"MATRIX_LAYOUT_COLUMN_MAJOR"; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: - return L"MATRIX_LAYOUT_MUL_OPTIMAL"; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: - return L"MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL"; - default: - return L""; - } +static bool +IsMatrixLayoutInFilter(const wchar_t *FilterKey, + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + WEX::Common::String ParamValue; + if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue(FilterKey, + ParamValue))) { + // Filter not set, so treat as no filter + return true; + } + if (ParamValue.IsEmpty()) { + // Empty filter, so treat as no filter + return true; } - // This multiplier is used to compute the row/column stride for a matrix - // given it's element size. - static int - GetStrideMultiplierForMatrixDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { - switch (DataType) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: - return 1; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: - return 2; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: - return 4; - default: - WEX::Logging::Log::Error(L"Unsupported matrix data type"); - return 1; - } + // Check if the filter matches the target data type + LPCWSTR FilterString = reinterpret_cast(ParamValue.GetBuffer()); + return MatrixLayoutToFilterString(MatrixLayout) == FilterString; +} + +static std::wstring MatrixLayoutToHlslLayoutString( + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + switch (MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + return L"MATRIX_LAYOUT_ROW_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + return L"MATRIX_LAYOUT_COLUMN_MAJOR"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL: + return L"MATRIX_LAYOUT_MUL_OPTIMAL"; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL: + return L"MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL"; + default: + return L""; } +} - static int GetNumPackedElementsForInputDataType( - D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) { - // Int8 packed types are the only ones that have more than 1 element per - // shader variable - switch (InputInterpretation) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: - return 4; - default: - return 1; - } +// This multiplier is used to compute the row/column stride for a matrix +// given it's element size. +static int +GetStrideMultiplierForMatrixDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return 1; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return 2; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return 4; + default: + WEX::Logging::Log::Error(L"Unsupported matrix data type"); + return 1; } +} - // This type is used in generated HLSL source to represent the vector type - // for the given data type. - static std::wstring - GetHlslDataTypeForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { - switch (DataType) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: - return L"int16_t"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: - return L"uint16_t"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: - return L"int32_t"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: - return L"uint32_t"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: - return L"half"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: - return L"float"; - default: - WEX::Logging::Log::Error(L"Unsupported input data type"); - return L""; - } +static int GetNumPackedElementsForInputDataType( + D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) { + // Int8 packed types are the only ones that have more than 1 element per + // shader variable + switch (InputInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return 4; + default: + return 1; } +} - static std::wstring GetHlslInterpretationForDataType( - D3D12_LINEAR_ALGEBRA_DATATYPE Interpretation) { - switch (Interpretation) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - return L"DATA_TYPE_SINT8_T4_PACKED"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: - return L"DATA_TYPE_UINT8_T4_PACKED"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: - return L"DATA_TYPE_SINT8"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: - return L"DATA_TYPE_UINT8"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: - return L"DATA_TYPE_SINT16"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: - return L"DATA_TYPE_UINT16"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: - return L"DATA_TYPE_SINT32"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: - return L"DATA_TYPE_UINT32"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: - return L"DATA_TYPE_FLOAT16"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: - return L"DATA_TYPE_FLOAT32"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: - return L"DATA_TYPE_FLOAT8_E4M3"; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: - return L"DATA_TYPE_FLOAT8_E5M2"; - default: - WEX::Logging::Log::Error(L"Unsupported interpretation"); - return L""; - } +// This type is used in generated HLSL source to represent the vector type +// for the given data type. +static std::wstring +GetHlslDataTypeForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"int16_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"uint16_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"int32_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"uint32_t"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"half"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"float"; + default: + WEX::Logging::Log::Error(L"Unsupported input data type"); + return L""; } +} - // The returned data type is used for matrix conversion. It is hard-coded - // for the test framework where all integer matrices start as SINT8 and - // all FP matrices start as FLOAT32. - static D3D12_LINEAR_ALGEBRA_DATATYPE - GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) { - switch (MatrixInterpretation) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: - case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: - return D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; - default: - return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; - } +static std::wstring +GetHlslInterpretationForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE Interpretation) { + switch (Interpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + return L"DATA_TYPE_SINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + return L"DATA_TYPE_UINT8_T4_PACKED"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + return L"DATA_TYPE_SINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + return L"DATA_TYPE_UINT8"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + return L"DATA_TYPE_SINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return L"DATA_TYPE_UINT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + return L"DATA_TYPE_SINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return L"DATA_TYPE_UINT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + return L"DATA_TYPE_FLOAT16"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return L"DATA_TYPE_FLOAT32"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + return L"DATA_TYPE_FLOAT8_E4M3"; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return L"DATA_TYPE_FLOAT8_E5M2"; + default: + WEX::Logging::Log::Error(L"Unsupported interpretation"); + return L""; + } +} + +// The returned data type is used for matrix conversion. It is hard-coded +// for the test framework where all integer matrices start as SINT8 and +// all FP matrices start as FLOAT32. +static D3D12_LINEAR_ALGEBRA_DATATYPE +GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) { + switch (MatrixInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + return D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + default: + return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; } - }; // namespace CoopVecHelpers \ No newline at end of file +} +}; // namespace CoopVecHelpers \ No newline at end of file From 906b8b7cf65eb435e3f9e6afa86963e305b06528 Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Tue, 6 May 2025 09:41:42 -0400 Subject: [PATCH 15/17] Only support CoopVec tests if using d3d12.h with at least ID3D12GraphicsCommandList10 --- tools/clang/unittests/HLSLExec/CoopVecAPI.h | 16 ++++++++-- .../unittests/HLSLExec/ExecutionTest.cpp | 30 +++++++++++++++++-- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/CoopVecAPI.h b/tools/clang/unittests/HLSLExec/CoopVecAPI.h index 027b30af24..16c1105edc 100644 --- a/tools/clang/unittests/HLSLExec/CoopVecAPI.h +++ b/tools/clang/unittests/HLSLExec/CoopVecAPI.h @@ -1,7 +1,10 @@ #pragma once // clang-format off -#if D3D12_PREVIEW_SDK_VERSION < 717 +#if !defined(D3D12_PREVIEW_SDK_VERSION) || D3D12_PREVIEW_SDK_VERSION < 717 + +#ifdef __ID3D12GraphicsCommandList10_INTERFACE_DEFINED__ +#define HAVE_COOPVEC_API 1 // This file contains the definitions of the D3D12 cooperative vector API. // It is used to test the cooperative vector API on older SDKs. @@ -160,7 +163,16 @@ ID3D12GraphicsCommandList11 : public ID3D12GraphicsCommandList10 _In_ UINT DescCount) = 0; }; - + #endif /* __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ */ +#else // __ID3D12GraphicsCommandList10_INTERFACE_DEFINED__ +// The used d3d12.h header does not support ID3D12GraphicsCommandList10, +// so we cannot define ID3D12GraphicsCommandList11. +#define HAVE_COOPVEC_API 0 +#endif // __ID3D12GraphicsCommandList10_INTERFACE_DEFINED__ + +#else // D3D12_PREVIEW_SDK_VERSION < 717 +// Preview header has CoopVec support +#define HAVE_COOPVEC_API 1 #endif // D3D12_PREVIEW_SDK_VERSION < 717 diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 59bf4b9323..81b335074a 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -781,6 +781,10 @@ class ExecutionTest { void RunResourceTest(ID3D12Device *pDevice, const char *pShader, const wchar_t *sm, bool isDynamic); + void runCoopVecMulTest(); + void runCoopVecOuterProductTest(); + +#if HAVE_COOPVEC_API struct CoopVecMulSubtestConfig { int InputPerThread; int OutputPerThread; @@ -790,7 +794,6 @@ class ExecutionTest { bool Bias; }; - void runCoopVecMulTest(); void runCoopVecMulTestConfig(ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps); @@ -805,7 +808,6 @@ class ExecutionTest { D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout; }; - void runCoopVecOuterProductTest(); void runCoopVecOuterProductTestConfig( ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps); @@ -813,6 +815,7 @@ class ExecutionTest { ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, CoopVecOuterProductSubtestConfig &Config); +#endif // HAVE_COOPVEC_API template void WaveIntrinsicsActivePrefixTest(TableParameter *pParameterList, @@ -1771,6 +1774,7 @@ class ExecutionTest { } bool DoesDeviceSupportCooperativeVector(ID3D12Device *Device) { +#if HAVE_COOPVEC_API D3D12_FEATURE_DATA_D3D12_OPTIONS_EXPERIMENTAL O; if (FAILED(Device->CheckFeatureSupport( (D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS_EXPERIMENTAL, &O, @@ -1778,6 +1782,10 @@ class ExecutionTest { return false; return O.CooperativeVectorTier != D3D12_COOPERATIVE_VECTOR_TIER_NOT_SUPPORTED; +#else + UNREFERENCED_PARAMETER(Device); + return false; +#endif } bool IsFallbackPathEnabled() { @@ -11994,6 +12002,12 @@ VERIFY_SUCCEEDED(DoArraysMatch(OutputVector, ExpectedVector, // // The current implementation will always write the final output data as float. void ExecutionTest::runCoopVecMulTest() { +#if !HAVE_COOPVEC_API + WEX::Logging::Log::Comment( + "Cooperative vector API not supported in build configuration. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; +#else // Create device and verify coopvec support CComPtr D3DDevice; if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { @@ -12104,8 +12118,10 @@ void ExecutionTest::runCoopVecMulTest() { // Run the test runCoopVecMulTestConfig(D3DDevice, MulAddConfig); } +#endif // HAVE_COOPVEC_API } +#if HAVE_COOPVEC_API void ExecutionTest::runCoopVecMulTestConfig( ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps) { @@ -12730,6 +12746,7 @@ void main(uint threadIdx : SV_GroupThreadID) VERIFY_IS_TRUE(Equal); } } +#endif // HAVE_COOPVEC_API TEST_F(ExecutionTest, CoopVec_Mul) { WEX::TestExecution::SetVerifyOutput verifySettings( @@ -12738,6 +12755,12 @@ TEST_F(ExecutionTest, CoopVec_Mul) { } void ExecutionTest::runCoopVecOuterProductTest() { +#if !HAVE_COOPVEC_API + WEX::Logging::Log::Comment( + "Cooperative vector API not supported in build configuration. Skipping."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; +#else // Create device and verify coopvec support CComPtr D3DDevice; if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { @@ -12771,8 +12794,10 @@ void ExecutionTest::runCoopVecOuterProductTest() { // Run the test runCoopVecOuterProductTestConfig(D3DDevice, AccumulateConfig); } +#endif // HAVE_COOPVEC_API } +#if HAVE_COOPVEC_API void ExecutionTest::runCoopVecOuterProductTestConfig( ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps) { @@ -13314,6 +13339,7 @@ void main(uint threadIdx : SV_GroupThreadID) VERIFY_IS_TRUE(Equal); } } +#endif // HAVE_COOPVEC_API TEST_F(ExecutionTest, CoopVec_OuterProduct) { WEX::TestExecution::SetVerifyOutput verifySettings( From be3992bf6fe249b47e354655129d328a7b11edad Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Tue, 6 May 2025 12:19:10 -0400 Subject: [PATCH 16/17] Fix build without coopvec support --- tools/clang/unittests/HLSLExec/CoopVec.h | 6 +++++- tools/clang/unittests/HLSLExec/ExecutionTest.cpp | 6 ++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/CoopVec.h b/tools/clang/unittests/HLSLExec/CoopVec.h index 2de6461420..2b1dd0b9d6 100644 --- a/tools/clang/unittests/HLSLExec/CoopVec.h +++ b/tools/clang/unittests/HLSLExec/CoopVec.h @@ -1,5 +1,7 @@ #pragma once +#if HAVE_COOPVEC_API + #include #include #include @@ -353,4 +355,6 @@ GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) { return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; } } -}; // namespace CoopVecHelpers \ No newline at end of file +}; // namespace CoopVecHelpers + +#endif // HAVE_COOPVEC_API diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 81b335074a..55d569dd8d 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -64,8 +64,8 @@ #include #include #include "LongVectors.h" -#include "CoopVec.h" #include "CoopVecAPI.h" +#include "CoopVec.h" // clang-format on #pragma comment(lib, "d3dcompiler.lib") @@ -1904,10 +1904,12 @@ class ExecutionTest { std::vector Features; Features.push_back(D3D12ExperimentalShaderModels); + +#if HAVE_COOPVEC_API if (GetTestParamBool(L"CooperativeVectorExperimental")) { Features.push_back(D3D12CooperativeVectorExperiment); } - +#endif return pD3D12EnableExperimentalFeatures((UINT)Features.size(), Features.data(), nullptr, nullptr); } From 608840a3ae6155a8350cdf08e60cb5367ba77c6a Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Tue, 6 May 2025 12:22:30 -0400 Subject: [PATCH 17/17] Style fixes for CoopVec.h --- tools/clang/unittests/HLSLExec/CoopVec.h | 85 ++++++++++++------------ 1 file changed, 42 insertions(+), 43 deletions(-) diff --git a/tools/clang/unittests/HLSLExec/CoopVec.h b/tools/clang/unittests/HLSLExec/CoopVec.h index 2b1dd0b9d6..f166c61f67 100644 --- a/tools/clang/unittests/HLSLExec/CoopVec.h +++ b/tools/clang/unittests/HLSLExec/CoopVec.h @@ -12,20 +12,20 @@ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { private: - DXC_MICROCOM_REF_FIELD(m_dwRef) + DXC_MICROCOM_REF_FIELD(RefCount) dxc::DxcDllSupport &DxcSupport; public: LinAlgHeaderIncludeHandler() = delete; LinAlgHeaderIncludeHandler(dxc::DxcDllSupport &DxcSupport) - : m_dwRef(0), DxcSupport(DxcSupport) {} + : RefCount(0), DxcSupport(DxcSupport) {} - DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef) + DXC_MICROCOM_ADDREF_RELEASE_IMPL(RefCount) - HRESULT STDMETHODCALLTYPE LoadSource(LPCWSTR pFilename, - IDxcBlob **ppIncludeSource) { - if (wcscmp(pFilename, L"dx/linalg.h") == 0 || - wcscmp(pFilename, L".\\dx\\linalg.h") == 0) { + HRESULT STDMETHODCALLTYPE LoadSource(LPCWSTR Filename, + IDxcBlob **IncludeSource) { + if (wcscmp(Filename, L"dx/linalg.h") == 0 || + wcscmp(Filename, L".\\dx\\linalg.h") == 0) { WEX::Common::String ParamValue; if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue( L"LinAlgHeader", ParamValue))) { @@ -37,26 +37,25 @@ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { LPCWSTR RealHeaderPath = reinterpret_cast(ParamValue.GetBuffer()); - CComPtr pHeaderUtils; + CComPtr HeaderUtils; - IFT(DxcSupport.CreateInstance(CLSID_DxcUtils, &pHeaderUtils)); + IFT(DxcSupport.CreateInstance(CLSID_DxcUtils, &HeaderUtils)); - IDxcBlobEncoding *pHeaderBlob; - IFT(pHeaderUtils->LoadFile(RealHeaderPath, nullptr, &pHeaderBlob)); + IDxcBlobEncoding *HeaderBlob; + IFT(HeaderUtils->LoadFile(RealHeaderPath, nullptr, &HeaderBlob)); - *ppIncludeSource = pHeaderBlob; + *IncludeSource = HeaderBlob; return S_OK; } return E_FAIL; } - HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, - void **ppvObject) override { + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID IID, void **Object) override { // FIXME: This is a workaround for a warning-as-error about unused parameters. #pragma warning(push) #pragma warning(disable : 4100) - return DoBasicQueryInterface(this, iid, ppvObject); + return DoBasicQueryInterface(this, IID, Object); #pragma warning(pop) } }; @@ -65,15 +64,15 @@ namespace CoopVecHelpers { template static std::vector CreateAllOnesInputMatrix(uint32_t Width, uint32_t Height) { - std::vector inputMatrix(Width * Height); + std::vector InputMatrix(Width * Height); for (uint32_t i = 0; i < Width * Height; i++) { if constexpr (std::is_same_v || std::is_same_v) { - inputMatrix[i] = 1; + InputMatrix[i] = 1; } else if constexpr (std::is_same_v) { - inputMatrix[i] = ConvertFloat32ToFloat16(1.0f); + InputMatrix[i] = ConvertFloat32ToFloat16(1.0f); } else if constexpr (std::is_same_v) { - inputMatrix[i] = 1.0f; + InputMatrix[i] = 1.0f; } else { WEX::Logging::Log::Error(L"Unsupported input type"); break; @@ -81,17 +80,17 @@ static std::vector CreateAllOnesInputMatrix(uint32_t Width, } // Convert to uint8_t vector - std::vector uint8InputMatrix(inputMatrix.size() * sizeof(EltTy)); - std::memcpy(uint8InputMatrix.data(), inputMatrix.data(), - inputMatrix.size() * sizeof(EltTy)); - return uint8InputMatrix; + std::vector Uint8InputMatrix(InputMatrix.size() * sizeof(EltTy)); + std::memcpy(Uint8InputMatrix.data(), InputMatrix.data(), + InputMatrix.size() * sizeof(EltTy)); + return Uint8InputMatrix; } template static std::vector CreateInputVector(uint32_t NumThreads, uint32_t EltsPerThread) { - std::vector inputVector(NumThreads * EltsPerThread); - std::fill(inputVector.begin(), inputVector.end(), EltTy(0)); + std::vector InputVector(NumThreads * EltsPerThread); + std::fill(InputVector.begin(), InputVector.end(), EltTy(0)); if (EltsPerThread < 2) { WEX::Logging::Log::Error(L"EltsPerThread must be at least 2"); return std::vector(); @@ -99,14 +98,14 @@ static std::vector CreateInputVector(uint32_t NumThreads, for (uint32_t TID = 0; TID < NumThreads; TID++) { if constexpr (std::is_same_v || std::is_same_v) { - inputVector[TID * EltsPerThread + 0] = 1; - inputVector[TID * EltsPerThread + 1] = 1; + InputVector[TID * EltsPerThread + 0] = 1; + InputVector[TID * EltsPerThread + 1] = 1; } else if constexpr (std::is_same_v) { - inputVector[TID * EltsPerThread + 0] = ConvertFloat32ToFloat16(1.0f); - inputVector[TID * EltsPerThread + 1] = ConvertFloat32ToFloat16(1.0f); + InputVector[TID * EltsPerThread + 0] = ConvertFloat32ToFloat16(1.0f); + InputVector[TID * EltsPerThread + 1] = ConvertFloat32ToFloat16(1.0f); } else if constexpr (std::is_same_v) { - inputVector[TID * EltsPerThread + 0] = 1.0f; - inputVector[TID * EltsPerThread + 1] = 1.0f; + InputVector[TID * EltsPerThread + 0] = 1.0f; + InputVector[TID * EltsPerThread + 1] = 1.0f; } else { WEX::Logging::Log::Error(L"Unsupported input type"); break; @@ -114,31 +113,31 @@ static std::vector CreateInputVector(uint32_t NumThreads, } // Convert to uint8_t vector - std::vector uint8InputVector(inputVector.size() * sizeof(EltTy)); - std::memcpy(uint8InputVector.data(), inputVector.data(), - inputVector.size() * sizeof(EltTy)); - return uint8InputVector; + std::vector Uint8InputVector(InputVector.size() * sizeof(EltTy)); + std::memcpy(Uint8InputVector.data(), InputVector.data(), + InputVector.size() * sizeof(EltTy)); + return Uint8InputVector; } template static std::vector CreateInputBias(uint32_t NumElts) { - std::vector inputBias(NumElts); + std::vector InputBias(NumElts); if constexpr (std::is_same_v || std::is_same_v) { - std::fill(inputBias.begin(), inputBias.end(), EltTy(1)); + std::fill(InputBias.begin(), InputBias.end(), EltTy(1)); } else if constexpr (std::is_same_v) { - std::fill(inputBias.begin(), inputBias.end(), + std::fill(InputBias.begin(), InputBias.end(), ConvertFloat32ToFloat16(1.0f)); } else if constexpr (std::is_same_v) { - std::fill(inputBias.begin(), inputBias.end(), 1); + std::fill(InputBias.begin(), InputBias.end(), 1); } else { WEX::Logging::Log::Error(L"Unsupported bias type"); } // Convert to uint8_t vector - std::vector uint8InputBias(inputBias.size() * sizeof(EltTy)); - std::memcpy(uint8InputBias.data(), inputBias.data(), - inputBias.size() * sizeof(EltTy)); - return uint8InputBias; + std::vector Uint8InputBias(InputBias.size() * sizeof(EltTy)); + std::memcpy(Uint8InputBias.data(), InputBias.data(), + InputBias.size() * sizeof(EltTy)); + return Uint8InputBias; } static std::wstring