1212
1313#include " HlslExecTestUtils.h"
1414
15+ #include < algorithm>
1516#include < array>
1617#include < bitset>
1718#include < iomanip>
@@ -175,37 +176,37 @@ enum class ValidationType {
175176};
176177
177178template <typename T>
178- bool doValuesMatch (T A, T B, float Tolerance, ValidationType) {
179- if (Tolerance == 0 .0f )
179+ bool doValuesMatch (T A, T B, double Tolerance, ValidationType) {
180+ if (Tolerance == 0.0 )
180181 return A == B;
181182
182183 T Diff = A > B ? A - B : B - A;
183184 return Diff <= Tolerance;
184185}
185186
186- bool doValuesMatch (HLSLBool_t A, HLSLBool_t B, float , ValidationType) {
187+ bool doValuesMatch (HLSLBool_t A, HLSLBool_t B, double , ValidationType) {
187188 return A == B;
188189}
189190
190- bool doValuesMatch (HLSLHalf_t A, HLSLHalf_t B, float Tolerance,
191+ bool doValuesMatch (HLSLHalf_t A, HLSLHalf_t B, double Tolerance,
191192 ValidationType ValidationType) {
192193 switch (ValidationType) {
193194 case ValidationType::Epsilon:
194- return CompareHalfEpsilon (A.Val , B.Val , Tolerance);
195+ return CompareHalfEpsilon (A.Val , B.Val , static_cast < float >( Tolerance) );
195196 case ValidationType::Ulp:
196- return CompareHalfULP (A.Val , B.Val , Tolerance);
197+ return CompareHalfULP (A.Val , B.Val , static_cast < float >( Tolerance) );
197198 default :
198199 hlsl_test::LogErrorFmt (
199200 L" Invalid ValidationType. Expecting Epsilon or ULP." );
200201 return false ;
201202 }
202203}
203204
204- bool doValuesMatch (float A, float B, float Tolerance,
205+ bool doValuesMatch (float A, float B, double Tolerance,
205206 ValidationType ValidationType) {
206207 switch (ValidationType) {
207208 case ValidationType::Epsilon:
208- return CompareFloatEpsilon (A, B, Tolerance);
209+ return CompareFloatEpsilon (A, B, static_cast < float >( Tolerance) );
209210 case ValidationType::Ulp: {
210211 // Tolerance is in ULPs. Convert to int for the comparison.
211212 const int IntTolerance = static_cast <int >(Tolerance);
@@ -218,7 +219,7 @@ bool doValuesMatch(float A, float B, float Tolerance,
218219 }
219220}
220221
221- bool doValuesMatch (double A, double B, float Tolerance,
222+ bool doValuesMatch (double A, double B, double Tolerance,
222223 ValidationType ValidationType) {
223224 switch (ValidationType) {
224225 case ValidationType::Epsilon:
@@ -237,7 +238,7 @@ bool doValuesMatch(double A, double B, float Tolerance,
237238
238239template <typename T>
239240bool doVectorsMatch (const std::vector<T> &ActualValues,
240- const std::vector<T> &ExpectedValues, float Tolerance,
241+ const std::vector<T> &ExpectedValues, double Tolerance,
241242 ValidationType ValidationType, bool VerboseLogging) {
242243
243244 DXASSERT (
@@ -247,6 +248,11 @@ bool doVectorsMatch(const std::vector<T> &ActualValues,
247248 if (VerboseLogging) {
248249 logLongVector (ActualValues, L" ActualValues" );
249250 logLongVector (ExpectedValues, L" ExpectedValues" );
251+
252+ hlsl_test::LogCommentFmt (
253+ L" ValidationType: %s, Tolerance: %17g" ,
254+ ValidationType == ValidationType::Epsilon ? L" Epsilon" : L" ULP" ,
255+ Tolerance);
250256 }
251257
252258 // Stash mismatched indexes for easy failure logging later
@@ -534,14 +540,14 @@ InputSets<T> buildTestInputs(size_t VectorSize, const InputSet OpInputSets[3],
534540}
535541
536542struct ValidationConfig {
537- float Tolerance = 0 .0f ;
543+ double Tolerance = 0.0 ;
538544 ValidationType Type = ValidationType::Epsilon;
539545
540- static ValidationConfig Epsilon (float Tolerance) {
546+ static ValidationConfig Epsilon (double Tolerance) {
541547 return ValidationConfig{Tolerance, ValidationType::Epsilon};
542548 }
543549
544- static ValidationConfig Ulp (float Tolerance) {
550+ static ValidationConfig Ulp (double Tolerance) {
545551 return ValidationConfig{Tolerance, ValidationType::Ulp};
546552 }
547553};
@@ -593,7 +599,8 @@ template <typename T> struct DefaultValidation {
593599 }
594600};
595601
596- // Strict Validation - require exact matches for all types
602+ // Strict Validation - Defaults to exact matches.
603+ // Tolerance can be set to a non-zero value to allow for a wider range.
597604struct StrictValidation {
598605 ValidationConfig ValidationConfig;
599606};
@@ -935,7 +942,7 @@ struct Op<OpType::AsUint_SplitDouble, double, 1> : StrictValidation {};
935942// values.
936943template <> struct ExpectedBuilder <OpType::AsUint_SplitDouble, double > {
937944 static std::vector<uint32_t >
938- buildExpected (Op<OpType::AsUint_SplitDouble, double , 1 >,
945+ buildExpected (Op<OpType::AsUint_SplitDouble, double , 1 > & ,
939946 const InputSets<double > &Inputs) {
940947 DXASSERT_NOMSG (Inputs.size () == 1 );
941948
@@ -1009,7 +1016,7 @@ DEFAULT_OP_1(OpType::Log2, (std::log2(A)));
10091016template <> struct Op <OpType::Frexp, float , 1 > : DefaultValidation<float > {};
10101017
10111018template <> struct ExpectedBuilder <OpType::Frexp, float > {
1012- static std::vector<float > buildExpected (Op<OpType::Frexp, float , 1 >,
1019+ static std::vector<float > buildExpected (Op<OpType::Frexp, float , 1 > & ,
10131020 const InputSets<float > &Inputs) {
10141021 DXASSERT_NOMSG (Inputs.size () == 1 );
10151022
@@ -1079,7 +1086,7 @@ OP_3(OpType::Select, StrictValidation, (static_cast<bool>(A) ? B : C));
10791086#define REDUCTION_OP (OP, STDFUNC ) \
10801087 template <typename T> struct Op <OP, T, 1 > : StrictValidation {}; \
10811088 template <typename T> struct ExpectedBuilder <OP, T> { \
1082- static std::vector<HLSLBool_t> buildExpected (Op<OP, T, 1 >, \
1089+ static std::vector<HLSLBool_t> buildExpected (Op<OP, T, 1 > &, \
10831090 const InputSets<T> &Inputs) { \
10841091 const bool Res = STDFUNC (Inputs[0 ].begin (), Inputs[0 ].end (), \
10851092 [](T A) { return A != static_cast <T>(0 ); }); \
@@ -1097,22 +1104,97 @@ REDUCTION_OP(OpType::All_Zero, (std::all_of));
10971104
10981105#undef REDUCTION_OP
10991106
1100- template <typename T> struct Op <OpType::Dot, T, 2 > : DefaultValidation<T> {};
1107+ template <typename T> struct Op <OpType::Dot, T, 2 > : StrictValidation {};
11011108template <typename T> struct ExpectedBuilder <OpType::Dot, T> {
1102- static std::vector<T> buildExpected (Op<OpType::Dot, T, 2 >,
1109+ // For Dot, buildExpected is a special case: it also computes an absolute
1110+ // epsilon for validation because Dot is a compound operation. Expected value
1111+ // is computed by multiplying and accumulating in fp64 for higher precision.
1112+ // Absolute epsilon is computed by reordering the accumulation into a
1113+ // worst-case sequence, then summing the per-step epsilons to produce a
1114+ // conservative error tolerance for the entire Dot operation.
1115+ static std::vector<T> buildExpected (Op<OpType::Dot, T, 2 > &Op,
11031116 const InputSets<T> &Inputs) {
1104- T DotProduct = T ();
11051117
1106- for (size_t I = 0 ; I < Inputs[0 ].size (); ++I) {
1107- DotProduct += Inputs[0 ][I] * Inputs[1 ][I];
1118+ std::vector<double > PositiveProducts;
1119+ std::vector<double > NegativeProducts;
1120+
1121+ const size_t VectorSize = Inputs[0 ].size ();
1122+
1123+ // Floating point ops have a tolerance of 0.5 ULPs per operation as per the
1124+ // DX spec.
1125+ const double ULPTolerance = 0.5 ;
1126+
1127+ // Accumulate in fp64 to improve precision.
1128+ double DotProduct = 0.0 ; // computed reference result
1129+ double AbsoluteEpsilon = 0.0 ; // computed tolerance
1130+ for (size_t I = 0 ; I < VectorSize; ++I) {
1131+ double Product = Inputs[0 ][I] * Inputs[1 ][I];
1132+ AbsoluteEpsilon += computeAbsoluteEpsilon<T>(Product, ULPTolerance);
1133+
1134+ DotProduct += Product;
1135+
1136+ if (Product >= 0.0 )
1137+ PositiveProducts.push_back (Product);
1138+ else
1139+ NegativeProducts.push_back (Product);
11081140 }
11091141
1142+ // Sort each by magnitude so that we can accumulate them in worst case
1143+ // order.
1144+ std::sort (PositiveProducts.begin (), PositiveProducts.end (),
1145+ std::greater<double >());
1146+ std::sort (NegativeProducts.begin (), NegativeProducts.end ());
1147+
1148+ // Helper to sum the products and compute/add to the running absolute
1149+ // epsilon total.
1150+ auto SumProducts = [&AbsoluteEpsilon,
1151+ ULPTolerance](const std::vector<double > &Values) {
1152+ double Sum = Values.empty () ? 0.0 : Values[0 ];
1153+ for (size_t I = 1 ; I < Values.size (); ++I) {
1154+ Sum += Values[I];
1155+ AbsoluteEpsilon += computeAbsoluteEpsilon<T>(Sum, ULPTolerance);
1156+ }
1157+ return Sum;
1158+ };
1159+
1160+ // Accumulate products in the worst case order while computing the absolute
1161+ // epsilon error for each intermediate step. And accumulate that error.
1162+ const double SumPos = SumProducts (PositiveProducts);
1163+ const double SumNeg = SumProducts (NegativeProducts);
1164+
1165+ if (!PositiveProducts.empty () && !NegativeProducts.empty ())
1166+ AbsoluteEpsilon +=
1167+ computeAbsoluteEpsilon<T>((SumPos + SumNeg), ULPTolerance);
1168+
1169+ Op.ValidationConfig = ValidationConfig::Epsilon (AbsoluteEpsilon);
1170+
11101171 std::vector<T> Expected;
1111- Expected.push_back (DotProduct);
1172+ Expected.push_back (static_cast <T>( DotProduct) );
11121173 return Expected;
11131174 }
11141175};
11151176
1177+ template <typename T>
1178+ static double computeAbsoluteEpsilon (double A, double ULPTolerance) {
1179+ DXASSERT ((!isinf (A) && !isnan (A)),
1180+ " Input values should not produce inf or nan results" );
1181+
1182+ // ULP is a positive value by definition. So, working with abs(A) simplifies
1183+ // our logic for computing ULP in the first place.
1184+ A = std::abs (A);
1185+
1186+ double ULP = 0.0 ;
1187+
1188+ if constexpr (std::is_same_v<T, HLSLHalf_t>)
1189+ ULP = HLSLHalf_t::GetULP (A);
1190+ else
1191+ ULP =
1192+ std::nextafter (static_cast <T>(A), std::numeric_limits<T>::infinity ()) -
1193+ static_cast <T>(A);
1194+
1195+ return ULP * ULPTolerance;
1196+ }
1197+
11161198template <typename T>
11171199struct Op <OpType::ShuffleVector, T, 1 > : DefaultValidation<T> {};
11181200template <typename T> struct ExpectedBuilder <OpType::ShuffleVector, T> {
0 commit comments