Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 55 additions & 50 deletions tools/clang/unittests/HLSLExec/LinAlgTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,21 +254,10 @@ static VariantCompType makeExpected(ComponentType CompType, size_t NumElements,
return std::vector<float>();
}

static bool shouldSkipBecauseSM610Unsupported(ID3D12Device *Device) {
// Never skip in an HLK environment
#ifdef _HLK_CONF
return false;
#endif

// Don't skip if a device is available
if (Device)
return false;

// Skip GPU execution
static void logCompiledButSkipping() {
hlsl_test::LogCommentFmt(
L"Shader compiled OK; skipping execution (no SM 6.10 device)");
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
return true;
}

class DxilConf_SM610_LinAlg {
Expand Down Expand Up @@ -299,49 +288,34 @@ class DxilConf_SM610_LinAlg {
TEST_METHOD(ElementAccess_Wave_16x16_F16);

private:
bool createDevice();
D3D_SHADER_MODEL createDevice();

CComPtr<ID3D12Device> D3DDevice;
dxc::SpecificDllLoader DxcSupport;
bool VerboseLogging = false;
bool EmulateTest = false;
bool Initialized = false;
bool CompileOnly = false;
std::optional<D3D12SDKSelector> D3D12SDK;

WEX::TestExecution::SetVerifyOutput VerifyOutput{
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures};
};

/// Creates the device and setups the test scenario with the following variants
/// HLK build: Require SM6.10 supported fail otherwise
/// Non-HLK, no SM6.10 support: Compile shaders, then exit with skip
/// Non-HLK, SM6.10 support: Compile shaders and run full test
bool DxilConf_SM610_LinAlg::createDevice() {
bool FailIfRequirementsNotMet = false;
#ifdef _HLK_CONF
FailIfRequirementsNotMet = true;
#endif
/// Attempts to create a device. If tests are being emulated this an SM6.8
/// device is attempted. Durning normal execution SM6.10 is required.
Comment thread
V-FEXrt marked this conversation as resolved.
Outdated
D3D_SHADER_MODEL DxilConf_SM610_LinAlg::createDevice() {
if (EmulateTest) {
if(D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_8, false))
return D3D_SHADER_MODEL_6_8;

const bool SkipUnsupported = FailIfRequirementsNotMet;
if (!D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10,
SkipUnsupported)) {
if (FailIfRequirementsNotMet) {
hlsl_test::LogErrorFmt(
L"Device creation failed, resulting in test failure, since "
L"FailIfRequirementsNotMet is set. The expectation is that this "
L"test will only be executed if something has previously "
L"determined that the system meets the requirements of this "
L"test.");
return false;
}
return D3D_SHADER_MODEL_NONE;
}

if (EmulateTest) {
hlsl_test::LogWarningFmt(L"EmulateTest flag set. Tests are NOT REAL");
return D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_8, false);
}
if (D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10, false))
return D3D_SHADER_MODEL_6_10;

return true;
return D3D_SHADER_MODEL_NONE;
}

bool DxilConf_SM610_LinAlg::setupClass() {
Expand All @@ -354,7 +328,26 @@ bool DxilConf_SM610_LinAlg::setupClass() {
VerboseLogging);
WEX::TestExecution::RuntimeParameters::TryGetValue(L"EmulateTest",
EmulateTest);
return createDevice();
D3D_SHADER_MODEL SupportedSM = createDevice();

if (EmulateTest) {
hlsl_test::LogWarningFmt(L"EmulateTest flag set. Tests are NOT REAL");
if (SupportedSM != D3D_SHADER_MODEL_6_8) {
hlsl_test::LogErrorFmt(
L"Device creation failed. Expected a driver supporting SM6.8");
return false;
}
}

#ifdef _HLK_CONF
if (SupportedSM != D3D_SHADER_MODEL_6_10) {
hlsl_test::LogErrorFmt(
L"Device creation failed. Expected a driver supporting SM6.10");
return false;
}
#endif

CompileOnly = SupportedSM == D3D_SHADER_MODEL_NONE;
}

return true;
Expand All @@ -366,11 +359,17 @@ bool DxilConf_SM610_LinAlg::setupMethod() {
if (D3DDevice && D3DDevice->GetDeviceRemovedReason() == S_OK)
return true;

// Device is expected to be null. No point in recreating it
if (CompileOnly)
return true;

hlsl_test::LogCommentFmt(L"Device was lost!");
D3DDevice.Release();

hlsl_test::LogCommentFmt(L"Recreating device");
return createDevice();

// !CompileOnly implies we expect it to succeeded
return createDevice() != D3D_SHADER_MODEL_NONE;
}

static const char LoadStoreShader[] = R"(
Expand Down Expand Up @@ -400,7 +399,7 @@ static const char LoadStoreShader[] = R"(

static void runLoadStoreRoundtrip(ID3D12Device *Device,
dxc::SpecificDllLoader &DxcSupport,
const MatrixParams &Params, bool Verbose) {
const MatrixParams &Params, bool Verbose, bool CompileOnly) {
const size_t NumElements = Params.totalElements();
const size_t BufferSize = Params.totalBytes();

Expand All @@ -417,8 +416,10 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
// Always verify the shader compiles.
compileShader(DxcSupport, LoadStoreShader, Target.c_str(), Args, Verbose);

if (shouldSkipBecauseSM610Unsupported(Device))
if (CompileOnly) {
logCompiledButSkipping();
return;
}

auto Expected = makeExpected(Params.CompType, NumElements, 1, true);

Expand Down Expand Up @@ -457,7 +458,7 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() {
Params.NumThreads = 4;
Params.Enable16Bit = true;
Params.EmulateTest = EmulateTest;
runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging);
runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);
}

static const char SplatStoreShader[] = R"(
Expand Down Expand Up @@ -493,7 +494,7 @@ static const char SplatStoreShader[] = R"(
static void runSplatStore(ID3D12Device *Device,
dxc::SpecificDllLoader &DxcSupport,
const MatrixParams &Params, float FillValue,
bool Verbose) {
bool Verbose, bool CompileOnly) {
const size_t NumElements = Params.totalElements();
const size_t BufferSize = Params.totalBytes();
std::string Target = "cs_6_10";
Expand All @@ -508,8 +509,10 @@ static void runSplatStore(ID3D12Device *Device,
// Always verify the shader compiles.
compileShader(DxcSupport, SplatStoreShader, Target.c_str(), Args, Verbose);

if (shouldSkipBecauseSM610Unsupported(Device))
if (CompileOnly) {
logCompiledButSkipping();
return;
}

auto Expected = makeExpected(Params.CompType, NumElements, FillValue, false);

Expand Down Expand Up @@ -538,7 +541,7 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
Params.NumThreads = 4;
Params.Enable16Bit = true;
Params.EmulateTest = EmulateTest;
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging);
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging, CompileOnly);
}

static const char ElementAccessShader[] = R"(
Expand Down Expand Up @@ -598,7 +601,7 @@ static const char ElementAccessShader[] = R"(

static void runElementAccess(ID3D12Device *Device,
dxc::SpecificDllLoader &DxcSupport,
const MatrixParams &Params, bool Verbose) {
const MatrixParams &Params, bool Verbose, bool CompileOnly) {
const size_t NumElements = Params.totalElements();
const size_t NumThreads = Params.NumThreads;
const size_t InputBufSize = Params.totalBytes();
Expand All @@ -621,8 +624,10 @@ static void runElementAccess(ID3D12Device *Device,

compileShader(DxcSupport, ElementAccessShader, Target.c_str(), Args, Verbose);

if (shouldSkipBecauseSM610Unsupported(Device))
if (CompileOnly) {
logCompiledButSkipping();
return;
}

auto Expected = makeExpected(Params.CompType, NumElements, 1, true);

Expand Down Expand Up @@ -673,7 +678,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
Params.NumThreads = 4;
Params.Enable16Bit = true;
Params.EmulateTest = EmulateTest;
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);
}

} // namespace LinAlg
Loading