Skip to content

Commit ec6ba97

Browse files
[SM6.10][HLK][LinAlg] Address lingering feedback (#8343)
Fixes #8332 Resolve lingering feedback from the previous infrastructure PR --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 6f0743c commit ec6ba97

3 files changed

Lines changed: 169 additions & 155 deletions

File tree

tools/clang/unittests/HLSLExec/HlslTestDataTypes.h

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010

1111
#include <ostream>
1212

13+
#include <windows.h>
14+
1315
#include <DirectXMath.h>
1416
#include <DirectXPackedVector.h>
1517

18+
#include "HlslTestUtils.h"
1619
#include "dxc/Support/Global.h"
1720

1821
// These types bridge the gap between C++ and HLSL type representations.
@@ -460,6 +463,101 @@ struct HLSLMin16Uint_t {
460463
uint32_t Val;
461464
};
462465

466+
enum class ValidationType {
467+
Epsilon,
468+
Ulp,
469+
};
470+
471+
template <typename T>
472+
inline bool doValuesMatch(T A, T B, double Tolerance, ValidationType) {
473+
if (Tolerance == 0.0)
474+
return A == B;
475+
476+
T Diff = A > B ? A - B : B - A;
477+
return Diff <= Tolerance;
478+
}
479+
480+
inline bool doValuesMatch(HLSLBool_t A, HLSLBool_t B, double, ValidationType) {
481+
return A == B;
482+
}
483+
484+
inline bool doValuesMatch(HLSLHalf_t A, HLSLHalf_t B, double Tolerance,
485+
ValidationType ValidationType) {
486+
switch (ValidationType) {
487+
case ValidationType::Epsilon:
488+
return CompareHalfEpsilon(A.Val, B.Val, static_cast<float>(Tolerance));
489+
case ValidationType::Ulp:
490+
return CompareHalfULP(A.Val, B.Val, static_cast<float>(Tolerance));
491+
default:
492+
hlsl_test::LogErrorFmt(
493+
L"Invalid ValidationType. Expecting Epsilon or ULP.");
494+
return false;
495+
}
496+
}
497+
498+
// Min precision float comparison: convert to half and compare in fp16 space.
499+
// This reuses the same tolerance values as HLSLHalf_t. Min precision is at
500+
// least 16-bit, so fp16 tolerances are an upper bound for all cases.
501+
inline bool doValuesMatch(HLSLMin16Float_t A, HLSLMin16Float_t B,
502+
double Tolerance, ValidationType ValidationType) {
503+
auto HalfA = DirectX::PackedVector::XMConvertFloatToHalf(A.Val);
504+
auto HalfB = DirectX::PackedVector::XMConvertFloatToHalf(B.Val);
505+
switch (ValidationType) {
506+
case ValidationType::Epsilon:
507+
return CompareHalfEpsilon(HalfA, HalfB, static_cast<float>(Tolerance));
508+
case ValidationType::Ulp:
509+
return CompareHalfULP(HalfA, HalfB, static_cast<float>(Tolerance));
510+
default:
511+
hlsl_test::LogErrorFmt(
512+
L"Invalid ValidationType. Expecting Epsilon or ULP.");
513+
return false;
514+
}
515+
}
516+
517+
inline bool doValuesMatch(HLSLMin16Int_t A, HLSLMin16Int_t B, double,
518+
ValidationType) {
519+
return A == B;
520+
}
521+
522+
inline bool doValuesMatch(HLSLMin16Uint_t A, HLSLMin16Uint_t B, double,
523+
ValidationType) {
524+
return A == B;
525+
}
526+
527+
inline bool doValuesMatch(float A, float B, double Tolerance,
528+
ValidationType ValidationType) {
529+
switch (ValidationType) {
530+
case ValidationType::Epsilon:
531+
return CompareFloatEpsilon(A, B, static_cast<float>(Tolerance));
532+
case ValidationType::Ulp: {
533+
// Tolerance is in ULPs. Convert to int for the comparison.
534+
const int IntTolerance = static_cast<int>(Tolerance);
535+
return CompareFloatULP(A, B, IntTolerance);
536+
};
537+
default:
538+
hlsl_test::LogErrorFmt(
539+
L"Invalid ValidationType. Expecting Epsilon or ULP.");
540+
return false;
541+
}
542+
}
543+
544+
inline bool doValuesMatch(double A, double B, double Tolerance,
545+
ValidationType ValidationType) {
546+
switch (ValidationType) {
547+
case ValidationType::Epsilon:
548+
return CompareDoubleEpsilon(A, B, Tolerance);
549+
case ValidationType::Ulp: {
550+
// Tolerance is in ULPs. Convert to int64_t for the comparison.
551+
const int64_t IntTolerance = static_cast<int64_t>(Tolerance);
552+
return CompareDoubleULP(A, B, IntTolerance);
553+
};
554+
default:
555+
hlsl_test::LogErrorFmt(
556+
L"Invalid ValidationType. Expecting Epsilon or ULP.");
557+
return false;
558+
}
559+
}
560+
463561
} // namespace HLSLTestDataTypes
464562

465563
#endif

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 66 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ using hlsl::DXIL::LinalgMatrixLayout;
3838
using hlsl::DXIL::MatrixScope;
3939
using hlsl::DXIL::MatrixUse;
4040

41+
using HLSLTestDataTypes::doValuesMatch;
4142
using HLSLTestDataTypes::HLSLHalf_t;
43+
using HLSLTestDataTypes::ValidationType;
4244

4345
using VariantCompType = std::variant<std::vector<float>, std::vector<int32_t>,
4446
std::vector<HLSLHalf_t>>;
@@ -111,10 +113,8 @@ static bool verifyFloatBuffer(const float *Actual, const float *Expected,
111113
float Tolerance = 0.0f) {
112114
bool Success = true;
113115
for (size_t I = 0; I < Count; I++) {
114-
float Diff = Actual[I] - Expected[I];
115-
if (Diff < 0)
116-
Diff = -Diff;
117-
if (Diff > Tolerance) {
116+
if (!doValuesMatch(Actual[I], Expected[I], Tolerance,
117+
ValidationType::Epsilon)) {
118118
hlsl_test::LogErrorFmt(L"Mismatch at index %zu: actual=%f, expected=%f",
119119
I, static_cast<double>(Actual[I]),
120120
static_cast<double>(Expected[I]));
@@ -132,7 +132,7 @@ static bool verifyIntBuffer(const int32_t *Actual, const int32_t *Expected,
132132
size_t Count, bool Verbose) {
133133
bool Success = true;
134134
for (size_t I = 0; I < Count; I++) {
135-
if (Actual[I] != Expected[I]) {
135+
if (!doValuesMatch(Actual[I], Expected[I], 0.0, ValidationType::Epsilon)) {
136136
hlsl_test::LogErrorFmt(L"Mismatch at index %zu: actual=%d, expected=%d",
137137
I, Actual[I], Expected[I]);
138138
Success = false;
@@ -149,10 +149,8 @@ static bool verifyHalfBuffer(const HLSLHalf_t *Actual,
149149
bool Verbose, HLSLHalf_t Tolerance = 0.0f) {
150150
bool Success = true;
151151
for (size_t I = 0; I < Count; I++) {
152-
HLSLHalf_t Diff = Actual[I] - Expected[I];
153-
if (Diff < 0.0f)
154-
Diff = -Diff;
155-
if (Diff > Tolerance) {
152+
if (!doValuesMatch(Actual[I], Expected[I], Tolerance,
153+
ValidationType::Epsilon)) {
156154
hlsl_test::LogErrorFmt(L"Mismatch at index %zu: actual=%f, expected=%f",
157155
I, static_cast<float>(Actual[I]),
158156
static_cast<float>(Expected[I]));
@@ -254,21 +252,10 @@ static VariantCompType makeExpected(ComponentType CompType, size_t NumElements,
254252
return std::vector<float>();
255253
}
256254

257-
static bool shouldSkipBecauseSM610Unsupported(ID3D12Device *Device) {
258-
// Never skip in an HLK environment
259-
#ifdef _HLK_CONF
260-
return false;
261-
#endif
262-
263-
// Don't skip if a device is available
264-
if (Device)
265-
return false;
266-
267-
// Skip GPU execution
255+
static void logCompiledButSkipping() {
268256
hlsl_test::LogCommentFmt(
269257
L"Shader compiled OK; skipping execution (no SM 6.10 device)");
270258
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
271-
return true;
272259
}
273260

274261
class DxilConf_SM610_LinAlg {
@@ -299,49 +286,34 @@ class DxilConf_SM610_LinAlg {
299286
TEST_METHOD(ElementAccess_Wave_16x16_F16);
300287

301288
private:
302-
bool createDevice();
289+
D3D_SHADER_MODEL createDevice();
303290

304291
CComPtr<ID3D12Device> D3DDevice;
305292
dxc::SpecificDllLoader DxcSupport;
306293
bool VerboseLogging = false;
307294
bool EmulateTest = false;
308295
bool Initialized = false;
296+
bool CompileOnly = false;
309297
std::optional<D3D12SDKSelector> D3D12SDK;
310298

311299
WEX::TestExecution::SetVerifyOutput VerifyOutput{
312300
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures};
313301
};
314302

315-
/// Creates the device and setups the test scenario with the following variants
316-
/// HLK build: Require SM6.10 supported fail otherwise
317-
/// Non-HLK, no SM6.10 support: Compile shaders, then exit with skip
318-
/// Non-HLK, SM6.10 support: Compile shaders and run full test
319-
bool DxilConf_SM610_LinAlg::createDevice() {
320-
bool FailIfRequirementsNotMet = false;
321-
#ifdef _HLK_CONF
322-
FailIfRequirementsNotMet = true;
323-
#endif
303+
/// Attempts to create a device. If shaders are being emulated, creation of an
304+
/// SM6.8 device is attempted; otherwise, an SM6.10 device is attempted.
305+
D3D_SHADER_MODEL DxilConf_SM610_LinAlg::createDevice() {
306+
if (EmulateTest) {
307+
if (D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_8, false))
308+
return D3D_SHADER_MODEL_6_8;
324309

325-
const bool SkipUnsupported = FailIfRequirementsNotMet;
326-
if (!D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10,
327-
SkipUnsupported)) {
328-
if (FailIfRequirementsNotMet) {
329-
hlsl_test::LogErrorFmt(
330-
L"Device creation failed, resulting in test failure, since "
331-
L"FailIfRequirementsNotMet is set. The expectation is that this "
332-
L"test will only be executed if something has previously "
333-
L"determined that the system meets the requirements of this "
334-
L"test.");
335-
return false;
336-
}
310+
return D3D_SHADER_MODEL_NONE;
337311
}
338312

339-
if (EmulateTest) {
340-
hlsl_test::LogWarningFmt(L"EmulateTest flag set. Tests are NOT REAL");
341-
return D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_8, false);
342-
}
313+
if (D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10, false))
314+
return D3D_SHADER_MODEL_6_10;
343315

344-
return true;
316+
return D3D_SHADER_MODEL_NONE;
345317
}
346318

347319
bool DxilConf_SM610_LinAlg::setupClass() {
@@ -354,7 +326,26 @@ bool DxilConf_SM610_LinAlg::setupClass() {
354326
VerboseLogging);
355327
WEX::TestExecution::RuntimeParameters::TryGetValue(L"EmulateTest",
356328
EmulateTest);
357-
return createDevice();
329+
D3D_SHADER_MODEL SupportedSM = createDevice();
330+
331+
if (EmulateTest) {
332+
hlsl_test::LogWarningFmt(L"EmulateTest flag set. Tests are NOT REAL");
333+
if (SupportedSM != D3D_SHADER_MODEL_6_8) {
334+
hlsl_test::LogErrorFmt(
335+
L"Device creation failed. Expected a driver supporting SM6.8");
336+
return false;
337+
}
338+
}
339+
340+
#ifdef _HLK_CONF
341+
if (SupportedSM != D3D_SHADER_MODEL_6_10) {
342+
hlsl_test::LogErrorFmt(
343+
L"Device creation failed. Expected a driver supporting SM6.10");
344+
return false;
345+
}
346+
#endif
347+
348+
CompileOnly = SupportedSM == D3D_SHADER_MODEL_NONE;
358349
}
359350

360351
return true;
@@ -366,11 +357,17 @@ bool DxilConf_SM610_LinAlg::setupMethod() {
366357
if (D3DDevice && D3DDevice->GetDeviceRemovedReason() == S_OK)
367358
return true;
368359

360+
// Device is expected to be null. No point in recreating it
361+
if (CompileOnly)
362+
return true;
363+
369364
hlsl_test::LogCommentFmt(L"Device was lost!");
370365
D3DDevice.Release();
371366

372367
hlsl_test::LogCommentFmt(L"Recreating device");
373-
return createDevice();
368+
369+
// !CompileOnly implies we expect it to succeed
370+
return createDevice() != D3D_SHADER_MODEL_NONE;
374371
}
375372

376373
static const char LoadStoreShader[] = R"(
@@ -400,7 +397,8 @@ static const char LoadStoreShader[] = R"(
400397

401398
static void runLoadStoreRoundtrip(ID3D12Device *Device,
402399
dxc::SpecificDllLoader &DxcSupport,
403-
const MatrixParams &Params, bool Verbose) {
400+
const MatrixParams &Params, bool Verbose,
401+
bool CompileOnly) {
404402
const size_t NumElements = Params.totalElements();
405403
const size_t BufferSize = Params.totalBytes();
406404

@@ -417,8 +415,10 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
417415
// Always verify the shader compiles.
418416
compileShader(DxcSupport, LoadStoreShader, Target.c_str(), Args, Verbose);
419417

420-
if (shouldSkipBecauseSM610Unsupported(Device))
418+
if (CompileOnly) {
419+
logCompiledButSkipping();
421420
return;
421+
}
422422

423423
auto Expected = makeExpected(Params.CompType, NumElements, 1, true);
424424

@@ -457,7 +457,8 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() {
457457
Params.NumThreads = 4;
458458
Params.Enable16Bit = true;
459459
Params.EmulateTest = EmulateTest;
460-
runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging);
460+
runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging,
461+
CompileOnly);
461462
}
462463

463464
static const char SplatStoreShader[] = R"(
@@ -493,7 +494,7 @@ static const char SplatStoreShader[] = R"(
493494
static void runSplatStore(ID3D12Device *Device,
494495
dxc::SpecificDllLoader &DxcSupport,
495496
const MatrixParams &Params, float FillValue,
496-
bool Verbose) {
497+
bool Verbose, bool CompileOnly) {
497498
const size_t NumElements = Params.totalElements();
498499
const size_t BufferSize = Params.totalBytes();
499500
std::string Target = "cs_6_10";
@@ -508,8 +509,10 @@ static void runSplatStore(ID3D12Device *Device,
508509
// Always verify the shader compiles.
509510
compileShader(DxcSupport, SplatStoreShader, Target.c_str(), Args, Verbose);
510511

511-
if (shouldSkipBecauseSM610Unsupported(Device))
512+
if (CompileOnly) {
513+
logCompiledButSkipping();
512514
return;
515+
}
513516

514517
auto Expected = makeExpected(Params.CompType, NumElements, FillValue, false);
515518

@@ -538,7 +541,8 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
538541
Params.NumThreads = 4;
539542
Params.Enable16Bit = true;
540543
Params.EmulateTest = EmulateTest;
541-
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging);
544+
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging,
545+
CompileOnly);
542546
}
543547

544548
static const char ElementAccessShader[] = R"(
@@ -598,7 +602,8 @@ static const char ElementAccessShader[] = R"(
598602

599603
static void runElementAccess(ID3D12Device *Device,
600604
dxc::SpecificDllLoader &DxcSupport,
601-
const MatrixParams &Params, bool Verbose) {
605+
const MatrixParams &Params, bool Verbose,
606+
bool CompileOnly) {
602607
const size_t NumElements = Params.totalElements();
603608
const size_t NumThreads = Params.NumThreads;
604609
const size_t InputBufSize = Params.totalBytes();
@@ -621,8 +626,10 @@ static void runElementAccess(ID3D12Device *Device,
621626

622627
compileShader(DxcSupport, ElementAccessShader, Target.c_str(), Args, Verbose);
623628

624-
if (shouldSkipBecauseSM610Unsupported(Device))
629+
if (CompileOnly) {
630+
logCompiledButSkipping();
625631
return;
632+
}
626633

627634
auto Expected = makeExpected(Params.CompType, NumElements, 1, true);
628635

@@ -673,7 +680,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
673680
Params.NumThreads = 4;
674681
Params.Enable16Bit = true;
675682
Params.EmulateTest = EmulateTest;
676-
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
683+
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);
677684
}
678685

679686
} // namespace LinAlg

0 commit comments

Comments
 (0)