Skip to content

Commit 35a30df

Browse files
committed
[SM6.10][HLK] LinAlg element access exec tests
1 parent 2762c56 commit 35a30df

1 file changed

Lines changed: 191 additions & 0 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,18 @@ class DxilConf_SM610_LinAlg {
152152
TEST_CLASS_SETUP(setupClass);
153153
TEST_METHOD_SETUP(setupMethod);
154154

155+
// Load/Store
155156
TEST_METHOD(LoadStoreRoundtrip_Wave_F32);
156157
TEST_METHOD(LoadStoreRoundtrip_Wave_I32);
158+
159+
// Splat Store
157160
TEST_METHOD(SplatStore_Wave_F32);
158161
TEST_METHOD(SplatStore_Wave_I32);
159162

163+
// Element access
164+
TEST_METHOD(ElementAccess_Wave_F32);
165+
TEST_METHOD(ElementAccess_Wave_I32);
166+
160167
private:
161168
bool createDevice();
162169

@@ -389,6 +396,7 @@ static void runSplatStore(ID3D12Device *Device,
389396
}
390397
#endif
391398

399+
// Build expected data.
392400
std::vector<float> ExpectedFloats;
393401
std::vector<int32_t> ExpectedInts;
394402
switch (Params.CompType) {
@@ -456,4 +464,187 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_I32() {
456464
runSplatStore(D3DDevice, DxcSupport, Params, 7.0f, VerboseLogging);
457465
}
458466

467+
static const char ElementAccessShader[] = R"(
468+
RWByteAddressBuffer Input : register(u0);
469+
RWByteAddressBuffer Output : register(u1);
470+
471+
// 0,0 = 0
472+
// 0,1 = 4
473+
// 1,0 = 16
474+
// 3,3 = 60
475+
// TODO: this assumes M=4,N=4
476+
uint cordToByteOffset(uint2 coord) {
477+
return (coord.x * 4 + coord.y) * 4;
478+
}
479+
480+
[numthreads(NUMTHREADS, 1, 1)]
481+
void main(uint threadIndex : SV_GroupIndex) {
482+
__builtin_LinAlgMatrix
483+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
484+
Mat;
485+
__builtin_LinAlg_MatrixLoadFromDescriptor(
486+
Mat, Input, 0, STRIDE, LAYOUT, 0);
487+
488+
// Copy Matrix values from input to output without assuming order
489+
for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
490+
uint2 Coord = __builtin_LinAlg_MatrixGetCoordinate(Mat, I);
491+
uint Offset = cordToByteOffset(Coord);
492+
#if COMP_TYPE == 9
493+
float Elem;
494+
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
495+
Output.Store(Offset, asuint(Elem));
496+
#else
497+
uint Elem;
498+
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
499+
Output.Store(Offset, Elem);
500+
#endif
501+
}
502+
503+
// Store each threads Length in the output after the copied matrix
504+
uint finalIdx = (M_DIM * N_DIM + threadIndex) * 4;
505+
uint Len = __builtin_LinAlg_MatrixLength(Mat);
506+
Output.Store(finalIdx, Len);
507+
}
508+
)";
509+
510+
static void runElementAccess(ID3D12Device *Device,
511+
dxc::SpecificDllLoader &DxcSupport,
512+
const MatrixParams &Params, bool Verbose) {
513+
const size_t NumElements = Params.totalElements();
514+
const size_t NumThreads = Params.NumThreads;
515+
const size_t InputBufSize = Params.totalBytes();
516+
// Output: 4 bytes per element
517+
// 1 element for each mat idx
518+
// 1 element for each thread's length
519+
const size_t OutputBufSize = (NumElements + NumThreads) * 4;
520+
521+
std::string Args = buildCompilerArgs(Params);
522+
523+
compileShader(DxcSupport, ElementAccessShader, "cs_6_10", Args, Verbose);
524+
525+
#ifndef _HLK_CONF
526+
// Skip GPU execution if no device.
527+
if (!Device) {
528+
hlsl_test::LogCommentFmt(
529+
L"Shader compiled OK; skipping execution (no SM 6.10 device)");
530+
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
531+
return;
532+
}
533+
#endif
534+
535+
// Build expected data.
536+
std::vector<float> ExpectedFloats(NumElements);
537+
std::vector<int32_t> ExpectedInts(NumElements);
538+
for (size_t I = 0; I < NumElements; I++) {
539+
ExpectedFloats[I] = static_cast<float>(I + 1);
540+
ExpectedInts[I] = static_cast<int32_t>(I + 1);
541+
}
542+
543+
auto Op = createComputeOp(ElementAccessShader, "cs_6_10",
544+
"UAV(u0), UAV(u1)", Args.c_str());
545+
addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname");
546+
addUAVBuffer(Op.get(), "Output", OutputBufSize, true);
547+
addRootUAV(Op.get(), 0, "Input");
548+
addRootUAV(Op.get(), 1, "Output");
549+
550+
auto Result = runShaderOp(
551+
Device, DxcSupport, std::move(Op),
552+
[&](LPCSTR Name, std::vector<BYTE> &Data, st::ShaderOp *) {
553+
if (_stricmp(Name, "Input") != 0)
554+
return;
555+
556+
switch (Params.CompType) {
557+
case ComponentType::F32: {
558+
float *Ptr = reinterpret_cast<float *>(Data.data());
559+
for (size_t I = 0; I < NumElements; I++)
560+
Ptr[I] = static_cast<float>(I + 1);
561+
break;
562+
}
563+
case ComponentType::I32: {
564+
int32_t *Ptr = reinterpret_cast<int32_t *>(Data.data());
565+
for (size_t I = 0; I < NumElements; I++)
566+
Ptr[I] = static_cast<int32_t>(I + 1);
567+
break;
568+
}
569+
default:
570+
VERIFY_IS_TRUE(false, "Saw unsupported component type");
571+
break;
572+
}
573+
});
574+
575+
MappedData OutData;
576+
Result->Test->GetReadBackData("Output", &OutData);
577+
const uint32_t *Out = static_cast<const uint32_t *>(OutData.data());
578+
579+
// Build actual data.
580+
std::vector<float> ActualFloats(NumElements);
581+
std::vector<int32_t> ActualInts(NumElements);
582+
for (size_t I = 0; I < NumElements * 4; I = I + 4) {
583+
switch (Params.CompType) {
584+
case ComponentType::F32: {
585+
float Actual;
586+
memcpy(&Actual, &Out[I], sizeof(float));
587+
ActualFloats[I/4] = Actual;
588+
break;
589+
}
590+
case ComponentType::I32: {
591+
ActualInts[I/4] = Out[I];
592+
break;
593+
}
594+
default:
595+
VERIFY_IS_TRUE(false, "Saw unsupported component type");
596+
break;
597+
}
598+
}
599+
600+
// Verify element values match input data.
601+
switch (Params.CompType) {
602+
case ComponentType::F32:
603+
VERIFY_IS_TRUE(verifyFloatBuffer(ActualFloats.data(), ExpectedFloats.data(),
604+
NumElements, Verbose));
605+
break;
606+
case ComponentType::I32:
607+
VERIFY_IS_TRUE(verifyIntBuffer(ActualInts.data(), ExpectedInts.data(),
608+
NumElements, Verbose));
609+
break;
610+
default:
611+
VERIFY_IS_TRUE(false, "Saw unsupported component type");
612+
break;
613+
}
614+
615+
// The sum of the values returned by Length across all threads must be
616+
// greater than or equal to the total number of matrix elements
617+
size_t TotalLength = 0;
618+
for (size_t I = NumElements * 4; I < (NumElements + NumThreads) * 4; I = I + 4) {
619+
TotalLength += Out[I];
620+
}
621+
VERIFY_IS_TRUE(TotalLength >= NumElements, "Sum of all lengths must be gte num elements");
622+
}
623+
624+
void DxilConf_SM610_LinAlg::ElementAccess_Wave_F32() {
625+
MatrixParams Params = {};
626+
Params.CompType = ComponentType::F32;
627+
Params.M = 4;
628+
Params.N = 4;
629+
Params.Use = MatrixUse::Accumulator;
630+
Params.Scope = MatrixScope::Wave;
631+
Params.Layout = LinalgMatrixLayout::RowMajor;
632+
Params.NumThreads = 4;
633+
Params.Enable16Bit = false;
634+
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
635+
}
636+
637+
void DxilConf_SM610_LinAlg::ElementAccess_Wave_I32() {
638+
MatrixParams Params = {};
639+
Params.CompType = ComponentType::I32;
640+
Params.M = 4;
641+
Params.N = 4;
642+
Params.Use = MatrixUse::Accumulator;
643+
Params.Scope = MatrixScope::Wave;
644+
Params.Layout = LinalgMatrixLayout::RowMajor;
645+
Params.NumThreads = 4;
646+
Params.Enable16Bit = false;
647+
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
648+
}
649+
459650
} // namespace LinAlg

0 commit comments

Comments
 (0)