Skip to content

Commit 5dde799

Browse files
Finish support for NumLayers=2
1 parent 721087a commit 5dde799

2 files changed

Lines changed: 492 additions & 193 deletions

File tree

tools/clang/unittests/HLSLExec/CoopVec.h

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,74 @@ struct TestVector {
448448
uint8_t *getBuffer() { return Buffer; }
449449
const uint8_t *getBuffer() const { return Buffer; }
450450

451+
// Copy assignment operator
452+
TestVector &operator=(const TestVector &other) {
453+
if (this != &other) {
454+
// Free existing buffer
455+
if (Buffer) {
456+
#ifdef _MSC_VER
457+
_aligned_free(Buffer);
458+
#else
459+
std::free(Buffer);
460+
#endif
461+
Buffer = nullptr;
462+
}
463+
464+
// Copy metadata
465+
NumVectors = other.NumVectors;
466+
VectorSize = other.VectorSize;
467+
ElementSize = other.ElementSize;
468+
Stride = other.Stride;
469+
TotalBytes = other.TotalBytes;
470+
471+
// Allocate new buffer
472+
void *Ptr = nullptr;
473+
#ifdef _MSC_VER
474+
Ptr = _aligned_malloc(TotalBytes, 16);
475+
#else
476+
Ptr = std::aligned_alloc(16, TotalBytes);
477+
#endif
478+
Buffer = reinterpret_cast<uint8_t *>(Ptr);
479+
480+
// Copy data
481+
if (other.Buffer) {
482+
std::memcpy(Buffer, other.Buffer, TotalBytes);
483+
}
484+
}
485+
return *this;
486+
}
487+
488+
// Move assignment operator
489+
TestVector &operator=(TestVector &&other) noexcept {
490+
if (this != &other) {
491+
// Free existing buffer
492+
if (Buffer) {
493+
#ifdef _MSC_VER
494+
_aligned_free(Buffer);
495+
#else
496+
std::free(Buffer);
497+
#endif
498+
}
499+
500+
// Move metadata and buffer
501+
NumVectors = other.NumVectors;
502+
VectorSize = other.VectorSize;
503+
ElementSize = other.ElementSize;
504+
Stride = other.Stride;
505+
TotalBytes = other.TotalBytes;
506+
Buffer = other.Buffer;
507+
508+
// Reset the source object
509+
other.NumVectors = 0;
510+
other.VectorSize = 0;
511+
other.ElementSize = 0;
512+
other.Stride = 0;
513+
other.TotalBytes = 0;
514+
other.Buffer = nullptr;
515+
}
516+
return *this;
517+
}
518+
451519
template <typename T> T *getVector(size_t I) {
452520
uint8_t *Ptr = Buffer + I * Stride;
453521
return reinterpret_cast<T *>(Ptr);
@@ -481,6 +549,20 @@ struct TestVector {
481549
}
482550
}
483551

552+
template <typename T> void fillAllOnesTestData() {
553+
// Create a vector of (1, 1, 1, ...)
554+
for (size_t I = 0; I < NumVectors; ++I) {
555+
T *Vec = getVector<T>(I);
556+
for (size_t J = 0; J < VectorSize; ++J)
557+
if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
558+
// Special case for HALF, which requires conversion from float
559+
Vec[J] = static_cast<T>(ConvertFloat32ToFloat16(1.0f));
560+
} else {
561+
Vec[J] = static_cast<T>(1);
562+
}
563+
}
564+
}
565+
484566
static TestVector
485567
createSimpleTestVector(size_t NumVectors, size_t VectorSize,
486568
D3D12_LINEAR_ALGEBRA_DATATYPE DataType,
@@ -553,6 +635,199 @@ struct TestVector {
553635
}
554636
return Vec;
555637
}
638+
639+
// Builds a NumVectors x VectorSize TestVector in which every element is one.
//
// DataInterpretation selects the storage element type:
//   - SINT8 / UINT8 / SINT16 / UINT16 / SINT32 / UINT32 -> int8_t elements
//   - FLOAT16 / FLOAT_E4M3 / FLOAT_E5M2 / FLOAT32       -> float elements
// NOTE(review): the 16- and 32-bit integer interpretations are stored as
// single bytes here; confirm that is what callers expect.
//
// Throws std::invalid_argument("Unsupported data type") for any other value;
// the check happens before the TestVector is constructed.
static TestVector
createAllOnesTestMatrix(size_t NumVectors, size_t VectorSize,
                        D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
  // Classify the interpretation once; both the element size and the fill
  // routine follow from this.
  bool StoreAsInt8 = false;
  switch (DataInterpretation) {
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
  case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
  case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
  case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
    StoreAsInt8 = true;
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
    StoreAsInt8 = false;
    break;
  default:
    throw std::invalid_argument("Unsupported data type");
  }

  const size_t ElementSize = StoreAsInt8 ? sizeof(int8_t) : sizeof(float);
  TestVector Vec(NumVectors, VectorSize, ElementSize);

  if (StoreAsInt8)
    Vec.fillAllOnesTestData<int8_t>();
  else
    Vec.fillAllOnesTestData<float>();

  return Vec;
}
682+
683+
// Fills in a D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO describing a
// conversion of this object (treated as a row-major source matrix with
// getNumVectors() rows and getVectorSize() columns) into DestDataType with
// the requested MatrixLayout.
//
// The source data type comes from CoopVecHelpers::GetMatrixSrcDataType, and
// the source stride/size from getStride()/getTotalBytes(). The destination
// stride is computed here for row-major and column-major layouts and left at
// zero for every other layout. The destination info is then passed to
// ID3D12DevicePreview::GetLinearAlgebraMatrixConversionDestinationInfo,
// which is expected to populate DestSize.
//
// DestDataType values not handled by the switch leave the destination data
// type zero-initialized and the destination element size at 0.
D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO
getConversionInfo(ID3D12Device *D3DDevice,
                  D3D12_LINEAR_ALGEBRA_DATATYPE DestDataType,
                  D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT MatrixLayout) {
  D3D12_LINEAR_ALGEBRA_MATRIX_CONVERSION_INFO ConvertInfo = {};
  auto &Src = ConvertInfo.SrcInfo;
  auto &Dest = ConvertInfo.DestInfo;

  // Source side: always row-major, laid out exactly like this object.
  Src.SrcDataType = ::CoopVecHelpers::GetMatrixSrcDataType(DestDataType);
  Src.SrcLayout = D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR;

  // Destination side: pick the stored data type and its size in bytes.
  Dest.DestSize = 0; // Will be populated by driver
  int DestEltSize = 0;
  switch (DestDataType) {
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
  case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
    // The packed variant is stored as plain SINT8 in the destination.
    Dest.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8;
    DestEltSize = 1;
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
    Dest.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16;
    DestEltSize = 2; // FP16
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
    Dest.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3;
    DestEltSize = 1; // FP8
    break;
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
    Dest.DestDataType = D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2;
    DestEltSize = 1; // FP8
    break;
  }

  Src.SrcStride = static_cast<UINT>(getStride());
  Src.SrcSize = static_cast<UINT>(getTotalBytes());

  Dest.DestLayout = MatrixLayout;
  Dest.NumRows = static_cast<UINT>(getNumVectors());
  Dest.NumColumns = static_cast<UINT>(getVectorSize());

  // Only the two linear layouts get an explicit stride; anything else keeps
  // a stride of zero.
  switch (MatrixLayout) {
  case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
    Dest.DestStride = static_cast<UINT>(getVectorSize()) * DestEltSize;
    break;
  case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
    Dest.DestStride = static_cast<UINT>(getNumVectors()) * DestEltSize;
    break;
  default:
    Dest.DestStride = 0;
    break;
  }

  // Ask the preview device interface to fill in the destination info.
  {
    CComPtr<ID3D12DevicePreview> PreviewDevice;
    VERIFY_SUCCEEDED(
        D3DDevice->QueryInterface(__uuidof(ID3D12DevicePreview),
                                  reinterpret_cast<void **>(&PreviewDevice)));

    PreviewDevice->GetLinearAlgebraMatrixConversionDestinationInfo(&Dest);
  }

  return ConvertInfo;
}
746+
747+
// Reference (CPU) implementation of Result = Matrix * Input (+ Bias).
//
// For every vector V in InputVector and every row R of Matrix, computes the
// dot product of V with R over Matrix.getVectorSize() elements, optionally
// adds Bias[R], and stores the result as a float. The returned TestVector has
// InputVector.getNumVectors() vectors of Matrix.getNumVectors() floats each.
//
// Parameters:
//   Matrix               - one row per output element. Read as float when
//                          MatrixInterpretation is FLOAT16 / FLOAT_E4M3 /
//                          FLOAT_E5M2, and as int8_t when it is SINT8.
//   InputVector          - input vectors. Read as float when InputType is
//                          FLOAT32; otherwise read as HALF (float matrix
//                          path) or int8_t (SINT8 matrix path).
//   Bias                 - vector 0 is read as HALF (float matrix path) or
//                          int32_t (SINT8 matrix path). Only accessed when
//                          HasBias is true.
//   HasBias              - whether to add the bias to each accumulated sum.
//   MatrixInterpretation - selects the float or the SINT8 code path.
//   InputType            - selects how InputVector elements are decoded.
//
// Throws std::invalid_argument("Unsupported matrix interpretation") for any
// other MatrixInterpretation.
static TestVector
matrixVectorMultiply(const TestVector &Matrix, const TestVector &InputVector,
                     const TestVector &Bias, bool HasBias,
                     D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation,
                     D3D12_LINEAR_ALGEBRA_DATATYPE InputType) {
  // Float-typed interpretations keep the matrix data as 32-bit floats.
  bool IsFloatMatrix = false;
  switch (MatrixInterpretation) {
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
  case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
    IsFloatMatrix = true;
    break;
  default:
    break;
  }

  // Loop bounds are hoisted and kept as size_t so they match the index type
  // taken by getVector() and avoid signed/unsigned comparisons.
  const size_t NumInputs = InputVector.getNumVectors();
  const size_t NumOutputs = Matrix.getNumVectors();
  const size_t NumColumns = Matrix.getVectorSize();
  const bool InputIsFP32 = InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32;

  TestVector ResultVec(NumInputs, NumOutputs, sizeof(float));

  if (IsFloatMatrix) {
    // The bias row does not depend on the input vector; fetch it once.
    const DirectX::PackedVector::HALF *BiasFP16 =
        HasBias ? Bias.getVector<DirectX::PackedVector::HALF>(0) : nullptr;

    for (size_t VecIdx = 0; VecIdx < NumInputs; ++VecIdx) {
      float *Out = ResultVec.getVector<float>(VecIdx);
      for (size_t OutputIdx = 0; OutputIdx < NumOutputs; ++OutputIdx) {
        const float *Row = Matrix.getVector<float>(OutputIdx);
        float Acc = 0;

        for (size_t InputIdx = 0; InputIdx < NumColumns; ++InputIdx) {
          float InputElem;
          if (InputIsFP32) {
            InputElem = InputVector.getVector<float>(VecIdx)[InputIdx];
          } else {
            // Any non-FP32 input is decoded as half precision on this path.
            InputElem = ConvertFloat16ToFloat32(
                InputVector.getVector<DirectX::PackedVector::HALF>(
                    VecIdx)[InputIdx]);
          }
          Acc += InputElem * Row[InputIdx];
        }

        if (HasBias)
          Acc += ConvertFloat16ToFloat32(BiasFP16[OutputIdx]);

        Out[OutputIdx] = Acc;
      }
    }
  } else if (MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8) {
    // The bias row does not depend on the input vector; fetch it once.
    const int32_t *BiasI32 = HasBias ? Bias.getVector<int32_t>(0) : nullptr;

    for (size_t VecIdx = 0; VecIdx < NumInputs; ++VecIdx) {
      float *Out = ResultVec.getVector<float>(VecIdx);
      for (size_t OutputIdx = 0; OutputIdx < NumOutputs; ++OutputIdx) {
        const int8_t *Row = Matrix.getVector<int8_t>(OutputIdx);
        int Acc = 0;

        for (size_t InputIdx = 0; InputIdx < NumColumns; ++InputIdx) {
          int InputElem;
          if (InputIsFP32) {
            // FP32 inputs are truncated toward zero before the integer
            // multiply-accumulate.
            InputElem = static_cast<int>(
                InputVector.getVector<float>(VecIdx)[InputIdx]);
          } else {
            InputElem = InputVector.getVector<int8_t>(VecIdx)[InputIdx];
          }
          Acc += InputElem * static_cast<int>(Row[InputIdx]);
        }

        if (HasBias)
          Acc += BiasI32[OutputIdx];

        // Integer results are reported as floats, like the float path.
        Out[OutputIdx] = static_cast<float>(Acc);
      }
    }
  } else {
    throw std::invalid_argument("Unsupported matrix interpretation");
  }

  return ResultVec;
}
556831
};
557832
}; // namespace CoopVecHelpers
558833

0 commit comments

Comments
 (0)