diff --git a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp index 2e4ce65d57..6eb637cdcd 100644 --- a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp +++ b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp @@ -228,37 +228,58 @@ static bool fillInputBuffer(LPCSTR Name, std::vector &Data, return false; } -static VariantCompType makeExpected(ComponentType CompType, size_t NumElements, - float StartingVal, bool Increment) { +static VariantCompType makeExpected(ComponentType CompType, int32_t M, + int32_t N, float StartingVal, + bool Increment = true, + bool Transpose = false) { + const size_t NumElements = M * N; + std::vector Floats(NumElements); + std::vector Ints(NumElements); + std::vector Halfs(NumElements); + + for (size_t I = 0; I < M; ++I) { + for (size_t J = 0; J < N; ++J) { + size_t Value = I * M + J; + size_t Idx = Transpose ? J * N + I : Value; + switch (CompType) { + case ComponentType::F32: + Floats[Idx] = StartingVal + static_cast(Increment ? Value : 0); + break; + case ComponentType::I32: + VERIFY_IS_TRUE(StartingVal < static_cast( + std::numeric_limits::max()), + "Value too large to cast to int32_t"); + VERIFY_IS_TRUE(StartingVal > static_cast( + std::numeric_limits::min()), + "Value too small to cast to int32_t"); + Ints[Idx] = static_cast(StartingVal) + + static_cast(Increment ? Value : 0); + break; + case ComponentType::F16: { + // Downcasting is safe here since HLSLHalf_t will clamp if F is too + // large. + float F = StartingVal + static_cast(Increment ? Value : 0); + Halfs[Idx] = HLSLHalf_t(F); + break; + } + default: + VERIFY_IS_TRUE(false, "Unable to fill unexpected ComponentType"); + break; + } + } + } + switch (CompType) { - case ComponentType::F32: { - std::vector Floats(NumElements); - for (size_t I = 0; I < NumElements; I++) - Floats[I] = StartingVal + static_cast(Increment ? I : 0); + case ComponentType::F32: return Floats; - } - case ComponentType::I32: { - DXASSERT(StartingVal < static_cast(INT_MAX), - "Value too large to cast to int32_t"); - std::vector Ints(NumElements); - for (size_t I = 0; I < NumElements; I++) - Ints[I] = static_cast(StartingVal) + - static_cast(Increment ? I : 0); + case ComponentType::I32: return Ints; - } - case ComponentType::F16: { - std::vector Halfs(NumElements); - for (size_t I = 0; I < NumElements; I++) { - // Downcasting is safe here since HLSLHalf_t will clamp if F is too large. - float F = StartingVal + static_cast(Increment ? I : 0); - Halfs[I] = HLSLHalf_t(F); - } + case ComponentType::F16: return Halfs; + default: + VERIFY_IS_TRUE(false, "Unable to fill unexpected ComponentType"); + return Floats; } - } - - DXASSERT(false, "Unable to fill unexpected ComponentType"); - return std::vector(); } static void logCompiledButSkipping() { @@ -384,6 +405,7 @@ static const char LoadStoreShader[] = R"( RWByteAddressBuffer Output : register(u1); #ifndef EMULATE_TEST + [WaveSize(4, 64)] [numthreads(NUMTHREADS, 1, 1)] void main() { __builtin_LinAlgMatrix @@ -429,7 +451,7 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device, return; } - auto Expected = makeExpected(Params.CompType, NumElements, 1, true); + auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1); // Construct the ShaderOp: two UAV buffers, load from one, store to other. auto Op = createComputeOp(LoadStoreShader, Target.c_str(), "UAV(u0), UAV(u1)", @@ -463,7 +485,7 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() { Params.Use = MatrixUse::A; Params.Scope = MatrixScope::Wave; Params.Layout = LinalgMatrixLayout::RowMajor; - Params.NumThreads = 4; + Params.NumThreads = 64; Params.Enable16Bit = true; Params.EmulateTest = EmulateTest; runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging, @@ -474,6 +496,7 @@ static const char SplatStoreShader[] = R"( RWByteAddressBuffer Output : register(u0); #ifndef EMULATE_TEST + [WaveSize(4, 64)] [numthreads(NUMTHREADS, 1, 1)] void main() { __builtin_LinAlgMatrix @@ -517,7 +540,8 @@ static void runSplatStore(ID3D12Device *Device, return; } - auto Expected = makeExpected(Params.CompType, NumElements, FillValue, false); + auto Expected = + makeExpected(Params.CompType, Params.M, Params.N, FillValue, false); auto Op = createComputeOp(SplatStoreShader, Target.c_str(), "UAV(u0)", Args.c_str()); @@ -541,7 +565,7 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() { Params.Use = MatrixUse::Accumulator; Params.Scope = MatrixScope::Wave; Params.Layout = LinalgMatrixLayout::RowMajor; - Params.NumThreads = 4; + Params.NumThreads = 64; Params.Enable16Bit = true; Params.EmulateTest = EmulateTest; runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging, @@ -553,11 +577,13 @@ static const char ElementAccessShader[] = R"( RWByteAddressBuffer Output : register(u1); // flatten the 2D index into a 1D index then scale by element size + // Always store row-major and work it out in the test runner uint coordToByteOffset(uint2 coord) { - return (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE; + return (coord.y * N_DIM + coord.x) * ELEM_SIZE; } #ifndef EMULATE_TEST + [WaveSize(4, 64)] [numthreads(NUMTHREADS, 1, 1)] void main(uint threadIndex : SV_GroupIndex) { __builtin_LinAlgMatrix @@ -605,8 +631,7 @@ static void runElementAccess(ID3D12Device *Device, const size_t NumThreads = Params.NumThreads; const size_t InputBufSize = Params.totalBytes(); const size_t ElementSize = elementSize(Params.CompType); - const size_t MajorDim = - Params.Layout == LinalgMatrixLayout::RowMajor ? Params.M : Params.N; + // Output: ElementSize bytes per element // 1 element for each mat idx // 1 uint for each thread's length @@ -618,7 +643,6 @@ static void runElementAccess(ID3D12Device *Device, Target = "cs_6_8"; std::stringstream ExtraDefs; - ExtraDefs << " -DMAJOR_DIM=" << MajorDim; std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); compileShader(DxcSupport, ElementAccessShader, Target.c_str(), Args, Verbose); @@ -628,7 +652,7 @@ static void runElementAccess(ID3D12Device *Device, return; } - auto Expected = makeExpected(Params.CompType, NumElements, 1, true); + auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1); auto Op = createComputeOp(ElementAccessShader, Target.c_str(), "UAV(u0), UAV(u1)", Args.c_str()); @@ -674,7 +698,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() { Params.Use = MatrixUse::Accumulator; Params.Scope = MatrixScope::Wave; Params.Layout = LinalgMatrixLayout::RowMajor; - Params.NumThreads = 4; + Params.NumThreads = 64; Params.Enable16Bit = true; Params.EmulateTest = EmulateTest; runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);