66#include < DirectXPackedVector.h>
77
88#include < cstdlib>
9+ #include < random>
910#include < vector>
1011
1112#include " dxc/Support/microcom.h"
@@ -358,6 +359,15 @@ GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) {
358359 }
359360}
360361
362+ bool IsIntegralDataType (D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
363+ return DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 ||
364+ DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8 ||
365+ DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16 ||
366+ DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16 ||
367+ DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 ||
368+ DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32;
369+ }
370+
361371struct TestVector {
362372private:
363373 size_t NumVectors = 0 ;
@@ -534,39 +544,61 @@ struct TestVector {
534544 }
535545 }
536546
537- template <typename T> void fillSimpleTestData () {
538- // Create a vector of (1, 1, 0, ...)
547+ template <typename T>
548+ void fillSimpleTestData (D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation,
549+ std::mt19937 &Rnd) {
539550 for (size_t I = 0 ; I < NumVectors; ++I) {
540551 T *Vec = getVector<T>(I);
541552 for (size_t J = 0 ; J < VectorSize; ++J)
542- if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
543- // Special case for HALF, which requires conversion from float
544- Vec[J] = static_cast <T>(
545- ConvertFloat32ToFloat16 ((J == 0 || J == 1 ) ? 1 .0f : 0 .0f ));
553+ if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF> ||
554+ std::is_same_v<T, float >) {
555+ float Elt = 0 .0f ;
556+ if (IsIntegralDataType (MatrixInterpretation)) {
557+ Elt = (float )(Rnd () & 0x7 ) - 3 .0f ;
558+ } else {
559+ Elt = ((float )(Rnd () & 0x3 ) - 1 .0f ) / 2 .0f ;
560+ }
561+ if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
562+ Vec[J] = static_cast <T>(ConvertFloat32ToFloat16 (Elt));
563+ } else {
564+ Vec[J] = static_cast <T>(Elt);
565+ }
546566 } else {
547- Vec[J] = static_cast <T>((J == 0 || J == 1 ) ? 1 : 0 );
567+ if constexpr (std::is_signed_v<T>) {
568+ Vec[J] = static_cast <T>((int32_t )(Rnd () & 0xf ) - 8 );
569+ } else {
570+ Vec[J] = static_cast <T>((uint32_t )(Rnd () & 0xf ));
571+ }
548572 }
549573 }
550574 }
551575
552- template <typename T> void fillAllOnesTestData () {
553- // Create a vector of (1, 1, 1, ...)
576+ template <typename T> void FillSimpleMatrixTestData (std::mt19937 &Rnd) {
554577 for (size_t I = 0 ; I < NumVectors; ++I) {
555578 T *Vec = getVector<T>(I);
556579 for (size_t J = 0 ; J < VectorSize; ++J)
557580 if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
558- // Special case for HALF, which requires conversion from float
559- Vec[J] = static_cast <T>(ConvertFloat32ToFloat16 (1 .0f ));
581+ float Elt = ((float )(Rnd () & 0x3 ) - 1 .0f ) / 2 .0f ;
582+ Vec[J] = static_cast <T>(ConvertFloat32ToFloat16 (Elt));
583+ } else if constexpr (std::is_same_v<T, float >) {
584+ float Elt = ((float )(Rnd () & 0x3 ) - 1 .0f ) / 2 .0f ;
585+ Vec[J] = static_cast <T>(Elt);
560586 } else {
561- Vec[J] = static_cast <T>(1 );
587+ if constexpr (std::is_signed_v<T>) {
588+ Vec[J] = static_cast <T>((int32_t )(Rnd () & 0xf ) - 8 );
589+ } else {
590+ Vec[J] = static_cast <T>((uint32_t )(Rnd () & 0xf ));
591+ }
562592 }
563593 }
564594 }
565595
566596 static TestVector
567597 createSimpleTestVector (size_t NumVectors, size_t VectorSize,
568598 D3D12_LINEAR_ALGEBRA_DATATYPE DataType,
569- D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
599+ D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation,
600+ D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation,
601+ std::mt19937 &Rnd) {
570602 size_t ElementSize;
571603 switch (DataType) {
572604 case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
@@ -600,35 +632,36 @@ struct TestVector {
600632 TestVector Vec (NumVectors, VectorSize, ElementSize);
601633 switch (DataType) {
602634 case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
603- Vec.fillSimpleTestData <int8_t >();
635+ Vec.fillSimpleTestData <int8_t >(MatrixInterpretation, Rnd );
604636 break ;
605637 case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
606- Vec.fillSimpleTestData <uint8_t >();
638+ Vec.fillSimpleTestData <uint8_t >(MatrixInterpretation, Rnd );
607639 break ;
608640 case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
609- Vec.fillSimpleTestData <int16_t >();
641+ Vec.fillSimpleTestData <int16_t >(MatrixInterpretation, Rnd );
610642 break ;
611643 case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
612- Vec.fillSimpleTestData <uint16_t >();
644+ Vec.fillSimpleTestData <uint16_t >(MatrixInterpretation, Rnd );
613645 break ;
614646 case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
615- Vec.fillSimpleTestData <int32_t >();
647+ Vec.fillSimpleTestData <int32_t >(MatrixInterpretation, Rnd );
616648 break ;
617649 case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
618650 if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED ||
619651 DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) {
620- Vec.fillSimpleTestData <uint8_t >();
652+ Vec.fillSimpleTestData <uint8_t >(MatrixInterpretation, Rnd );
621653 } else {
622- Vec.fillSimpleTestData <uint32_t >();
654+ Vec.fillSimpleTestData <uint32_t >(MatrixInterpretation, Rnd );
623655 }
624656 break ;
625657 case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
626658 case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
627659 case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
628- Vec.fillSimpleTestData <DirectX::PackedVector::HALF>();
660+ Vec.fillSimpleTestData <DirectX::PackedVector::HALF>(MatrixInterpretation,
661+ Rnd);
629662 break ;
630663 case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
631- Vec.fillSimpleTestData <float >();
664+ Vec.fillSimpleTestData <float >(MatrixInterpretation, Rnd );
632665 break ;
633666 default :
634667 throw std::invalid_argument (" Unsupported data type" );
@@ -638,7 +671,8 @@ struct TestVector {
638671
639672 static TestVector
640673 createAllOnesTestMatrix (size_t NumVectors, size_t VectorSize,
641- D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
674+ D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation,
675+ std::mt19937 &Rnd) {
642676 size_t ElementSize;
643677 switch (DataInterpretation) {
644678 case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
@@ -666,13 +700,13 @@ struct TestVector {
666700 case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
667701 case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
668702 case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
669- Vec.fillAllOnesTestData <int8_t >();
703+ Vec.FillSimpleMatrixTestData <int8_t >(Rnd );
670704 break ;
671705 case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
672706 case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
673707 case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
674708 case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
675- Vec.fillAllOnesTestData <float >();
709+ Vec.FillSimpleMatrixTestData <float >(Rnd );
676710 break ;
677711 default :
678712 throw std::invalid_argument (" Unsupported data type" );
@@ -724,10 +758,12 @@ struct TestVector {
724758 ConvertInfo.DestInfo .NumColumns = (UINT)getVectorSize ();
725759
726760 if (MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) {
727- ConvertInfo.DestInfo .DestStride = (UINT)getVectorSize () * DestEltSize;
761+ ConvertInfo.DestInfo .DestStride =
762+ ((UINT)getVectorSize () * DestEltSize + 15 ) & ~15 ;
728763 } else if (MatrixLayout ==
729764 D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) {
730- ConvertInfo.DestInfo .DestStride = (UINT)getNumVectors () * DestEltSize;
765+ ConvertInfo.DestInfo .DestStride =
766+ ((UINT)getNumVectors () * DestEltSize + 15 ) & ~15 ;
731767 }
732768
733769 // Get destination size using preview interface
0 commit comments