diff --git a/tools/clang/unittests/HLSLExec/CoopVec.h b/tools/clang/unittests/HLSLExec/CoopVec.h index f166c61f67..c5c81800ac 100644 --- a/tools/clang/unittests/HLSLExec/CoopVec.h +++ b/tools/clang/unittests/HLSLExec/CoopVec.h @@ -4,13 +4,17 @@ #include #include + +#include +#include +#include #include #include "dxc/Support/microcom.h" #include "CoopVecAPI.h" -struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { +class LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { private: DXC_MICROCOM_REF_FIELD(RefCount) dxc::DxcDllSupport &DxcSupport; @@ -29,11 +33,14 @@ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { WEX::Common::String ParamValue; if (FAILED(WEX::TestExecution::RuntimeParameters::TryGetValue( L"LinAlgHeader", ParamValue))) { + WEX::Logging::Log::Error( + L"Missing expected TAEF runtime parameter LinAlgHeader"); return E_FAIL; } - if (ParamValue.IsEmpty()) { + + if (ParamValue.IsEmpty()) return E_FAIL; - } + LPCWSTR RealHeaderPath = reinterpret_cast(ParamValue.GetBuffer()); @@ -61,11 +68,12 @@ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { }; namespace CoopVecHelpers { + template -static std::vector CreateAllOnesInputMatrix(uint32_t Width, - uint32_t Height) { +static std::vector CreateAllOnesInputMatrix(size_t Width, + size_t Height) { std::vector InputMatrix(Width * Height); - for (uint32_t i = 0; i < Width * Height; i++) { + for (size_t i = 0; i < Width * Height; i++) { if constexpr (std::is_same_v || std::is_same_v) { InputMatrix[i] = 1; @@ -74,7 +82,7 @@ static std::vector CreateAllOnesInputMatrix(uint32_t Width, } else if constexpr (std::is_same_v) { InputMatrix[i] = 1.0f; } else { - WEX::Logging::Log::Error(L"Unsupported input type"); + VERIFY_FAIL(L"Unsupported input type"); break; } } @@ -86,60 +94,6 @@ static std::vector CreateAllOnesInputMatrix(uint32_t Width, return Uint8InputMatrix; } -template -static std::vector CreateInputVector(uint32_t NumThreads, - uint32_t EltsPerThread) { - std::vector InputVector(NumThreads * EltsPerThread); - std::fill(InputVector.begin(), InputVector.end(), EltTy(0)); - if (EltsPerThread < 2) { - WEX::Logging::Log::Error(L"EltsPerThread must be at least 2"); - return std::vector(); - } - for (uint32_t TID = 0; TID < NumThreads; TID++) { - if constexpr (std::is_same_v || - std::is_same_v) { - InputVector[TID * EltsPerThread + 0] = 1; - InputVector[TID * EltsPerThread + 1] = 1; - } else if constexpr (std::is_same_v) { - InputVector[TID * EltsPerThread + 0] = ConvertFloat32ToFloat16(1.0f); - InputVector[TID * EltsPerThread + 1] = ConvertFloat32ToFloat16(1.0f); - } else if constexpr (std::is_same_v) { - InputVector[TID * EltsPerThread + 0] = 1.0f; - InputVector[TID * EltsPerThread + 1] = 1.0f; - } else { - WEX::Logging::Log::Error(L"Unsupported input type"); - break; - } - } - - // Convert to uint8_t vector - std::vector Uint8InputVector(InputVector.size() * sizeof(EltTy)); - std::memcpy(Uint8InputVector.data(), InputVector.data(), - InputVector.size() * sizeof(EltTy)); - return Uint8InputVector; -} - -template -static std::vector CreateInputBias(uint32_t NumElts) { - std::vector InputBias(NumElts); - if constexpr (std::is_same_v || - std::is_same_v) { - std::fill(InputBias.begin(), InputBias.end(), EltTy(1)); - } else if constexpr (std::is_same_v) { - std::fill(InputBias.begin(), InputBias.end(), - ConvertFloat32ToFloat16(1.0f)); - } else if constexpr (std::is_same_v) { - std::fill(InputBias.begin(), InputBias.end(), 1); - } else { - WEX::Logging::Log::Error(L"Unsupported bias type"); - } - // Convert to uint8_t vector - std::vector Uint8InputBias(InputBias.size() * sizeof(EltTy)); - std::memcpy(Uint8InputBias.data(), InputBias.data(), - InputBias.size() * sizeof(EltTy)); - return Uint8InputBias; -} - static std::wstring DataTypeToFilterString(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { switch (DataType) { @@ -168,7 +122,9 @@ DataTypeToFilterString(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: return L"FLOAT_E5M2"; default: - return L""; + VERIFY_FAIL(WEX::Common::String().Format( + L"Unrecognized D3D12_LINEAR_ALGEBRA_DATATYPE: %d", DataType)); + return L""; } } @@ -243,7 +199,7 @@ static std::wstring MatrixLayoutToHlslLayoutString( // This multiplier is used to compute the row/column stride for a matrix // given it's element size. -static int +static size_t GetStrideMultiplierForMatrixDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { switch (DataType) { case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: @@ -261,12 +217,12 @@ GetStrideMultiplierForMatrixDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: return 4; default: - WEX::Logging::Log::Error(L"Unsupported matrix data type"); + VERIFY_FAIL(L"Unsupported matrix data type"); return 1; } } -static int GetNumPackedElementsForInputDataType( +static size_t GetNumPackedElementsForInputDataType( D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) { // Int8 packed types are the only ones that have more than 1 element per // shader variable @@ -297,8 +253,8 @@ GetHlslDataTypeForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: return L"float"; default: - WEX::Logging::Log::Error(L"Unsupported input data type"); - return L""; + VERIFY_FAIL(L"Unsupported input data type"); + return L""; } } @@ -330,8 +286,8 @@ GetHlslInterpretationForDataType(D3D12_LINEAR_ALGEBRA_DATATYPE Interpretation) { case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: return L"DATA_TYPE_FLOAT8_E5M2"; default: - WEX::Logging::Log::Error(L"Unsupported interpretation"); - return L""; + VERIFY_FAIL(L"Unsupported interpretation"); + return L""; } } @@ -354,6 +310,486 @@ GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) { return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; } } + +static bool IsIntegralDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) { + return DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8 || + DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16 || + DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16 || + DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 || + DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32; +} + +static size_t +GetVectorElementSize(D3D12_LINEAR_ALGEBRA_DATATYPE DataType, + D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) { + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + return sizeof(int8_t); + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + return sizeof(int16_t); + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) { + return sizeof(int8_t); + } else { + return sizeof(int32_t); + } + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + return sizeof(DirectX::PackedVector::HALF); + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + return sizeof(float); + default: + VERIFY_FAIL(L"Unsupported data type"); + return 0; + } +} + +static size_t +GetMatrixElementSize(D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) { + switch (DataInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + // The CPU reference matrix is always int8 for all integer + // interpretations. The GPU version will be converted to the destination + // format by ConvertLinearAlgebraMatrix. + return sizeof(int8_t); + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + // The CPU reference matrix is always FP32 for all FP interpretations. + // The GPU version will be converted to the destination format by + // ConvertLinearAlgebraMatrix. + return sizeof(float); + default: + VERIFY_FAIL(L"Unsupported data type"); + return 0; + } +} + +class TestVector { +private: + size_t NumVectors = 0; + size_t VectorSize = 0; + size_t ElementSize = 0; + size_t Stride = 0; + size_t TotalBytes = 0; + std::unique_ptr Buffer; + +public: + TestVector(size_t NumVectors, size_t VectorSize, size_t ElementSize, + size_t Alignment = 16) + : NumVectors(NumVectors), VectorSize(VectorSize), + ElementSize(ElementSize) { + if (NumVectors == 0) + VERIFY_FAIL(L"NumVectors must be greater than 0"); + if (VectorSize == 0) + VERIFY_FAIL(L"VectorSize must be greater than 0"); + if (ElementSize == 0) + VERIFY_FAIL(L"ElementSize must be greater than 0"); + + const size_t VectorBytes = VectorSize * ElementSize; + Stride = ((VectorBytes + Alignment - 1) / Alignment) * Alignment; + TotalBytes = Stride * NumVectors; + + Buffer = std::make_unique(TotalBytes); + std::fill(Buffer.get(), Buffer.get() + TotalBytes, (uint8_t)0xFF); + } + + // Copy constructor + TestVector(const TestVector &other) + : NumVectors(other.NumVectors), VectorSize(other.VectorSize), + ElementSize(other.ElementSize), Stride(other.Stride), + TotalBytes(other.TotalBytes) { + if (other.Buffer) { + Buffer = std::make_unique(TotalBytes); + std::memcpy(Buffer.get(), other.Buffer.get(), TotalBytes); + } + } + + // Move constructor + TestVector(TestVector &&other) noexcept + : NumVectors(other.NumVectors), VectorSize(other.VectorSize), + ElementSize(other.ElementSize), Stride(other.Stride), + TotalBytes(other.TotalBytes), Buffer(std::move(other.Buffer)) { + // Reset the source object + other.NumVectors = 0; + other.VectorSize = 0; + other.ElementSize = 0; + other.Stride = 0; + other.TotalBytes = 0; + } + + ~TestVector() = default; + + size_t getNumVectors() const { return NumVectors; } + size_t getVectorSize() const { return VectorSize; } + size_t getElementSize() const { return ElementSize; } + size_t getStride() const { return Stride; } + size_t getTotalBytes() const { return TotalBytes; } + uint8_t *getBuffer() { return Buffer.get(); } + const uint8_t *getBuffer() const { return Buffer.get(); } + + // Copy assignment operator + TestVector &operator=(const TestVector &other) { + if (this != &other) { + // Copy metadata + NumVectors = other.NumVectors; + VectorSize = other.VectorSize; + ElementSize = other.ElementSize; + Stride = other.Stride; + TotalBytes = other.TotalBytes; + + // Copy data + if (other.Buffer) { + Buffer = std::make_unique(TotalBytes); + std::memcpy(Buffer.get(), other.Buffer.get(), TotalBytes); + } else { + Buffer.reset(); + } + } + return *this; + } + + // Move assignment operator + TestVector &operator=(TestVector &&other) noexcept { + if (this != &other) { + // Move metadata and buffer + NumVectors = other.NumVectors; + VectorSize = other.VectorSize; + ElementSize = other.ElementSize; + Stride = other.Stride; + TotalBytes = other.TotalBytes; + Buffer = std::move(other.Buffer); + + // Reset the source object + other.NumVectors = 0; + other.VectorSize = 0; + other.ElementSize = 0; + other.Stride = 0; + other.TotalBytes = 0; + } + return *this; + } + + template T *getVector(size_t I) { + return reinterpret_cast(Buffer.get() + I * Stride); + } + + template const T *getVector(size_t I) const { + return reinterpret_cast(Buffer.get() + I * Stride); + } + + template void fill(const T &Value) { + for (size_t I = 0; I < NumVectors; ++I) { + T *Vec = getVector(I); + for (size_t J = 0; J < VectorSize; ++J) + Vec[J] = Value; + } + } + + template + void fillSimpleTestData(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation, + std::mt19937 &Rnd) { + for (size_t I = 0; I < NumVectors; ++I) { + T *Vec = getVector(I); + for (size_t J = 0; J < VectorSize; ++J) + if constexpr (std::is_same_v || + std::is_same_v) { + float Elt = 0.0f; + + // Generate random input in the following ranges: + // - Integral types: [-3, 4] by 1 + // - FP types: [-0.5, 1] by 0.5 + if (IsIntegralDataType(MatrixInterpretation)) + Elt = static_cast(Rnd() & 0x7) - 3.0f; + else + Elt = (static_cast(Rnd() & 0x3) - 1.0f) / 2.0f; + + if constexpr (std::is_same_v) + Vec[J] = static_cast(ConvertFloat32ToFloat16(Elt)); + else + Vec[J] = static_cast(Elt); + } else { + // Generate random input in the following ranges: + // - Signed types: [-8, 7] by 1 + // - Unsigned types: [0, 15] by 1 + if constexpr (std::is_signed_v) + Vec[J] = static_cast((int32_t)(Rnd() & 0xf) - 8); + else + Vec[J] = static_cast((uint32_t)(Rnd() & 0xf)); + } + } + } + + template void FillSimpleMatrixTestData(std::mt19937 &Rnd) { + for (size_t I = 0; I < NumVectors; ++I) { + T *Vec = getVector(I); + for (size_t J = 0; J < VectorSize; ++J) + if constexpr (std::is_same_v) { + float Elt = (static_cast(Rnd() & 0x3) - 1.0f) / 2.0f; + Vec[J] = static_cast(ConvertFloat32ToFloat16(Elt)); + } else if constexpr (std::is_same_v) { + float Elt = (static_cast(Rnd() & 0x3) - 1.0f) / 2.0f; + Vec[J] = static_cast(Elt); + } else { + if constexpr (std::is_signed_v) { + Vec[J] = static_cast((int32_t)(Rnd() & 0xf) - 8); + } else { + Vec[J] = static_cast((uint32_t)(Rnd() & 0xf)); + } + } + } + } + + static TestVector + createSimpleTestVector(size_t NumVectors, size_t VectorSize, + D3D12_LINEAR_ALGEBRA_DATATYPE DataType, + D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation, + D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation, + std::mt19937 &Rnd) { + const size_t ElementSize = + ::CoopVecHelpers::GetVectorElementSize(DataType, DataInterpretation); + + TestVector Vec(NumVectors, VectorSize, ElementSize); + switch (DataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + Vec.fillSimpleTestData(MatrixInterpretation, Rnd); + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + Vec.fillSimpleTestData(MatrixInterpretation, Rnd); + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + Vec.fillSimpleTestData(MatrixInterpretation, Rnd); + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + Vec.fillSimpleTestData(MatrixInterpretation, Rnd); + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + Vec.fillSimpleTestData(MatrixInterpretation, Rnd); + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) { + Vec.fillSimpleTestData(MatrixInterpretation, Rnd); + } else { + Vec.fillSimpleTestData(MatrixInterpretation, Rnd); + } + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + Vec.fillSimpleTestData(MatrixInterpretation, + Rnd); + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + Vec.fillSimpleTestData(MatrixInterpretation, Rnd); + break; + default: + VERIFY_FAIL(L"Unsupported data type"); + break; + } + return Vec; + } + + static TestVector + createSimpleTestMatrix(size_t NumVectors, size_t VectorSize, + D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation, + std::mt19937 &Rnd) { + const size_t ElementSize = + ::CoopVecHelpers::GetMatrixElementSize(DataInterpretation); + + TestVector Vec(NumVectors, VectorSize, ElementSize); + switch (DataInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: + // The CPU reference matrix is always int8 for all integer + // interpretations. The GPU version will be converted to the destination + // format by ConvertLinearAlgebraMatrix. + Vec.FillSimpleMatrixTestData(Rnd); + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: + // The CPU reference matrix is always FP32 for all FP interpretations. + // The GPU version will be converted to the destination format by + // ConvertLinearAlgebraMatrix. + Vec.FillSimpleMatrixTestData(Rnd); + break; + default: + VERIFY_FAIL(L"Unsupported data type"); + break; + } + return Vec; + } + + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO + getConversionInfo(ID3D12Device *D3DDevice, + D3D12_LINEAR_ALGEBRA_DATATYPE DestDataType, + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) { + // Create source matrix info + D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {}; + ConvertInfo.SrcInfo.SrcDataType = + ::CoopVecHelpers::GetMatrixSrcDataType(DestDataType); + ConvertInfo.SrcInfo.SrcLayout = + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; + + // Create destination matrix info + ConvertInfo.DestInfo.DestSize = 0; // Will be populated by driver + UINT DestEltSize = 0; + switch (DestDataType) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; + DestEltSize = 1; + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; + DestEltSize = 2; // FP16 + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + ConvertInfo.DestInfo.DestDataType = + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; + DestEltSize = 1; // FP8 + break; + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + ConvertInfo.DestInfo.DestDataType = + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; + DestEltSize = 1; // FP8 + break; + } + ConvertInfo.SrcInfo.SrcStride = static_cast(getStride()); + ConvertInfo.SrcInfo.SrcSize = static_cast(getTotalBytes()); + + ConvertInfo.DestInfo.DestLayout = MatrixLayout; + ConvertInfo.DestInfo.DestStride = 0; + ConvertInfo.DestInfo.NumRows = static_cast(getNumVectors()); + ConvertInfo.DestInfo.NumColumns = static_cast(getVectorSize()); + + if (MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { + // Align to 16 bytes + ConvertInfo.DestInfo.DestStride = + (static_cast(getVectorSize()) * DestEltSize + 15) & ~15; + } else if (MatrixLayout == + D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { + // Align to 16 bytes + ConvertInfo.DestInfo.DestStride = + (static_cast(getNumVectors()) * DestEltSize + 15) & ~15; + } + + // Get destination size using preview interface + { + CComPtr PreviewDevice; + VERIFY_SUCCEEDED(D3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), + (void **)&PreviewDevice)); + + // Query required destination size + PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( + &ConvertInfo.DestInfo); + } + + return ConvertInfo; + } + + static TestVector + matrixVectorMultiply(const TestVector &Matrix, const TestVector &InputVector, + const TestVector &Bias, bool HasBias, + D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation, + D3D12_LINEAR_ALGEBRA_DATATYPE InputType) { + // The CPU reference matrix is FP32 for all FP interpretations. + bool IsMatrixFP32 = false; + switch (MatrixInterpretation) { + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: + IsMatrixFP32 = true; + break; + default: + break; + } + + TestVector ResultVec(InputVector.getNumVectors(), Matrix.getNumVectors(), + sizeof(float)); + + if (IsMatrixFP32) { + for (size_t VecIdx = 0; VecIdx < InputVector.getNumVectors(); ++VecIdx) { + const DirectX::PackedVector::HALF *InputBiasFP16 = + Bias.getVector(0); + for (size_t OutputIdx = 0; OutputIdx < Matrix.getNumVectors(); + ++OutputIdx) { + float Acc = 0; + + for (size_t InputIdx = 0; InputIdx < Matrix.getVectorSize(); + ++InputIdx) { + float InputElem; + if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) + InputElem = InputVector.getVector(VecIdx)[InputIdx]; + else + InputElem = ConvertFloat16ToFloat32( + InputVector.getVector( + VecIdx)[InputIdx]); + + float const MatrixElem = + Matrix.getVector(OutputIdx)[InputIdx]; + Acc += InputElem * MatrixElem; + } + + if (HasBias) + Acc += ConvertFloat16ToFloat32(InputBiasFP16[OutputIdx]); + + ResultVec.getVector(VecIdx)[OutputIdx] = Acc; + } + } + } else if (MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8) { + for (size_t VecIdx = 0; VecIdx < InputVector.getNumVectors(); ++VecIdx) { + const int32_t *InputBiasI32 = Bias.getVector(0); + for (size_t OutputIdx = 0; OutputIdx < Matrix.getNumVectors(); + ++OutputIdx) { + int Acc = 0; + + for (size_t InputIdx = 0; InputIdx < Matrix.getVectorSize(); + ++InputIdx) { + int InputElem; + if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) + InputElem = static_cast( + InputVector.getVector(VecIdx)[InputIdx]); + else + InputElem = InputVector.getVector(VecIdx)[InputIdx]; + + int const MatrixElem = + Matrix.getVector(OutputIdx)[InputIdx]; + Acc += InputElem * MatrixElem; + } + + if (HasBias) + Acc += InputBiasI32[OutputIdx]; + + ResultVec.getVector(VecIdx)[OutputIdx] = + static_cast(Acc); + } + } + } else { + VERIFY_FAIL(L"Unsupported matrix interpretation"); + } + + return ResultVec; + } +}; }; // namespace CoopVecHelpers #endif // HAVE_COOPVEC_API diff --git a/tools/clang/unittests/HLSLExec/CoopVecAPI.h b/tools/clang/unittests/HLSLExec/CoopVecAPI.h index 16c1105edc..563366e0bc 100644 --- a/tools/clang/unittests/HLSLExec/CoopVecAPI.h +++ b/tools/clang/unittests/HLSLExec/CoopVecAPI.h @@ -145,18 +145,16 @@ ID3D12DevicePreview : public IUnknown #endif /* __ID3D12DevicePreview_INTERFACE_DEFINED__ */ -#ifndef __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ -#define __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ +#ifndef __ID3D12GraphicsCommandListPreview_INTERFACE_DEFINED__ +#define __ID3D12GraphicsCommandListPreview_INTERFACE_DEFINED__ -EXTERN_C const IID IID_ID3D12GraphicsCommandList11; +EXTERN_C const IID IID_ID3D12GraphicsCommandListPreview; -MIDL_INTERFACE("f0dcfabc-a84a-4fe3-b3b9-eab26b306c38") -ID3D12GraphicsCommandList11 : public ID3D12GraphicsCommandList10 +MIDL_INTERFACE("536d9bb6-9eee-4c75-86e8-e29e29e08ed3") +ID3D12GraphicsCommandListPreview : public ID3D12GraphicsCommandList10 { public: virtual void STDMETHODCALLTYPE Reserved0() = 0; - virtual void STDMETHODCALLTYPE Reserved1() = 0; - virtual void STDMETHODCALLTYPE Reserved2() = 0; virtual void STDMETHODCALLTYPE ConvertLinearAlgebraMatrix( _In_ const D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO *pDesc, @@ -164,7 +162,7 @@ ID3D12GraphicsCommandList11 : public ID3D12GraphicsCommandList10 }; -#endif /* __ID3D12GraphicsCommandList11_INTERFACE_DEFINED__ */ +#endif /* __ID3D12GraphicsCommandListPreview_INTERFACE_DEFINED__ */ #else // __ID3D12GraphicsCommandList10_INTERFACE_DEFINED__ // The used d3d12.h header does not support ID3D12GraphicsCommandList10, diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 55d569dd8d..b54ebe6f95 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -758,7 +758,7 @@ class ExecutionTest { #endif } - bool UseDebugIfaces() { return false; } + bool UseDebugIfaces() { return true; } bool SaveImages() { return GetTestParamBool(L"SaveImages"); } @@ -786,10 +786,10 @@ class ExecutionTest { #if HAVE_COOPVEC_API struct CoopVecMulSubtestConfig { - int InputPerThread; - int OutputPerThread; - int NumThreads; - int NumLevels; + size_t InputPerThread; + size_t OutputPerThread; + size_t NumThreads; + size_t NumLayers; D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout; bool Bias; }; @@ -799,12 +799,12 @@ class ExecutionTest { D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps); void runCoopVecMulSubtest(ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, - CoopVecMulSubtestConfig &Config); + CoopVecMulSubtestConfig &Config, bool RunCompute); struct CoopVecOuterProductSubtestConfig { - int DimM; // Row Count - int DimN; // Column Count - int NumThreads; + size_t DimM; // Row Count + size_t DimN; // Column Count + size_t NumThreads; D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout; }; @@ -814,7 +814,8 @@ class ExecutionTest { void runCoopVecOuterProductSubtest( ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, - CoopVecOuterProductSubtestConfig &Config); + CoopVecOuterProductSubtestConfig &Config, bool RunCompute); + #endif // HAVE_COOPVEC_API template @@ -12013,13 +12014,27 @@ void ExecutionTest::runCoopVecMulTest() { // Create device and verify coopvec support CComPtr D3DDevice; if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { +#ifdef _HLK_CONF + LOG_ERROR_FMT_THROW( + L"Device does not support SM 6.9. Can't run these tests."); +#else + WEX::Logging::Log::Comment( + "Device does not support SM 6.9. Can't run these tests."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); return; +#endif } + if (!DoesDeviceSupportCooperativeVector(D3DDevice)) { +#ifdef _HLK_CONF + LOG_ERROR_FMT_THROW( + L"Device does not support cooperative vectors. Can't run these tests."); +#else WEX::Logging::Log::Comment( - "Device does not support cooperative vector. Skipping."); + "Device does not support cooperative vectors. Can't run these tests."); WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); return; +#endif } // Query coopvec feature data. First call gets the size of the arrays. The @@ -12149,6 +12164,14 @@ void ExecutionTest::runCoopVecMulTestConfig( {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, @@ -12157,6 +12180,14 @@ void ExecutionTest::runCoopVecMulTestConfig( {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, @@ -12165,6 +12196,14 @@ void ExecutionTest::runCoopVecMulTestConfig( {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, false}, {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, @@ -12181,6 +12220,104 @@ void ExecutionTest::runCoopVecMulTestConfig( false}, {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, true}, + {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + + // NumLayers=2 tests + {16, 16, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {16, 16, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {16, 16, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {16, 16, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {32, 8, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {32, 8, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {32, 8, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {32, 8, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {17, 63, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {17, 63, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {17, 63, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {17, 63, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {1, 1, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {1, 1, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {1, 1, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false}, + {1, 1, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true}, + {16, 16, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {16, 16, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {16, 16, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {16, 16, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {32, 8, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {32, 8, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {32, 8, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {32, 8, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {17, 63, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {17, 63, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {17, 63, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {17, 63, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {1, 1, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {1, 1, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {1, 1, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false}, + {1, 1, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true}, + {16, 16, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {16, 16, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {16, 16, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {16, 16, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {32, 8, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {32, 8, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {32, 8, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {32, 8, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {17, 63, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {17, 63, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {17, 63, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {17, 63, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {1, 1, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {1, 1, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {1, 1, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false}, + {1, 1, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true}, + {16, 16, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {16, 16, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {16, 16, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {16, 16, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {32, 8, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {32, 8, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {32, 8, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {32, 8, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {17, 63, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {17, 63, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {17, 63, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {17, 63, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {1, 1, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {1, 1, 16, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, + {1, 1, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + false}, + {1, 1, 32, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL, + true}, }; for (auto Config : TestConfigs) { @@ -12194,38 +12331,74 @@ void ExecutionTest::runCoopVecMulTestConfig( continue; } + if (Config.NumLayers > 1) { + const bool IsPackedType = + MulProps.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + MulProps.InputInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8 || + MulProps.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || + MulProps.InputInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED; + + const bool IsFullPrecisionIntegerBias = + MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 || + MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32; + + if (IsPackedType && IsFullPrecisionIntegerBias) + // In the current framework this would require repacking the accumulator + // vectors in HLSL. + continue; + } + bool IsInFilter = CoopVecHelpers::IsMatrixLayoutInFilter( L"CoopVecMatrixLayout", Config.MatrixLayout); if (!IsInFilter) { continue; } - runCoopVecMulSubtest(D3DDevice, MulProps, Config); + // Run once as compute, then again as graphics (pixel shader) + runCoopVecMulSubtest(D3DDevice, MulProps, Config, true); + runCoopVecMulSubtest(D3DDevice, MulProps, Config, false); } } void ExecutionTest::runCoopVecMulSubtest( ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_MUL &MulProps, - CoopVecMulSubtestConfig &Config) { + CoopVecMulSubtestConfig &Config, bool RunCompute) { + + std::mt19937 Rnd(0x42); LogCommentFmt( - L"Running test for InputPerThread: %d, OutputPerThread: %d, NumThreads: " - L"%d, NumLevels: %d, Bias: %s, MatrixLayout: %s", + L"Running test for InputPerThread: %zu, OutputPerThread: %zu, " + L"NumThreads: " + L"%zu, NumLayers: %zu, Bias: %s, MatrixLayout: %s, Stage: %s", Config.InputPerThread, Config.OutputPerThread, Config.NumThreads, - Config.NumLevels, Config.Bias ? L"true" : L"false", - CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); + Config.NumLayers, Config.Bias ? L"true" : L"false", + CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str(), + RunCompute ? L"Compute" : L"Pixel"); - const int OutputBufferSize = (Config.OutputPerThread * Config.NumThreads * 4); + const size_t OutputBufferSize = + (Config.OutputPerThread * Config.NumThreads * 4); // Create root signature with a single root entry for all SRVs and UAVs CComPtr RootSignature; { CD3DX12_DESCRIPTOR_RANGE Ranges[2]; - Ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 3, 0, - 0); // InputVector, InputMatrix, InputBias + Ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, + 2 + static_cast(Config.NumLayers), 0, + 0); // InputVector, InputBias, InputMatrices[] Ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // OutputBuffer - CreateRootSignatureFromRanges(D3DDevice, &RootSignature, Ranges, 2, nullptr, - 0); + + CD3DX12_ROOT_PARAMETER RootParams[2]; + RootParams[0].InitAsDescriptorTable(_countof(Ranges), Ranges, + D3D12_SHADER_VISIBILITY_ALL); + RootParams[1].InitAsUnorderedAccessView(/* register */ 10, /* space */ 0, + D3D12_SHADER_VISIBILITY_ALL); + + CD3DX12_ROOT_SIGNATURE_DESC RootSignatureDesc; + RootSignatureDesc.Init(_countof(RootParams), RootParams, 0, nullptr, + D3D12_ROOT_SIGNATURE_FLAG_NONE); + CreateRootSignatureFromDesc(D3DDevice, &RootSignatureDesc, &RootSignature); } // Create descriptor heap with space for 4 descriptors: 3 SRVs and 1 UAV @@ -12233,7 +12406,7 @@ void ExecutionTest::runCoopVecMulSubtest( { D3D12_DESCRIPTOR_HEAP_DESC Desc = {}; Desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; - Desc.NumDescriptors = 4; + Desc.NumDescriptors = 3 + static_cast(Config.NumLayers); Desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; VERIFY_SUCCEEDED( D3DDevice->CreateDescriptorHeap(&Desc, IID_PPV_ARGS(&DescriptorHeap))); @@ -12241,137 +12414,281 @@ void ExecutionTest::runCoopVecMulSubtest( CD3DX12_CPU_DESCRIPTOR_HANDLE BaseHandle( DescriptorHeap->GetCPUDescriptorHandleForHeapStart()); - // Create the compute pipeline state for the CoopVec shader - CComPtr ComputePipelineState; - { - std::string ShaderSource = R"( + // Our input matrix is really a set of row vectors, which we can represent + // as a TestVector. + std::vector<::CoopVecHelpers::TestVector> InputMatrices; + for (size_t I = 0; I < Config.NumLayers - 1; ++I) { + // Each layer except the last is InputPerThread x InputPerThread + InputMatrices.push_back( + ::CoopVecHelpers::TestVector::createSimpleTestMatrix( + Config.InputPerThread, Config.InputPerThread, + MulProps.MatrixInterpretation, Rnd)); + } + // Last layer, matrix size is OutputPerThread x InputPerThread + InputMatrices.push_back(::CoopVecHelpers::TestVector::createSimpleTestMatrix( + Config.OutputPerThread, Config.InputPerThread, + MulProps.MatrixInterpretation, Rnd)); + + auto InputVector = CoopVecHelpers::TestVector::createSimpleTestVector( + Config.NumThreads, Config.InputPerThread, MulProps.InputType, + MulProps.InputInterpretation, MulProps.MatrixInterpretation, Rnd); + auto InputBias = CoopVecHelpers::TestVector::createSimpleTestVector( + 1, std::max(Config.OutputPerThread, Config.InputPerThread), + MulProps.BiasInterpretation, MulProps.BiasInterpretation, + MulProps.MatrixInterpretation, Rnd); + + // Calculate reference output + auto ExpectedOutput = InputVector; + for (size_t I = 0; I < Config.NumLayers; ++I) { + ExpectedOutput = ::CoopVecHelpers::TestVector::matrixVectorMultiply( + InputMatrices[I], ExpectedOutput, InputBias, Config.Bias, + MulProps.MatrixInterpretation, + I == 0 ? MulProps.InputType : D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32); + } + + std::string ShaderSource = R"( #include "dx/linalg.h" ByteAddressBuffer InputVector : register(t0); ByteAddressBuffer InputBias : register(t1); -ByteAddressBuffer InputMatrix : register(t2); +ByteAddressBuffer InputMatrix[NUM_LAYERS] : register(t2); RWByteAddressBuffer OutputBuffer: register(u0); -[shader("compute")] -[numthreads(NUM_THREADS, 1, 1)] -void main(uint threadIdx : SV_GroupThreadID) +RWStructuredBuffer AtomicCounter : register(u10); + +#if USE_GROUPSHARED +groupshared vector inputGS[NUM_THREADS]; +groupshared vector outputGS[NUM_THREADS]; +#endif + +void RunCoopVecTest(uint threadIdx) { using namespace dx::linalg; - // Ensure 4-byte alignment for vector loads - uint inputOffset = (INPUT_PER_THREAD * threadIdx * (sizeof(INPUT_DATA_TYPE) / INPUT_DIVISOR)); - inputOffset = (inputOffset + 3) & ~3; // Align to 4 bytes - vector input = InputVector.Load >(inputOffset); + uint inputOffset = (threadIdx * INPUT_VECTOR_STRIDE); + vector input = InputVector.Load >(inputOffset); + +#if USE_GROUPSHARED + // Use groupshared memory to grab the "next" thread's input vector. + inputGS[threadIdx] = input; + GroupMemoryBarrierWithGroupSync(); + input = inputGS[(threadIdx + 1) % NUM_THREADS]; +#endif - MatrixRef mat = { InputMatrix, 0, STRIDE }; + VectorRef biasVec = { InputBias, 0 }; + vector output; +)"; - vector accum; + if (Config.NumLayers == 1) { + ShaderSource += R"( + MatrixRef mat = { InputMatrix[0], 0, STRIDE0 }; if (USE_BIAS) { - VectorRef biasVec = { InputBias, 0 }; - accum = MulAdd(mat, MakeInterpretedVector(input), biasVec); + output = MulAdd(mat, MakeInterpretedVector(input), biasVec); } else { - accum = Mul(mat, MakeInterpretedVector(input)); + output = Mul(mat, MakeInterpretedVector(input)); } +)"; + } else if (Config.NumLayers == 2) { + ShaderSource += R"( + vector accum; - vector result = (vector)accum; + MatrixRef mat0 = { InputMatrix[0], 0, STRIDE0 }; + if (USE_BIAS) { + accum = MulAdd(mat0, MakeInterpretedVector(input), biasVec); + } else { + accum = Mul(mat0, MakeInterpretedVector(input)); + } + + // Dummy activation function; all of our intermediates above -10000 + accum = max(accum, -10000); + + MatrixRef mat1 = { InputMatrix[1], 0, STRIDE1 }; + if (USE_BIAS) { + output = MulAdd(mat1, MakeInterpretedVector(accum), biasVec); + } else { + output = Mul(mat1, MakeInterpretedVector(accum)); + } +)"; + } + + ShaderSource += R"( + vector result = (vector)output; + +#if USE_GROUPSHARED + // Use groupshared memory to grab the "previous" thread's output vector. + outputGS[threadIdx] = result; + GroupMemoryBarrierWithGroupSync(); + result = outputGS[(threadIdx + NUM_THREADS - 1) % NUM_THREADS]; +#endif // Ensure 4-byte alignment for vector store uint outputOffset = OUTPUT_PER_THREAD * threadIdx * sizeof(float); - outputOffset = (outputOffset + 3) & ~3; // Align to 4 bytes OutputBuffer.Store >(outputOffset, result); } - )"; - auto CreateDefineFromInt = [](const wchar_t *Name, int Value) { - std::wstringstream Stream; - Stream << L"-D" << Name << L"=" << Value; - return Stream.str(); - }; +[shader("compute")] +[numthreads(NUM_THREADS, 1, 1)] +void main(uint threadIdx : SV_GroupThreadID) +{ + RunCoopVecTest(threadIdx); +} + +float4 vs_main(uint vid : SV_VertexID) : SV_Position { + switch (vid) { + case 0: + return float4(-1, 1, 0, 1); + case 1: + return float4(3, 1, 0, 1); + case 2: + return float4(-1, -3, 0, 1); + } + return float4(0, 0, 0, 0); +} + +float4 ps_main() : SV_Target { + uint threadIdx; + InterlockedAdd(AtomicCounter[0], 1, threadIdx); + // threadIdx may exceed NUM_THREADS, but bounds checking on the vector + // loads/stores will prevent any faults from occurring. This lets us + // exercise the CoopVec implementation on more threads, giving us + // further confidence that there are no bad interactions between "good" + // threads and threads that fail bounds checking and operate on all-zero + // input data. This also gives us some additional testing of long vector + // bounds-checking. + RunCoopVecTest(threadIdx); + return float4(1, 1, 1, 1); +} +)"; + + auto CreateDefineFromSize = [](const wchar_t *Name, size_t Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; - auto CreateDefineFromString = [](const wchar_t *Name, - const std::wstring &Value) { - std::wstringstream Stream; - Stream << L"-D" << Name << L"=" << Value; - return Stream.str(); - }; + auto CreateDefineFromString = [](const wchar_t *Name, + const std::wstring &Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; - int Stride = 0; - const std::wstring HlslMatrixLayout = - CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); - int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType( - MulProps.MatrixInterpretation); - switch (Config.MatrixLayout) { - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: - Stride = Config.InputPerThread * StrideMultiplier; - break; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: - Stride = Config.OutputPerThread * StrideMultiplier; - break; - } + const std::wstring HlslMatrixLayout = + CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); + const size_t InputDivisor = + CoopVecHelpers::GetNumPackedElementsForInputDataType( + MulProps.InputInterpretation); + const std::wstring InputDataType = + CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.InputType); + const std::wstring AccumDataType = + CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.BiasInterpretation); + const std::wstring MatrixDataTypeEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + MulProps.MatrixInterpretation); + const std::wstring InputInterpretationEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + MulProps.InputInterpretation); + const std::wstring BiasInterpretationEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + MulProps.BiasInterpretation); + + auto InputPerThreadDefine = + CreateDefineFromSize(L"INPUT_PER_THREAD", Config.InputPerThread); + auto OutputPerThreadDefine = + CreateDefineFromSize(L"OUTPUT_PER_THREAD", Config.OutputPerThread); + auto NumThreadsDefine = + CreateDefineFromSize(L"NUM_THREADS", Config.NumThreads); + auto InputDataTypeDefine = + CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType); + auto InputDivisorDefine = CreateDefineFromSize( + L"INPUT_VECTOR_NUM_ELEMENTS", + (Config.InputPerThread + InputDivisor - 1) / InputDivisor); + auto AccumDataTypeDefine = + CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType); + auto InputInterpretationEnumDefine = CreateDefineFromString( + L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum); + auto HlslMatrixLayoutDefine = + CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout); + auto MatrixDataTypeEnumDefine = + CreateDefineFromString(L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum); + auto UseBiasDefine = CreateDefineFromSize(L"USE_BIAS", Config.Bias ? 1 : 0); + // Treat the accumulator interpretation the same as the input interpretation + // for the purposes of MakeInterpretedVector. + auto AccumInterpretationEnumDefine = CreateDefineFromString( + L"ACCUM_INTERPRETATION_ENUM", InputInterpretationEnum); + auto InputVectorStrideDefine = + CreateDefineFromSize(L"INPUT_VECTOR_STRIDE", InputVector.getStride()); + auto NumLayersDefine = CreateDefineFromSize(L"NUM_LAYERS", Config.NumLayers); + auto BiasInterpretationEnumDefine = CreateDefineFromString( + L"BIAS_INTERPRETATION_ENUM", BiasInterpretationEnum); + auto UseGroupsharedDefine = + CreateDefineFromSize(L"USE_GROUPSHARED", RunCompute ? 1 : 0); + + std::vector Options = { + L"-enable-16bit-types", + InputPerThreadDefine.c_str(), + OutputPerThreadDefine.c_str(), + NumThreadsDefine.c_str(), + InputDataTypeDefine.c_str(), + InputDivisorDefine.c_str(), + AccumDataTypeDefine.c_str(), + InputInterpretationEnumDefine.c_str(), + HlslMatrixLayoutDefine.c_str(), + MatrixDataTypeEnumDefine.c_str(), + UseBiasDefine.c_str(), + AccumInterpretationEnumDefine.c_str(), + InputVectorStrideDefine.c_str(), + NumLayersDefine.c_str(), + BiasInterpretationEnumDefine.c_str(), + UseGroupsharedDefine.c_str(), + }; - const int InputDivisor = - CoopVecHelpers::GetNumPackedElementsForInputDataType( - MulProps.InputInterpretation); - const std::wstring InputDataType = - CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.InputType); - const std::wstring AccumDataType = - CoopVecHelpers::GetHlslDataTypeForDataType(MulProps.BiasInterpretation); - const std::wstring MatrixDataTypeEnum = - CoopVecHelpers::GetHlslInterpretationForDataType( - MulProps.MatrixInterpretation); - const std::wstring InputInterpretationEnum = - CoopVecHelpers::GetHlslInterpretationForDataType( - MulProps.InputInterpretation); - const std::wstring AccumInterpretationEnum = - CoopVecHelpers::GetHlslInterpretationForDataType( - MulProps.BiasInterpretation); - - auto InputPerThreadDefine = - CreateDefineFromInt(L"INPUT_PER_THREAD", Config.InputPerThread); - auto OutputPerThreadDefine = - CreateDefineFromInt(L"OUTPUT_PER_THREAD", Config.OutputPerThread); - auto NumThreadsDefine = - CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads); - auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride); - auto InputDataTypeDefine = - CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType); - auto InputDivisorDefine = - CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor); - auto AccumDataTypeDefine = - CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType); - auto InputInterpretationEnumDefine = CreateDefineFromString( - L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum); - auto HlslMatrixLayoutDefine = - CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout); - auto MatrixDataTypeEnumDefine = - CreateDefineFromString(L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum); - auto UseBiasDefine = CreateDefineFromInt(L"USE_BIAS", Config.Bias ? 1 : 0); - auto AccumInterpretationEnumDefine = CreateDefineFromString( - L"ACCUM_INTERPRETATION_ENUM", AccumInterpretationEnum); - - LPCWSTR Options[] = { - L"-enable-16bit-types", - InputPerThreadDefine.c_str(), - OutputPerThreadDefine.c_str(), - NumThreadsDefine.c_str(), - StrideDefine.c_str(), - InputDataTypeDefine.c_str(), - InputDivisorDefine.c_str(), - AccumDataTypeDefine.c_str(), - InputInterpretationEnumDefine.c_str(), - HlslMatrixLayoutDefine.c_str(), - MatrixDataTypeEnumDefine.c_str(), - UseBiasDefine.c_str(), - AccumInterpretationEnumDefine.c_str(), - }; + std::vector StrideDefines; + for (size_t I = 0; I < Config.NumLayers; ++I) { + auto ConvertInfo = InputMatrices[I].getConversionInfo( + D3DDevice, MulProps.MatrixInterpretation, Config.MatrixLayout); + wchar_t StrideName[16]; + swprintf(StrideName, _countof(StrideName), L"STRIDE%zu", I); + StrideDefines.push_back( + CreateDefineFromSize(StrideName, ConvertInfo.DestInfo.DestStride)); + Options.push_back(StrideDefines[I].c_str()); + } + + CComPtr IncludeHandler = + new LinAlgHeaderIncludeHandler(m_support); - CComPtr IncludeHandler = - new LinAlgHeaderIncludeHandler(m_support); + // Create the pipeline state for the CoopVec shaders + CComPtr PipelineState; + if (RunCompute) { CreateComputePSO(D3DDevice, RootSignature, ShaderSource.c_str(), L"cs_6_9", - &ComputePipelineState, Options, _countof(Options), + &PipelineState, Options.data(), (int)Options.size(), IncludeHandler); + } else { + CComPtr VertexShader; + CComPtr PixelShader; + + CompileFromText(ShaderSource.c_str(), L"vs_main", L"vs_6_9", &VertexShader, + Options.data(), (int)Options.size(), IncludeHandler); + CompileFromText(ShaderSource.c_str(), L"ps_main", L"ps_6_9", &PixelShader, + Options.data(), (int)Options.size(), IncludeHandler); + + D3D12_GRAPHICS_PIPELINE_STATE_DESC PsoDesc = {}; + PsoDesc.pRootSignature = RootSignature; + PsoDesc.VS = CD3DX12_SHADER_BYTECODE(VertexShader); + PsoDesc.PS = CD3DX12_SHADER_BYTECODE(PixelShader); + PsoDesc.RasterizerState = CD3DX12_RASTERIZER_DESC(D3D12_DEFAULT); + PsoDesc.BlendState = CD3DX12_BLEND_DESC(D3D12_DEFAULT); + PsoDesc.DepthStencilState.DepthEnable = FALSE; + PsoDesc.DepthStencilState.StencilEnable = FALSE; + PsoDesc.SampleMask = UINT_MAX; + PsoDesc.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE; + PsoDesc.NumRenderTargets = 1; + PsoDesc.RTVFormats[0] = DXGI_FORMAT_R8G8B8A8_UNORM; + PsoDesc.SampleDesc.Count = 1; + VERIFY_SUCCEEDED(D3DDevice->CreateGraphicsPipelineState( + &PsoDesc, IID_PPV_ARGS(&PipelineState))); } // Create a command list for the compute shader. @@ -12385,314 +12702,89 @@ void main(uint threadIdx : SV_GroupThreadID) VERIFY_SUCCEEDED(D3DDevice->CreateCommandAllocator( D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&CommandAllocator))); VERIFY_SUCCEEDED(D3DDevice->CreateCommandList( - 0, D3D12_COMMAND_LIST_TYPE_DIRECT, CommandAllocator, ComputePipelineState, + 0, D3D12_COMMAND_LIST_TYPE_DIRECT, CommandAllocator, PipelineState, IID_PPV_ARGS(&CommandList))); - // Setup input data - auto ExpectedOutputBuffer = - std::make_unique(Config.OutputPerThread * Config.NumThreads); - - // Setup input matrix as all-ones in sint8 format. This will later be - // converted to the appropriate data type by the matrix conversion API. - CComPtr InputMatrixSRVResource, InputMatrixSRVUploadResource; - std::vector InputMatrix; - if (MulProps.MatrixInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || - MulProps.MatrixInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || - MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || - MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix( - Config.InputPerThread, Config.OutputPerThread); - } else if (MulProps.MatrixInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || - MulProps.MatrixInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || - MulProps.MatrixInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { - // Matrix source data is fp32, which gets converted to fp16 during matrix - // conversion - InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix( - Config.InputPerThread, Config.OutputPerThread); - } else { - WEX::Logging::Log::Error(L"Unsupported matrix data type"); - return; + std::vector> InputMatrixSRVResources( + Config.NumLayers); + std::vector> InputMatrixSRVUploadResources( + Config.NumLayers); + for (size_t I = 0; I < Config.NumLayers; ++I) { + CreateTestResources( + D3DDevice, CommandList, InputMatrices[I].getBuffer(), + InputMatrices[I].getTotalBytes(), + CD3DX12_RESOURCE_DESC::Buffer(InputMatrices[I].getTotalBytes()), + &InputMatrixSRVResources[I], &InputMatrixSRVUploadResources[I]); } - CreateTestResources(D3DDevice, CommandList, InputMatrix.data(), - InputMatrix.size(), - CD3DX12_RESOURCE_DESC::Buffer(InputMatrix.size()), - &InputMatrixSRVResource, &InputMatrixSRVUploadResource); - // Create input vector of an appropriate type. All integer types start as // SINT8 for now. CComPtr InputVecSRVResource, InputVecSRVUploadResource; - std::vector InputVector; - - if ((MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32 && - (MulProps.InputInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || - MulProps.InputInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED)) || - MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || - MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - InputVector = CoopVecHelpers::CreateInputVector( - Config.NumThreads, Config.InputPerThread); - } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || - MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || - MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { - InputVector = - CoopVecHelpers::CreateInputVector( - Config.NumThreads, Config.InputPerThread); - } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - InputVector = CoopVecHelpers::CreateInputVector( - Config.NumThreads, Config.InputPerThread); - } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { - InputVector = CoopVecHelpers::CreateInputVector( - Config.NumThreads, Config.InputPerThread); - } else if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { - InputVector = CoopVecHelpers::CreateInputVector( - Config.NumThreads, Config.InputPerThread); - } else { - WEX::Logging::Log::Error(L"Unsupported input data type"); - return; - } - if (InputVector.size() % 4 != 0) { - // Align size to 4 bytes for ByteAddressBuffer - InputVector.resize(InputVector.size() + 4 - (InputVector.size() % 4)); - } - CreateTestResources(D3DDevice, CommandList, InputVector.data(), - InputVector.size(), - CD3DX12_RESOURCE_DESC::Buffer(InputVector.size()), - &InputVecSRVResource, &InputVecSRVUploadResource); + + CreateTestResources( + D3DDevice, CommandList, InputVector.getBuffer(), + InputVector.getTotalBytes(), + CD3DX12_RESOURCE_DESC::Buffer(InputVector.getTotalBytes()), + &InputVecSRVResource, &InputVecSRVUploadResource); // This increments baseHandle CreateRawSRV(D3DDevice, BaseHandle, - (UINT)(InputVector.size() / sizeof(int32_t)), + static_cast(InputVector.getTotalBytes() / sizeof(int32_t)), InputVecSRVResource); // Create input bias CComPtr InputBiasSRVResource, InputBiasSRVUploadResource; - std::vector InputBias; - - if (MulProps.BiasInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || - MulProps.BiasInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED || - MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || - MulProps.BiasInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - InputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); - } else if (MulProps.BiasInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32) { - InputBias = - CoopVecHelpers::CreateInputBias(Config.OutputPerThread); - } else if (MulProps.BiasInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32) { - InputBias = - CoopVecHelpers::CreateInputBias(Config.OutputPerThread); - } else if (MulProps.BiasInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { - InputBias = CoopVecHelpers::CreateInputBias( - Config.OutputPerThread); - } else if (MulProps.BiasInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - InputBias = CoopVecHelpers::CreateInputBias(Config.OutputPerThread); - } else { - WEX::Logging::Log::Error(L"Unsupported bias data type"); - return; - } - if (InputBias.size() % 4 != 0) { - // Align size to 4 bytes for ByteAddressBuffer - InputBias.resize(InputBias.size() + 4 - (InputBias.size() % 4)); - } - CreateTestResources(D3DDevice, CommandList, InputBias.data(), - InputBias.size(), - CD3DX12_RESOURCE_DESC::Buffer(InputBias.size()), + CreateTestResources(D3DDevice, CommandList, InputBias.getBuffer(), + InputBias.getTotalBytes(), + CD3DX12_RESOURCE_DESC::Buffer(InputBias.getTotalBytes()), &InputBiasSRVResource, &InputBiasSRVUploadResource); // This increments baseHandle CreateRawSRV(D3DDevice, BaseHandle, - (UINT)(InputBias.size() / sizeof(int32_t)), + static_cast(InputBias.getTotalBytes() / sizeof(int32_t)), InputBiasSRVResource); - // Calculate reference output - // FIXME: This does not capture all cases, but is sufficient for the preview - // feature set - if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8) { - // The input bias is really an array of int32_t - std::vector InputBiasI32(InputBias.size() / sizeof(int32_t)); - std::memcpy(InputBiasI32.data(), InputBias.data(), InputBias.size()); - - // The input vector is really an array of float if our vector input type is - // FLOAT32 - std::vector InputVectorF32(InputVector.size() / sizeof(int32_t)); - if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - std::memcpy(InputVectorF32.data(), InputVector.data(), - InputVector.size()); - } - - for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { - for (int OutputIdx = 0; OutputIdx < Config.OutputPerThread; ++OutputIdx) { - int Acc = 0; - - for (int InputIdx = 0; InputIdx < Config.InputPerThread; ++InputIdx) { - int InputElem; - if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - InputElem = (int) - InputVectorF32[ThreadIdx * Config.InputPerThread + InputIdx]; - } else { - InputElem = - InputVector[ThreadIdx * Config.InputPerThread + InputIdx]; - } - int const MatrixElem = - InputMatrix[OutputIdx * Config.InputPerThread + InputIdx]; - Acc += InputElem * MatrixElem; - } - - if (Config.Bias) { - Acc += InputBiasI32[OutputIdx]; - } - - float Result = float(Acc); - ExpectedOutputBuffer[ThreadIdx * Config.OutputPerThread + OutputIdx] = - Result; - } - } - } else if (MulProps.MatrixInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || - MulProps.MatrixInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || - MulProps.MatrixInterpretation == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { - // The input bias/vector is really an array of float16 - std::vector InputVectorFP16( - InputVector.size() / sizeof(DirectX::PackedVector::HALF)); - std::memcpy(InputVectorFP16.data(), InputVector.data(), InputVector.size()); - - std::vector InputBiasFP16( - InputBias.size() / sizeof(DirectX::PackedVector::HALF)); - std::memcpy(InputBiasFP16.data(), InputBias.data(), InputBias.size()); - - // The CPU reference matrix is float - std::vector InputMatrixFP32(InputMatrix.size() / sizeof(float)); - std::memcpy(InputMatrixFP32.data(), InputMatrix.data(), InputMatrix.size()); - - for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { - for (int OutputIdx = 0; OutputIdx < Config.OutputPerThread; ++OutputIdx) { - float Acc = 0; - - for (int InputIdx = 0; InputIdx < Config.InputPerThread; ++InputIdx) { - float const InputElem = ConvertFloat16ToFloat32( - InputVectorFP16[ThreadIdx * Config.InputPerThread + InputIdx]); - float const MatrixElem = - InputMatrixFP32[OutputIdx * Config.InputPerThread + InputIdx]; - Acc += InputElem * MatrixElem; - } - - if (Config.Bias) { - Acc += ConvertFloat16ToFloat32(InputBiasFP16[OutputIdx]); - } - - float Result = Acc; - ExpectedOutputBuffer[ThreadIdx * Config.OutputPerThread + OutputIdx] = - Result; - } - } - } - - CComPtr ConvertedMatrixResource; - { - // Create source matrix info - D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {}; - ConvertInfo.SrcInfo.SrcDataType = - CoopVecHelpers::GetMatrixSrcDataType(MulProps.MatrixInterpretation); - ConvertInfo.SrcInfo.SrcLayout = - D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; - - // Create destination matrix info - ConvertInfo.DestInfo.DestSize = 0; // Will be populated by driver - int SrcEltSize = 0; - int DestEltSize = 0; - switch (MulProps.MatrixInterpretation) { - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: - case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: - ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; - SrcEltSize = 1; - DestEltSize = 1; - break; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: - ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16; - SrcEltSize = 4; // FP32 - DestEltSize = 2; // FP16 - break; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: - ConvertInfo.DestInfo.DestDataType = - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3; - SrcEltSize = 4; // FP32 - DestEltSize = 1; // FP8 - break; - case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: - ConvertInfo.DestInfo.DestDataType = - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2; - SrcEltSize = 4; // FP32 - DestEltSize = 1; // FP8 - break; - } - ConvertInfo.SrcInfo.SrcStride = Config.InputPerThread * SrcEltSize; - ConvertInfo.SrcInfo.SrcSize = - Config.InputPerThread * Config.OutputPerThread * SrcEltSize; - - ConvertInfo.DestInfo.DestLayout = Config.MatrixLayout; - ConvertInfo.DestInfo.DestStride = 0; - ConvertInfo.DestInfo.NumRows = Config.OutputPerThread; - ConvertInfo.DestInfo.NumColumns = Config.InputPerThread; + // Create converted matrix resource and SRV for each input matrix + std::vector> ConvertedMatrixResources( + Config.NumLayers); + for (size_t I = 0; I < Config.NumLayers; ++I) { + auto ConvertInfo = InputMatrices[I].getConversionInfo( + D3DDevice, MulProps.MatrixInterpretation, Config.MatrixLayout); - if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { - ConvertInfo.DestInfo.DestStride = Config.InputPerThread * DestEltSize; - } else if (Config.MatrixLayout == - D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { - ConvertInfo.DestInfo.DestStride = Config.OutputPerThread * DestEltSize; - } - - // Get destination size using preview interface - { - CComPtr PreviewDevice; - VERIFY_SUCCEEDED(D3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview), - (void **)&PreviewDevice)); - - // Query required destination size - PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo( - &ConvertInfo.DestInfo); - } + UINT SRVSize = (ConvertInfo.DestInfo.DestSize + 15) / 16 * 16; // Create resource to hold matrix copy - CreateTestResources( - D3DDevice, CommandList, nullptr, 0, - CD3DX12_RESOURCE_DESC::Buffer(ConvertInfo.DestInfo.DestSize), - &ConvertedMatrixResource, nullptr); + CreateTestResources(D3DDevice, CommandList, nullptr, SRVSize, + CD3DX12_RESOURCE_DESC::Buffer(SRVSize), + &ConvertedMatrixResources[I], nullptr); // Set up data descriptors ConvertInfo.DataDesc.DestVA = - ConvertedMatrixResource->GetGPUVirtualAddress(); - ConvertInfo.DataDesc.SrcVA = InputMatrixSRVResource->GetGPUVirtualAddress(); + ConvertedMatrixResources[I]->GetGPUVirtualAddress(); + ConvertInfo.DataDesc.SrcVA = + InputMatrixSRVResources[I]->GetGPUVirtualAddress(); // Get command list interface and perform conversion - CComPtr CommandList11; - VERIFY_SUCCEEDED(CommandList->QueryInterface( - __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11)); - CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); + CComPtr CommandListPreview; + VERIFY_SUCCEEDED( + CommandList->QueryInterface(__uuidof(ID3D12GraphicsCommandListPreview), + (void **)&CommandListPreview)); + CommandListPreview->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); - // This increments baseHandle - if ((ConvertInfo.DestInfo.DestSize % 4) != 0) { - WEX::Logging::Log::Error(L"DestSize is not aligned to 4 bytes"); - return; - } - CreateRawSRV(D3DDevice, BaseHandle, - ConvertInfo.DestInfo.DestSize / sizeof(int32_t), - ConvertedMatrixResource); + // This increments BaseHandle + CreateRawSRV(D3DDevice, BaseHandle, SRVSize / sizeof(int32_t), + ConvertedMatrixResources[I]); } + // Create resource for atomic counter + CComPtr AtomicCounterResource; + uint32_t AtomicCounterInit = 0; + CreateTestResources(D3DDevice, CommandList, &AtomicCounterInit, + sizeof(AtomicCounterInit), + CD3DX12_RESOURCE_DESC::Buffer(sizeof(AtomicCounterInit)), + &AtomicCounterResource, nullptr); + CComPtr UavResource; CComPtr UavUploadResource; CComPtr UavReadResource; @@ -12705,23 +12797,68 @@ void main(uint threadIdx : SV_GroupThreadID) CreateTestUavs(D3DDevice, CommandList, OutputBufferInit.data(), OutputBufferSize, &UavResource, &UavUploadResource, &UavReadResource); - CreateRawUAV(D3DDevice, BaseHandle, OutputBufferSize / 4, UavResource); + CreateRawUAV(D3DDevice, BaseHandle, static_cast(OutputBufferSize / 4), + UavResource); CommandList->Close(); ExecuteCommandList(CommandQueue, CommandList); WaitForSignal(CommandQueue, FO); VERIFY_SUCCEEDED(CommandAllocator->Reset()); - VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, ComputePipelineState)); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, PipelineState)); SetDescriptorHeap(CommandList, DescriptorHeap); CD3DX12_GPU_DESCRIPTOR_HANDLE ResHandle( DescriptorHeap->GetGPUDescriptorHandleForHeapStart()); - CommandList->SetComputeRootSignature(RootSignature); - CommandList->SetComputeRootDescriptorTable(0, ResHandle); - CommandList->SetPipelineState(ComputePipelineState); - CommandList->Dispatch(1, 1, 1); + CComPtr RtvHeap; + CComPtr RenderTarget; + CComPtr RenderTargetRead; + + if (RunCompute) { + CommandList->SetComputeRootSignature(RootSignature); + CommandList->SetComputeRootDescriptorTable(0, ResHandle); + CommandList->SetPipelineState(PipelineState); + CommandList->Dispatch(1, 1, 1); + } else { + UINT FrameCount = 1; + UINT RtvDescSize = 0; + CreateRtvDescriptorHeap(D3DDevice, FrameCount, &RtvHeap, &RtvDescSize); + CreateRenderTargetAndReadback(D3DDevice, RtvHeap, 100, 100, &RenderTarget, + &RenderTargetRead); + + D3D12_RESOURCE_DESC RtDesc = RenderTarget->GetDesc(); + D3D12_VIEWPORT Viewport; + D3D12_RECT ScissorRect; + + memset(&Viewport, 0, sizeof(Viewport)); + Viewport.Height = static_cast(RtDesc.Height); + Viewport.Width = static_cast(RtDesc.Width); + Viewport.MaxDepth = 1.0f; + memset(&ScissorRect, 0, sizeof(ScissorRect)); + ScissorRect.right = static_cast(RtDesc.Width); + ScissorRect.bottom = static_cast(RtDesc.Height); + CommandList->SetGraphicsRootSignature(RootSignature); + CommandList->SetGraphicsRootDescriptorTable(0, ResHandle); + CommandList->SetGraphicsRootUnorderedAccessView( + 1, AtomicCounterResource->GetGPUVirtualAddress()); + CommandList->RSSetViewports(1, &Viewport); + CommandList->RSSetScissorRects(1, &ScissorRect); + + // Indicate that the buffer will be used as a render target. + RecordTransitionBarrier(CommandList, RenderTarget, + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_RESOURCE_STATE_RENDER_TARGET); + + CD3DX12_CPU_DESCRIPTOR_HANDLE RtvHandle( + RtvHeap->GetCPUDescriptorHandleForHeapStart(), 0, RtvDescSize); + CommandList->OMSetRenderTargets(1, &RtvHandle, FALSE, nullptr); + + CommandList->ClearRenderTargetView(RtvHandle, ClearColor, 0, nullptr); + CommandList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST); + CommandList->DrawInstanced(3, 1, 0, 0); + } + RecordTransitionBarrier(CommandList, UavResource, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE); @@ -12731,18 +12868,36 @@ void main(uint threadIdx : SV_GroupThreadID) WaitForSignal(CommandQueue, FO); { - MappedData MappedData(UavReadResource, OutputBufferSize); + MappedData MappedData(UavReadResource, static_cast(OutputBufferSize)); - float *ResultBuffer = (float *)MappedData.data(); + float *ResultBuffer = reinterpret_cast(MappedData.data()); bool Equal = true; - for (int i = 0; i < OutputBufferSize / sizeof(float); i++) { - if (isnan(ResultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || - fabs(ResultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { - LogErrorFmt(L"Result mismatch at index %d", i); - LogErrorFmt(L"ResultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, - ResultBuffer[i], i, ExpectedOutputBuffer[i]); - Equal = false; - break; + + float MaxError = 0.00001f; + if (MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + // Allow for more error in fp16 relative to the fp32 reference + MaxError = 0.1f; + } else if (MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3) { + // And even more error for the fp8 formats + MaxError = 1.0f; + } else if (MulProps.MatrixInterpretation == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + MaxError = 3.0f; + } + + for (size_t i = 0; i < Config.NumThreads && Equal; ++i) { + for (size_t j = 0; j < Config.OutputPerThread; ++j) { + float Result = ResultBuffer[i * Config.OutputPerThread + j]; + float Expected = ExpectedOutput.getVector(i)[j]; + if (isnan(Result) || isnan(Expected) || + fabs(Result - Expected) > MaxError) { + LogErrorFmt(L"Result mismatch at vector %zu, element %zu", i, j); + LogErrorFmt(L"Result: %f, Expected: %f", Result, Expected); + Equal = false; + break; + } } } VERIFY_IS_TRUE(Equal); @@ -12766,13 +12921,27 @@ void ExecutionTest::runCoopVecOuterProductTest() { // Create device and verify coopvec support CComPtr D3DDevice; if (!CreateDevice(&D3DDevice, D3D_SHADER_MODEL_6_9)) { +#ifdef _HLK_CONF + LOG_ERROR_FMT_THROW( + L"Device does not support SM 6.9. Can't run these tests."); +#else + WEX::Logging::Log::Comment( + "Device does not support SM 6.9. Can't run these tests."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); return; +#endif } + if (!DoesDeviceSupportCooperativeVector(D3DDevice)) { +#ifdef _HLK_CONF + LOG_ERROR_FMT_THROW( + L"Device does not support cooperative vectors. Can't run these tests."); +#else WEX::Logging::Log::Comment( - "Device does not support cooperative vector. Skipping."); + "Device does not support cooperative vectors. Can't run these tests."); WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); return; +#endif } // Query coopvec feature data. First call gets the size of the arrays. The @@ -12810,7 +12979,14 @@ void ExecutionTest::runCoopVecOuterProductTestConfig( .c_str()); constexpr CoopVecOuterProductSubtestConfig TestConfigs[] = { - {4, 4, 2, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL}, + {4, 4, 16, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL}, + {4, 4, 32, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL}, + {16, 16, 16, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL}, + {16, 16, 32, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL}, + {32, 32, 16, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL}, + {32, 32, 32, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL}, + {64, 64, 16, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL}, + {64, 64, 32, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL}, }; for (auto Config : TestConfigs) { @@ -12824,29 +13000,44 @@ void ExecutionTest::runCoopVecOuterProductTestConfig( continue; } - runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config); + // Run once in compute, then once in graphics (pixel shader) + runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config, true); + runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config, false); } } void ExecutionTest::runCoopVecOuterProductSubtest( ID3D12Device *D3DDevice, D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps, - CoopVecOuterProductSubtestConfig &Config) { + CoopVecOuterProductSubtestConfig &Config, bool RunCompute) { + + std::mt19937 Rnd(0x42); LogCommentFmt( - L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s", + L"Running test for DimM: %zu, DimN: %zu, NumThreads: %zu, MatrixLayout: " + L"%s, " + L"Stage: %s", Config.DimM, Config.DimN, Config.NumThreads, - CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str()); + CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str(), + RunCompute ? L"Compute" : L"Pixel"); // Create root signature with a single root entry for all SRVs and UAVs CComPtr RootSignature; { - CD3DX12_DESCRIPTOR_RANGE ranges[2]; - ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0, - 0); // InputVector1, InputVector2 - ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // AccumMatrix - CreateRootSignatureFromRanges(D3DDevice, &RootSignature, ranges, 2, nullptr, - 0); + CD3DX12_DESCRIPTOR_RANGE Ranges[2]; + Ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0, 0); + Ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); + + CD3DX12_ROOT_PARAMETER RootParams[2]; + RootParams[0].InitAsDescriptorTable(_countof(Ranges), Ranges, + D3D12_SHADER_VISIBILITY_ALL); + RootParams[1].InitAsUnorderedAccessView(/* register */ 10, /* space */ 0, + D3D12_SHADER_VISIBILITY_ALL); + + CD3DX12_ROOT_SIGNATURE_DESC RootSignatureDesc; + RootSignatureDesc.Init(_countof(RootParams), RootParams, 0, nullptr, + D3D12_ROOT_SIGNATURE_FLAG_NONE); + CreateRootSignatureFromDesc(D3DDevice, &RootSignatureDesc, &RootSignature); } // Create descriptor heap with space for 3 descriptors: 2 SRVs and 1 UAV @@ -12862,119 +13053,242 @@ void ExecutionTest::runCoopVecOuterProductSubtest( CD3DX12_CPU_DESCRIPTOR_HANDLE BaseHandle( DescriptorHeap->GetCPUDescriptorHandleForHeapStart()); - // Create a compute pipeline state object. - CComPtr ComputePipelineState; - { - std::string ShaderSource = R"( + // Setup input matrix as all-ones in sint8/fp32 format. This will later be + // converted to the appropriate data type by the matrix conversion API. + + std::vector InputMatrix; + if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || + AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, + Config.DimM); + } else if (AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || + AccumulateProps.AccumulationType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { + // Matrix source data is fp32, which gets converted to fp16 during matrix + // conversion + InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, + Config.DimM); + } else { + WEX::Logging::Log::Comment(L"Unsupported matrix data type"); + return; + } + + // Create input vectors + auto InputVector1 = CoopVecHelpers::TestVector::createSimpleTestVector( + Config.NumThreads, Config.DimM, AccumulateProps.InputType, + AccumulateProps.InputType, AccumulateProps.AccumulationType, Rnd); + auto InputVector2 = CoopVecHelpers::TestVector::createSimpleTestVector( + Config.NumThreads, Config.DimN, AccumulateProps.InputType, + AccumulateProps.InputType, AccumulateProps.AccumulationType, Rnd); + + // Calculate reference output + auto ExpectedOutputBufferI8 = + CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); + std::vector ExpectedOutputBuffer(ExpectedOutputBufferI8.size() / + sizeof(float)); + std::memcpy(ExpectedOutputBuffer.data(), ExpectedOutputBufferI8.data(), + ExpectedOutputBufferI8.size()); + + if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { + for (size_t ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + auto *InputVector1FP16 = + InputVector1.getVector(ThreadIdx); + auto *InputVector2FP16 = + InputVector2.getVector(ThreadIdx); + for (size_t M = 0; M < Config.DimM; ++M) { + for (size_t N = 0; N < Config.DimN; ++N) { + float acc = ConvertFloat16ToFloat32(InputVector1FP16[M]) * + ConvertFloat16ToFloat32(InputVector2FP16[N]); + ExpectedOutputBuffer[M * Config.DimN + N] += acc; + } + } + } + } else if (AccumulateProps.InputType == + D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { + for (size_t ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { + auto *InputVector1FP32 = InputVector1.getVector(ThreadIdx); + auto *InputVector2FP32 = InputVector2.getVector(ThreadIdx); + for (size_t M = 0; M < Config.DimM; ++M) { + for (size_t N = 0; N < Config.DimN; ++N) { + float Acc = InputVector1FP32[M] * InputVector2FP32[N]; + ExpectedOutputBuffer[M * Config.DimN + N] += Acc; + } + } + } + } else { + WEX::Logging::Log::Comment(L"Unsupported input data type"); + return; + } + + std::string ShaderSource = R"( #include "dx/linalg.h" ByteAddressBuffer InputVector1 : register(t0); ByteAddressBuffer InputVector2 : register(t1); RWByteAddressBuffer AccumMatrix : register(u0); -[shader("compute")] -[numthreads(NUM_THREADS, 1, 1)] -void main(uint threadIdx : SV_GroupThreadID) +RWStructuredBuffer AtomicCounter : register(u10); + +void RunCoopVecTest(uint threadIdx) { -#if 1 using namespace dx::linalg; // Ensure 4-byte alignment for vector loads - uint inputOffset1 = (DIM_M * threadIdx * sizeof(INPUT_DATA_TYPE)); - inputOffset1 = (inputOffset1 + 3) & ~3; // Align to 4 bytes + uint inputOffset1 = threadIdx * INPUT_VECTOR_1_STRIDE; vector input1 = InputVector1.Load >(inputOffset1); - uint inputOffset2 = (DIM_N * threadIdx * sizeof(INPUT_DATA_TYPE)); - inputOffset2 = (inputOffset2 + 3) & ~3; // Align to 4 bytes + uint inputOffset2 = threadIdx * INPUT_VECTOR_2_STRIDE; vector input2 = InputVector2.Load >(inputOffset2); RWMatrixRef mat = { AccumMatrix, 0, STRIDE }; OuterProductAccumulate(input1, input2, mat); -#endif } - )"; - auto CreateDefineFromInt = [](const wchar_t *Name, int Value) { - std::wstringstream Stream; - Stream << L"-D" << Name << L"=" << Value; - return Stream.str(); - }; +[shader("compute")] +[numthreads(NUM_THREADS, 1, 1)] +void main(uint threadIdx : SV_GroupThreadID) +{ + RunCoopVecTest(threadIdx); +} - auto CreateDefineFromString = [](const wchar_t *Name, - const wchar_t *Value) { - std::wstringstream Stream; - Stream << L"-D" << Name << L"=" << Value; - return Stream.str(); - }; +float4 vs_main(uint vid : SV_VertexID) : SV_Position { + switch (vid) { + case 0: + return float4(-1, 1, 0, 1); + case 1: + return float4(3, 1, 0, 1); + case 2: + return float4(-1, -3, 0, 1); + } + return float4(0, 0, 0, 0); +} - int Stride = 0; - const std::wstring HlslMatrixLayout = - CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); - int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType( - AccumulateProps.AccumulationType); - switch (Config.MatrixLayout) { - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: - Stride = Config.DimN * StrideMultiplier; - break; - case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: - Stride = Config.DimM * StrideMultiplier; - break; - } +float4 ps_main() : SV_Target { + uint threadIdx; + InterlockedAdd(AtomicCounter[0], 1, threadIdx); + if (threadIdx < NUM_THREADS) + RunCoopVecTest(threadIdx); + return float4(1, 1, 1, 1); +} +)"; - const int InputDivisor = - CoopVecHelpers::GetNumPackedElementsForInputDataType( - AccumulateProps.InputType); - const std::wstring InputDataType = - CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType); - const std::wstring AccumDataType = - CoopVecHelpers::GetHlslDataTypeForDataType( - AccumulateProps.AccumulationType); - const std::wstring MatrixDataTypeEnum = - CoopVecHelpers::GetHlslInterpretationForDataType( - AccumulateProps.AccumulationType); - const std::wstring InputInterpretationEnum = - CoopVecHelpers::GetHlslInterpretationForDataType( - AccumulateProps.InputType); - - auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM); - auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN); - auto NumThreadsDefine = - CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads); - auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride); - auto InputDataTypeDefine = - CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str()); - auto InputDivisorDefine = - CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor); - auto AccumDataTypeDefine = - CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str()); - auto InputInterpretationEnumDefine = CreateDefineFromString( - L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str()); - auto HlslMatrixLayoutDefine = - CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str()); - auto MatrixDataTypeEnumDefine = CreateDefineFromString( - L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str()); - - LPCWSTR Options[] = { - L"-enable-16bit-types", - DimMDefine.c_str(), - DimNDefine.c_str(), - NumThreadsDefine.c_str(), - StrideDefine.c_str(), - InputDataTypeDefine.c_str(), - InputDivisorDefine.c_str(), - AccumDataTypeDefine.c_str(), - InputInterpretationEnumDefine.c_str(), - HlslMatrixLayoutDefine.c_str(), - MatrixDataTypeEnumDefine.c_str(), - }; + auto CreateDefineFromSize = [](const wchar_t *Name, size_t Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; - CComPtr IncludeHandler = - new LinAlgHeaderIncludeHandler(m_support); + auto CreateDefineFromString = [](const wchar_t *Name, const wchar_t *Value) { + std::wstringstream Stream; + Stream << L"-D" << Name << L"=" << Value; + return Stream.str(); + }; + + size_t Stride = 0; + const std::wstring HlslMatrixLayout = + CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout); + size_t StrideMultiplier = + CoopVecHelpers::GetStrideMultiplierForMatrixDataType( + AccumulateProps.AccumulationType); + switch (Config.MatrixLayout) { + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR: + Stride = Config.DimN * StrideMultiplier; + break; + case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR: + Stride = Config.DimM * StrideMultiplier; + break; + } + + const size_t InputDivisor = + CoopVecHelpers::GetNumPackedElementsForInputDataType( + AccumulateProps.InputType); + const std::wstring InputDataType = + CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType); + const std::wstring AccumDataType = CoopVecHelpers::GetHlslDataTypeForDataType( + AccumulateProps.AccumulationType); + const std::wstring MatrixDataTypeEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + AccumulateProps.AccumulationType); + const std::wstring InputInterpretationEnum = + CoopVecHelpers::GetHlslInterpretationForDataType( + AccumulateProps.InputType); + + auto DimMDefine = CreateDefineFromSize(L"DIM_M", Config.DimM); + auto DimNDefine = CreateDefineFromSize(L"DIM_N", Config.DimN); + auto NumThreadsDefine = + CreateDefineFromSize(L"NUM_THREADS", Config.NumThreads); + auto StrideDefine = CreateDefineFromSize(L"STRIDE", Stride); + auto InputDataTypeDefine = + CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str()); + auto InputDivisorDefine = + CreateDefineFromSize(L"INPUT_DIVISOR", InputDivisor); + auto AccumDataTypeDefine = + CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str()); + auto InputInterpretationEnumDefine = CreateDefineFromString( + L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str()); + auto HlslMatrixLayoutDefine = + CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str()); + auto MatrixDataTypeEnumDefine = CreateDefineFromString( + L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str()); + auto InputVector1StrideDefine = + CreateDefineFromSize(L"INPUT_VECTOR_1_STRIDE", InputVector1.getStride()); + auto InputVector2StrideDefine = + CreateDefineFromSize(L"INPUT_VECTOR_2_STRIDE", InputVector2.getStride()); + + LPCWSTR Options[] = { + L"-enable-16bit-types", + DimMDefine.c_str(), + DimNDefine.c_str(), + NumThreadsDefine.c_str(), + StrideDefine.c_str(), + InputDataTypeDefine.c_str(), + InputDivisorDefine.c_str(), + AccumDataTypeDefine.c_str(), + InputInterpretationEnumDefine.c_str(), + HlslMatrixLayoutDefine.c_str(), + MatrixDataTypeEnumDefine.c_str(), + InputVector1StrideDefine.c_str(), + InputVector2StrideDefine.c_str(), + }; + + CComPtr IncludeHandler = + new LinAlgHeaderIncludeHandler(m_support); + + // Create the pipeline state for the CoopVec shaders + CComPtr PipelineState; + if (RunCompute) { CreateComputePSO(D3DDevice, RootSignature, ShaderSource.c_str(), L"cs_6_9", - &ComputePipelineState, Options, _countof(Options), + &PipelineState, Options, _countof(Options), IncludeHandler); + } else { + CComPtr VertexShader; + CComPtr PixelShader; + + CompileFromText(ShaderSource.c_str(), L"vs_main", L"vs_6_9", &VertexShader, + Options, _countof(Options), IncludeHandler); + CompileFromText(ShaderSource.c_str(), L"ps_main", L"ps_6_9", &PixelShader, + Options, _countof(Options), IncludeHandler); + + D3D12_GRAPHICS_PIPELINE_STATE_DESC PsoDesc = {}; + PsoDesc.pRootSignature = RootSignature; + PsoDesc.VS = CD3DX12_SHADER_BYTECODE(VertexShader); + PsoDesc.PS = CD3DX12_SHADER_BYTECODE(PixelShader); + PsoDesc.RasterizerState = CD3DX12_RASTERIZER_DESC(D3D12_DEFAULT); + PsoDesc.BlendState = CD3DX12_BLEND_DESC(D3D12_DEFAULT); + PsoDesc.DepthStencilState.DepthEnable = FALSE; + PsoDesc.DepthStencilState.StencilEnable = FALSE; + PsoDesc.SampleMask = UINT_MAX; + PsoDesc.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE; + PsoDesc.NumRenderTargets = 1; + PsoDesc.RTVFormats[0] = DXGI_FORMAT_R8G8B8A8_UNORM; + PsoDesc.SampleDesc.Count = 1; + VERIFY_SUCCEEDED(D3DDevice->CreateGraphicsPipelineState( + &PsoDesc, IID_PPV_ARGS(&PipelineState))); } // Create a command list for the compute shader. @@ -12988,147 +13302,41 @@ void main(uint threadIdx : SV_GroupThreadID) VERIFY_SUCCEEDED(D3DDevice->CreateCommandAllocator( D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&CommandAllocator))); VERIFY_SUCCEEDED(D3DDevice->CreateCommandList( - 0, D3D12_COMMAND_LIST_TYPE_DIRECT, CommandAllocator, ComputePipelineState, + 0, D3D12_COMMAND_LIST_TYPE_DIRECT, CommandAllocator, PipelineState, IID_PPV_ARGS(&CommandList))); - // Setup input matrix as all-ones in sint8/fp32 format. This will later be - // converted to the appropriate data type by the matrix conversion API. CComPtr InputMatrixSRVResource, InputMatrixSRVUploadResource; - std::vector InputMatrix; - if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || - AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, - Config.DimM); - } else if (AccumulateProps.AccumulationType == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || - AccumulateProps.AccumulationType == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || - AccumulateProps.AccumulationType == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { - // Matrix source data is fp32, which gets converted to fp16 during matrix - // conversion - InputMatrix = CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, - Config.DimM); - } else { - WEX::Logging::Log::Error(L"Unsupported matrix data type"); - return; - } - CreateTestResources(D3DDevice, CommandList, InputMatrix.data(), InputMatrix.size(), CD3DX12_RESOURCE_DESC::Buffer(InputMatrix.size()), &InputMatrixSRVResource, &InputMatrixSRVUploadResource); - // Create input vectors CComPtr InputVecSRVResource1, InputVecSRVUploadResource1; - std::vector InputVector1; CComPtr InputVecSRVResource2, InputVecSRVUploadResource2; - std::vector InputVector2; - - if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 || - AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8) { - InputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, - Config.DimM); - InputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, - Config.DimN); - } else if (AccumulateProps.InputType == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16 || - AccumulateProps.InputType == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 || - AccumulateProps.InputType == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { - InputVector1 = - CoopVecHelpers::CreateInputVector( - Config.NumThreads, Config.DimM); - InputVector2 = - CoopVecHelpers::CreateInputVector( - Config.NumThreads, Config.DimN); - } else if (AccumulateProps.InputType == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - InputVector1 = CoopVecHelpers::CreateInputVector(Config.NumThreads, - Config.DimM); - InputVector2 = CoopVecHelpers::CreateInputVector(Config.NumThreads, - Config.DimN); - } else { - WEX::Logging::Log::Error(L"Unsupported input data type"); - return; - } - if (InputVector1.size() % 4 != 0) { - // Align size to 4 bytes for ByteAddressBuffer - InputVector1.resize(InputVector1.size() + 4 - (InputVector1.size() % 4)); - } - if (InputVector2.size() % 4 != 0) { - // Align size to 4 bytes for ByteAddressBuffer - InputVector2.resize(InputVector2.size() + 4 - (InputVector2.size() % 4)); - } - CreateTestResources(D3DDevice, CommandList, InputVector1.data(), - InputVector1.size(), - CD3DX12_RESOURCE_DESC::Buffer(InputVector1.size()), - &InputVecSRVResource1, &InputVecSRVUploadResource1); - CreateTestResources(D3DDevice, CommandList, InputVector2.data(), - InputVector2.size(), - CD3DX12_RESOURCE_DESC::Buffer(InputVector2.size()), - &InputVecSRVResource2, &InputVecSRVUploadResource2); - - // This increments baseHandle - CreateRawSRV(D3DDevice, BaseHandle, - (UINT)(InputVector1.size() / sizeof(int32_t)), - InputVecSRVResource1); - CreateRawSRV(D3DDevice, BaseHandle, - (UINT)(InputVector2.size() / sizeof(int32_t)), - InputVecSRVResource2); - // Calculate reference output - auto ExpectedOutputBufferI8 = - CoopVecHelpers::CreateAllOnesInputMatrix(Config.DimN, Config.DimM); - std::vector ExpectedOutputBuffer(ExpectedOutputBufferI8.size() / - sizeof(float)); - std::memcpy(ExpectedOutputBuffer.data(), ExpectedOutputBufferI8.data(), - ExpectedOutputBufferI8.size()); + CreateTestResources( + D3DDevice, CommandList, InputVector1.getBuffer(), + InputVector1.getTotalBytes(), + CD3DX12_RESOURCE_DESC::Buffer(InputVector1.getTotalBytes()), + &InputVecSRVResource1, &InputVecSRVUploadResource1); + CreateTestResources( + D3DDevice, CommandList, InputVector2.getBuffer(), + InputVector2.getTotalBytes(), + CD3DX12_RESOURCE_DESC::Buffer(InputVector2.getTotalBytes()), + &InputVecSRVResource2, &InputVecSRVUploadResource2); - if (AccumulateProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16) { - std::vector InputVector1FP16( - InputVector1.size() / sizeof(DirectX::PackedVector::HALF)); - std::memcpy(InputVector1FP16.data(), InputVector1.data(), - InputVector1.size()); - - std::vector InputVector2FP16( - InputVector2.size() / sizeof(DirectX::PackedVector::HALF)); - std::memcpy(InputVector2FP16.data(), InputVector2.data(), - InputVector2.size()); - - for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { - for (int M = 0; M < Config.DimM; ++M) { - for (int N = 0; N < Config.DimN; ++N) { - float acc = ConvertFloat16ToFloat32(InputVector1FP16[M]) * - ConvertFloat16ToFloat32(InputVector2FP16[N]); - ExpectedOutputBuffer[M * Config.DimN + N] += acc; - } - } - } - } else if (AccumulateProps.InputType == - D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) { - std::vector InputVector1FP32(InputVector1.size() / sizeof(float)); - std::memcpy(InputVector1FP32.data(), InputVector1.data(), - InputVector1.size()); - - std::vector InputVector2FP32(InputVector2.size() / sizeof(float)); - std::memcpy(InputVector2FP32.data(), InputVector2.data(), - InputVector2.size()); - - for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) { - for (int M = 0; M < Config.DimM; ++M) { - for (int N = 0; N < Config.DimN; ++N) { - float Acc = InputVector1FP32[ThreadIdx * Config.DimM + M] * - InputVector2FP32[ThreadIdx * Config.DimN + N]; - ExpectedOutputBuffer[M * Config.DimN + N] += Acc; - } - } - } - } + // This increments baseHandle + CreateRawSRV( + D3DDevice, BaseHandle, + static_cast(InputVector1.getTotalBytes() / sizeof(int32_t)), + InputVecSRVResource1); + CreateRawSRV( + D3DDevice, BaseHandle, + static_cast(InputVector2.getTotalBytes() / sizeof(int32_t)), + InputVecSRVResource2); CComPtr ConvertedMatrixResource, ConvertedMatrixReadResource; - int ConvertedMatrixSize = 0; + UINT ConvertedMatrixSize = 0; { // Create source matrix info D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_SRC_INFO SrcInfo = {}; @@ -13139,8 +13347,8 @@ void main(uint threadIdx : SV_GroupThreadID) // Create destination matrix info D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_DEST_INFO DestInfo = {}; DestInfo.DestSize = 0; // Will be populated by driver - int SrcEltSize = 0; - int DestEltSize = 0; + UINT SrcEltSize = 0; + UINT DestEltSize = 0; switch (AccumulateProps.AccumulationType) { case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED: @@ -13164,19 +13372,19 @@ void main(uint threadIdx : SV_GroupThreadID) DestEltSize = 1; // FP8 break; } - SrcInfo.SrcStride = Config.DimM * SrcEltSize; - SrcInfo.SrcSize = Config.DimM * Config.DimN * SrcEltSize; + SrcInfo.SrcStride = static_cast(Config.DimM * SrcEltSize); + SrcInfo.SrcSize = static_cast(Config.DimM * Config.DimN * SrcEltSize); DestInfo.DestLayout = Config.MatrixLayout; DestInfo.DestStride = 0; - DestInfo.NumRows = Config.DimM; - DestInfo.NumColumns = Config.DimN; + DestInfo.NumRows = static_cast(Config.DimM); + DestInfo.NumColumns = static_cast(Config.DimN); if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) { - DestInfo.DestStride = Config.DimM * DestEltSize; + DestInfo.DestStride = static_cast(Config.DimM * DestEltSize); } else if (Config.MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) { - DestInfo.DestStride = Config.DimM * DestEltSize; + DestInfo.DestStride = static_cast(Config.DimM * DestEltSize); } // Create conversion info @@ -13210,10 +13418,11 @@ void main(uint threadIdx : SV_GroupThreadID) ConvertInfo.DataDesc = DataDesc; // Get command list interface and perform conversion - CComPtr CommandList11; - VERIFY_SUCCEEDED(CommandList->QueryInterface( - __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11)); - CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); + CComPtr CommandListPreview; + VERIFY_SUCCEEDED( + CommandList->QueryInterface(__uuidof(ID3D12GraphicsCommandListPreview), + (void **)&CommandListPreview)); + CommandListPreview->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); // This increments baseHandle if ((ConvertInfo.DestInfo.DestSize % 4) != 0) { @@ -13225,27 +13434,79 @@ void main(uint threadIdx : SV_GroupThreadID) ConvertedMatrixResource); } + // Create resource for atomic counter + CComPtr AtomicCounterResource; + uint32_t AtomicCounterInit = 0; + CreateTestResources(D3DDevice, CommandList, &AtomicCounterInit, + sizeof(AtomicCounterInit), + CD3DX12_RESOURCE_DESC::Buffer(sizeof(AtomicCounterInit)), + &AtomicCounterResource, nullptr); + CommandList->Close(); ExecuteCommandList(CommandQueue, CommandList); WaitForSignal(CommandQueue, FO); VERIFY_SUCCEEDED(CommandAllocator->Reset()); - VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, ComputePipelineState)); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, PipelineState)); SetDescriptorHeap(CommandList, DescriptorHeap); CD3DX12_GPU_DESCRIPTOR_HANDLE ResHandle( DescriptorHeap->GetGPUDescriptorHandleForHeapStart()); - CommandList->SetComputeRootSignature(RootSignature); - CommandList->SetComputeRootDescriptorTable(0, ResHandle); - CommandList->SetPipelineState(ComputePipelineState); - CommandList->Dispatch(1, 1, 1); + CComPtr RtvHeap; + CComPtr RenderTarget; + CComPtr RenderTargetRead; + + if (RunCompute) { + CommandList->SetComputeRootSignature(RootSignature); + CommandList->SetComputeRootDescriptorTable(0, ResHandle); + CommandList->SetPipelineState(PipelineState); + CommandList->Dispatch(1, 1, 1); + } else { + UINT FrameCount = 1; + UINT RtvDescSize = 0; + CreateRtvDescriptorHeap(D3DDevice, FrameCount, &RtvHeap, &RtvDescSize); + CreateRenderTargetAndReadback(D3DDevice, RtvHeap, 100, 100, &RenderTarget, + &RenderTargetRead); + + D3D12_RESOURCE_DESC RtDesc = RenderTarget->GetDesc(); + D3D12_VIEWPORT Viewport; + D3D12_RECT ScissorRect; + + memset(&Viewport, 0, sizeof(Viewport)); + Viewport.Height = static_cast(RtDesc.Height); + Viewport.Width = static_cast(RtDesc.Width); + Viewport.MaxDepth = 1.0f; + memset(&ScissorRect, 0, sizeof(ScissorRect)); + ScissorRect.right = static_cast(RtDesc.Width); + ScissorRect.bottom = static_cast(RtDesc.Height); + CommandList->SetGraphicsRootSignature(RootSignature); + CommandList->SetGraphicsRootDescriptorTable(0, ResHandle); + CommandList->SetGraphicsRootUnorderedAccessView( + 1, AtomicCounterResource->GetGPUVirtualAddress()); + CommandList->RSSetViewports(1, &Viewport); + CommandList->RSSetScissorRects(1, &ScissorRect); + + // Indicate that the buffer will be used as a render target. + RecordTransitionBarrier(CommandList, RenderTarget, + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_RESOURCE_STATE_RENDER_TARGET); + + CD3DX12_CPU_DESCRIPTOR_HANDLE RtvHandle( + RtvHeap->GetCPUDescriptorHandleForHeapStart(), 0, RtvDescSize); + CommandList->OMSetRenderTargets(1, &RtvHandle, FALSE, nullptr); + + CommandList->ClearRenderTargetView(RtvHandle, ClearColor, 0, nullptr); + CommandList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST); + CommandList->DrawInstanced(3, 1, 0, 0); + } + CommandList->Close(); ExecuteCommandList(CommandQueue, CommandList); WaitForSignal(CommandQueue, FO); VERIFY_SUCCEEDED(CommandAllocator->Reset()); - VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, ComputePipelineState)); + VERIFY_SUCCEEDED(CommandList->Reset(CommandAllocator, PipelineState)); // Convert matrix to sint8/fp32 row-major format before reading back to the // CPU. A new resource is created, along with a readback resource, for the @@ -13263,8 +13524,8 @@ void main(uint threadIdx : SV_GroupThreadID) ConvertInfo.DestInfo.DestSize = 0; // Will be populated by driver ConvertInfo.DestInfo.DestLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR; - ConvertInfo.DestInfo.NumRows = Config.DimM; - ConvertInfo.DestInfo.NumColumns = Config.DimN; + ConvertInfo.DestInfo.NumRows = static_cast(Config.DimM); + ConvertInfo.DestInfo.NumColumns = static_cast(Config.DimN); if (AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32 || @@ -13275,10 +13536,12 @@ void main(uint threadIdx : SV_GroupThreadID) AccumulateProps.AccumulationType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) { ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; - ConvertInfo.DestInfo.DestStride = Config.DimN * sizeof(float); + ConvertInfo.DestInfo.DestStride = + static_cast(Config.DimN * sizeof(float)); } else { ConvertInfo.DestInfo.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8; - ConvertInfo.DestInfo.DestStride = Config.DimN * sizeof(int8_t); + ConvertInfo.DestInfo.DestStride = + static_cast(Config.DimN * sizeof(int8_t)); } // Get destination size using preview interface @@ -13309,10 +13572,11 @@ void main(uint threadIdx : SV_GroupThreadID) ConvertedMatrixResource->GetGPUVirtualAddress(); // Get command list interface and perform conversion - CComPtr CommandList11; - VERIFY_SUCCEEDED(CommandList->QueryInterface( - __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11)); - CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); + CComPtr CommandListPreview; + VERIFY_SUCCEEDED( + CommandList->QueryInterface(__uuidof(ID3D12GraphicsCommandListPreview), + (void **)&CommandListPreview)); + CommandListPreview->ConvertLinearAlgebraMatrix(&ConvertInfo, 1); } RecordTransitionBarrier(CommandList, MatrixRowMajorResource, @@ -13324,15 +13588,17 @@ void main(uint threadIdx : SV_GroupThreadID) WaitForSignal(CommandQueue, FO); { - MappedData MappedData(MatrixRowMajorReadResource, (UINT)InputMatrix.size()); + MappedData MappedData(MatrixRowMajorReadResource, + static_cast(InputMatrix.size())); float *ResultBuffer = (float *)MappedData.data(); bool Equal = true; - for (int i = 0; i < (UINT)InputMatrix.size() / sizeof(float); i++) { + for (size_t i = 0; + i < static_cast(InputMatrix.size() / sizeof(float)); i++) { if (isnan(ResultBuffer[i]) || isnan(ExpectedOutputBuffer[i]) || fabs(ResultBuffer[i] - ExpectedOutputBuffer[i]) > 0.00001) { LogErrorFmt(L"Result mismatch at index %d", i); - LogErrorFmt(L"ResultBuffer[%d]: %f, ExpectedOutputBuffer[%d]: %f", i, + LogErrorFmt(L"ResultBuffer[%zu]: %f, ExpectedOutputBuffer[%zu]: %f", i, ResultBuffer[i], i, ExpectedOutputBuffer[i]); Equal = false; break;