diff --git a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp index 6eb637cdcd..570dcb705a 100644 --- a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp +++ b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp @@ -108,8 +108,6 @@ static std::string buildCompilerArgs(const MatrixParams &Params, SS << " -DELEM_TYPE=uint"; break; } - if (Params.EmulateTest) - SS << " -DEMULATE_TEST"; if (Params.Enable16Bit) SS << " -enable-16bit-types"; if (ExtraDefines) @@ -282,12 +280,6 @@ static VariantCompType makeExpected(ComponentType CompType, int32_t M, } } -static void logCompiledButSkipping() { - hlsl_test::LogCommentFmt( - L"Shader compiled OK; skipping execution (no SM 6.10 device)"); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); -} - class DxilConf_SM610_LinAlg { public: BEGIN_TEST_CLASS(DxilConf_SM610_LinAlg) @@ -316,36 +308,16 @@ class DxilConf_SM610_LinAlg { TEST_METHOD(ElementAccess_Wave_16x16_F16); private: - D3D_SHADER_MODEL createDevice(); - CComPtr D3DDevice; dxc::SpecificDllLoader DxcSupport; bool VerboseLogging = false; - bool EmulateTest = false; bool Initialized = false; - bool CompileOnly = false; std::optional D3D12SDK; WEX::TestExecution::SetVerifyOutput VerifyOutput{ WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures}; }; -/// Attempts to create a device. If shaders are being emulated then a SM6.8 -/// device is attempted. Otherwise a SM6.10 device is attempted -D3D_SHADER_MODEL DxilConf_SM610_LinAlg::createDevice() { - if (EmulateTest) { - if (D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_8, false)) - return D3D_SHADER_MODEL_6_8; - - return D3D_SHADER_MODEL_NONE; - } - - if (D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10, false)) - return D3D_SHADER_MODEL_6_10; - - return D3D_SHADER_MODEL_NONE; -} - bool DxilConf_SM610_LinAlg::setupClass() { if (!Initialized) { Initialized = true; @@ -354,28 +326,18 @@ bool DxilConf_SM610_LinAlg::setupClass() { D3D12SDK = D3D12SDKSelector(); WEX::TestExecution::RuntimeParameters::TryGetValue(L"VerboseLogging", VerboseLogging); - WEX::TestExecution::RuntimeParameters::TryGetValue(L"EmulateTest", - EmulateTest); - D3D_SHADER_MODEL SupportedSM = createDevice(); - - if (EmulateTest) { - hlsl_test::LogWarningFmt(L"EmulateTest flag set. Tests are NOT REAL"); - if (SupportedSM != D3D_SHADER_MODEL_6_8) { - hlsl_test::LogErrorFmt( - L"Device creation failed. Expected a driver supporting SM6.8"); - return false; - } - } + if (!D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10, false)) { #ifdef _HLK_CONF - if (SupportedSM != D3D_SHADER_MODEL_6_10) { hlsl_test::LogErrorFmt( L"Device creation failed. Expected a driver supporting SM6.10"); +#else + hlsl_test::LogWarningFmt( + L"Device creation failed. Expected a driver supporting SM6.10"); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); +#endif return false; } -#endif - - CompileOnly = SupportedSM == D3D_SHADER_MODEL_NONE; } return true; @@ -387,27 +349,24 @@ bool DxilConf_SM610_LinAlg::setupMethod() { if (D3DDevice && D3DDevice->GetDeviceRemovedReason() == S_OK) return true; - // Device is expected to be null. No point in recreating it - if (CompileOnly) - return true; - hlsl_test::LogCommentFmt(L"Device was lost!"); D3DDevice.Release(); hlsl_test::LogCommentFmt(L"Recreating device"); - // !CompileOnly implies we expect it to succeeded - return createDevice() != D3D_SHADER_MODEL_NONE; + return D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10, false); } static const char LoadStoreShader[] = R"( RWByteAddressBuffer Input : register(u0); RWByteAddressBuffer Output : register(u1); -#ifndef EMULATE_TEST [WaveSize(4, 64)] [numthreads(NUMTHREADS, 1, 1)] - void main() { + void main(uint threadID : SV_GroupIndex) { + if (WaveReadLaneFirst(threadID) != 0) + return; + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]] Mat; @@ -416,45 +375,26 @@ static const char LoadStoreShader[] = R"( __builtin_LinAlg_MatrixStoreToDescriptor( Mat, Output, OFFSET, STRIDE, LAYOUT, 128); } -#else - [numthreads(NUMTHREADS, 1, 1)] - void main() { - for (uint I = 0; I < M_DIM*N_DIM; ++I) { - Output.Store(I*ELEM_SIZE, Input.Load(I*ELEM_SIZE)); - } - } -#endif )"; static void runLoadStoreRoundtrip(ID3D12Device *Device, dxc::SpecificDllLoader &DxcSupport, - const MatrixParams &Params, bool Verbose, - bool CompileOnly) { + const MatrixParams &Params, bool Verbose) { const size_t NumElements = Params.totalElements(); const size_t BufferSize = Params.totalBytes(); - std::string Target = "cs_6_10"; - if (Params.EmulateTest) - Target = "cs_6_8"; - // TODO: these should be varied by test to ensure full coverage std::stringstream ExtraDefs; ExtraDefs << " -DOFFSET=" << 0; std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); - // Always verify the shader compiles. - compileShader(DxcSupport, LoadStoreShader, Target.c_str(), Args, Verbose); - - if (CompileOnly) { - logCompiledButSkipping(); - return; - } + compileShader(DxcSupport, LoadStoreShader, "cs_6_10", Args, Verbose); auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1); // Construct the ShaderOp: two UAV buffers, load from one, store to other. - auto Op = createComputeOp(LoadStoreShader, Target.c_str(), "UAV(u0), UAV(u1)", + auto Op = createComputeOp(LoadStoreShader, "cs_6_10", "UAV(u0), UAV(u1)", Args.c_str()); addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname"); addUAVBuffer(Op.get(), "Output", BufferSize, true); @@ -487,18 +427,18 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() { Params.Layout = LinalgMatrixLayout::RowMajor; Params.NumThreads = 64; Params.Enable16Bit = true; - Params.EmulateTest = EmulateTest; - runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging, - CompileOnly); + runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging); } static const char SplatStoreShader[] = R"( RWByteAddressBuffer Output : register(u0); -#ifndef EMULATE_TEST [WaveSize(4, 64)] [numthreads(NUMTHREADS, 1, 1)] - void main() { + void main(uint threadID : SV_GroupIndex) { + if (WaveReadLaneFirst(threadID) != 0) + return; + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]] Mat; @@ -506,45 +446,27 @@ static const char SplatStoreShader[] = R"( __builtin_LinAlg_MatrixStoreToDescriptor( Mat, Output, 0, STRIDE, LAYOUT, 128); } -#else - [numthreads(NUMTHREADS, 1, 1)] - void main() { - ELEM_TYPE fill = FILL_VALUE; - for (uint I = 0; I < M_DIM*N_DIM; ++I) { - Output.Store(I*ELEM_SIZE, fill); - } - } -#endif )"; static void runSplatStore(ID3D12Device *Device, dxc::SpecificDllLoader &DxcSupport, const MatrixParams &Params, float FillValue, - bool Verbose, bool CompileOnly) { + bool Verbose) { const size_t NumElements = Params.totalElements(); const size_t BufferSize = Params.totalBytes(); - std::string Target = "cs_6_10"; - if (Params.EmulateTest) - Target = "cs_6_8"; std::stringstream ExtraDefs; ExtraDefs << "-DFILL_VALUE=" << FillValue; std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); - // Always verify the shader compiles. - compileShader(DxcSupport, SplatStoreShader, Target.c_str(), Args, Verbose); - - if (CompileOnly) { - logCompiledButSkipping(); - return; - } + compileShader(DxcSupport, SplatStoreShader, "cs_6_10", Args, Verbose); auto Expected = makeExpected(Params.CompType, Params.M, Params.N, FillValue, false); - auto Op = createComputeOp(SplatStoreShader, Target.c_str(), "UAV(u0)", - Args.c_str()); + auto Op = + createComputeOp(SplatStoreShader, "cs_6_10", "UAV(u0)", Args.c_str()); addUAVBuffer(Op.get(), "Output", BufferSize, true); addRootUAV(Op.get(), 0, "Output"); @@ -567,9 +489,7 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() { Params.Layout = LinalgMatrixLayout::RowMajor; Params.NumThreads = 64; Params.Enable16Bit = true; - Params.EmulateTest = EmulateTest; - runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging, - CompileOnly); + runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging); } static const char ElementAccessShader[] = R"( @@ -582,10 +502,12 @@ static const char ElementAccessShader[] = R"( return (coord.y * N_DIM + coord.x) * ELEM_SIZE; } -#ifndef EMULATE_TEST [WaveSize(4, 64)] [numthreads(NUMTHREADS, 1, 1)] - void main(uint threadIndex : SV_GroupIndex) { + void main(uint threadID : SV_GroupIndex) { + if (WaveReadLaneFirst(threadID) != 0) + return; + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]] Mat; @@ -603,30 +525,15 @@ static const char ElementAccessShader[] = R"( // Save the matrix length that this thread saw. The length is written // to the output right after the matrix, offset by the thread index - uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint)); + uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadID * sizeof(uint)); uint Len = __builtin_LinAlg_MatrixLength(Mat); Output.Store(LenIdx, Len); } -#else - [numthreads(NUMTHREADS, 1, 1)] - void main(uint threadIndex : SV_GroupIndex) { - uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint)); - Output.Store(LenIdx, M_DIM * N_DIM / NUMTHREADS); - - if (threadIndex != 0) - return; - - for (uint I = 0; I < M_DIM*N_DIM; ++I) { - Output.Store(I*ELEM_SIZE, Input.Load(I*ELEM_SIZE)); - } - } -#endif )"; static void runElementAccess(ID3D12Device *Device, dxc::SpecificDllLoader &DxcSupport, - const MatrixParams &Params, bool Verbose, - bool CompileOnly) { + const MatrixParams &Params, bool Verbose) { const size_t NumElements = Params.totalElements(); const size_t NumThreads = Params.NumThreads; const size_t InputBufSize = Params.totalBytes(); @@ -638,24 +545,15 @@ static void runElementAccess(ID3D12Device *Device, const size_t OutputBufSize = NumElements * ElementSize + NumThreads * sizeof(uint32_t); - std::string Target = "cs_6_10"; - if (Params.EmulateTest) - Target = "cs_6_8"; - std::stringstream ExtraDefs; std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); - compileShader(DxcSupport, ElementAccessShader, Target.c_str(), Args, Verbose); - - if (CompileOnly) { - logCompiledButSkipping(); - return; - } + compileShader(DxcSupport, ElementAccessShader, "cs_6_10", Args, Verbose); auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1); - auto Op = createComputeOp(ElementAccessShader, Target.c_str(), - "UAV(u0), UAV(u1)", Args.c_str()); + auto Op = createComputeOp(ElementAccessShader, "cs_6_10", "UAV(u0), UAV(u1)", + Args.c_str()); addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname"); addUAVBuffer(Op.get(), "Output", OutputBufSize, true); addRootUAV(Op.get(), 0, "Input"); @@ -700,8 +598,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() { Params.Layout = LinalgMatrixLayout::RowMajor; Params.NumThreads = 64; Params.Enable16Bit = true; - Params.EmulateTest = EmulateTest; - runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly); + runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging); } } // namespace LinAlg