6868#pragma comment(lib, "dxguid.lib")
6969#pragma comment(lib, "version.lib")
7070
71+ const float HALF_MAX = 65504.0f;
72+ const float HALF_MIN = 6.10e-5f;
73+
7174// A more recent Windows SDK than currently required is needed for these.
7275typedef HRESULT(WINAPI *D3D12EnableExperimentalFeaturesFn)(
7376 UINT NumFeatures, __in_ecount(NumFeatures) const IID *pIIDs,
@@ -507,10 +510,10 @@ class ExecutionTest {
507510 L"Table:ShaderOpArithTable.xml#LongVector_BinaryOpTable")
508511 END_TEST_METHOD()
509512
510- // BEGIN_TEST_METHOD(LongVector_BinaryOpTest_float16)
511- // TEST_METHOD_PROPERTY(L"DataSource",
512- // L"Table:ShaderOpArithTable.xml#LongVector_BinaryOpTable")
513- // END_TEST_METHOD()
513+ BEGIN_TEST_METHOD(LongVector_BinaryOpTest_float16)
514+ TEST_METHOD_PROPERTY(L"DataSource",
515+ L"Table:ShaderOpArithTable.xml#LongVector_BinaryOpTable")
516+ END_TEST_METHOD()
514517
515518 BEGIN_TEST_METHOD(LongVector_BinaryOpTest_float32)
516519 TEST_METHOD_PROPERTY(L"DataSource",
@@ -552,10 +555,10 @@ class ExecutionTest {
552555 L"Table:ShaderOpArithTable.xml#LongVector_BinaryOpTable")
553556 END_TEST_METHOD()
554557
555- // BEGIN_TEST_METHOD(LongVector_UnaryOpTest_float16)
556- // TEST_METHOD_PROPERTY(L"DataSource",
557- // L"Table:ShaderOpArithTable.xml#LongVector_UnaryOpTable")
558- // END_TEST_METHOD()
558+ BEGIN_TEST_METHOD(LongVector_UnaryOpTest_float16)
559+ TEST_METHOD_PROPERTY(L"DataSource",
560+ L"Table:ShaderOpArithTable.xml#LongVector_UnaryOpTable")
561+ END_TEST_METHOD()
559562
560563 BEGIN_TEST_METHOD(LongVector_UnaryOpTest_float32)
561564 TEST_METHOD_PROPERTY(L"DataSource",
@@ -6138,6 +6141,13 @@ class TableParameterHandler {
61386141 return nullptr;
61396142 }
61406143
6144+ template <typename T>
6145+ std::vector<T> GetTableParamByName(LPCWSTR name) {
6146+ std::vector<WEX::Common::String> *table = GetTableParamByName(name);
6147+ return parseStringsToNumbers<T>(GetTableParamByName(name));
6148+ }
6149+
6150+
61416151 void clearTableParameter() {
61426152 for (size_t i = 0; i < m_tableSize; ++i) {
61436153 m_table[i].m_int32 = 0;
@@ -11222,51 +11232,112 @@ TEST_F(ExecutionTest, PackUnpackTest) {
1122211232 }
1122311233}
1122411234
11235+ // A helper struct because C++ bools are 1 byte and HLSL bools are 4 bytes.
11236+ // Take int32_t as a constuctor argument and convert it to bool when needed.
11237+ // Comparisons cast to a bool because we only care if the bool representation is
11238+ // true or false.
11239+ struct hlslBool_t
11240+ {
11241+ hlslBool_t() : val(0) {}
11242+ hlslBool_t(int32_t val) : val(val) {}
11243+ hlslBool_t(bool val) : val(val) {}
11244+ hlslBool_t(const hlslBool_t& other) : val(other.val) {}
11245+
11246+ bool operator==(const hlslBool_t& other) const{
11247+ return static_cast<bool>(val) == static_cast<bool>(other.val);
11248+ }
11249+
11250+ bool operator!=(const hlslBool_t& other) const{
11251+ return static_cast<bool>(val) != static_cast<bool>(other.val);
11252+ }
11253+
11254+ // So we can construct strings using std::wostream
11255+ friend std::wostream& operator<<(std::wostream& os, const hlslBool_t& obj) {
11256+ os << static_cast<bool>(obj.val);
11257+ return os;
11258+ }
11259+
11260+ int32_t val = 0;
11261+ };
11262+
11263+ // No native float16 type in C++. So we use uint16_t to represent it.
11264+ struct hlslHalf_t
11265+ {
11266+ hlslHalf_t() : val(0) {}
11267+ hlslHalf_t(uint16_t val) : val(val) {}
11268+ hlslHalf_t(const hlslHalf_t& other) : val(other.val) {}
11269+
11270+ bool operator==(const hlslHalf_t& other) const{
11271+ return val == other.val;
11272+ }
11273+
11274+ bool operator<(const hlslHalf_t& other) const{
11275+ return val < other.val;
11276+ }
11277+
11278+ bool operator>(const hlslHalf_t& other) const{
11279+ return val > other.val;
11280+ }
11281+
11282+ bool operator>(double d) const{
11283+ return static_cast<double>(val) > d;
11284+ }
11285+
11286+ bool operator!=(const hlslHalf_t& other) const{
11287+ return val != other.val;
11288+ }
11289+
11290+ hlslHalf_t operator-(const hlslHalf_t& other) const{
11291+ return hlslHalf_t(val - other.val);
11292+ }
11293+
11294+ // So we can construct strings using std::wostream
11295+ friend std::wostream& operator<<(std::wostream& os, const hlslHalf_t& obj) {
11296+ os << static_cast<long>(obj.val);
11297+ return os;
11298+ }
11299+
11300+ uint16_t val = 0;
11301+ };
11302+
1122511303// TODOLongVec : Need to change the members to vectors. But when I did that the
1122611304// copy logic to read back from the shader buffer fails. Needs to fix that
1122711305// before checking in PR.
1122811306// SLongVectorBinaryOp is used in ShaderOpArithTable.xml. The shader program
1122911307// uses the struct defintion to read from the input global buffer.
11230- template <typename T, std::size_t N,
11231- // By checking that T is an arithmetic we can keep the compiler error messages
11232- // from misusing the template much cleaner.
11233- typename = std::enable_if_t<std::is_arithmetic_v<T>>>
11308+ template <typename T, std::size_t N>
1123411309struct SLongVectorBinaryOp {
1123511310 T scalarInput;
1123611311 std::array<T, N> vecInput1;
1123711312 std::array<T, N> vecInput2;
1123811313 std::array<T, N> vecOutput;
1123911314};
1124011315
11241- // vec1 == Expected vector
11242- // vec2 == Actual vector
11243- template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>> >
11316+ // vec1 == Actual vector
11317+ // vec2 == Expected vector
11318+ template <typename T>
1124411319bool DoVectorsMatch(const std::vector<T>& vec1, const std::vector<T>& vec2, double tolerance) {
11245- // Ensure both vectors have the same size
11320+ // Sanity check. Ensure both vectors have the same size
1124611321 if (vec1.size() != vec2.size()) {
1124711322 VERIFY_IS_TRUE(false, L"Vectors are of different sizes!");
1124811323 return false;
1124911324 }
1125011325
11251- bool vectorsMatch = true;
11252-
1125311326 // Stash mismatched indexes for easy failure logging later
1125411327 std::vector<size_t> mismatchedIndexes;
1125511328 for (size_t i = 0; i < vec1.size(); ++i) {
1125611329 if (tolerance == 0 && vec1[i] != vec2[i]) {
11257- vectorsMatch = false;
1125811330 mismatchedIndexes.push_back(i);
11259- } else if constexpr (std::is_same_v<T, bool>) {
11260- // Compiler was very picky and wanted an explicit case for bools to
11261- // avoid a warning about comparing bools with >.
11331+ } else if constexpr (std::is_same_v<T, hlslBool_t>) {
11332+ // Compiler was very picky and wanted an explicit case for any T that
11333+ // doesn't implement the operators in the below else. ( > and -). It
11334+ // wouldn't accept putting this constexpr as an or case in the above if.
1126211335 if (vec1[i] != vec2[i]) {
11263- vectorsMatch = false;
1126411336 mismatchedIndexes.push_back(i);
1126511337 }
1126611338 } else {
1126711339 T diff = vec1[i] > vec2[i] ? vec1[i] - vec2[i] : vec2[i] - vec1[i];
1126811340 if (diff > tolerance) {
11269- vectorsMatch = false;
1127011341 mismatchedIndexes.push_back(i);
1127111342 }
1127211343 }
@@ -11277,18 +11348,33 @@ bool DoVectorsMatch(const std::vector<T>& vec1, const std::vector<T>& vec2, doub
1127711348 for (size_t index : mismatchedIndexes) {
1127811349 std::wstringstream wss(L"");
1127911350 wss << L"Mismatch at Index: " << index;
11280- wss << L" Expected Value:" << vec1[index] << ",";
11281- wss << L" Actual Value:" << vec2[index];
11351+ wss << L" Actual Value:" << vec1[index] << ",";
11352+ wss << L" Expected Value:" << vec2[index];
1128211353 WEX::Logging::Log::Error(wss.str().c_str());
1128311354 }
1128411355 }
1128511356
11286- return vectorsMatch ;
11357+ return mismatchedIndexes.empty() ;
1128711358}
1128811359
11289- template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>> >
11360+ template <typename T>
1129011361T parseStringToNumber(std::wstring& str) {
11291- if constexpr (std::is_same_v<T, uint16_t>) {
11362+ if constexpr (std::is_same_v<T, hlslBool_t>) {
11363+ // HLSL bool is 4 bytes. C++ bool is 1 byte.
11364+ // So we use an int32_t to hold the value.
11365+ return hlslBool_t(std::stol(str));
11366+ } else if constexpr (std::is_same_v<T, hlslHalf_t>) {
11367+ auto num = std::stod(str);
11368+ if(num < HALF_MIN || num > HALF_MAX) {
11369+ LogCommentFmt(L"Value is out of range. HALF_MIN:%f, HALF_MAX: %f, Value: %f. Will clamp.", HALF_MIN, HALF_MAX, num);
11370+ if(num < HALF_MIN) {
11371+ num = HALF_MIN;
11372+ } else {
11373+ num = HALF_MAX;
11374+ }
11375+ }
11376+ return static_cast<uint16_t>(num);
11377+ } else if constexpr (std::is_same_v<T, uint16_t>) {
1129211378 return static_cast<uint16_t>(std::stoul(str));
1129311379 } else if constexpr (std::is_same_v<T, uint32_t>) {
1129411380 return std::stoul(str);
@@ -11315,31 +11401,30 @@ T parseStringToNumber(std::wstring& str) {
1131511401// A helper to get the hlsl type as a string for a given C++ type.
1131611402template <typename T>
1131711403std::string GetHLSLTypeString() {
11318- if (std::is_same <T, bool>::value ) {
11404+ if (std::is_same_v <T, hlslBool_t> ) {
1131911405 return "bool";
11320- // TODO: Need special logic for half. No half in C++
11321- //} else if (std::is_same<T, half>::value) {
11322- // return "half";
11323- } else if (std::is_same<T, float>::value) {
11406+ } else if (std::is_same_v<T, hlslHalf_t>) {
11407+ return "half";
11408+ } else if (std::is_same_v<T, float>) {
1132411409 return "float";
11325- } else if (std::is_same <T, double>::value ) {
11410+ } else if (std::is_same_v <T, double>) {
1132611411 return "double";
11327- } else if (std::is_same <T, int16_t>::value ) {
11412+ } else if (std::is_same_v <T, int16_t>) {
1132811413 return "int16_t";
11329- } else if (std::is_same <T, int32_t>::value ) {
11414+ } else if (std::is_same_v <T, int32_t>) {
1133011415 return "int";
11331- } else if (std::is_same <T, int64_t>::value ) {
11416+ } else if (std::is_same_v <T, int64_t>) {
1133211417 return "int64_t";
11333- } else if (std::is_same <T, uint16_t>::value ) {
11418+ } else if (std::is_same_v <T, uint16_t>) {
1133411419 return "uint16_t";
11335- } else if (std::is_same <T, uint32_t>::value ) {
11420+ } else if (std::is_same_v <T, uint32_t>) {
1133611421 return "uint32_t";
11337- } else if (std::is_same <T, uint64_t>::value ) {
11422+ } else if (std::is_same_v <T, uint64_t>) {
1133811423 return "uint64_t";
11339- // TODO : Need special logic for these types in C++
11340- //} else if (std::is_same <T, packed_int16_t>::value ) {
11424+ // TODOLONGVEC : Need special logic for these types in C++
11425+ //} else if (std::is_same_v <T, packed_int16_t>) {
1134111426 // return "packed_int16_t";
11342- //} else if (std::is_same <T, packed_uint16_t>::value ) {
11427+ //} else if (std::is_same_v <T, packed_uint16_t>) {
1134311428 // return "packed_uint16_t";
1134411429 } else {
1134511430 std::string errStr("GetHLSLTypeString() Unsupported type: ");
@@ -11358,7 +11443,7 @@ std::string GetHLSLTypeString() {
1135811443// Or three strings ["1" , "2", "3"] will parse to 3 single elements vectors.
1135911444// Vectors for a single element is a little weird, but it allows this helper
1136011445// function to be very flexible.
11361- template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>> >
11446+ template <typename T>
1136211447std::vector<std::vector<T>> parseStringsToNumbers(const std::vector<WEX::Common::String>* input) {
1136311448 std::vector<std::vector<T>> result = {};
1136411449 if (input == nullptr || input->empty()) {
@@ -11387,18 +11472,15 @@ std::vector<std::vector<T>> parseStringsToNumbers(const std::vector<WEX::Common:
1138711472TEST_F(ExecutionTest, LongVector_BinaryOpTest_bool) {
1138811473 WEX::TestExecution::SetVerifyOutput verifySettings(
1138911474 WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
11390- // TODOLONGVEC: Do all the binary ops make sense on bools?
11391- WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped, L"Skipping bool test for now.");
11392- //LongVectorBinaryOpTestBase<bool>();
11475+ LongVectorBinaryOpTestBase<hlslBool_t>();
1139311476}
1139411477
11395- // TODO: No half available in C++. Need to add logic for this type
11396- //TEST_F(ExecutionTest, LongVector_BinaryOpTest_float16) {
11397- // WEX::TestExecution::SetVerifyOutput verifySettings(
11398- // WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
11399- //
11400- // LongVectorBinaryOpTestBase<float>();
11401- //}
11478+ TEST_F(ExecutionTest, LongVector_BinaryOpTest_float16) {
11479+ WEX::TestExecution::SetVerifyOutput verifySettings(
11480+ WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
11481+
11482+ LongVectorBinaryOpTestBase<hlslHalf_t>();
11483+ }
1140211484
1140311485TEST_F(ExecutionTest, LongVector_BinaryOpTest_float32) {
1140411486 WEX::TestExecution::SetVerifyOutput verifySettings(
@@ -11457,8 +11539,8 @@ void ExecutionTest::LongVectorBinaryOpTestBase() {
1145711539 LongVectorBinaryOpTestBase<T, 16>();
1145811540 LongVectorBinaryOpTestBase<T, 17>();
1145911541 LongVectorBinaryOpTestBase<T, 35>();
11460- LongVectorBinaryOpTestBase<T, 100>();
11461- // TODOLONGVEC: 1024 breaks the size limit for structured buffers
11542+ // TODOLONGVEC: 100, 1024 breaks the size limit for structured buffers
11543+ //LongVectorBinaryOpTestBase<T, 100>();
1146211544 //LongVectorBinaryOpTestBase<T, 1024>();
1146311545}
1146411546
@@ -11526,9 +11608,15 @@ void ExecutionTest::LongVectorBinaryOpTestBase() {
1152611608 // The two operand vectors must be the same size.
1152711609 VERIFY_IS_TRUE(inputVectorCount == vecInput2.size() || inputVectorCount == scalarInputs.size());
1152811610
11529- // Additionally all elements (also vectors) in the vector must be the same
11530- // size.
11531- // TODO: Add check.
11611+ // Sanity because breaking the XML would be easy.
11612+ if(hasVecInput2)
11613+ {
11614+ for(size_t i = 0 ; i < vecInput1.size(); i++)
11615+ {
11616+ VERIFY_IS_TRUE(vecInput1[i].size() == vecInput2[i].size(), L"Input1 and Input2 vectors must be the same size.");
11617+ }
11618+ }
11619+
1153211620 // Keep filling the vectors with 'copies' of themselves until we hit N
1153311621 if(vecInput1[0].size() != N)
1153411622 {
@@ -11636,21 +11724,19 @@ void ExecutionTest::LongVectorBinaryOpTestBase() {
1163611724 }
1163711725}
1163811726
11639- template <typename T, std::size_t N,
11640- typename = std::enable_if_t<std::is_arithmetic_v<T>>>
11727+ template <typename T, std::size_t N>
1164111728struct SLongVectorUnaryOp {
1164211729 T clampArgC;
1164311730 T clampArgT;
1164411731 std::array<T, N> vecInput1;
1164511732 std::array<T, N> vecOutput;
1164611733};
1164711734
11648- // TODO: No half available in C++. Need to add logic for this type
11649- //TEST_F(ExecutionTest, LongVector_UnaryOpTest_float16) {
11650- // WEX::TestExecution::SetVerifyOutput verifySettings(
11651- // WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
11652- // LongVectorUnaryOpTestBase<float>();
11653- //}
11735+ TEST_F(ExecutionTest, LongVector_UnaryOpTest_float16) {
11736+ WEX::TestExecution::SetVerifyOutput verifySettings(
11737+ WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
11738+ LongVectorUnaryOpTestBase<hlslHalf_t>();
11739+ }
1165411740
1165511741TEST_F(ExecutionTest, LongVector_UnaryOpTest_float32) {
1165611742 WEX::TestExecution::SetVerifyOutput verifySettings(
@@ -11710,8 +11796,8 @@ void ExecutionTest::LongVectorUnaryOpTestBase() {
1171011796 LongVectorUnaryOpTestBase<T, 16>();
1171111797 LongVectorUnaryOpTestBase<T, 17>();
1171211798 LongVectorUnaryOpTestBase<T, 35>();
11713- LongVectorUnaryOpTestBase<T, 100>();
11714- // TODOLONGVEC: 1024 breaks the size limit for structured buffers
11799+ // TODOLONGVEC: 100, 1024 breaks the size limit for structured buffers
11800+ //LongVectorUnaryOpTestBase<T, 100>();
1171511801 //LongVectorUnaryOpTestBase<T, 1024>();
1171611802}
1171711803
0 commit comments