diff --git a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp
index 1044119685..06ca264cc2 100644
--- a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp
+++ b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp
@@ -152,11 +152,18 @@ class DxilConf_SM610_LinAlg {
   TEST_CLASS_SETUP(setupClass);
   TEST_METHOD_SETUP(setupMethod);
 
+  // Load/Store
   TEST_METHOD(LoadStoreRoundtrip_Wave_F32);
   TEST_METHOD(LoadStoreRoundtrip_Wave_I32);
+
+  // Splat Store
   TEST_METHOD(SplatStore_Wave_F32);
   TEST_METHOD(SplatStore_Wave_I32);
 
+  // Element access
+  TEST_METHOD(ElementAccess_Wave_F32);
+  TEST_METHOD(ElementAccess_Wave_I32);
+
 private:
   bool createDevice();
 
@@ -389,6 +396,7 @@ static void runSplatStore(ID3D12Device *Device,
   }
 #endif
 
+  // Build expected data.
   std::vector ExpectedFloats;
   std::vector ExpectedInts;
   switch (Params.CompType) {
@@ -456,4 +464,211 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_I32() {
   runSplatStore(D3DDevice, DxcSupport, Params, 7.0f, VerboseLogging);
 }
 
+// HLSL for the ElementAccess tests: loads a matrix from Input, then writes
+// each element and every thread's MatrixLength to Output.
+static const char ElementAccessShader[] = R"(
+  RWByteAddressBuffer Input : register(u0);
+  RWByteAddressBuffer Output : register(u1);
+
+  // flatten the 2D index into a 1d index then scale by element size
+  uint cordToByteOffset(uint2 coord) {
+    return (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE;
+  }
+
+  [numthreads(NUMTHREADS, 1, 1)]
+  void main(uint threadIndex : SV_GroupIndex) {
+    __builtin_LinAlgMatrix
+      [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
+      Mat;
+    __builtin_LinAlg_MatrixLoadFromDescriptor(
+      Mat, Input, 0, STRIDE, LAYOUT, 0);
+
+    // Copy Matrix values from input to output without assuming order
+    for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
+      uint2 Coord = __builtin_LinAlg_MatrixGetCoordinate(Mat, I);
+      uint Offset = cordToByteOffset(Coord);
+#if COMP_TYPE == 9
+      float Elem;
+      __builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
+      Output.Store(Offset, asuint(Elem));
+#else
+      uint Elem;
+      __builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
+      Output.Store(Offset, Elem);
+#endif
+    }
+
+    // Store each threads Length in the output after the copied matrix
+    uint finalIdx = (M_DIM * N_DIM + threadIndex) * ELEM_SIZE;
+    uint Len = __builtin_LinAlg_MatrixLength(Mat);
+    Output.Store(finalIdx, Len);
+  }
+)";
+
+// Runs the element-access shader and checks its output buffer.
+//
+// The shader copies every matrix element to Output at the byte offset derived
+// from the element's coordinate, then appends each thread's MatrixLength.
+// Only compiles the shader (and marks the test skipped) when Device is null
+// in non-HLK builds.
+//
+// Device     - device to execute on; may be null outside of HLK builds.
+// DxcSupport - compiler loader used to build the shader.
+// Params     - matrix shape, component type and thread count under test.
+// MajorDim   - value of the shader's MAJOR_DIM define (stride used when
+//              flattening a coordinate into a linear element index).
+// Verbose    - forwarded to the compile and buffer verification helpers.
+//
+// NOTE(review): callers pass Params.M as MajorDim, which only matters for
+// non-square matrices; confirm it matches the coordinate order returned by
+// MatrixGetCoordinate before adding such a case.
+static void runElementAccess(ID3D12Device *Device,
+                             dxc::SpecificDllLoader &DxcSupport,
+                             const MatrixParams &Params, int MajorDim,
+                             bool Verbose) {
+  const size_t NumElements = Params.totalElements();
+  const size_t NumThreads = Params.NumThreads;
+  const size_t InputBufSize = Params.totalBytes();
+  const size_t ElementSize = elemSize(Params.CompType);
+  // Output: ElementSize bytes per element
+  //   1 element for each mat idx
+  //   1 element for each thread's length
+  const size_t OutputBufSize = (NumElements + NumThreads) * ElementSize;
+
+  std::stringstream ExtraDefs;
+  ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
+  ExtraDefs << " -DELEM_SIZE=" << ElementSize;
+  std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
+
+  compileShader(DxcSupport, ElementAccessShader, "cs_6_10", Args, Verbose);
+
+#ifndef _HLK_CONF
+  // Skip GPU execution if no device.
+  if (!Device) {
+    hlsl_test::LogCommentFmt(
+        L"Shader compiled OK; skipping execution (no SM 6.10 device)");
+    WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
+    return;
+  }
+#endif
+
+  // Build expected data.
+  std::vector<float> ExpectedFloats(NumElements);
+  std::vector<int32_t> ExpectedInts(NumElements);
+  for (size_t I = 0; I < NumElements; I++) {
+    ExpectedFloats[I] = static_cast<float>(I + 1);
+    ExpectedInts[I] = static_cast<int32_t>(I + 1);
+  }
+
+  auto Op = createComputeOp(ElementAccessShader, "cs_6_10", "UAV(u0), UAV(u1)",
+                            Args.c_str());
+  addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname");
+  addUAVBuffer(Op.get(), "Output", OutputBufSize, true);
+  addRootUAV(Op.get(), 0, "Input");
+  addRootUAV(Op.get(), 1, "Output");
+
+  auto Result =
+      runShaderOp(Device, DxcSupport, std::move(Op),
+                  [&](LPCSTR Name, std::vector<BYTE> &Data, st::ShaderOp *) {
+                    if (_stricmp(Name, "Input") != 0)
+                      return;
+
+                    switch (Params.CompType) {
+                    case ComponentType::F32: {
+                      float *Ptr = reinterpret_cast<float *>(Data.data());
+                      for (size_t I = 0; I < NumElements; I++)
+                        Ptr[I] = static_cast<float>(I + 1);
+                      break;
+                    }
+                    case ComponentType::I32: {
+                      int32_t *Ptr = reinterpret_cast<int32_t *>(Data.data());
+                      for (size_t I = 0; I < NumElements; I++)
+                        Ptr[I] = static_cast<int32_t>(I + 1);
+                      break;
+                    }
+                    default:
+                      VERIFY_IS_TRUE(false, "Saw unsupported component type");
+                      break;
+                    }
+                  });
+
+  MappedData OutData;
+  Result->Test->GetReadBackData("Output", &OutData);
+  const uint32_t *Out = static_cast<const uint32_t *>(OutData.data());
+
+  // Build actual data. Out points at 32-bit words, so it is indexed by element
+  // rather than by byte offset: element I lives at Out[I], and the per-thread
+  // lengths follow at Out[NumElements .. NumElements + NumThreads).
+  std::vector<float> ActualFloats(NumElements);
+  std::vector<int32_t> ActualInts(NumElements);
+  for (size_t I = 0; I < NumElements; ++I) {
+    switch (Params.CompType) {
+    case ComponentType::F32: {
+      float Actual;
+      memcpy(&Actual, &Out[I], sizeof(float));
+      ActualFloats[I] = Actual;
+      break;
+    }
+    case ComponentType::I32: {
+      ActualInts[I] = static_cast<int32_t>(Out[I]);
+      break;
+    }
+    default:
+      VERIFY_IS_TRUE(false, "Saw unsupported component type");
+      break;
+    }
+  }
+
+  // Verify element values match input data.
+  switch (Params.CompType) {
+  case ComponentType::F32:
+    VERIFY_IS_TRUE(verifyFloatBuffer(ActualFloats.data(), ExpectedFloats.data(),
+                                     NumElements, Verbose));
+    break;
+  case ComponentType::I32:
+    VERIFY_IS_TRUE(verifyIntBuffer(ActualInts.data(), ExpectedInts.data(),
+                                   NumElements, Verbose));
+    break;
+  default:
+    VERIFY_IS_TRUE(false, "Saw unsupported component type");
+    break;
+  }
+
+  // The sum of the values returned by Length across all threads must be
+  // greater than or equal to the total number of matrix elements
+  size_t TotalLength = 0;
+  for (size_t I = NumElements; I < NumElements + NumThreads; ++I)
+    TotalLength += Out[I];
+  VERIFY_IS_TRUE(TotalLength >= NumElements,
+                 "Sum of all lengths must be gte num elements");
+}
+
+// Element access on a 4x4 F32 wave-scope accumulator matrix.
+void DxilConf_SM610_LinAlg::ElementAccess_Wave_F32() {
+  MatrixParams Params = {};
+  Params.CompType = ComponentType::F32;
+  Params.M = 4;
+  Params.N = 4;
+  Params.Use = MatrixUse::Accumulator;
+  Params.Scope = MatrixScope::Wave;
+  Params.Layout = LinalgMatrixLayout::RowMajor;
+  Params.NumThreads = 4;
+  Params.Enable16Bit = false;
+  runElementAccess(D3DDevice, DxcSupport, Params, Params.M, VerboseLogging);
+}
+
+// Element access on a 4x4 I32 wave-scope accumulator matrix.
+void DxilConf_SM610_LinAlg::ElementAccess_Wave_I32() {
+  MatrixParams Params = {};
+  Params.CompType = ComponentType::I32;
+  Params.M = 4;
+  Params.N = 4;
+  Params.Use = MatrixUse::Accumulator;
+  Params.Scope = MatrixScope::Wave;
+  Params.Layout = LinalgMatrixLayout::RowMajor;
+  Params.NumThreads = 4;
+  Params.Enable16Bit = false;
+  runElementAccess(D3DDevice, DxcSupport, Params, Params.M, VerboseLogging);
+}
+
 } // namespace LinAlg