Skip to content

Commit f5ed820

Browse files
committed
Add halfs
1 parent de2907b commit f5ed820

1 file changed

Lines changed: 156 additions & 70 deletions

File tree

tools/clang/unittests/HLSLExec/ExecutionTest.cpp

Lines changed: 156 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@
6868
#pragma comment(lib, "dxguid.lib")
6969
#pragma comment(lib, "version.lib")
7070

71+
const float HALF_MAX = 65504.0f;
72+
const float HALF_MIN = 6.10e-5f;
73+
7174
// A more recent Windows SDK than currently required is needed for these.
7275
typedef HRESULT(WINAPI *D3D12EnableExperimentalFeaturesFn)(
7376
UINT NumFeatures, __in_ecount(NumFeatures) const IID *pIIDs,
@@ -507,10 +510,10 @@ class ExecutionTest {
507510
L"Table:ShaderOpArithTable.xml#LongVector_BinaryOpTable")
508511
END_TEST_METHOD()
509512

510-
//BEGIN_TEST_METHOD(LongVector_BinaryOpTest_float16)
511-
//TEST_METHOD_PROPERTY(L"DataSource",
512-
// L"Table:ShaderOpArithTable.xml#LongVector_BinaryOpTable")
513-
//END_TEST_METHOD()
513+
BEGIN_TEST_METHOD(LongVector_BinaryOpTest_float16)
514+
TEST_METHOD_PROPERTY(L"DataSource",
515+
L"Table:ShaderOpArithTable.xml#LongVector_BinaryOpTable")
516+
END_TEST_METHOD()
514517

515518
BEGIN_TEST_METHOD(LongVector_BinaryOpTest_float32)
516519
TEST_METHOD_PROPERTY(L"DataSource",
@@ -552,10 +555,10 @@ class ExecutionTest {
552555
L"Table:ShaderOpArithTable.xml#LongVector_BinaryOpTable")
553556
END_TEST_METHOD()
554557

555-
//BEGIN_TEST_METHOD(LongVector_UnaryOpTest_float16)
556-
//TEST_METHOD_PROPERTY(L"DataSource",
557-
// L"Table:ShaderOpArithTable.xml#LongVector_UnaryOpTable")
558-
//END_TEST_METHOD()
558+
BEGIN_TEST_METHOD(LongVector_UnaryOpTest_float16)
559+
TEST_METHOD_PROPERTY(L"DataSource",
560+
L"Table:ShaderOpArithTable.xml#LongVector_UnaryOpTable")
561+
END_TEST_METHOD()
559562

560563
BEGIN_TEST_METHOD(LongVector_UnaryOpTest_float32)
561564
TEST_METHOD_PROPERTY(L"DataSource",
@@ -6138,6 +6141,13 @@ class TableParameterHandler {
61386141
return nullptr;
61396142
}
61406143

6144+
template <typename T>
6145+
std::vector<T> GetTableParamByName(LPCWSTR name) {
6146+
std::vector<WEX::Common::String> *table = GetTableParamByName(name);
6147+
return parseStringsToNumbers<T>(GetTableParamByName(name));
6148+
}
6149+
6150+
61416151
void clearTableParameter() {
61426152
for (size_t i = 0; i < m_tableSize; ++i) {
61436153
m_table[i].m_int32 = 0;
@@ -11222,51 +11232,112 @@ TEST_F(ExecutionTest, PackUnpackTest) {
1122211232
}
1122311233
}
1122411234

11235+
// A helper struct because C++ bools are 1 byte and HLSL bools are 4 bytes.
11236+
// Take int32_t as a constuctor argument and convert it to bool when needed.
11237+
// Comparisons cast to a bool because we only care if the bool representation is
11238+
// true or false.
11239+
struct hlslBool_t
11240+
{
11241+
hlslBool_t() : val(0) {}
11242+
hlslBool_t(int32_t val) : val(val) {}
11243+
hlslBool_t(bool val) : val(val) {}
11244+
hlslBool_t(const hlslBool_t& other) : val(other.val) {}
11245+
11246+
bool operator==(const hlslBool_t& other) const{
11247+
return static_cast<bool>(val) == static_cast<bool>(other.val);
11248+
}
11249+
11250+
bool operator!=(const hlslBool_t& other) const{
11251+
return static_cast<bool>(val) != static_cast<bool>(other.val);
11252+
}
11253+
11254+
// So we can construct strings using std::wostream
11255+
friend std::wostream& operator<<(std::wostream& os, const hlslBool_t& obj) {
11256+
os << static_cast<bool>(obj.val);
11257+
return os;
11258+
}
11259+
11260+
int32_t val = 0;
11261+
};
11262+
11263+
// No native float16 type in C++. So we use uint16_t to represent it.
11264+
struct hlslHalf_t
11265+
{
11266+
hlslHalf_t() : val(0) {}
11267+
hlslHalf_t(uint16_t val) : val(val) {}
11268+
hlslHalf_t(const hlslHalf_t& other) : val(other.val) {}
11269+
11270+
bool operator==(const hlslHalf_t& other) const{
11271+
return val == other.val;
11272+
}
11273+
11274+
bool operator<(const hlslHalf_t& other) const{
11275+
return val < other.val;
11276+
}
11277+
11278+
bool operator>(const hlslHalf_t& other) const{
11279+
return val > other.val;
11280+
}
11281+
11282+
bool operator>(double d) const{
11283+
return static_cast<double>(val) > d;
11284+
}
11285+
11286+
bool operator!=(const hlslHalf_t& other) const{
11287+
return val != other.val;
11288+
}
11289+
11290+
hlslHalf_t operator-(const hlslHalf_t& other) const{
11291+
return hlslHalf_t(val - other.val);
11292+
}
11293+
11294+
// So we can construct strings using std::wostream
11295+
friend std::wostream& operator<<(std::wostream& os, const hlslHalf_t& obj) {
11296+
os << static_cast<long>(obj.val);
11297+
return os;
11298+
}
11299+
11300+
uint16_t val = 0;
11301+
};
11302+
1122511303
// TODOLongVec : Need to change the members to vectors. But when I did that the
1122611304
// copy logic to read back from the shader buffer fails. Needs to fix that
1122711305
// before checking in PR.
1122811306
// SLongVectorBinaryOp is used in ShaderOpArithTable.xml. The shader program
1122911307
// uses the struct defintion to read from the input global buffer.
11230-
template <typename T, std::size_t N,
11231-
// By checking that T is an arithmetic we can keep the compiler error messages
11232-
// from misusing the template much cleaner.
11233-
typename = std::enable_if_t<std::is_arithmetic_v<T>>>
11308+
template <typename T, std::size_t N>
1123411309
struct SLongVectorBinaryOp {
1123511310
T scalarInput;
1123611311
std::array<T, N> vecInput1;
1123711312
std::array<T, N> vecInput2;
1123811313
std::array<T, N> vecOutput;
1123911314
};
1124011315

11241-
// vec1 == Expected vector
11242-
// vec2 == Actual vector
11243-
template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
11316+
// vec1 == Actual vector
11317+
// vec2 == Expected vector
11318+
template <typename T>
1124411319
bool DoVectorsMatch(const std::vector<T>& vec1, const std::vector<T>& vec2, double tolerance) {
11245-
// Ensure both vectors have the same size
11320+
// Sanity check. Ensure both vectors have the same size
1124611321
if (vec1.size() != vec2.size()) {
1124711322
VERIFY_IS_TRUE(false, L"Vectors are of different sizes!");
1124811323
return false;
1124911324
}
1125011325

11251-
bool vectorsMatch = true;
11252-
1125311326
// Stash mismatched indexes for easy failure logging later
1125411327
std::vector<size_t> mismatchedIndexes;
1125511328
for (size_t i = 0; i < vec1.size(); ++i) {
1125611329
if (tolerance == 0 && vec1[i] != vec2[i]) {
11257-
vectorsMatch = false;
1125811330
mismatchedIndexes.push_back(i);
11259-
} else if constexpr (std::is_same_v<T, bool>) {
11260-
// Compiler was very picky and wanted an explicit case for bools to
11261-
// avoid a warning about comparing bools with >.
11331+
} else if constexpr (std::is_same_v<T, hlslBool_t>) {
11332+
// Compiler was very picky and wanted an explicit case for any T that
11333+
// doesn't implement the operators in the below else. ( > and -). It
11334+
// wouldn't accept putting this constexpr as an or case in the above if.
1126211335
if (vec1[i] != vec2[i]) {
11263-
vectorsMatch = false;
1126411336
mismatchedIndexes.push_back(i);
1126511337
}
1126611338
} else {
1126711339
T diff = vec1[i] > vec2[i] ? vec1[i] - vec2[i] : vec2[i] - vec1[i];
1126811340
if (diff > tolerance) {
11269-
vectorsMatch = false;
1127011341
mismatchedIndexes.push_back(i);
1127111342
}
1127211343
}
@@ -11277,18 +11348,33 @@ bool DoVectorsMatch(const std::vector<T>& vec1, const std::vector<T>& vec2, doub
1127711348
for (size_t index : mismatchedIndexes) {
1127811349
std::wstringstream wss(L"");
1127911350
wss << L"Mismatch at Index: " << index;
11280-
wss << L" Expected Value:" << vec1[index] << ",";
11281-
wss << L" Actual Value:" << vec2[index];
11351+
wss << L" Actual Value:" << vec1[index] << ",";
11352+
wss << L" Expected Value:" << vec2[index];
1128211353
WEX::Logging::Log::Error(wss.str().c_str());
1128311354
}
1128411355
}
1128511356

11286-
return vectorsMatch;
11357+
return mismatchedIndexes.empty();
1128711358
}
1128811359

11289-
template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
11360+
template <typename T>
1129011361
T parseStringToNumber(std::wstring& str) {
11291-
if constexpr (std::is_same_v<T, uint16_t>) {
11362+
if constexpr (std::is_same_v<T, hlslBool_t>) {
11363+
// HLSL bool is 4 bytes. C++ bool is 1 byte.
11364+
// So we use an int32_t to hold the value.
11365+
return hlslBool_t(std::stol(str));
11366+
} else if constexpr (std::is_same_v<T, hlslHalf_t>) {
11367+
auto num = std::stod(str);
11368+
if(num < HALF_MIN || num > HALF_MAX) {
11369+
LogCommentFmt(L"Value is out of range. HALF_MIN:%f, HALF_MAX: %f, Value: %f. Will clamp.", HALF_MIN, HALF_MAX, num);
11370+
if(num < HALF_MIN) {
11371+
num = HALF_MIN;
11372+
} else {
11373+
num = HALF_MAX;
11374+
}
11375+
}
11376+
return static_cast<uint16_t>(num);
11377+
} else if constexpr (std::is_same_v<T, uint16_t>) {
1129211378
return static_cast<uint16_t>(std::stoul(str));
1129311379
} else if constexpr (std::is_same_v<T, uint32_t>) {
1129411380
return std::stoul(str);
@@ -11315,31 +11401,30 @@ T parseStringToNumber(std::wstring& str) {
1131511401
// A helper to get the hlsl type as a string for a given C++ type.
1131611402
template <typename T>
1131711403
std::string GetHLSLTypeString() {
11318-
if (std::is_same<T, bool>::value) {
11404+
if (std::is_same_v<T, hlslBool_t>) {
1131911405
return "bool";
11320-
// TODO: Need special logic for half. No half in C++
11321-
//} else if (std::is_same<T, half>::value) {
11322-
// return "half";
11323-
} else if (std::is_same<T, float>::value) {
11406+
} else if (std::is_same_v<T, hlslHalf_t>) {
11407+
return "half";
11408+
} else if (std::is_same_v<T, float>) {
1132411409
return "float";
11325-
} else if (std::is_same<T, double>::value) {
11410+
} else if (std::is_same_v<T, double>) {
1132611411
return "double";
11327-
} else if (std::is_same<T, int16_t>::value) {
11412+
} else if (std::is_same_v<T, int16_t>) {
1132811413
return "int16_t";
11329-
} else if (std::is_same<T, int32_t>::value) {
11414+
} else if (std::is_same_v<T, int32_t>) {
1133011415
return "int";
11331-
} else if (std::is_same<T, int64_t>::value) {
11416+
} else if (std::is_same_v<T, int64_t>) {
1133211417
return "int64_t";
11333-
} else if (std::is_same<T, uint16_t>::value) {
11418+
} else if (std::is_same_v<T, uint16_t>) {
1133411419
return "uint16_t";
11335-
} else if (std::is_same<T, uint32_t>::value) {
11420+
} else if (std::is_same_v<T, uint32_t>) {
1133611421
return "uint32_t";
11337-
} else if (std::is_same<T, uint64_t>::value) {
11422+
} else if (std::is_same_v<T, uint64_t>) {
1133811423
return "uint64_t";
11339-
// TODO: Need special logic for these types in C++
11340-
//} else if (std::is_same<T, packed_int16_t>::value) {
11424+
// TODOLONGVEC: Need special logic for these types in C++
11425+
//} else if (std::is_same_v<T, packed_int16_t>) {
1134111426
// return "packed_int16_t";
11342-
//} else if (std::is_same<T, packed_uint16_t>::value) {
11427+
//} else if (std::is_same_v<T, packed_uint16_t>) {
1134311428
// return "packed_uint16_t";
1134411429
} else {
1134511430
std::string errStr("GetHLSLTypeString() Unsupported type: ");
@@ -11358,7 +11443,7 @@ std::string GetHLSLTypeString() {
1135811443
// Or three strings ["1" , "2", "3"] will parse to 3 single elements vectors.
1135911444
// Vectors for a single element is a little weird, but it allows this helper
1136011445
// function to be very flexible.
11361-
template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
11446+
template <typename T>
1136211447
std::vector<std::vector<T>> parseStringsToNumbers(const std::vector<WEX::Common::String>* input) {
1136311448
std::vector<std::vector<T>> result = {};
1136411449
if (input == nullptr || input->empty()) {
@@ -11387,18 +11472,15 @@ std::vector<std::vector<T>> parseStringsToNumbers(const std::vector<WEX::Common:
1138711472
TEST_F(ExecutionTest, LongVector_BinaryOpTest_bool) {
1138811473
WEX::TestExecution::SetVerifyOutput verifySettings(
1138911474
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
11390-
// TODOLONGVEC: Do all the binary ops make sense on bools?
11391-
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped, L"Skipping bool test for now.");
11392-
//LongVectorBinaryOpTestBase<bool>();
11475+
LongVectorBinaryOpTestBase<hlslBool_t>();
1139311476
}
1139411477

11395-
// TODO: No half available in C++. Need to add logic for this type
11396-
//TEST_F(ExecutionTest, LongVector_BinaryOpTest_float16) {
11397-
// WEX::TestExecution::SetVerifyOutput verifySettings(
11398-
// WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
11399-
//
11400-
// LongVectorBinaryOpTestBase<float>();
11401-
//}
11478+
TEST_F(ExecutionTest, LongVector_BinaryOpTest_float16) {
11479+
WEX::TestExecution::SetVerifyOutput verifySettings(
11480+
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
11481+
11482+
LongVectorBinaryOpTestBase<hlslHalf_t>();
11483+
}
1140211484

1140311485
TEST_F(ExecutionTest, LongVector_BinaryOpTest_float32) {
1140411486
WEX::TestExecution::SetVerifyOutput verifySettings(
@@ -11457,8 +11539,8 @@ void ExecutionTest::LongVectorBinaryOpTestBase() {
1145711539
LongVectorBinaryOpTestBase<T, 16>();
1145811540
LongVectorBinaryOpTestBase<T, 17>();
1145911541
LongVectorBinaryOpTestBase<T, 35>();
11460-
LongVectorBinaryOpTestBase<T, 100>();
11461-
// TODOLONGVEC: 1024 breaks the size limit for structured buffers
11542+
// TODOLONGVEC: 100, 1024 breaks the size limit for structured buffers
11543+
//LongVectorBinaryOpTestBase<T, 100>();
1146211544
//LongVectorBinaryOpTestBase<T, 1024>();
1146311545
}
1146411546

@@ -11526,9 +11608,15 @@ void ExecutionTest::LongVectorBinaryOpTestBase() {
1152611608
// The two operand vectors must be the same size.
1152711609
VERIFY_IS_TRUE(inputVectorCount == vecInput2.size() || inputVectorCount == scalarInputs.size());
1152811610

11529-
// Additionally all elements (also vectors) in the vector must be the same
11530-
// size.
11531-
// TODO: Add check.
11611+
// Sanity because breaking the XML would be easy.
11612+
if(hasVecInput2)
11613+
{
11614+
for(size_t i = 0 ; i < vecInput1.size(); i++)
11615+
{
11616+
VERIFY_IS_TRUE(vecInput1[i].size() == vecInput2[i].size(), L"Input1 and Input2 vectors must be the same size.");
11617+
}
11618+
}
11619+
1153211620
// Keep filling the vectors with 'copies' of themselves until we hit N
1153311621
if(vecInput1[0].size() != N)
1153411622
{
@@ -11636,21 +11724,19 @@ void ExecutionTest::LongVectorBinaryOpTestBase() {
1163611724
}
1163711725
}
1163811726

11639-
template <typename T, std::size_t N,
11640-
typename = std::enable_if_t<std::is_arithmetic_v<T>>>
11727+
template <typename T, std::size_t N>
1164111728
struct SLongVectorUnaryOp {
1164211729
T clampArgC;
1164311730
T clampArgT;
1164411731
std::array<T, N> vecInput1;
1164511732
std::array<T, N> vecOutput;
1164611733
};
1164711734

11648-
// TODO: No half available in C++. Need to add logic for this type
11649-
//TEST_F(ExecutionTest, LongVector_UnaryOpTest_float16) {
11650-
// WEX::TestExecution::SetVerifyOutput verifySettings(
11651-
// WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
11652-
// LongVectorUnaryOpTestBase<float>();
11653-
//}
11735+
TEST_F(ExecutionTest, LongVector_UnaryOpTest_float16) {
11736+
WEX::TestExecution::SetVerifyOutput verifySettings(
11737+
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
11738+
LongVectorUnaryOpTestBase<hlslHalf_t>();
11739+
}
1165411740

1165511741
TEST_F(ExecutionTest, LongVector_UnaryOpTest_float32) {
1165611742
WEX::TestExecution::SetVerifyOutput verifySettings(
@@ -11710,8 +11796,8 @@ void ExecutionTest::LongVectorUnaryOpTestBase() {
1171011796
LongVectorUnaryOpTestBase<T, 16>();
1171111797
LongVectorUnaryOpTestBase<T, 17>();
1171211798
LongVectorUnaryOpTestBase<T, 35>();
11713-
LongVectorUnaryOpTestBase<T, 100>();
11714-
// TODOLONGVEC: 1024 breaks the size limit for structured buffers
11799+
// TODOLONGVEC: 100, 1024 breaks the size limit for structured buffers
11800+
//LongVectorUnaryOpTestBase<T, 100>();
1171511801
//LongVectorUnaryOpTestBase<T, 1024>();
1171611802
}
1171711803

0 commit comments

Comments
 (0)