Skip to content

Commit 2c3b284

Browse files
alsepkowCopilot
andcommitted
NFC: Move shared validation helpers into HLSLTestDataTypes.h
Move isFloatingPointType, ValidationType, ValidationConfig, DefaultValidation, StrictValidation, and doValuesMatch overloads from LongVectors.cpp and LinearAlgebra.cpp into the shared header. Both files now use 'using' declarations to reference them. Co-authored-by: Copilot <[email protected]>
1 parent 17506cb commit 2c3b284

3 files changed

Lines changed: 121 additions & 185 deletions

File tree

tools/clang/unittests/HLSLExec/HLSLTestDataTypes.h

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
#include <cstdint>
66
#include <limits>
77
#include <ostream>
8+
#include <type_traits>
89

910
#include <DirectXMath.h>
1011
#include <DirectXPackedVector.h>
1112

1213
#include "dxc/Support/Global.h"
14+
#include "HlslTestUtils.h"
1315

1416
// Shared HLSL type wrappers for use in execution tests.
1517
// These types bridge the gap between C++ and HLSL type representations.
@@ -490,6 +492,112 @@ struct F8E5M2_t {
490492
}
491493
};
492494

495+
//
496+
// Shared type traits and validation infrastructure.
497+
//
498+
499+
template <typename T> constexpr bool isFloatingPointType() {
500+
return std::is_same_v<T, float> || std::is_same_v<T, double> ||
501+
std::is_same_v<T, HLSLHalf_t>;
502+
}
503+
504+
enum class ValidationType {
505+
Epsilon,
506+
Ulp,
507+
};
508+
509+
struct ValidationConfig {
510+
double Tolerance = 0.0;
511+
ValidationType Type = ValidationType::Epsilon;
512+
513+
static ValidationConfig Epsilon(double Tolerance) {
514+
return ValidationConfig{Tolerance, ValidationType::Epsilon};
515+
}
516+
517+
static ValidationConfig Ulp(double Tolerance) {
518+
return ValidationConfig{Tolerance, ValidationType::Ulp};
519+
}
520+
};
521+
522+
// Default validation: ULP for floating point, exact for integers.
523+
template <typename T> struct DefaultValidation {
524+
ValidationConfig ValidationConfig;
525+
526+
DefaultValidation() {
527+
if constexpr (isFloatingPointType<T>())
528+
ValidationConfig = ValidationConfig::Ulp(1.0f);
529+
}
530+
};
531+
532+
// Strict validation: exact match by default.
533+
struct StrictValidation {
534+
ValidationConfig ValidationConfig;
535+
};
536+
537+
//
538+
// Value comparison overloads used by both LongVector and LinearAlgebra tests.
539+
//
540+
541+
template <typename T>
542+
inline bool doValuesMatch(T A, T B, double Tolerance, ValidationType) {
543+
if (Tolerance == 0.0)
544+
return A == B;
545+
546+
T Diff = A > B ? A - B : B - A;
547+
return Diff <= Tolerance;
548+
}
549+
550+
inline bool doValuesMatch(HLSLBool_t A, HLSLBool_t B, double,
551+
ValidationType) {
552+
return A == B;
553+
}
554+
555+
inline bool doValuesMatch(HLSLHalf_t A, HLSLHalf_t B, double Tolerance,
556+
ValidationType VType) {
557+
switch (VType) {
558+
case ValidationType::Epsilon:
559+
return CompareHalfEpsilon(A.Val, B.Val, static_cast<float>(Tolerance));
560+
case ValidationType::Ulp:
561+
return CompareHalfULP(A.Val, B.Val, static_cast<float>(Tolerance));
562+
default:
563+
hlsl_test::LogErrorFmt(
564+
L"Invalid ValidationType. Expecting Epsilon or ULP.");
565+
return false;
566+
}
567+
}
568+
569+
inline bool doValuesMatch(float A, float B, double Tolerance,
570+
ValidationType VType) {
571+
switch (VType) {
572+
case ValidationType::Epsilon:
573+
return CompareFloatEpsilon(A, B, static_cast<float>(Tolerance));
574+
case ValidationType::Ulp: {
575+
const int IntTolerance = static_cast<int>(Tolerance);
576+
return CompareFloatULP(A, B, IntTolerance);
577+
}
578+
default:
579+
hlsl_test::LogErrorFmt(
580+
L"Invalid ValidationType. Expecting Epsilon or ULP.");
581+
return false;
582+
}
583+
}
584+
585+
inline bool doValuesMatch(double A, double B, double Tolerance,
586+
ValidationType VType) {
587+
switch (VType) {
588+
case ValidationType::Epsilon:
589+
return CompareDoubleEpsilon(A, B, Tolerance);
590+
case ValidationType::Ulp: {
591+
const int64_t IntTolerance = static_cast<int64_t>(Tolerance);
592+
return CompareDoubleULP(A, B, IntTolerance);
593+
}
594+
default:
595+
hlsl_test::LogErrorFmt(
596+
L"Invalid ValidationType. Expecting Epsilon or ULP.");
597+
return false;
598+
}
599+
}
600+
493601
} // namespace HLSLTestDataTypes
494602

495603
#endif // HLSLTESTDATATYPES_H

tools/clang/unittests/HLSLExec/LinearAlgebra.cpp

Lines changed: 7 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
#include <type_traits>
2323
#include <vector>
2424

25-
using namespace HLSLTestDataTypes; // For HLSLHalf_t, HLSLBool_t
26-
2725
namespace LinearAlgebra {
2826

2927
//
@@ -89,89 +87,13 @@ DATA_TYPE(uint32_t, "uint", "ComponentType::U32", 4, false)
8987

9088
#undef DATA_TYPE
9189

92-
template <typename T> constexpr bool isFloatingPointType() {
93-
return std::is_same_v<T, float> || std::is_same_v<T, double> ||
94-
std::is_same_v<T, HLSLHalf_t>;
95-
}
96-
97-
//
98-
// Validation
99-
//
100-
101-
enum class ValidationType { Epsilon, Ulp };
102-
103-
struct ValidationConfig {
104-
double Tolerance = 0.0;
105-
ValidationType Type = ValidationType::Epsilon;
106-
107-
static ValidationConfig Epsilon(double Tol) {
108-
return {Tol, ValidationType::Epsilon};
109-
}
110-
111-
static ValidationConfig Ulp(double Tol) { return {Tol, ValidationType::Ulp}; }
112-
};
113-
114-
// Default validation: ULP for floating point, exact for integers.
115-
template <typename T> struct DefaultValidation {
116-
ValidationConfig ValidationConfig;
117-
118-
DefaultValidation() {
119-
if constexpr (isFloatingPointType<T>())
120-
ValidationConfig = ValidationConfig::Ulp(1.0);
121-
}
122-
};
123-
124-
// Strict validation: exact match.
125-
struct StrictValidation {
126-
ValidationConfig ValidationConfig;
127-
};
128-
129-
//
130-
// Value comparison overloads following LongVector patterns.
131-
//
132-
133-
template <typename T>
134-
bool doValuesMatch(T A, T B, double Tolerance, ValidationType) {
135-
if (Tolerance == 0.0)
136-
return A == B;
137-
138-
T Diff = A > B ? A - B : B - A;
139-
return Diff <= Tolerance;
140-
}
141-
142-
bool doValuesMatch(HLSLHalf_t A, HLSLHalf_t B, double Tolerance,
143-
ValidationType VType) {
144-
switch (VType) {
145-
case ValidationType::Epsilon:
146-
return CompareHalfEpsilon(A.Val, B.Val, static_cast<float>(Tolerance));
147-
case ValidationType::Ulp:
148-
return CompareHalfULP(A.Val, B.Val, static_cast<float>(Tolerance));
149-
default:
150-
return false;
151-
}
152-
}
153-
154-
bool doValuesMatch(float A, float B, double Tolerance, ValidationType VType) {
155-
switch (VType) {
156-
case ValidationType::Epsilon:
157-
return CompareFloatEpsilon(A, B, static_cast<float>(Tolerance));
158-
case ValidationType::Ulp:
159-
return CompareFloatULP(A, B, static_cast<int>(Tolerance));
160-
default:
161-
return false;
162-
}
163-
}
164-
165-
bool doValuesMatch(double A, double B, double Tolerance, ValidationType VType) {
166-
switch (VType) {
167-
case ValidationType::Epsilon:
168-
return CompareDoubleEpsilon(A, B, Tolerance);
169-
case ValidationType::Ulp:
170-
return CompareDoubleULP(A, B, static_cast<int64_t>(Tolerance));
171-
default:
172-
return false;
173-
}
174-
}
90+
using HLSLTestDataTypes::isFloatingPointType;
91+
using HLSLTestDataTypes::ValidationType;
92+
using HLSLTestDataTypes::ValidationConfig;
93+
using HLSLTestDataTypes::DefaultValidation;
94+
using HLSLTestDataTypes::StrictValidation;
95+
using HLSLTestDataTypes::doValuesMatch;
96+
using HLSLTestDataTypes::HLSLHalf_t;
17597

17698
template <typename T>
17799
bool doVectorsMatch(const std::vector<T> &Actual,

tools/clang/unittests/HLSLExec/LongVectors.cpp

Lines changed: 6 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,12 @@ DATA_TYPE(double, "double", 8)
6464

6565
#undef DATA_TYPE
6666

67-
template <typename T> constexpr bool isFloatingPointType() {
68-
return std::is_same_v<T, float> || std::is_same_v<T, double> ||
69-
std::is_same_v<T, HLSLHalf_t>;
70-
}
67+
using HLSLTestDataTypes::isFloatingPointType;
68+
using HLSLTestDataTypes::ValidationType;
69+
using HLSLTestDataTypes::ValidationConfig;
70+
using HLSLTestDataTypes::DefaultValidation;
71+
using HLSLTestDataTypes::StrictValidation;
72+
using HLSLTestDataTypes::doValuesMatch;
7173

7274
//
7375
// Operation Types
@@ -186,72 +188,6 @@ void logLongVector(const std::vector<T> &Values, const std::wstring &Name) {
186188
hlsl_test::LogCommentFmt(Wss.str().c_str());
187189
}
188190

189-
enum class ValidationType {
190-
Epsilon,
191-
Ulp,
192-
};
193-
194-
template <typename T>
195-
bool doValuesMatch(T A, T B, double Tolerance, ValidationType) {
196-
if (Tolerance == 0.0)
197-
return A == B;
198-
199-
T Diff = A > B ? A - B : B - A;
200-
return Diff <= Tolerance;
201-
}
202-
203-
bool doValuesMatch(HLSLBool_t A, HLSLBool_t B, double, ValidationType) {
204-
return A == B;
205-
}
206-
207-
bool doValuesMatch(HLSLHalf_t A, HLSLHalf_t B, double Tolerance,
208-
ValidationType ValidationType) {
209-
switch (ValidationType) {
210-
case ValidationType::Epsilon:
211-
return CompareHalfEpsilon(A.Val, B.Val, static_cast<float>(Tolerance));
212-
case ValidationType::Ulp:
213-
return CompareHalfULP(A.Val, B.Val, static_cast<float>(Tolerance));
214-
default:
215-
hlsl_test::LogErrorFmt(
216-
L"Invalid ValidationType. Expecting Epsilon or ULP.");
217-
return false;
218-
}
219-
}
220-
221-
bool doValuesMatch(float A, float B, double Tolerance,
222-
ValidationType ValidationType) {
223-
switch (ValidationType) {
224-
case ValidationType::Epsilon:
225-
return CompareFloatEpsilon(A, B, static_cast<float>(Tolerance));
226-
case ValidationType::Ulp: {
227-
// Tolerance is in ULPs. Convert to int for the comparison.
228-
const int IntTolerance = static_cast<int>(Tolerance);
229-
return CompareFloatULP(A, B, IntTolerance);
230-
};
231-
default:
232-
hlsl_test::LogErrorFmt(
233-
L"Invalid ValidationType. Expecting Epsilon or ULP.");
234-
return false;
235-
}
236-
}
237-
238-
bool doValuesMatch(double A, double B, double Tolerance,
239-
ValidationType ValidationType) {
240-
switch (ValidationType) {
241-
case ValidationType::Epsilon:
242-
return CompareDoubleEpsilon(A, B, Tolerance);
243-
case ValidationType::Ulp: {
244-
// Tolerance is in ULPs. Convert to int64_t for the comparison.
245-
const int64_t IntTolerance = static_cast<int64_t>(Tolerance);
246-
return CompareDoubleULP(A, B, IntTolerance);
247-
};
248-
default:
249-
hlsl_test::LogErrorFmt(
250-
L"Invalid ValidationType. Expecting Epsilon or ULP.");
251-
return false;
252-
}
253-
}
254-
255191
template <typename T>
256192
bool doVectorsMatch(const std::vector<T> &ActualValues,
257193
const std::vector<T> &ExpectedValues, double Tolerance,
@@ -563,19 +499,6 @@ InputSets<T> buildTestInputs(size_t VectorSize, const InputSet OpInputSets[3],
563499
return Inputs;
564500
}
565501

566-
struct ValidationConfig {
567-
double Tolerance = 0.0;
568-
ValidationType Type = ValidationType::Epsilon;
569-
570-
static ValidationConfig Epsilon(double Tolerance) {
571-
return ValidationConfig{Tolerance, ValidationType::Epsilon};
572-
}
573-
574-
static ValidationConfig Ulp(double Tolerance) {
575-
return ValidationConfig{Tolerance, ValidationType::Ulp};
576-
}
577-
};
578-
579502
template <typename T, typename OUT_TYPE>
580503
void runAndVerify(
581504
ID3D12Device *D3DDevice, bool VerboseLogging, const Operation &Operation,
@@ -614,23 +537,6 @@ template <OpType OP, typename T, size_t Arity> struct Op;
614537
// member functions.
615538
template <OpType OP, typename T> struct ExpectedBuilder;
616539

617-
// Default Validation configuration - ULP for floating point types, exact
618-
// matches for everything else.
619-
template <typename T> struct DefaultValidation {
620-
ValidationConfig ValidationConfig;
621-
622-
DefaultValidation() {
623-
if constexpr (isFloatingPointType<T>())
624-
ValidationConfig = ValidationConfig::Ulp(1.0f);
625-
}
626-
};
627-
628-
// Strict Validation - Defaults to exact matches.
629-
// Tolerance can be set to a non-zero value to allow for a wider range.
630-
struct StrictValidation {
631-
ValidationConfig ValidationConfig;
632-
};
633-
634540
// Macros to build up common patterns of Op definitions
635541

636542
#define OP_1(OP, VALIDATION, IMPL) \

0 commit comments

Comments
 (0)