Skip to content

Commit 811bf1f

Browse files
authored
[SM6.10][HLK] LinAlg element access exec tests (#8317)
Adds HLK smoke test for LinAlg matrix element access
1 parent 834978b commit 811bf1f

1 file changed

Lines changed: 204 additions & 4 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 204 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ using hlsl::DXIL::MatrixScope;
3838
using hlsl::DXIL::MatrixUse;
3939

4040
/// Return the byte size of a single element for the given component type.
41-
static int elemSize(ComponentType CT) {
41+
static int elementSize(ComponentType CT) {
4242
switch (CT) {
4343
case ComponentType::F16:
4444
case ComponentType::I16:
@@ -64,22 +64,23 @@ struct MatrixParams {
6464
bool Enable16Bit;
6565

6666
int strideBytes() const {
67-
int ES = elemSize(CompType);
67+
int ES = elementSize(CompType);
6868
if (Layout == LinalgMatrixLayout::RowMajor)
6969
return N * ES;
7070
return M * ES;
7171
}
7272

7373
size_t totalElements() const { return static_cast<size_t>(M) * N; }
7474

75-
size_t totalBytes() const { return totalElements() * elemSize(CompType); }
75+
size_t totalBytes() const { return totalElements() * elementSize(CompType); }
7676
};
7777

7878
static std::string buildCompilerArgs(const MatrixParams &Params,
7979
const char *ExtraDefines = nullptr) {
8080
std::stringstream SS;
8181
SS << "-HV 202x";
8282
SS << " -DCOMP_TYPE=" << static_cast<int>(Params.CompType);
83+
SS << " -DCOMP_TYPE_F32=" << 9;
8384
SS << " -DM_DIM=" << Params.M;
8485
SS << " -DN_DIM=" << Params.N;
8586
SS << " -DUSE=" << static_cast<int>(Params.Use);
@@ -152,11 +153,18 @@ class DxilConf_SM610_LinAlg {
152153
TEST_CLASS_SETUP(setupClass);
153154
TEST_METHOD_SETUP(setupMethod);
154155

156+
// Load/Store
155157
TEST_METHOD(LoadStoreRoundtrip_Wave_F32);
156158
TEST_METHOD(LoadStoreRoundtrip_Wave_I32);
159+
160+
// Splat Store
157161
TEST_METHOD(SplatStore_Wave_F32);
158162
TEST_METHOD(SplatStore_Wave_I32);
159163

164+
// Element access
165+
TEST_METHOD(ElementAccess_Wave_F32);
166+
TEST_METHOD(ElementAccess_Wave_I32);
167+
160168
private:
161169
bool createDevice();
162170

@@ -266,7 +274,6 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
266274
}
267275
#endif
268276

269-
// Build expected data.
270277
std::vector<float> ExpectedFloats(NumElements);
271278
std::vector<int32_t> ExpectedInts(NumElements);
272279
for (size_t I = 0; I < NumElements; I++) {
@@ -456,4 +463,197 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_I32() {
456463
runSplatStore(D3DDevice, DxcSupport, Params, 7.0f, VerboseLogging);
457464
}
458465

466+
static const char ElementAccessShader[] = R"(
467+
RWByteAddressBuffer Input : register(u0);
468+
RWByteAddressBuffer Output : register(u1);
469+
470+
// flatten the 2D index into a 1D index then scale by element size
471+
uint coordToByteOffset(uint2 coord) {
472+
return (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE;
473+
}
474+
475+
[numthreads(NUMTHREADS, 1, 1)]
476+
void main(uint threadIndex : SV_GroupIndex) {
477+
__builtin_LinAlgMatrix
478+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
479+
Mat;
480+
__builtin_LinAlg_MatrixLoadFromDescriptor(
481+
Mat, Input, 0, STRIDE, LAYOUT, 0);
482+
483+
// Copy Matrix values from input to output without assuming order
484+
for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
485+
uint2 Coord = __builtin_LinAlg_MatrixGetCoordinate(Mat, I);
486+
uint Offset = coordToByteOffset(Coord);
487+
#if COMP_TYPE == COMP_TYPE_F32
488+
float Elem;
489+
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
490+
Output.Store(Offset, asuint(Elem));
491+
#else
492+
uint Elem;
493+
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
494+
Output.Store(Offset, Elem);
495+
#endif
496+
}
497+
498+
// Save the matrix length that this thread saw. The length is written
499+
// to the output right after the matrix, offset by the thread index
500+
uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
501+
uint Len = __builtin_LinAlg_MatrixLength(Mat);
502+
Output.Store(LenIdx, Len);
503+
}
504+
)";
505+
506+
static void runElementAccess(ID3D12Device *Device,
507+
dxc::SpecificDllLoader &DxcSupport,
508+
const MatrixParams &Params, bool Verbose) {
509+
const size_t NumElements = Params.totalElements();
510+
const size_t NumThreads = Params.NumThreads;
511+
const size_t InputBufSize = Params.totalBytes();
512+
const size_t ElementSize = elementSize(Params.CompType);
513+
const size_t MajorDim =
514+
Params.Layout == LinalgMatrixLayout::RowMajor ? Params.M : Params.N;
515+
// Output: ElementSize bytes per element
516+
// 1 element for each mat idx
517+
// 1 uint for each thread's length
518+
const size_t OutputBufSize =
519+
NumElements * ElementSize + NumThreads * sizeof(uint32_t);
520+
521+
std::stringstream ExtraDefs;
522+
ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
523+
ExtraDefs << " -DELEM_SIZE=" << ElementSize;
524+
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
525+
526+
compileShader(DxcSupport, ElementAccessShader, "cs_6_10", Args, Verbose);
527+
528+
#ifndef _HLK_CONF
529+
// Skip GPU execution if no device.
530+
if (!Device) {
531+
hlsl_test::LogCommentFmt(
532+
L"Shader compiled OK; skipping execution (no SM 6.10 device)");
533+
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
534+
return;
535+
}
536+
#endif
537+
538+
std::vector<float> ExpectedFloats(NumElements);
539+
std::vector<int32_t> ExpectedInts(NumElements);
540+
for (size_t I = 0; I < NumElements; I++) {
541+
ExpectedFloats[I] = static_cast<float>(I + 1);
542+
ExpectedInts[I] = static_cast<int32_t>(I + 1);
543+
}
544+
545+
auto Op = createComputeOp(ElementAccessShader, "cs_6_10", "UAV(u0), UAV(u1)",
546+
Args.c_str());
547+
addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname");
548+
addUAVBuffer(Op.get(), "Output", OutputBufSize, true);
549+
addRootUAV(Op.get(), 0, "Input");
550+
addRootUAV(Op.get(), 1, "Output");
551+
552+
auto Result =
553+
runShaderOp(Device, DxcSupport, std::move(Op),
554+
[&](LPCSTR Name, std::vector<BYTE> &Data, st::ShaderOp *) {
555+
if (_stricmp(Name, "Input") != 0)
556+
return;
557+
558+
switch (Params.CompType) {
559+
case ComponentType::F32: {
560+
float *Ptr = reinterpret_cast<float *>(Data.data());
561+
for (size_t I = 0; I < NumElements; I++)
562+
Ptr[I] = static_cast<float>(I + 1);
563+
break;
564+
}
565+
case ComponentType::I32: {
566+
int32_t *Ptr = reinterpret_cast<int32_t *>(Data.data());
567+
for (size_t I = 0; I < NumElements; I++)
568+
Ptr[I] = static_cast<int32_t>(I + 1);
569+
break;
570+
}
571+
default:
572+
VERIFY_IS_TRUE(false, "Saw unsupported component type");
573+
break;
574+
}
575+
});
576+
577+
MappedData OutData;
578+
Result->Test->GetReadBackData("Output", &OutData);
579+
const BYTE *Out = static_cast<const BYTE *>(OutData.data());
580+
581+
std::vector<float> ActualFloats(NumElements);
582+
std::vector<int32_t> ActualInts(NumElements);
583+
for (size_t I = 0; I < NumElements; ++I) {
584+
switch (Params.CompType) {
585+
case ComponentType::F32: {
586+
float Actual;
587+
memcpy(&Actual, &Out[I * ElementSize], ElementSize);
588+
ActualFloats[I] = Actual;
589+
break;
590+
}
591+
case ComponentType::I32: {
592+
int32_t Actual;
593+
memcpy(&Actual, &Out[I * ElementSize], ElementSize);
594+
ActualInts[I] = Actual;
595+
break;
596+
}
597+
default:
598+
VERIFY_IS_TRUE(false, "Saw unsupported component type");
599+
break;
600+
}
601+
}
602+
603+
// Verify element values match input data.
604+
switch (Params.CompType) {
605+
case ComponentType::F32:
606+
VERIFY_IS_TRUE(verifyFloatBuffer(ActualFloats.data(), ExpectedFloats.data(),
607+
NumElements, Verbose));
608+
break;
609+
case ComponentType::I32:
610+
VERIFY_IS_TRUE(verifyIntBuffer(ActualInts.data(), ExpectedInts.data(),
611+
NumElements, Verbose));
612+
break;
613+
default:
614+
VERIFY_IS_TRUE(false, "Saw unsupported component type");
615+
break;
616+
}
617+
618+
// The sum of the values returned by Length across all threads must be
619+
// greater than or equal to the total number of matrix elements
620+
size_t MatrixEndOffset = NumElements * ElementSize;
621+
size_t LengthValuesEnd = MatrixEndOffset + (NumThreads * sizeof(uint32_t));
622+
size_t TotalLength = 0;
623+
for (size_t I = MatrixEndOffset; I < LengthValuesEnd;
624+
I = I + sizeof(uint32_t)) {
625+
uint32_t Length;
626+
memcpy(&Length, &Out[I], sizeof(uint32_t));
627+
TotalLength += Length;
628+
}
629+
VERIFY_IS_TRUE(TotalLength >= NumElements,
630+
"Sum of all lengths must be gte num elements");
631+
}
632+
633+
void DxilConf_SM610_LinAlg::ElementAccess_Wave_F32() {
634+
MatrixParams Params = {};
635+
Params.CompType = ComponentType::F32;
636+
Params.M = 16;
637+
Params.N = 16;
638+
Params.Use = MatrixUse::Accumulator;
639+
Params.Scope = MatrixScope::Wave;
640+
Params.Layout = LinalgMatrixLayout::RowMajor;
641+
Params.NumThreads = 4;
642+
Params.Enable16Bit = false;
643+
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
644+
}
645+
646+
void DxilConf_SM610_LinAlg::ElementAccess_Wave_I32() {
647+
MatrixParams Params = {};
648+
Params.CompType = ComponentType::I32;
649+
Params.M = 16;
650+
Params.N = 16;
651+
Params.Use = MatrixUse::Accumulator;
652+
Params.Scope = MatrixScope::Wave;
653+
Params.Layout = LinalgMatrixLayout::RowMajor;
654+
Params.NumThreads = 4;
655+
Params.Enable16Bit = false;
656+
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
657+
}
658+
459659
} // namespace LinAlg

0 commit comments

Comments
 (0)