Skip to content

Commit bb9b8f8

Browse files
committed
[SM6.10][HLK][LinAlg] Address lingering feedback
1 parent 8d34720 commit bb9b8f8

1 file changed

Lines changed: 55 additions & 50 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -254,21 +254,10 @@ static VariantCompType makeExpected(ComponentType CompType, size_t NumElements,
254254
return std::vector<float>();
255255
}
256256

257-
static bool shouldSkipBecauseSM610Unsupported(ID3D12Device *Device) {
258-
// Never skip in an HLK environment
259-
#ifdef _HLK_CONF
260-
return false;
261-
#endif
262-
263-
// Don't skip if a device is available
264-
if (Device)
265-
return false;
266-
267-
// Skip GPU execution
257+
static void logCompiledButSkipping() {
268258
hlsl_test::LogCommentFmt(
269259
L"Shader compiled OK; skipping execution (no SM 6.10 device)");
270260
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
271-
return true;
272261
}
273262

274263
class DxilConf_SM610_LinAlg {
@@ -299,49 +288,34 @@ class DxilConf_SM610_LinAlg {
299288
TEST_METHOD(ElementAccess_Wave_16x16_F16);
300289

301290
private:
302-
bool createDevice();
291+
D3D_SHADER_MODEL createDevice();
303292

304293
CComPtr<ID3D12Device> D3DDevice;
305294
dxc::SpecificDllLoader DxcSupport;
306295
bool VerboseLogging = false;
307296
bool EmulateTest = false;
308297
bool Initialized = false;
298+
bool CompileOnly = false;
309299
std::optional<D3D12SDKSelector> D3D12SDK;
310300

311301
WEX::TestExecution::SetVerifyOutput VerifyOutput{
312302
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures};
313303
};
314304

315-
/// Creates the device and setups the test scenario with the following variants
316-
/// HLK build: Require SM6.10 supported fail otherwise
317-
/// Non-HLK, no SM6.10 support: Compile shaders, then exit with skip
318-
/// Non-HLK, SM6.10 support: Compile shaders and run full test
319-
bool DxilConf_SM610_LinAlg::createDevice() {
320-
bool FailIfRequirementsNotMet = false;
321-
#ifdef _HLK_CONF
322-
FailIfRequirementsNotMet = true;
323-
#endif
305+
/// Attempts to create a device. If tests are being emulated this an SM6.8
306+
/// device is attempted. Durning normal execution SM6.10 is required.
307+
D3D_SHADER_MODEL DxilConf_SM610_LinAlg::createDevice() {
308+
if (EmulateTest) {
309+
if(D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_8, false))
310+
return D3D_SHADER_MODEL_6_8;
324311

325-
const bool SkipUnsupported = FailIfRequirementsNotMet;
326-
if (!D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10,
327-
SkipUnsupported)) {
328-
if (FailIfRequirementsNotMet) {
329-
hlsl_test::LogErrorFmt(
330-
L"Device creation failed, resulting in test failure, since "
331-
L"FailIfRequirementsNotMet is set. The expectation is that this "
332-
L"test will only be executed if something has previously "
333-
L"determined that the system meets the requirements of this "
334-
L"test.");
335-
return false;
336-
}
312+
return D3D_SHADER_MODEL_NONE;
337313
}
338314

339-
if (EmulateTest) {
340-
hlsl_test::LogWarningFmt(L"EmulateTest flag set. Tests are NOT REAL");
341-
return D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_8, false);
342-
}
315+
if (D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10, false))
316+
return D3D_SHADER_MODEL_6_10;
343317

344-
return true;
318+
return D3D_SHADER_MODEL_NONE;
345319
}
346320

347321
bool DxilConf_SM610_LinAlg::setupClass() {
@@ -354,7 +328,26 @@ bool DxilConf_SM610_LinAlg::setupClass() {
354328
VerboseLogging);
355329
WEX::TestExecution::RuntimeParameters::TryGetValue(L"EmulateTest",
356330
EmulateTest);
357-
return createDevice();
331+
D3D_SHADER_MODEL SupportedSM = createDevice();
332+
333+
if (EmulateTest) {
334+
hlsl_test::LogWarningFmt(L"EmulateTest flag set. Tests are NOT REAL");
335+
if (SupportedSM != D3D_SHADER_MODEL_6_8) {
336+
hlsl_test::LogErrorFmt(
337+
L"Device creation failed. Expected a driver supporting SM6.8");
338+
return false;
339+
}
340+
}
341+
342+
#ifdef _HLK_CONF
343+
if (SupportedSM != D3D_SHADER_MODEL_6_10) {
344+
hlsl_test::LogErrorFmt(
345+
L"Device creation failed. Expected a driver supporting SM6.10");
346+
return false;
347+
}
348+
#endif
349+
350+
CompileOnly = SupportedSM == D3D_SHADER_MODEL_NONE;
358351
}
359352

360353
return true;
@@ -366,11 +359,17 @@ bool DxilConf_SM610_LinAlg::setupMethod() {
366359
if (D3DDevice && D3DDevice->GetDeviceRemovedReason() == S_OK)
367360
return true;
368361

362+
// Device is expected to be null. No point in recreating it
363+
if (CompileOnly)
364+
return true;
365+
369366
hlsl_test::LogCommentFmt(L"Device was lost!");
370367
D3DDevice.Release();
371368

372369
hlsl_test::LogCommentFmt(L"Recreating device");
373-
return createDevice();
370+
371+
// !CompileOnly implies we expect it to succeeded
372+
return createDevice() != D3D_SHADER_MODEL_NONE;
374373
}
375374

376375
static const char LoadStoreShader[] = R"(
@@ -400,7 +399,7 @@ static const char LoadStoreShader[] = R"(
400399

401400
static void runLoadStoreRoundtrip(ID3D12Device *Device,
402401
dxc::SpecificDllLoader &DxcSupport,
403-
const MatrixParams &Params, bool Verbose) {
402+
const MatrixParams &Params, bool Verbose, bool CompileOnly) {
404403
const size_t NumElements = Params.totalElements();
405404
const size_t BufferSize = Params.totalBytes();
406405

@@ -417,8 +416,10 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
417416
// Always verify the shader compiles.
418417
compileShader(DxcSupport, LoadStoreShader, Target.c_str(), Args, Verbose);
419418

420-
if (shouldSkipBecauseSM610Unsupported(Device))
419+
if (CompileOnly) {
420+
logCompiledButSkipping();
421421
return;
422+
}
422423

423424
auto Expected = makeExpected(Params.CompType, NumElements, 1, true);
424425

@@ -457,7 +458,7 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() {
457458
Params.NumThreads = 4;
458459
Params.Enable16Bit = true;
459460
Params.EmulateTest = EmulateTest;
460-
runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging);
461+
runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);
461462
}
462463

463464
static const char SplatStoreShader[] = R"(
@@ -493,7 +494,7 @@ static const char SplatStoreShader[] = R"(
493494
static void runSplatStore(ID3D12Device *Device,
494495
dxc::SpecificDllLoader &DxcSupport,
495496
const MatrixParams &Params, float FillValue,
496-
bool Verbose) {
497+
bool Verbose, bool CompileOnly) {
497498
const size_t NumElements = Params.totalElements();
498499
const size_t BufferSize = Params.totalBytes();
499500
std::string Target = "cs_6_10";
@@ -508,8 +509,10 @@ static void runSplatStore(ID3D12Device *Device,
508509
// Always verify the shader compiles.
509510
compileShader(DxcSupport, SplatStoreShader, Target.c_str(), Args, Verbose);
510511

511-
if (shouldSkipBecauseSM610Unsupported(Device))
512+
if (CompileOnly) {
513+
logCompiledButSkipping();
512514
return;
515+
}
513516

514517
auto Expected = makeExpected(Params.CompType, NumElements, FillValue, false);
515518

@@ -538,7 +541,7 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
538541
Params.NumThreads = 4;
539542
Params.Enable16Bit = true;
540543
Params.EmulateTest = EmulateTest;
541-
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging);
544+
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging, CompileOnly);
542545
}
543546

544547
static const char ElementAccessShader[] = R"(
@@ -598,7 +601,7 @@ static const char ElementAccessShader[] = R"(
598601

599602
static void runElementAccess(ID3D12Device *Device,
600603
dxc::SpecificDllLoader &DxcSupport,
601-
const MatrixParams &Params, bool Verbose) {
604+
const MatrixParams &Params, bool Verbose, bool CompileOnly) {
602605
const size_t NumElements = Params.totalElements();
603606
const size_t NumThreads = Params.NumThreads;
604607
const size_t InputBufSize = Params.totalBytes();
@@ -621,8 +624,10 @@ static void runElementAccess(ID3D12Device *Device,
621624

622625
compileShader(DxcSupport, ElementAccessShader, Target.c_str(), Args, Verbose);
623626

624-
if (shouldSkipBecauseSM610Unsupported(Device))
627+
if (CompileOnly) {
628+
logCompiledButSkipping();
625629
return;
630+
}
626631

627632
auto Expected = makeExpected(Params.CompType, NumElements, 1, true);
628633

@@ -673,7 +678,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
673678
Params.NumThreads = 4;
674679
Params.Enable16Bit = true;
675680
Params.EmulateTest = EmulateTest;
676-
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
681+
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);
677682
}
678683

679684
} // namespace LinAlg

0 commit comments

Comments
 (0)