|
4 | 4 |
|
5 | 5 | #include <DirectXMath.h> |
6 | 6 | #include <DirectXPackedVector.h> |
| 7 | + |
| 8 | +#include <cstdlib> |
7 | 9 | #include <vector> |
8 | 10 |
|
9 | 11 | #include "dxc/Support/microcom.h" |
@@ -61,6 +63,7 @@ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler { |
61 | 63 | }; |
62 | 64 |
|
63 | 65 | namespace CoopVecHelpers { |
| 66 | + |
64 | 67 | template <typename EltTy> |
65 | 68 | static std::vector<uint8_t> CreateAllOnesInputMatrix(uint32_t Width, |
66 | 69 | uint32_t Height) { |
@@ -354,6 +357,203 @@ GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) { |
354 | 357 | return D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32; |
355 | 358 | } |
356 | 359 | } |
| 360 | + |
| 361 | +struct TestVector { |
| 362 | +private: |
| 363 | + size_t NumVectors = 0; |
| 364 | + size_t VectorSize = 0; |
| 365 | + size_t ElementSize = 0; |
| 366 | + size_t Stride = 0; |
| 367 | + size_t TotalBytes = 0; |
| 368 | + uint8_t *Buffer = nullptr; |
| 369 | + |
| 370 | +public: |
| 371 | + TestVector(size_t NumVectors, size_t VectorSize, size_t ElementSize, |
| 372 | + size_t Alignment = 16) |
| 373 | + : NumVectors(NumVectors), VectorSize(VectorSize), |
| 374 | + ElementSize(ElementSize) { |
| 375 | + if (NumVectors == 0) { |
| 376 | + throw std::invalid_argument("NumVectors must be greater than 0"); |
| 377 | + } |
| 378 | + if (VectorSize == 0) { |
| 379 | + throw std::invalid_argument("VectorSize must be greater than 0"); |
| 380 | + } |
| 381 | + if (ElementSize == 0) { |
| 382 | + throw std::invalid_argument("ElementSize must be greater than 0"); |
| 383 | + } |
| 384 | + |
| 385 | + size_t VectorBytes = VectorSize * ElementSize; |
| 386 | + Stride = ((VectorBytes + Alignment - 1) / Alignment) * Alignment; |
| 387 | + TotalBytes = Stride * NumVectors; |
| 388 | + |
| 389 | + void *Ptr = nullptr; |
| 390 | +#ifdef _MSC_VER |
| 391 | + Ptr = _aligned_malloc(TotalBytes, Alignment); |
| 392 | +#else |
| 393 | + Ptr = std::aligned_alloc(Alignment, TotalBytes); |
| 394 | +#endif |
| 395 | + Buffer = reinterpret_cast<uint8_t *>(Ptr); |
| 396 | + std::fill(Buffer, Buffer + TotalBytes, (uint8_t)0xFF); |
| 397 | + } |
| 398 | + |
| 399 | + // Copy constructor |
| 400 | + TestVector(const TestVector &other) |
| 401 | + : NumVectors(other.NumVectors), VectorSize(other.VectorSize), |
| 402 | + ElementSize(other.ElementSize), Stride(other.Stride), |
| 403 | + TotalBytes(other.TotalBytes) { |
| 404 | + |
| 405 | + void *Ptr = nullptr; |
| 406 | +#ifdef _MSC_VER |
| 407 | + Ptr = _aligned_malloc(TotalBytes, 16); |
| 408 | +#else |
| 409 | + Ptr = std::aligned_alloc(16, TotalBytes); |
| 410 | +#endif |
| 411 | + Buffer = reinterpret_cast<uint8_t *>(Ptr); |
| 412 | + |
| 413 | + if (other.Buffer) { |
| 414 | + std::memcpy(Buffer, other.Buffer, TotalBytes); |
| 415 | + } |
| 416 | + } |
| 417 | + |
| 418 | + // Move constructor |
| 419 | + TestVector(TestVector &&other) noexcept |
| 420 | + : NumVectors(other.NumVectors), VectorSize(other.VectorSize), |
| 421 | + ElementSize(other.ElementSize), Stride(other.Stride), |
| 422 | + TotalBytes(other.TotalBytes), Buffer(other.Buffer) { |
| 423 | + |
| 424 | + // Reset the source object |
| 425 | + other.NumVectors = 0; |
| 426 | + other.VectorSize = 0; |
| 427 | + other.ElementSize = 0; |
| 428 | + other.Stride = 0; |
| 429 | + other.TotalBytes = 0; |
| 430 | + other.Buffer = nullptr; |
| 431 | + } |
| 432 | + |
| 433 | + ~TestVector() { |
| 434 | + if (Buffer) { |
| 435 | +#ifdef _MSC_VER |
| 436 | + _aligned_free(Buffer); |
| 437 | +#else |
| 438 | + std::free(Buffer); |
| 439 | +#endif |
| 440 | + } |
| 441 | + } |
| 442 | + |
| 443 | + size_t getNumVectors() const { return NumVectors; } |
| 444 | + size_t getVectorSize() const { return VectorSize; } |
| 445 | + size_t getElementSize() const { return ElementSize; } |
| 446 | + size_t getStride() const { return Stride; } |
| 447 | + size_t getTotalBytes() const { return TotalBytes; } |
| 448 | + uint8_t *getBuffer() { return Buffer; } |
| 449 | + const uint8_t *getBuffer() const { return Buffer; } |
| 450 | + |
| 451 | + template <typename T> T *getVector(size_t I) { |
| 452 | + uint8_t *Ptr = Buffer + I * Stride; |
| 453 | + return reinterpret_cast<T *>(Ptr); |
| 454 | + } |
| 455 | + |
| 456 | + template <typename T> const T *getVector(size_t I) const { |
| 457 | + const uint8_t *Ptr = Buffer + I * Stride; |
| 458 | + return reinterpret_cast<const T *>(Ptr); |
| 459 | + } |
| 460 | + |
| 461 | + template <typename T> void fill(const T &Value) { |
| 462 | + for (size_t I = 0; I < NumVectors; ++I) { |
| 463 | + T *Vec = getVector<T>(I); |
| 464 | + for (size_t J = 0; J < VectorSize; ++J) |
| 465 | + Vec[J] = Value; |
| 466 | + } |
| 467 | + } |
| 468 | + |
| 469 | + template <typename T> void fillSimpleTestData() { |
| 470 | + // Create a vector of (1, 1, 0, ...) |
| 471 | + for (size_t I = 0; I < NumVectors; ++I) { |
| 472 | + T *Vec = getVector<T>(I); |
| 473 | + for (size_t J = 0; J < VectorSize; ++J) |
| 474 | + if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) { |
| 475 | + // Special case for HALF, which requires conversion from float |
| 476 | + Vec[J] = static_cast<T>( |
| 477 | + ConvertFloat32ToFloat16((J == 0 || J == 1) ? 1.0f : 0.0f)); |
| 478 | + } else { |
| 479 | + Vec[J] = static_cast<T>((J == 0 || J == 1) ? 1 : 0); |
| 480 | + } |
| 481 | + } |
| 482 | + } |
| 483 | + |
| 484 | + static TestVector |
| 485 | + createSimpleTestVector(size_t NumVectors, size_t VectorSize, |
| 486 | + D3D12_LINEAR_ALGEBRA_DATATYPE DataType, |
| 487 | + D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) { |
| 488 | + size_t ElementSize; |
| 489 | + switch (DataType) { |
| 490 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: |
| 491 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: |
| 492 | + ElementSize = sizeof(int8_t); |
| 493 | + break; |
| 494 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: |
| 495 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: |
| 496 | + ElementSize = sizeof(int16_t); |
| 497 | + break; |
| 498 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: |
| 499 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: |
| 500 | + if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || |
| 501 | + DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) { |
| 502 | + ElementSize = sizeof(int8_t); |
| 503 | + } else { |
| 504 | + ElementSize = sizeof(int32_t); |
| 505 | + } |
| 506 | + break; |
| 507 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: |
| 508 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: |
| 509 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: |
| 510 | + ElementSize = sizeof(DirectX::PackedVector::HALF); |
| 511 | + break; |
| 512 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: |
| 513 | + ElementSize = sizeof(float); |
| 514 | + break; |
| 515 | + default: |
| 516 | + throw std::invalid_argument("Unsupported data type"); |
| 517 | + } |
| 518 | + TestVector Vec(NumVectors, VectorSize, ElementSize); |
| 519 | + switch (DataType) { |
| 520 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8: |
| 521 | + Vec.fillSimpleTestData<int8_t>(); |
| 522 | + break; |
| 523 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8: |
| 524 | + Vec.fillSimpleTestData<uint8_t>(); |
| 525 | + break; |
| 526 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16: |
| 527 | + Vec.fillSimpleTestData<int16_t>(); |
| 528 | + break; |
| 529 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16: |
| 530 | + Vec.fillSimpleTestData<uint16_t>(); |
| 531 | + break; |
| 532 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32: |
| 533 | + Vec.fillSimpleTestData<int32_t>(); |
| 534 | + break; |
| 535 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32: |
| 536 | + if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED || |
| 537 | + DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) { |
| 538 | + Vec.fillSimpleTestData<uint8_t>(); |
| 539 | + } else { |
| 540 | + Vec.fillSimpleTestData<uint32_t>(); |
| 541 | + } |
| 542 | + break; |
| 543 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3: |
| 544 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2: |
| 545 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16: |
| 546 | + Vec.fillSimpleTestData<DirectX::PackedVector::HALF>(); |
| 547 | + break; |
| 548 | + case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32: |
| 549 | + Vec.fillSimpleTestData<float>(); |
| 550 | + break; |
| 551 | + default: |
| 552 | + throw std::invalid_argument("Unsupported data type"); |
| 553 | + } |
| 554 | + return Vec; |
| 555 | + } |
| 556 | +}; |
357 | 557 | }; // namespace CoopVecHelpers |
358 | 558 |
|
359 | 559 | #endif // HAVE_COOPVEC_API |
0 commit comments