@@ -448,6 +448,74 @@ struct TestVector {
448448 uint8_t *getBuffer () { return Buffer; }
449449 const uint8_t *getBuffer () const { return Buffer; }
450450
451+ // Copy assignment operator
452+ TestVector &operator =(const TestVector &other) {
453+ if (this != &other) {
454+ // Free existing buffer
455+ if (Buffer) {
456+ #ifdef _MSC_VER
457+ _aligned_free (Buffer);
458+ #else
459+ std::free (Buffer);
460+ #endif
461+ Buffer = nullptr ;
462+ }
463+
464+ // Copy metadata
465+ NumVectors = other.NumVectors ;
466+ VectorSize = other.VectorSize ;
467+ ElementSize = other.ElementSize ;
468+ Stride = other.Stride ;
469+ TotalBytes = other.TotalBytes ;
470+
471+ // Allocate new buffer
472+ void *Ptr = nullptr ;
473+ #ifdef _MSC_VER
474+ Ptr = _aligned_malloc (TotalBytes, 16 );
475+ #else
476+ Ptr = std::aligned_alloc (16 , TotalBytes);
477+ #endif
478+ Buffer = reinterpret_cast <uint8_t *>(Ptr);
479+
480+ // Copy data
481+ if (other.Buffer ) {
482+ std::memcpy (Buffer, other.Buffer , TotalBytes);
483+ }
484+ }
485+ return *this ;
486+ }
487+
488+ // Move assignment operator
489+ TestVector &operator =(TestVector &&other) noexcept {
490+ if (this != &other) {
491+ // Free existing buffer
492+ if (Buffer) {
493+ #ifdef _MSC_VER
494+ _aligned_free (Buffer);
495+ #else
496+ std::free (Buffer);
497+ #endif
498+ }
499+
500+ // Move metadata and buffer
501+ NumVectors = other.NumVectors ;
502+ VectorSize = other.VectorSize ;
503+ ElementSize = other.ElementSize ;
504+ Stride = other.Stride ;
505+ TotalBytes = other.TotalBytes ;
506+ Buffer = other.Buffer ;
507+
508+ // Reset the source object
509+ other.NumVectors = 0 ;
510+ other.VectorSize = 0 ;
511+ other.ElementSize = 0 ;
512+ other.Stride = 0 ;
513+ other.TotalBytes = 0 ;
514+ other.Buffer = nullptr ;
515+ }
516+ return *this ;
517+ }
518+
451519 template <typename T> T *getVector (size_t I) {
452520 uint8_t *Ptr = Buffer + I * Stride;
453521 return reinterpret_cast <T *>(Ptr);
@@ -481,6 +549,20 @@ struct TestVector {
481549 }
482550 }
483551
552+ template <typename T> void fillAllOnesTestData () {
553+ // Create a vector of (1, 1, 1, ...)
554+ for (size_t I = 0 ; I < NumVectors; ++I) {
555+ T *Vec = getVector<T>(I);
556+ for (size_t J = 0 ; J < VectorSize; ++J)
557+ if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
558+ // Special case for HALF, which requires conversion from float
559+ Vec[J] = static_cast <T>(ConvertFloat32ToFloat16 (1 .0f ));
560+ } else {
561+ Vec[J] = static_cast <T>(1 );
562+ }
563+ }
564+ }
565+
484566 static TestVector
485567 createSimpleTestVector (size_t NumVectors, size_t VectorSize,
486568 D3D12_LINEAR_ALGEBRA_DATATYPE DataType,
@@ -553,6 +635,199 @@ struct TestVector {
553635 }
554636 return Vec;
555637 }
638+
639+ static TestVector
640+ createAllOnesTestMatrix (size_t NumVectors, size_t VectorSize,
641+ D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
642+ size_t ElementSize;
643+ switch (DataInterpretation) {
644+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
645+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
646+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
647+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
648+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
649+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
650+ ElementSize = sizeof (int8_t );
651+ break ;
652+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
653+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
654+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
655+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
656+ ElementSize = sizeof (float );
657+ break ;
658+ default :
659+ throw std::invalid_argument (" Unsupported data type" );
660+ }
661+ TestVector Vec (NumVectors, VectorSize, ElementSize);
662+ switch (DataInterpretation) {
663+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
664+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
665+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
666+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
667+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
668+ case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
669+ Vec.fillAllOnesTestData <int8_t >();
670+ break ;
671+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
672+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
673+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
674+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
675+ Vec.fillAllOnesTestData <float >();
676+ break ;
677+ default :
678+ throw std::invalid_argument (" Unsupported data type" );
679+ }
680+ return Vec;
681+ }
682+
683+ D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO
684+ getConversionInfo (ID3D12Device *D3DDevice,
685+ D3D12_LINEAR_ALGEBRA_DATATYPE DestDataType,
686+ D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) {
687+ // Create source matrix info
688+ D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {};
689+ ConvertInfo.SrcInfo .SrcDataType =
690+ ::CoopVecHelpers::GetMatrixSrcDataType (DestDataType);
691+ ConvertInfo.SrcInfo .SrcLayout =
692+ D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR;
693+
694+ // Create destination matrix info
695+ ConvertInfo.DestInfo .DestSize = 0 ; // Will be populated by driver
696+ int DestEltSize = 0 ;
697+ switch (DestDataType) {
698+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
699+ case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
700+ ConvertInfo.DestInfo .DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8;
701+ DestEltSize = 1 ;
702+ break ;
703+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
704+ ConvertInfo.DestInfo .DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16;
705+ DestEltSize = 2 ; // FP16
706+ break ;
707+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
708+ ConvertInfo.DestInfo .DestDataType =
709+ D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3;
710+ DestEltSize = 1 ; // FP8
711+ break ;
712+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
713+ ConvertInfo.DestInfo .DestDataType =
714+ D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2;
715+ DestEltSize = 1 ; // FP8
716+ break ;
717+ }
718+ ConvertInfo.SrcInfo .SrcStride = (UINT)getStride ();
719+ ConvertInfo.SrcInfo .SrcSize = (UINT)getTotalBytes ();
720+
721+ ConvertInfo.DestInfo .DestLayout = MatrixLayout;
722+ ConvertInfo.DestInfo .DestStride = 0 ;
723+ ConvertInfo.DestInfo .NumRows = (UINT)getNumVectors ();
724+ ConvertInfo.DestInfo .NumColumns = (UINT)getVectorSize ();
725+
726+ if (MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) {
727+ ConvertInfo.DestInfo .DestStride = (UINT)getVectorSize () * DestEltSize;
728+ } else if (MatrixLayout ==
729+ D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) {
730+ ConvertInfo.DestInfo .DestStride = (UINT)getNumVectors () * DestEltSize;
731+ }
732+
733+ // Get destination size using preview interface
734+ {
735+ CComPtr<ID3D12DevicePreview> PreviewDevice;
736+ VERIFY_SUCCEEDED (D3DDevice->QueryInterface (__uuidof (ID3D12DevicePreview),
737+ (void **)&PreviewDevice));
738+
739+ // Query required destination size
740+ PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo (
741+ &ConvertInfo.DestInfo );
742+ }
743+
744+ return ConvertInfo;
745+ }
746+
747+ static TestVector
748+ matrixVectorMultiply (const TestVector &Matrix, const TestVector &InputVector,
749+ const TestVector &Bias, bool HasBias,
750+ D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation,
751+ D3D12_LINEAR_ALGEBRA_DATATYPE InputType) {
752+ bool IsFP32 = false ;
753+ switch (MatrixInterpretation) {
754+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
755+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
756+ case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
757+ IsFP32 = true ;
758+ break ;
759+ default :
760+ break ;
761+ }
762+
763+ TestVector ResultVec (InputVector.getNumVectors (), Matrix.getNumVectors (),
764+ sizeof (float ));
765+
766+ if (IsFP32) {
767+ for (int VecIdx = 0 ; VecIdx < InputVector.getNumVectors (); ++VecIdx) {
768+ const DirectX::PackedVector::HALF *InputBiasFP16 =
769+ Bias.getVector <DirectX::PackedVector::HALF>(0 );
770+ for (int OutputIdx = 0 ; OutputIdx < Matrix.getNumVectors ();
771+ ++OutputIdx) {
772+ float Acc = 0 ;
773+
774+ for (int InputIdx = 0 ; InputIdx < Matrix.getVectorSize ();
775+ ++InputIdx) {
776+ float InputElem;
777+ if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) {
778+ InputElem = InputVector.getVector <float >(VecIdx)[InputIdx];
779+ } else {
780+ InputElem = ConvertFloat16ToFloat32 (
781+ InputVector.getVector <DirectX::PackedVector::HALF>(
782+ VecIdx)[InputIdx]);
783+ }
784+ float const MatrixElem =
785+ Matrix.getVector <float >(OutputIdx)[InputIdx];
786+ Acc += InputElem * MatrixElem;
787+ }
788+
789+ if (HasBias) {
790+ Acc += ConvertFloat16ToFloat32 (InputBiasFP16[OutputIdx]);
791+ }
792+
793+ float Result = Acc;
794+ ResultVec.getVector <float >(VecIdx)[OutputIdx] = Result;
795+ }
796+ }
797+ } else if (MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8) {
798+ for (int VecIdx = 0 ; VecIdx < InputVector.getNumVectors (); ++VecIdx) {
799+ const int32_t *InputBiasI32 = Bias.getVector <int32_t >(0 );
800+ for (int OutputIdx = 0 ; OutputIdx < Matrix.getNumVectors ();
801+ ++OutputIdx) {
802+ int Acc = 0 ;
803+
804+ for (int InputIdx = 0 ; InputIdx < Matrix.getVectorSize ();
805+ ++InputIdx) {
806+ int InputElem;
807+ if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) {
808+ InputElem = (int )InputVector.getVector <float >(VecIdx)[InputIdx];
809+ } else {
810+ InputElem = InputVector.getVector <int8_t >(VecIdx)[InputIdx];
811+ }
812+ int const MatrixElem =
813+ Matrix.getVector <int8_t >(OutputIdx)[InputIdx];
814+ Acc += InputElem * MatrixElem;
815+ }
816+
817+ if (HasBias) {
818+ Acc += InputBiasI32[OutputIdx];
819+ }
820+
821+ float Result = float (Acc);
822+ ResultVec.getVector <float >(VecIdx)[OutputIdx] = Result;
823+ }
824+ }
825+ } else {
826+ throw std::invalid_argument (" Unsupported matrix interpretation" );
827+ }
828+
829+ return ResultVec;
830+ }
556831};
557832}; // namespace CoopVecHelpers
558833
0 commit comments