@@ -190,6 +190,33 @@ static void compileShader(dxc::SpecificDllLoader &DxcSupport,
190190 }
191191}
192192
193+ // ===========================================================================
194+ // Test parameters
195+ // ===========================================================================
196+
197+ struct MatrixParams {
198+ ComponentType CompType;
199+ int M;
200+ int N;
201+ MatrixUse Use;
202+ MatrixScope Scope;
203+ LinalgMatrixLayout Layout;
204+ int NumThreads;
205+ bool Enable16Bit;
206+
207+ int strideBytes () const {
208+ int ES = elemSize (CompType);
209+ if (Layout == LinalgMatrixLayout::RowMajor)
210+ return N * ES;
211+ else
212+ return M * ES;
213+ }
214+
215+ size_t totalElements () const { return static_cast <size_t >(M) * N; }
216+
217+ size_t totalBytes () const { return totalElements () * elemSize (CompType); }
218+ };
219+
193220// ===========================================================================
194221// Compiler arguments builder
195222// ===========================================================================
@@ -257,33 +284,6 @@ static bool verifyIntBuffer(const void *Actual, const int32_t *Expected,
257284 return Success;
258285}
259286
260- // ===========================================================================
261- // Test parameters
262- // ===========================================================================
263-
264- struct MatrixParams {
265- ComponentType CompType;
266- int M;
267- int N;
268- MatrixUse Use;
269- MatrixScope Scope;
270- LinalgMatrixLayout Layout;
271- int NumThreads;
272- bool Enable16Bit;
273-
274- int strideBytes () const {
275- int ES = elemSize (CompType);
276- if (Layout == LinalgMatrixLayout::RowMajor)
277- return N * ES;
278- else
279- return M * ES;
280- }
281-
282- size_t totalElements () const { return static_cast <size_t >(M) * N; }
283-
284- size_t totalBytes () const { return totalElements () * elemSize (CompType); }
285- };
286-
287287// ===========================================================================
288288// Test class
289289// ===========================================================================
@@ -345,15 +345,16 @@ bool DxilConf_SM610_LinAlg::setupClass() {
345345 WEX::TestExecution::RuntimeParameters::TryGetValue (
346346 L" FailIfRequirementsNotMet" , FailIfRequirementsNotMet);
347347
348- // Try to create a device. In HLK mode, fail if unavailable.
349- // In dev mode, D3DDevice stays null and tests will compile shaders
350- // then skip GPU execution.
348+ const bool SkipUnsupported = !FailIfRequirementsNotMet;
351349 if (!D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_10,
352- /* SkipUnsupported= */ false )) {
350+ SkipUnsupported)) {
353351 if (FailIfRequirementsNotMet) {
354352 hlsl_test::LogErrorFmt (
355- L" Device creation failed for SM 6.10, and "
356- L" FailIfRequirementsNotMet is set." );
353+ L" Device creation failed, resulting in test failure, since "
354+ L" FailIfRequirementsNotMet is set. The expectation is that this "
355+ L" test will only be executed if something has previously "
356+ L" determined that the system meets the requirements of this "
357+ L" test." );
357358 return false ;
358359 }
359360 // No device — tests will compile shaders and skip execution.
@@ -364,14 +365,18 @@ bool DxilConf_SM610_LinAlg::setupClass() {
364365}
365366
366367bool DxilConf_SM610_LinAlg::setupMethod () {
367- // Re-create device if it was lost. If we never had one, that's fine —
368- // tests compile shaders and skip GPU execution .
368+ // It's possible a previous test case caused a device removal. If it did we
369+ // need to try and create a new device .
369370 if (D3DDevice && D3DDevice->GetDeviceRemovedReason () != S_OK) {
370- hlsl_test::LogCommentFmt (L" Device was lost!" );
371+ hlsl_test::LogCommentFmt (L" Device was lost! Recreating... " );
371372 D3DDevice.Release ();
372- D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_10,
373- /* SkipUnsupported=*/ false );
373+
374+ // We expect recreation to succeed since we had a working device before.
375+ const bool SkipUnsupported = false ;
376+ VERIFY_IS_TRUE (D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_10,
377+ SkipUnsupported));
374378 }
379+
375380 return true ;
376381}
377382
@@ -540,9 +545,18 @@ static void runSplatStore(ID3D12Device *Device,
540545 return ;
541546 }
542547
543- std::vector<float > ExpectedFloats (NumElements, FillValue);
544- std::vector<int32_t > ExpectedInts (NumElements,
545- static_cast <int32_t >(FillValue));
548+ std::vector<float > ExpectedFloats;
549+ std::vector<int32_t > ExpectedInts;
550+ switch (Params.CompType ) {
551+ case ComponentType::F32:
552+ ExpectedFloats.assign (NumElements, FillValue);
553+ break ;
554+ case ComponentType::I32:
555+ ExpectedInts.assign (NumElements, static_cast <int32_t >(FillValue));
556+ break ;
557+ default :
558+ break ;
559+ }
546560
547561 auto Op =
548562 createComputeOp (SplatStoreShader, " cs_6_10" , " UAV(u0)" , Args.c_str ());
0 commit comments