Skip to content

Commit fb89081

Browse files
Clean up vector handling code by introducing TestVector
1 parent 8ab7045 commit fb89081

2 files changed

Lines changed: 416 additions & 316 deletions

File tree

tools/clang/unittests/HLSLExec/CoopVec.h

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include <DirectXMath.h>
66
#include <DirectXPackedVector.h>
7+
8+
#include <cstdlib>
79
#include <vector>
810

911
#include "dxc/Support/microcom.h"
@@ -61,6 +63,7 @@ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler {
6163
};
6264

6365
namespace CoopVecHelpers {
66+
6467
template <typename EltTy>
6568
static std::vector<uint8_t> CreateAllOnesInputMatrix(uint32_t Width,
6669
uint32_t Height) {
@@ -354,6 +357,203 @@ GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) {
354357
return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32;
355358
}
356359
}
360+
361+
struct TestVector {
362+
private:
363+
size_t NumVectors = 0;
364+
size_t VectorSize = 0;
365+
size_t ElementSize = 0;
366+
size_t Stride = 0;
367+
size_t TotalBytes = 0;
368+
uint8_t *Buffer = nullptr;
369+
370+
public:
371+
TestVector(size_t NumVectors, size_t VectorSize, size_t ElementSize,
372+
size_t Alignment = 16)
373+
: NumVectors(NumVectors), VectorSize(VectorSize),
374+
ElementSize(ElementSize) {
375+
if (NumVectors == 0) {
376+
throw std::invalid_argument("NumVectors must be greater than 0");
377+
}
378+
if (VectorSize == 0) {
379+
throw std::invalid_argument("VectorSize must be greater than 0");
380+
}
381+
if (ElementSize == 0) {
382+
throw std::invalid_argument("ElementSize must be greater than 0");
383+
}
384+
385+
size_t VectorBytes = VectorSize * ElementSize;
386+
Stride = ((VectorBytes + Alignment - 1) / Alignment) * Alignment;
387+
TotalBytes = Stride * NumVectors;
388+
389+
void *Ptr = nullptr;
390+
#ifdef _MSC_VER
391+
Ptr = _aligned_malloc(TotalBytes, Alignment);
392+
#else
393+
Ptr = std::aligned_alloc(Alignment, TotalBytes);
394+
#endif
395+
Buffer = reinterpret_cast<uint8_t *>(Ptr);
396+
std::fill(Buffer, Buffer + TotalBytes, (uint8_t)0xFF);
397+
}
398+
399+
// Copy constructor
400+
TestVector(const TestVector &other)
401+
: NumVectors(other.NumVectors), VectorSize(other.VectorSize),
402+
ElementSize(other.ElementSize), Stride(other.Stride),
403+
TotalBytes(other.TotalBytes) {
404+
405+
void *Ptr = nullptr;
406+
#ifdef _MSC_VER
407+
Ptr = _aligned_malloc(TotalBytes, 16);
408+
#else
409+
Ptr = std::aligned_alloc(16, TotalBytes);
410+
#endif
411+
Buffer = reinterpret_cast<uint8_t *>(Ptr);
412+
413+
if (other.Buffer) {
414+
std::memcpy(Buffer, other.Buffer, TotalBytes);
415+
}
416+
}
417+
418+
// Move constructor
419+
TestVector(TestVector &&other) noexcept
420+
: NumVectors(other.NumVectors), VectorSize(other.VectorSize),
421+
ElementSize(other.ElementSize), Stride(other.Stride),
422+
TotalBytes(other.TotalBytes), Buffer(other.Buffer) {
423+
424+
// Reset the source object
425+
other.NumVectors = 0;
426+
other.VectorSize = 0;
427+
other.ElementSize = 0;
428+
other.Stride = 0;
429+
other.TotalBytes = 0;
430+
other.Buffer = nullptr;
431+
}
432+
433+
~TestVector() {
434+
if (Buffer) {
435+
#ifdef _MSC_VER
436+
_aligned_free(Buffer);
437+
#else
438+
std::free(Buffer);
439+
#endif
440+
}
441+
}
442+
443+
size_t getNumVectors() const { return NumVectors; }
444+
size_t getVectorSize() const { return VectorSize; }
445+
size_t getElementSize() const { return ElementSize; }
446+
size_t getStride() const { return Stride; }
447+
size_t getTotalBytes() const { return TotalBytes; }
448+
uint8_t *getBuffer() { return Buffer; }
449+
const uint8_t *getBuffer() const { return Buffer; }
450+
451+
template <typename T> T *getVector(size_t I) {
452+
uint8_t *Ptr = Buffer + I * Stride;
453+
return reinterpret_cast<T *>(Ptr);
454+
}
455+
456+
template <typename T> const T *getVector(size_t I) const {
457+
const uint8_t *Ptr = Buffer + I * Stride;
458+
return reinterpret_cast<const T *>(Ptr);
459+
}
460+
461+
template <typename T> void fill(const T &Value) {
462+
for (size_t I = 0; I < NumVectors; ++I) {
463+
T *Vec = getVector<T>(I);
464+
for (size_t J = 0; J < VectorSize; ++J)
465+
Vec[J] = Value;
466+
}
467+
}
468+
469+
template <typename T> void fillSimpleTestData() {
470+
// Create a vector of (1, 1, 0, ...)
471+
for (size_t I = 0; I < NumVectors; ++I) {
472+
T *Vec = getVector<T>(I);
473+
for (size_t J = 0; J < VectorSize; ++J)
474+
if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
475+
// Special case for HALF, which requires conversion from float
476+
Vec[J] = static_cast<T>(
477+
ConvertFloat32ToFloat16((J == 0 || J == 1) ? 1.0f : 0.0f));
478+
} else {
479+
Vec[J] = static_cast<T>((J == 0 || J == 1) ? 1 : 0);
480+
}
481+
}
482+
}
483+
484+
static TestVector
485+
createSimpleTestVector(size_t NumVectors, size_t VectorSize,
486+
D3D12_LINEAR_ALGEBRA_DATATYPE DataType,
487+
D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
488+
size_t ElementSize;
489+
switch (DataType) {
490+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
491+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
492+
ElementSize = sizeof(int8_t);
493+
break;
494+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
495+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
496+
ElementSize = sizeof(int16_t);
497+
break;
498+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
499+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
500+
if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED ||
501+
DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) {
502+
ElementSize = sizeof(int8_t);
503+
} else {
504+
ElementSize = sizeof(int32_t);
505+
}
506+
break;
507+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
508+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
509+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
510+
ElementSize = sizeof(DirectX::PackedVector::HALF);
511+
break;
512+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
513+
ElementSize = sizeof(float);
514+
break;
515+
default:
516+
throw std::invalid_argument("Unsupported data type");
517+
}
518+
TestVector Vec(NumVectors, VectorSize, ElementSize);
519+
switch (DataType) {
520+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
521+
Vec.fillSimpleTestData<int8_t>();
522+
break;
523+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
524+
Vec.fillSimpleTestData<uint8_t>();
525+
break;
526+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
527+
Vec.fillSimpleTestData<int16_t>();
528+
break;
529+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
530+
Vec.fillSimpleTestData<uint16_t>();
531+
break;
532+
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
533+
Vec.fillSimpleTestData<int32_t>();
534+
break;
535+
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
536+
if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED ||
537+
DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) {
538+
Vec.fillSimpleTestData<uint8_t>();
539+
} else {
540+
Vec.fillSimpleTestData<uint32_t>();
541+
}
542+
break;
543+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
544+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
545+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
546+
Vec.fillSimpleTestData<DirectX::PackedVector::HALF>();
547+
break;
548+
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
549+
Vec.fillSimpleTestData<float>();
550+
break;
551+
default:
552+
throw std::invalid_argument("Unsupported data type");
553+
}
554+
return Vec;
555+
}
556+
};
357557
}; // namespace CoopVecHelpers
358558

359559
#endif // HAVE_COOPVEC_API

0 commit comments

Comments
 (0)