@@ -369,7 +369,62 @@ bool IsIntegralDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
369369 DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32;
370370}
371371
372- struct TestVector {
372+ static size_t
373+ GetVectorElementSize (D3D12_LINEAR_ALGEBRA_DATATYPE DataType,
374+ D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
375+ switch (DataType) {
376+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
377+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
378+ return sizeof (int8_t );
379+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
380+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
381+ return sizeof (int16_t );
382+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
383+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
384+ if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED ||
385+ DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) {
386+ return sizeof (int8_t );
387+ } else {
388+ return sizeof (int32_t );
389+ }
390+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
391+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
392+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
393+ return sizeof (DirectX::PackedVector::HALF);
394+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
395+ return sizeof (float );
396+ default :
397+ throw std::invalid_argument (" Unsupported data type" );
398+ }
399+ }
400+
401+ static size_t
402+ GetMatrixElementSize (D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
403+ switch (DataInterpretation) {
404+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
405+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
406+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
407+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
408+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
409+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
410+ // The CPU reference matrix is always int8 for all integer
411+ // interpretations. The GPU version will be converted to the destination
412+ // format by ConvertLinearAlgebraMatrix.
413+ return sizeof (int8_t );
414+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
415+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
416+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
417+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
418+ // The CPU reference matrix is always FP32 for all FP interpretations.
419+ // The GPU version will be converted to the destination format by
420+ // ConvertLinearAlgebraMatrix.
421+ return sizeof (float );
422+ default :
423+ throw std::invalid_argument (" Unsupported data type" );
424+ }
425+ }
426+
427+ class TestVector {
373428private:
374429 size_t NumVectors = 0 ;
375430 size_t VectorSize = 0 ;
@@ -390,7 +445,7 @@ struct TestVector {
390445 if (ElementSize == 0 )
391446 throw std::invalid_argument (" ElementSize must be greater than 0" );
392447
393- size_t VectorBytes = VectorSize * ElementSize;
448+ const size_t VectorBytes = VectorSize * ElementSize;
394449 Stride = ((VectorBytes + Alignment - 1 ) / Alignment) * Alignment;
395450 TotalBytes = Stride * NumVectors;
396451
@@ -550,22 +605,21 @@ struct TestVector {
550605 if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF> ||
551606 std::is_same_v<T, float >) {
552607 float Elt = 0 .0f ;
553- if (IsIntegralDataType (MatrixInterpretation)) {
608+
609+ if (IsIntegralDataType (MatrixInterpretation))
554610 Elt = (float )(Rnd () & 0x7 ) - 3 .0f ;
555- } else {
611+ else
556612 Elt = ((float )(Rnd () & 0x3 ) - 1 .0f ) / 2 .0f ;
557- }
558- if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
613+
614+ if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>)
559615 Vec[J] = static_cast <T>(ConvertFloat32ToFloat16 (Elt));
560- } else {
616+ else
561617 Vec[J] = static_cast <T>(Elt);
562- }
563618 } else {
564- if constexpr (std::is_signed_v<T>) {
619+ if constexpr (std::is_signed_v<T>)
565620 Vec[J] = static_cast <T>((int32_t )(Rnd () & 0xf ) - 8 );
566- } else {
621+ else
567622 Vec[J] = static_cast <T>((uint32_t )(Rnd () & 0xf ));
568- }
569623 }
570624 }
571625 }
@@ -596,36 +650,9 @@ struct TestVector {
596650 D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation,
597651 D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation,
598652 std::mt19937 &Rnd) {
599- size_t ElementSize;
600- switch (DataType) {
601- case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
602- case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
603- ElementSize = sizeof (int8_t );
604- break ;
605- case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
606- case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
607- ElementSize = sizeof (int16_t );
608- break ;
609- case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
610- case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
611- if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED ||
612- DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) {
613- ElementSize = sizeof (int8_t );
614- } else {
615- ElementSize = sizeof (int32_t );
616- }
617- break ;
618- case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
619- case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
620- case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
621- ElementSize = sizeof (DirectX::PackedVector::HALF);
622- break ;
623- case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
624- ElementSize = sizeof (float );
625- break ;
626- default :
627- throw std::invalid_argument (" Unsupported data type" );
628- }
653+ const size_t ElementSize =
654+ ::CoopVecHelpers::GetVectorElementSize (DataType, DataInterpretation);
655+
629656 TestVector Vec (NumVectors, VectorSize, ElementSize);
630657 switch (DataType) {
631658 case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
@@ -670,25 +697,9 @@ struct TestVector {
670697 createAllOnesTestMatrix (size_t NumVectors, size_t VectorSize,
671698 D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation,
672699 std::mt19937 &Rnd) {
673- size_t ElementSize;
674- switch (DataInterpretation) {
675- case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
676- case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
677- case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
678- case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
679- case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
680- case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
681- ElementSize = sizeof (int8_t );
682- break ;
683- case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
684- case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
685- case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
686- case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
687- ElementSize = sizeof (float );
688- break ;
689- default :
690- throw std::invalid_argument (" Unsupported data type" );
691- }
700+ const size_t ElementSize =
701+ ::CoopVecHelpers::GetMatrixElementSize (DataInterpretation);
702+
692703 TestVector Vec (NumVectors, VectorSize, ElementSize);
693704 switch (DataInterpretation) {
694705 case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
@@ -808,21 +819,20 @@ struct TestVector {
808819 for (size_t InputIdx = 0 ; InputIdx < Matrix.getVectorSize ();
809820 ++InputIdx) {
810821 float InputElem;
811- if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) {
822+ if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32)
812823 InputElem = InputVector.getVector <float >(VecIdx)[InputIdx];
813- } else {
824+ else
814825 InputElem = ConvertFloat16ToFloat32 (
815826 InputVector.getVector <DirectX::PackedVector::HALF>(
816827 VecIdx)[InputIdx]);
817- }
828+
818829 float const MatrixElem =
819830 Matrix.getVector <float >(OutputIdx)[InputIdx];
820831 Acc += InputElem * MatrixElem;
821832 }
822833
823- if (HasBias) {
834+ if (HasBias)
824835 Acc += ConvertFloat16ToFloat32 (InputBiasFP16[OutputIdx]);
825- }
826836
827837 float Result = Acc;
828838 ResultVec.getVector <float >(VecIdx)[OutputIdx] = Result;
@@ -838,19 +848,19 @@ struct TestVector {
838848 for (size_t InputIdx = 0 ; InputIdx < Matrix.getVectorSize ();
839849 ++InputIdx) {
840850 int InputElem;
841- if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) {
842- InputElem = (int )InputVector.getVector <float >(VecIdx)[InputIdx];
843- } else {
851+ if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32)
852+ InputElem = static_cast <int >(
853+ InputVector.getVector <float >(VecIdx)[InputIdx]);
854+ else
844855 InputElem = InputVector.getVector <int8_t >(VecIdx)[InputIdx];
845- }
856+
846857 int const MatrixElem =
847858 Matrix.getVector <int8_t >(OutputIdx)[InputIdx];
848859 Acc += InputElem * MatrixElem;
849860 }
850861
851- if (HasBias) {
862+ if (HasBias)
852863 Acc += InputBiasI32[OutputIdx];
853- }
854864
855865 float Result = float (Acc);
856866 ResultVec.getVector <float >(VecIdx)[OutputIdx] = Result;
0 commit comments