Skip to content

Commit e8d9169

Browse files
alsepkowCopilot
andcommitted
Add bitwise and shift ops for min precision types
Add LeftShift and RightShift test entries for min16int and min16uint. Both produce valid min-precision DXIL (shl/ashr/lshr i16 with 4-bit shift masking). ReverseBits, CountBits, FirstBitHigh, FirstBitLow are excluded — DXC promotes min precision to i32 before calling these DXIL intrinsics, so they don't actually test min precision behavior. Infrastructure changes: - LongVectorTestData.h: Add Bitwise and BitShiftRhs input sets for HLSLMin16Int_t and HLSLMin16Uint_t matching int16_t/uint16_t names. Values constrained to 16-bit safe range. - LongVectorTestData.h: Add compound assignment operators (<<=, >>=, |=, &=, ^=) and unary ~ to both wrapper types to resolve ambiguity with integer promotion in template functions. - LongVectorTestData.h: Specialize std::is_signed for wrapper types so FirstBitHigh SFINAE selects the correct signed/unsigned variant. - LongVectors.cpp: Fix ReverseBits, ScanFromMSB, FirstBitLow to use explicit static_cast<T> for integer literals, avoiding ambiguous operator overload resolution with wrapper types. Co-authored-by: Copilot <[email protected]>
1 parent cb6175a commit e8d9169

2 files changed

Lines changed: 71 additions & 18 deletions

File tree

tools/clang/unittests/HLSLExec/LongVectorTestData.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,27 @@ struct HLSLMin16Int_t {
334334
HLSLMin16Int_t operator>>(const HLSLMin16Int_t &O) const {
335335
return HLSLMin16Int_t(Val >> O.Val);
336336
}
337+
HLSLMin16Int_t operator~() const { return HLSLMin16Int_t(~Val); }
338+
HLSLMin16Int_t &operator<<=(const HLSLMin16Int_t &O) {
339+
Val <<= O.Val;
340+
return *this;
341+
}
342+
HLSLMin16Int_t &operator>>=(const HLSLMin16Int_t &O) {
343+
Val >>= O.Val;
344+
return *this;
345+
}
346+
HLSLMin16Int_t &operator|=(const HLSLMin16Int_t &O) {
347+
Val |= O.Val;
348+
return *this;
349+
}
350+
HLSLMin16Int_t &operator&=(const HLSLMin16Int_t &O) {
351+
Val &= O.Val;
352+
return *this;
353+
}
354+
HLSLMin16Int_t &operator^=(const HLSLMin16Int_t &O) {
355+
Val ^= O.Val;
356+
return *this;
357+
}
337358
HLSLMin16Int_t operator&&(const HLSLMin16Int_t &O) const {
338359
return HLSLMin16Int_t(Val && O.Val);
339360
}
@@ -399,6 +420,27 @@ struct HLSLMin16Uint_t {
399420
HLSLMin16Uint_t operator>>(const HLSLMin16Uint_t &O) const {
400421
return HLSLMin16Uint_t(Val >> O.Val);
401422
}
423+
HLSLMin16Uint_t operator~() const { return HLSLMin16Uint_t(~Val); }
424+
HLSLMin16Uint_t &operator<<=(const HLSLMin16Uint_t &O) {
425+
Val <<= O.Val;
426+
return *this;
427+
}
428+
HLSLMin16Uint_t &operator>>=(const HLSLMin16Uint_t &O) {
429+
Val >>= O.Val;
430+
return *this;
431+
}
432+
HLSLMin16Uint_t &operator|=(const HLSLMin16Uint_t &O) {
433+
Val |= O.Val;
434+
return *this;
435+
}
436+
HLSLMin16Uint_t &operator&=(const HLSLMin16Uint_t &O) {
437+
Val &= O.Val;
438+
return *this;
439+
}
440+
HLSLMin16Uint_t &operator^=(const HLSLMin16Uint_t &O) {
441+
Val ^= O.Val;
442+
return *this;
443+
}
402444

403445
bool operator&&(const HLSLMin16Uint_t &O) const { return Val && O.Val; }
404446
bool operator||(const HLSLMin16Uint_t &O) const { return Val || O.Val; }
@@ -415,6 +457,7 @@ struct HLSLMin16Uint_t {
415457

416458
uint32_t Val;
417459
};
460+
418461
enum class InputSet {
419462
#define INPUT_SET(SYMBOL) SYMBOL,
420463
#include "LongVectorOps.def"
@@ -656,6 +699,7 @@ BEGIN_INPUT_SETS(HLSLMin16Int_t)
656699
INPUT_SET(InputSet::Default1, -6, 1, 7, 3, 8, 4, -3, 8, 8, -2);
657700
INPUT_SET(InputSet::Default2, 5, -6, -3, -2, 9, 3, 1, -3, -7, 2);
658701
INPUT_SET(InputSet::Default3, -5, 6, 3, 2, -9, -3, -1, 3, 7, -2);
702+
INPUT_SET(InputSet::BitShiftRhs, 1, 6, 3, 0, 9, 3, 12, 11, 11, 14);
659703
INPUT_SET(InputSet::Zero, 0);
660704
INPUT_SET(InputSet::NoZero, 1);
661705
INPUT_SET(InputSet::SelectCond, 0, 1);
@@ -671,6 +715,7 @@ INPUT_SET(InputSet::Default1, 3, 199, 3, 200, 5, 10, 22, 8, 9, 10);
671715
INPUT_SET(InputSet::Default2, 2, 111, 3, 4, 5, 9, 21, 8, 9, 10);
672716
INPUT_SET(InputSet::Default3, 4, 112, 4, 5, 3, 7, 21, 1, 11, 9);
673717
INPUT_SET(InputSet::Zero, 0);
718+
INPUT_SET(InputSet::BitShiftRhs, 1, 6, 3, 0, 9, 3, 11, 12, 12, 12);
674719
INPUT_SET(InputSet::SelectCond, 0, 1);
675720
INPUT_SET(InputSet::AllOnes, 1);
676721
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF,

tools/clang/unittests/HLSLExec/LongVectors.cpp

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -743,12 +743,12 @@ template <typename T> T Saturate(T A) {
743743
}
744744

745745
template <typename T> T ReverseBits(T A) {
746-
T Result = 0;
746+
T Result = static_cast<T>(0);
747747
const size_t NumBits = sizeof(T) * 8;
748748
for (size_t I = 0; I < NumBits; I++) {
749-
Result <<= 1;
750-
Result |= (A & 1);
751-
A >>= 1;
749+
Result <<= static_cast<T>(1);
750+
Result |= (A & static_cast<T>(1));
751+
A >>= static_cast<T>(1);
752752
}
753753
return Result;
754754
}
@@ -760,12 +760,13 @@ template <typename T> uint32_t CountBits(T A) {
760760
// General purpose bit scan from the MSB. Based on the value of LookingForZero
761761
// returns the index of the first high/low bit found.
762762
template <typename T> uint32_t ScanFromMSB(T A, bool LookingForZero) {
763-
if (A == 0)
763+
if (A == static_cast<T>(0))
764764
return std::numeric_limits<uint32_t>::max();
765765

766766
constexpr uint32_t NumBits = sizeof(T) * 8;
767767
for (int32_t I = NumBits - 1; I >= 0; --I) {
768-
bool BitSet = (A & (static_cast<T>(1) << I)) != 0;
768+
bool BitSet =
769+
(A & (static_cast<T>(1) << static_cast<T>(I))) != static_cast<T>(0);
769770
if (BitSet != LookingForZero)
770771
return static_cast<uint32_t>(I);
771772
}
@@ -788,11 +789,11 @@ FirstBitHigh(T A) {
788789
template <typename T> uint32_t FirstBitLow(T A) {
789790
const uint32_t NumBits = sizeof(T) * 8;
790791

791-
if (A == 0)
792+
if (A == static_cast<T>(0))
792793
return std::numeric_limits<uint32_t>::max();
793794

794795
for (uint32_t I = 0; I < NumBits; ++I) {
795-
if (A & (static_cast<T>(1) << I))
796+
if (A & (static_cast<T>(1) << static_cast<T>(I)))
796797
return static_cast<T>(I);
797798
}
798799

@@ -1888,8 +1889,8 @@ void dispatchMinPrecisionTest(ID3D12Device *D3DDevice, bool VerboseLogging,
18881889
constexpr const Operation &Operation = getOperation(OP);
18891890
Op<OP, T, Operation.Arity> Op;
18901891

1891-
// Min precision buffer storage width is implementation-defined, so we use
1892-
// full-precision types for Load/Store via BUFFER_TYPE/BUFFER_OUT_TYPE defines.
1892+
// Min precision buffer storage width is implementation-defined, so we
1893+
// use full-precision types for buffer I/O via BUFFER_TYPE/BUFFER_OUT_TYPE.
18931894
for (size_t VectorSize : InputVectorSizes) {
18941895
std::vector<std::vector<T>> Inputs =
18951896
buildTestInputs<T>(VectorSize, Operation.InputSets, Operation.Arity);
@@ -1919,14 +1920,13 @@ void dispatchMinPrecisionWaveOpTest(ID3D12Device *D3DDevice,
19191920
constexpr const Operation &Operation = getOperation(OP);
19201921
Op<OP, T, Operation.Arity> Op;
19211922

1922-
// Min precision buffer storage width is implementation-defined, so we use
1923-
// full-precision types for Load/Store via BUFFER_TYPE/BUFFER_OUT_TYPE defines.
1923+
// Min precision buffer storage width is implementation-defined, so we
1924+
// use full-precision types for buffer I/O via BUFFER_TYPE/BUFFER_OUT_TYPE.
19241925
for (size_t VectorSize : InputVectorSizes) {
19251926
std::vector<std::vector<T>> Inputs =
19261927
buildTestInputs<T>(VectorSize, Operation.InputSets, Operation.Arity);
19271928

1928-
auto Expected =
1929-
ExpectedBuilder<OP, T>::buildExpected(Op, Inputs, WaveSize);
1929+
auto Expected = ExpectedBuilder<OP, T>::buildExpected(Op, Inputs, WaveSize);
19301930

19311931
using OutT = typename decltype(Expected)::value_type;
19321932

@@ -3015,11 +3015,15 @@ class DxilConf_SM69_Vectorized_Core : public TestClassCommon {
30153015
HLK_MIN_PRECISION_TEST(Min, HLSLMin16Int_t);
30163016
HLK_MIN_PRECISION_TEST(Max, HLSLMin16Int_t);
30173017

3018-
// Bitwise (logical and shift — bit-manipulation excluded)
3018+
// Bitwise
30193019
HLK_MIN_PRECISION_TEST(And, HLSLMin16Int_t);
30203020
HLK_MIN_PRECISION_TEST(Or, HLSLMin16Int_t);
30213021
HLK_MIN_PRECISION_TEST(Xor, HLSLMin16Int_t);
3022-
3022+
HLK_MIN_PRECISION_TEST(LeftShift, HLSLMin16Int_t);
3023+
HLK_MIN_PRECISION_TEST(RightShift, HLSLMin16Int_t);
3024+
// Note: ReverseBits, CountBits, FirstBitHigh, FirstBitLow excluded -
3025+
// DXC promotes min precision to i32 before these intrinsics, so they
3026+
// don't operate at min precision.
30233027

30243028
// UnaryMath
30253029
HLK_MIN_PRECISION_TEST(Abs, HLSLMin16Int_t);
@@ -3111,11 +3115,15 @@ class DxilConf_SM69_Vectorized_Core : public TestClassCommon {
31113115
HLK_MIN_PRECISION_TEST(Min, HLSLMin16Uint_t);
31123116
HLK_MIN_PRECISION_TEST(Max, HLSLMin16Uint_t);
31133117

3114-
// Bitwise (logical and shift — bit-manipulation excluded)
3118+
// Bitwise
31153119
HLK_MIN_PRECISION_TEST(And, HLSLMin16Uint_t);
31163120
HLK_MIN_PRECISION_TEST(Or, HLSLMin16Uint_t);
31173121
HLK_MIN_PRECISION_TEST(Xor, HLSLMin16Uint_t);
3118-
3122+
HLK_MIN_PRECISION_TEST(LeftShift, HLSLMin16Uint_t);
3123+
HLK_MIN_PRECISION_TEST(RightShift, HLSLMin16Uint_t);
3124+
// Note: ReverseBits, CountBits, FirstBitHigh, FirstBitLow excluded -
3125+
// DXC promotes min precision to i32 before these intrinsics, so they
3126+
// don't operate at min precision.
31193127

31203128
// UnaryMath
31213129
HLK_MIN_PRECISION_TEST(Abs, HLSLMin16Uint_t);

0 commit comments

Comments
 (0)