@@ -38,7 +38,9 @@ using hlsl::DXIL::LinalgMatrixLayout;
3838using hlsl::DXIL::MatrixScope;
3939using hlsl::DXIL::MatrixUse;
4040
41+ using HLSLTestDataTypes::doValuesMatch;
4142using HLSLTestDataTypes::HLSLHalf_t;
43+ using HLSLTestDataTypes::ValidationType;
4244
4345using VariantCompType = std::variant<std::vector<float >, std::vector<int32_t >,
4446 std::vector<HLSLHalf_t>>;
@@ -111,10 +113,8 @@ static bool verifyFloatBuffer(const float *Actual, const float *Expected,
111113 float Tolerance = 0 .0f ) {
112114 bool Success = true ;
113115 for (size_t I = 0 ; I < Count; I++) {
114- float Diff = Actual[I] - Expected[I];
115- if (Diff < 0 )
116- Diff = -Diff;
117- if (Diff > Tolerance) {
116+ if (!doValuesMatch (Actual[I], Expected[I], Tolerance,
117+ ValidationType::Epsilon)) {
118118 hlsl_test::LogErrorFmt (L" Mismatch at index %zu: actual=%f, expected=%f" ,
119119 I, static_cast <double >(Actual[I]),
120120 static_cast <double >(Expected[I]));
@@ -132,7 +132,7 @@ static bool verifyIntBuffer(const int32_t *Actual, const int32_t *Expected,
132132 size_t Count, bool Verbose) {
133133 bool Success = true ;
134134 for (size_t I = 0 ; I < Count; I++) {
135- if (Actual[I] != Expected[I]) {
135+ if (! doValuesMatch ( Actual[I], Expected[I], 0.0 , ValidationType::Epsilon) ) {
136136 hlsl_test::LogErrorFmt (L" Mismatch at index %zu: actual=%d, expected=%d" ,
137137 I, Actual[I], Expected[I]);
138138 Success = false ;
@@ -149,10 +149,8 @@ static bool verifyHalfBuffer(const HLSLHalf_t *Actual,
149149 bool Verbose, HLSLHalf_t Tolerance = 0 .0f ) {
150150 bool Success = true ;
151151 for (size_t I = 0 ; I < Count; I++) {
152- HLSLHalf_t Diff = Actual[I] - Expected[I];
153- if (Diff < 0 .0f )
154- Diff = -Diff;
155- if (Diff > Tolerance) {
152+ if (!doValuesMatch (Actual[I], Expected[I], Tolerance,
153+ ValidationType::Epsilon)) {
156154 hlsl_test::LogErrorFmt (L" Mismatch at index %zu: actual=%f, expected=%f" ,
157155 I, static_cast <float >(Actual[I]),
158156 static_cast <float >(Expected[I]));
@@ -254,21 +252,10 @@ static VariantCompType makeExpected(ComponentType CompType, size_t NumElements,
254252 return std::vector<float >();
255253}
256254
257- static bool shouldSkipBecauseSM610Unsupported (ID3D12Device *Device) {
258- // Never skip in an HLK environment
259- #ifdef _HLK_CONF
260- return false ;
261- #endif
262-
263- // Don't skip if a device is available
264- if (Device)
265- return false ;
266-
267- // Skip GPU execution
255+ static void logCompiledButSkipping () {
268256 hlsl_test::LogCommentFmt (
269257 L" Shader compiled OK; skipping execution (no SM 6.10 device)" );
270258 WEX::Logging::Log::Result (WEX::Logging::TestResults::Skipped);
271- return true ;
272259}
273260
274261class DxilConf_SM610_LinAlg {
@@ -299,49 +286,34 @@ class DxilConf_SM610_LinAlg {
299286 TEST_METHOD (ElementAccess_Wave_16x16_F16);
300287
301288private:
302- bool createDevice ();
289+ D3D_SHADER_MODEL createDevice ();
303290
304291 CComPtr<ID3D12Device> D3DDevice;
305292 dxc::SpecificDllLoader DxcSupport;
306293 bool VerboseLogging = false ;
307294 bool EmulateTest = false ;
308295 bool Initialized = false ;
296+ bool CompileOnly = false ;
309297 std::optional<D3D12SDKSelector> D3D12SDK;
310298
311299 WEX::TestExecution::SetVerifyOutput VerifyOutput{
312300 WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures};
313301};
314302
315- // / Creates the device and setups the test scenario with the following variants
316- // / HLK build: Require SM6.10 supported fail otherwise
317- // / Non-HLK, no SM6.10 support: Compile shaders, then exit with skip
318- // / Non-HLK, SM6.10 support: Compile shaders and run full test
319- bool DxilConf_SM610_LinAlg::createDevice () {
320- bool FailIfRequirementsNotMet = false ;
321- #ifdef _HLK_CONF
322- FailIfRequirementsNotMet = true ;
323- #endif
303+ // / Attempts to create a device. If shaders are being emulated then a SM6.8
304+ // / device is attempted. Otherwise a SM6.10 device is attempted
305+ D3D_SHADER_MODEL DxilConf_SM610_LinAlg::createDevice () {
306+ if (EmulateTest) {
307+ if (D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_8, false ))
308+ return D3D_SHADER_MODEL_6_8;
324309
325- const bool SkipUnsupported = FailIfRequirementsNotMet;
326- if (!D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_10,
327- SkipUnsupported)) {
328- if (FailIfRequirementsNotMet) {
329- hlsl_test::LogErrorFmt (
330- L" Device creation failed, resulting in test failure, since "
331- L" FailIfRequirementsNotMet is set. The expectation is that this "
332- L" test will only be executed if something has previously "
333- L" determined that the system meets the requirements of this "
334- L" test." );
335- return false ;
336- }
310+ return D3D_SHADER_MODEL_NONE;
337311 }
338312
339- if (EmulateTest) {
340- hlsl_test::LogWarningFmt (L" EmulateTest flag set. Tests are NOT REAL" );
341- return D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_8, false );
342- }
313+ if (D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_10, false ))
314+ return D3D_SHADER_MODEL_6_10;
343315
344- return true ;
316+ return D3D_SHADER_MODEL_NONE ;
345317}
346318
347319bool DxilConf_SM610_LinAlg::setupClass () {
@@ -354,7 +326,26 @@ bool DxilConf_SM610_LinAlg::setupClass() {
354326 VerboseLogging);
355327 WEX::TestExecution::RuntimeParameters::TryGetValue (L" EmulateTest" ,
356328 EmulateTest);
357- return createDevice ();
329+ D3D_SHADER_MODEL SupportedSM = createDevice ();
330+
331+ if (EmulateTest) {
332+ hlsl_test::LogWarningFmt (L" EmulateTest flag set. Tests are NOT REAL" );
333+ if (SupportedSM != D3D_SHADER_MODEL_6_8) {
334+ hlsl_test::LogErrorFmt (
335+ L" Device creation failed. Expected a driver supporting SM6.8" );
336+ return false ;
337+ }
338+ }
339+
340+ #ifdef _HLK_CONF
341+ if (SupportedSM != D3D_SHADER_MODEL_6_10) {
342+ hlsl_test::LogErrorFmt (
343+ L" Device creation failed. Expected a driver supporting SM6.10" );
344+ return false ;
345+ }
346+ #endif
347+
348+ CompileOnly = SupportedSM == D3D_SHADER_MODEL_NONE;
358349 }
359350
360351 return true ;
@@ -366,11 +357,17 @@ bool DxilConf_SM610_LinAlg::setupMethod() {
366357 if (D3DDevice && D3DDevice->GetDeviceRemovedReason () == S_OK)
367358 return true ;
368359
360+ // Device is expected to be null. No point in recreating it
361+ if (CompileOnly)
362+ return true ;
363+
369364 hlsl_test::LogCommentFmt (L" Device was lost!" );
370365 D3DDevice.Release ();
371366
372367 hlsl_test::LogCommentFmt (L" Recreating device" );
373- return createDevice ();
368+
369+ // !CompileOnly implies we expect it to succeeded
370+ return createDevice () != D3D_SHADER_MODEL_NONE;
374371}
375372
376373static const char LoadStoreShader[] = R"(
@@ -400,7 +397,8 @@ static const char LoadStoreShader[] = R"(
400397
401398static void runLoadStoreRoundtrip (ID3D12Device *Device,
402399 dxc::SpecificDllLoader &DxcSupport,
403- const MatrixParams &Params, bool Verbose) {
400+ const MatrixParams &Params, bool Verbose,
401+ bool CompileOnly) {
404402 const size_t NumElements = Params.totalElements ();
405403 const size_t BufferSize = Params.totalBytes ();
406404
@@ -417,8 +415,10 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
417415 // Always verify the shader compiles.
418416 compileShader (DxcSupport, LoadStoreShader, Target.c_str (), Args, Verbose);
419417
420- if (shouldSkipBecauseSM610Unsupported (Device))
418+ if (CompileOnly) {
419+ logCompiledButSkipping ();
421420 return ;
421+ }
422422
423423 auto Expected = makeExpected (Params.CompType , NumElements, 1 , true );
424424
@@ -457,7 +457,8 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() {
457457 Params.NumThreads = 4 ;
458458 Params.Enable16Bit = true ;
459459 Params.EmulateTest = EmulateTest;
460- runLoadStoreRoundtrip (D3DDevice, DxcSupport, Params, VerboseLogging);
460+ runLoadStoreRoundtrip (D3DDevice, DxcSupport, Params, VerboseLogging,
461+ CompileOnly);
461462}
462463
463464static const char SplatStoreShader[] = R"(
@@ -493,7 +494,7 @@ static const char SplatStoreShader[] = R"(
493494static void runSplatStore (ID3D12Device *Device,
494495 dxc::SpecificDllLoader &DxcSupport,
495496 const MatrixParams &Params, float FillValue,
496- bool Verbose) {
497+ bool Verbose, bool CompileOnly ) {
497498 const size_t NumElements = Params.totalElements ();
498499 const size_t BufferSize = Params.totalBytes ();
499500 std::string Target = " cs_6_10" ;
@@ -508,8 +509,10 @@ static void runSplatStore(ID3D12Device *Device,
508509 // Always verify the shader compiles.
509510 compileShader (DxcSupport, SplatStoreShader, Target.c_str (), Args, Verbose);
510511
511- if (shouldSkipBecauseSM610Unsupported (Device))
512+ if (CompileOnly) {
513+ logCompiledButSkipping ();
512514 return ;
515+ }
513516
514517 auto Expected = makeExpected (Params.CompType , NumElements, FillValue, false );
515518
@@ -538,7 +541,8 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
538541 Params.NumThreads = 4 ;
539542 Params.Enable16Bit = true ;
540543 Params.EmulateTest = EmulateTest;
541- runSplatStore (D3DDevice, DxcSupport, Params, 42 .0f , VerboseLogging);
544+ runSplatStore (D3DDevice, DxcSupport, Params, 42 .0f , VerboseLogging,
545+ CompileOnly);
542546}
543547
544548static const char ElementAccessShader[] = R"(
@@ -598,7 +602,8 @@ static const char ElementAccessShader[] = R"(
598602
599603static void runElementAccess (ID3D12Device *Device,
600604 dxc::SpecificDllLoader &DxcSupport,
601- const MatrixParams &Params, bool Verbose) {
605+ const MatrixParams &Params, bool Verbose,
606+ bool CompileOnly) {
602607 const size_t NumElements = Params.totalElements ();
603608 const size_t NumThreads = Params.NumThreads ;
604609 const size_t InputBufSize = Params.totalBytes ();
@@ -621,8 +626,10 @@ static void runElementAccess(ID3D12Device *Device,
621626
622627 compileShader (DxcSupport, ElementAccessShader, Target.c_str (), Args, Verbose);
623628
624- if (shouldSkipBecauseSM610Unsupported (Device))
629+ if (CompileOnly) {
630+ logCompiledButSkipping ();
625631 return ;
632+ }
626633
627634 auto Expected = makeExpected (Params.CompType , NumElements, 1 , true );
628635
@@ -673,7 +680,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
673680 Params.NumThreads = 4 ;
674681 Params.Enable16Bit = true ;
675682 Params.EmulateTest = EmulateTest;
676- runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging);
683+ runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly );
677684}
678685
679686} // namespace LinAlg
0 commit comments