Skip to content

Commit 7ce1fe2

Browse files
committed
AccumulateToDescriptor
1 parent 40ad371 commit 7ce1fe2

1 file changed

Lines changed: 79 additions & 3 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,11 @@ class DxilConf_SM610_LinAlg {
315315
TEST_CLASS_SETUP(setupClass);
316316
TEST_METHOD_SETUP(setupMethod);
317317

318-
// Load/Store
318+
// Load/Store/Accumulate Descriptor
319319
TEST_METHOD(LoadStoreDescriptor_Wave_16x16_F16);
320-
321-
// Splat Store
322320
TEST_METHOD(SplatStore_Wave_16x16_F16);
321+
TEST_METHOD(AccumulateDescriptor_Wave_16x16_F16);
322+
TEST_METHOD(AccumulateDescriptor_Thread_16x16_F16);
323323

324324
// Element access
325325
TEST_METHOD(ElementAccess_Wave_16x16_F16);
@@ -528,6 +528,82 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
528528
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging);
529529
}
530530

531+
// HLSL compute shader source used by runAccumulateDescriptor.
// It declares a LinAlg matrix, fills it with FILL_VALUE, and calls
// __builtin_LinAlg_MatrixAccumulateToDescriptor on the Output UAV (u0) with
// offset argument 0. Threads for which WaveReadLaneFirst(threadID) is nonzero
// return before touching the matrix.
// USE_ACC is defined inside the shader; COMP_TYPE, M_DIM, N_DIM, SCOPE,
// STRIDE, LAYOUT, NUMTHREADS and FILL_VALUE are not defined here and must be
// supplied as compiler defines (FILL_VALUE comes from runAccumulateDescriptor;
// the rest presumably from buildCompilerArgs -- TODO confirm).
// NOTE(review): the trailing literal 128 passed to the accumulate builtin is
// not explained in this file -- confirm its meaning against the builtin's spec.
static const char AccumulateDescriptorShader[] = R"(
532+
#define USE_ACC 2
533+
RWByteAddressBuffer Output : register(u0);
534+
535+
[WaveSize(4, 64)]
536+
[numthreads(NUMTHREADS, 1, 1)]
537+
void main(uint threadID : SV_GroupIndex) {
538+
if (WaveReadLaneFirst(threadID) != 0)
539+
return;
540+
541+
__builtin_LinAlgMatrix
542+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]]
543+
Mat;
544+
__builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);
545+
__builtin_LinAlg_MatrixAccumulateToDescriptor(
546+
Mat, Output, 0, STRIDE, LAYOUT, 128);
547+
}
548+
)";
549+
550+
// Shared driver for the AccumulateDescriptor tests.
// Compiles AccumulateDescriptorShader for cs_6_10 with arguments derived from
// Params plus -DFILL_VALUE=<FillValue>, runs it with a single root UAV named
// "Output" of Params.totalBytes() bytes, reads the buffer back, and verifies
// Params.totalElements() components against the expected matrix.
//   Device     - D3D12 device the shader op is executed on.
//   DxcSupport - DXC loader passed to the compile and run helpers.
//   Params     - matrix description (component type, dimensions, etc.).
//   FillValue  - value the shader fills the matrix with; also used to build
//                the expected data.
//   Verbose    - forwarded to compileShader and verifyComponentBuffer.
static void runAccumulateDescriptor(ID3D12Device *Device,
551+
dxc::SpecificDllLoader &DxcSupport,
552+
const MatrixParams &Params, float FillValue,
553+
bool Verbose) {
554+
const size_t NumElements = Params.totalElements();
555+
const size_t BufferSize = Params.totalBytes();
556+
557+
std::stringstream ExtraDefs;
558+
// FillValue is streamed with default float formatting into the define.
ExtraDefs << "-DFILL_VALUE=" << FillValue;
559+
560+
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
561+
562+
// NOTE(review): the shader source is also handed to createComputeOp below;
// presumably this call is an up-front compile check -- confirm.
compileShader(DxcSupport, AccumulateDescriptorShader, "cs_6_10", Args, Verbose);
563+
564+
// NOTE(review): the meaning of the trailing `false` argument is not visible
// here -- confirm against makeExpectedMat.
auto Expected =
565+
makeExpectedMat(Params.CompType, Params.M, Params.N, FillValue, false);
566+
567+
auto Op =
568+
createComputeOp(AccumulateDescriptorShader, "cs_6_10", "UAV(u0)", Args.c_str());
569+
// NOTE(review): comparing the result directly against FillValue assumes the
// buffer starts zeroed; presumably the trailing `true` requests that --
// confirm against addUAVBuffer.
addUAVBuffer(Op.get(), "Output", BufferSize, true);
570+
addRootUAV(Op.get(), 0, "Output");
571+
572+
auto Result = runShaderOp(Device, DxcSupport, std::move(Op));
573+
574+
MappedData OutData;
575+
Result->Test->GetReadBackData("Output", &OutData);
576+
577+
VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
578+
Expected, NumElements, Verbose));
579+
}
580+
581+
// Accumulate-to-descriptor test for a 16x16 F16 accumulator matrix at wave
// scope: row-major layout, 64 threads, fill value 42.0f. All other
// MatrixParams fields keep their value-initialized defaults.
void DxilConf_SM610_LinAlg::AccumulateDescriptor_Wave_16x16_F16() {
582+
MatrixParams Params = {};
583+
Params.CompType = ComponentType::F16;
584+
Params.M = 16;
585+
Params.N = 16;
586+
Params.Use = MatrixUse::Accumulator;
587+
Params.Scope = MatrixScope::Wave;
588+
Params.Layout = LinalgMatrixLayout::RowMajor;
589+
Params.NumThreads = 64;
590+
// presumably turns on 16-bit type support when compiling -- TODO confirm
Params.Enable16Bit = true;
591+
runAccumulateDescriptor(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging);
592+
}
593+
594+
// Accumulate-to-descriptor test for a 16x16 F16 accumulator matrix at thread
// scope: row-major layout, a single thread, fill value 42.0f. Differs from the
// wave-scope variant only in Scope and NumThreads.
void DxilConf_SM610_LinAlg::AccumulateDescriptor_Thread_16x16_F16() {
595+
MatrixParams Params = {};
596+
Params.CompType = ComponentType::F16;
597+
Params.M = 16;
598+
Params.N = 16;
599+
Params.Use = MatrixUse::Accumulator;
600+
Params.Scope = MatrixScope::Thread;
601+
Params.Layout = LinalgMatrixLayout::RowMajor;
602+
Params.NumThreads = 1;
603+
// presumably turns on 16-bit type support when compiling -- TODO confirm
Params.Enable16Bit = true;
604+
runAccumulateDescriptor(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging);
605+
}
606+
531607
static const char ElementAccessShader[] = R"(
532608
RWByteAddressBuffer Input : register(u0);
533609
RWByteAddressBuffer Output : register(u1);

0 commit comments

Comments
 (0)