Skip to content

Commit ed3744a

Browse files
Address style/format
1 parent b7bc46b commit ed3744a

2 files changed

Lines changed: 133 additions & 116 deletions

File tree

tools/clang/unittests/HLSLExec/CoopVec.h

Lines changed: 81 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,62 @@ bool IsIntegralDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
369369
DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32;
370370
}
371371

372-
struct TestVector {
372+
static size_t
373+
GetVectorElementSize(D3D12_LINEAR_ALGEBRA_DATATYPE DataType,
374+
D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
375+
switch (DataType) {
376+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
377+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
378+
return sizeof(int8_t);
379+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
380+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
381+
return sizeof(int16_t);
382+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
383+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
384+
if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED ||
385+
DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) {
386+
return sizeof(int8_t);
387+
} else {
388+
return sizeof(int32_t);
389+
}
390+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
391+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
392+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
393+
return sizeof(DirectX::PackedVector::HALF);
394+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
395+
return sizeof(float);
396+
default:
397+
throw std::invalid_argument("Unsupported data type");
398+
}
399+
}
400+
401+
static size_t
402+
GetMatrixElementSize(D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
403+
switch (DataInterpretation) {
404+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
405+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
406+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
407+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
408+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
409+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
410+
// The CPU reference matrix is always int8 for all integer
411+
// interpretations. The GPU version will be converted to the destination
412+
// format by ConvertLinearAlgebraMatrix.
413+
return sizeof(int8_t);
414+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
415+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
416+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
417+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
418+
// The CPU reference matrix is always FP32 for all FP interpretations.
419+
// The GPU version will be converted to the destination format by
420+
// ConvertLinearAlgebraMatrix.
421+
return sizeof(float);
422+
default:
423+
throw std::invalid_argument("Unsupported data type");
424+
}
425+
}
426+
427+
class TestVector {
373428
private:
374429
size_t NumVectors = 0;
375430
size_t VectorSize = 0;
@@ -390,7 +445,7 @@ struct TestVector {
390445
if (ElementSize == 0)
391446
throw std::invalid_argument("ElementSize must be greater than 0");
392447

393-
size_t VectorBytes = VectorSize * ElementSize;
448+
const size_t VectorBytes = VectorSize * ElementSize;
394449
Stride = ((VectorBytes + Alignment - 1) / Alignment) * Alignment;
395450
TotalBytes = Stride * NumVectors;
396451

@@ -550,22 +605,21 @@ struct TestVector {
550605
if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF> ||
551606
std::is_same_v<T, float>) {
552607
float Elt = 0.0f;
553-
if (IsIntegralDataType(MatrixInterpretation)) {
608+
609+
if (IsIntegralDataType(MatrixInterpretation))
554610
Elt = (float)(Rnd() & 0x7) - 3.0f;
555-
} else {
611+
else
556612
Elt = ((float)(Rnd() & 0x3) - 1.0f) / 2.0f;
557-
}
558-
if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
613+
614+
if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>)
559615
Vec[J] = static_cast<T>(ConvertFloat32ToFloat16(Elt));
560-
} else {
616+
else
561617
Vec[J] = static_cast<T>(Elt);
562-
}
563618
} else {
564-
if constexpr (std::is_signed_v<T>) {
619+
if constexpr (std::is_signed_v<T>)
565620
Vec[J] = static_cast<T>((int32_t)(Rnd() & 0xf) - 8);
566-
} else {
621+
else
567622
Vec[J] = static_cast<T>((uint32_t)(Rnd() & 0xf));
568-
}
569623
}
570624
}
571625
}
@@ -596,36 +650,9 @@ struct TestVector {
596650
D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation,
597651
D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation,
598652
std::mt19937 &Rnd) {
599-
size_t ElementSize;
600-
switch (DataType) {
601-
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
602-
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
603-
ElementSize = sizeof(int8_t);
604-
break;
605-
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
606-
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
607-
ElementSize = sizeof(int16_t);
608-
break;
609-
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
610-
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
611-
if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED ||
612-
DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) {
613-
ElementSize = sizeof(int8_t);
614-
} else {
615-
ElementSize = sizeof(int32_t);
616-
}
617-
break;
618-
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
619-
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
620-
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
621-
ElementSize = sizeof(DirectX::PackedVector::HALF);
622-
break;
623-
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
624-
ElementSize = sizeof(float);
625-
break;
626-
default:
627-
throw std::invalid_argument("Unsupported data type");
628-
}
653+
const size_t ElementSize =
654+
::CoopVecHelpers::GetVectorElementSize(DataType, DataInterpretation);
655+
629656
TestVector Vec(NumVectors, VectorSize, ElementSize);
630657
switch (DataType) {
631658
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
@@ -670,25 +697,9 @@ struct TestVector {
670697
createAllOnesTestMatrix(size_t NumVectors, size_t VectorSize,
671698
D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation,
672699
std::mt19937 &Rnd) {
673-
size_t ElementSize;
674-
switch (DataInterpretation) {
675-
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
676-
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
677-
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
678-
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
679-
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
680-
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
681-
ElementSize = sizeof(int8_t);
682-
break;
683-
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
684-
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
685-
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
686-
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
687-
ElementSize = sizeof(float);
688-
break;
689-
default:
690-
throw std::invalid_argument("Unsupported data type");
691-
}
700+
const size_t ElementSize =
701+
::CoopVecHelpers::GetMatrixElementSize(DataInterpretation);
702+
692703
TestVector Vec(NumVectors, VectorSize, ElementSize);
693704
switch (DataInterpretation) {
694705
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
@@ -808,21 +819,20 @@ struct TestVector {
808819
for (size_t InputIdx = 0; InputIdx < Matrix.getVectorSize();
809820
++InputIdx) {
810821
float InputElem;
811-
if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) {
822+
if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32)
812823
InputElem = InputVector.getVector<float>(VecIdx)[InputIdx];
813-
} else {
824+
else
814825
InputElem = ConvertFloat16ToFloat32(
815826
InputVector.getVector<DirectX::PackedVector::HALF>(
816827
VecIdx)[InputIdx]);
817-
}
828+
818829
float const MatrixElem =
819830
Matrix.getVector<float>(OutputIdx)[InputIdx];
820831
Acc += InputElem * MatrixElem;
821832
}
822833

823-
if (HasBias) {
834+
if (HasBias)
824835
Acc += ConvertFloat16ToFloat32(InputBiasFP16[OutputIdx]);
825-
}
826836

827837
float Result = Acc;
828838
ResultVec.getVector<float>(VecIdx)[OutputIdx] = Result;
@@ -838,19 +848,19 @@ struct TestVector {
838848
for (size_t InputIdx = 0; InputIdx < Matrix.getVectorSize();
839849
++InputIdx) {
840850
int InputElem;
841-
if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) {
842-
InputElem = (int)InputVector.getVector<float>(VecIdx)[InputIdx];
843-
} else {
851+
if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32)
852+
InputElem = static_cast<int>(
853+
InputVector.getVector<float>(VecIdx)[InputIdx]);
854+
else
844855
InputElem = InputVector.getVector<int8_t>(VecIdx)[InputIdx];
845-
}
856+
846857
int const MatrixElem =
847858
Matrix.getVector<int8_t>(OutputIdx)[InputIdx];
848859
Acc += InputElem * MatrixElem;
849860
}
850861

851-
if (HasBias) {
862+
if (HasBias)
852863
Acc += InputBiasI32[OutputIdx];
853-
}
854864

855865
float Result = float(Acc);
856866
ResultVec.getVector<float>(VecIdx)[OutputIdx] = Result;

0 commit comments

Comments
 (0)