@@ -315,11 +315,11 @@ class DxilConf_SM610_LinAlg {
315315 TEST_CLASS_SETUP (setupClass);
316316 TEST_METHOD_SETUP (setupMethod);
317317
318- // Load/Store
318+ // Load/Store/Accumulate Descriptor
319319 TEST_METHOD (LoadStoreDescriptor_Wave_16x16_F16);
320-
321- // Splat Store
322320 TEST_METHOD (SplatStore_Wave_16x16_F16);
321+ TEST_METHOD (AccumulateDescriptor_Wave_16x16_F16);
322+ TEST_METHOD (AccumulateDescriptor_Thread_16x16_F16);
323323
324324 // Element access
325325 TEST_METHOD (ElementAccess_Wave_16x16_F16);
@@ -528,6 +528,82 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
528528 runSplatStore (D3DDevice, DxcSupport, Params, 42 .0f , VerboseLogging);
529529}
530530
531+ static const char AccumulateDescriptorShader[] = R"(
532+ #define USE_ACC 2
533+ RWByteAddressBuffer Output : register(u0);
534+
535+ [WaveSize(4, 64)]
536+ [numthreads(NUMTHREADS, 1, 1)]
537+ void main(uint threadID : SV_GroupIndex) {
538+ if (WaveReadLaneFirst(threadID) != 0)
539+ return;
540+
541+ __builtin_LinAlgMatrix
542+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]]
543+ Mat;
544+ __builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);
545+ __builtin_LinAlg_MatrixAccumulateToDescriptor(
546+ Mat, Output, 0, STRIDE, LAYOUT, 128);
547+ }
548+ )" ;
549+
550+ static void runAccumulateDescriptor (ID3D12Device *Device,
551+ dxc::SpecificDllLoader &DxcSupport,
552+ const MatrixParams &Params, float FillValue,
553+ bool Verbose) {
554+ const size_t NumElements = Params.totalElements ();
555+ const size_t BufferSize = Params.totalBytes ();
556+
557+ std::stringstream ExtraDefs;
558+ ExtraDefs << " -DFILL_VALUE=" << FillValue;
559+
560+ std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
561+
562+ compileShader (DxcSupport, AccumulateDescriptorShader, " cs_6_10" , Args, Verbose);
563+
564+ auto Expected =
565+ makeExpectedMat (Params.CompType , Params.M , Params.N , FillValue, false );
566+
567+ auto Op =
568+ createComputeOp (AccumulateDescriptorShader, " cs_6_10" , " UAV(u0)" , Args.c_str ());
569+ addUAVBuffer (Op.get (), " Output" , BufferSize, true );
570+ addRootUAV (Op.get (), 0 , " Output" );
571+
572+ auto Result = runShaderOp (Device, DxcSupport, std::move (Op));
573+
574+ MappedData OutData;
575+ Result->Test ->GetReadBackData (" Output" , &OutData);
576+
577+ VERIFY_IS_TRUE (verifyComponentBuffer (Params.CompType , OutData.data (),
578+ Expected, NumElements, Verbose));
579+ }
580+
581+ void DxilConf_SM610_LinAlg::AccumulateDescriptor_Wave_16x16_F16 () {
582+ MatrixParams Params = {};
583+ Params.CompType = ComponentType::F16;
584+ Params.M = 16 ;
585+ Params.N = 16 ;
586+ Params.Use = MatrixUse::Accumulator;
587+ Params.Scope = MatrixScope::Wave;
588+ Params.Layout = LinalgMatrixLayout::RowMajor;
589+ Params.NumThreads = 64 ;
590+ Params.Enable16Bit = true ;
591+ runAccumulateDescriptor (D3DDevice, DxcSupport, Params, 42 .0f , VerboseLogging);
592+ }
593+
594+ void DxilConf_SM610_LinAlg::AccumulateDescriptor_Thread_16x16_F16 () {
595+ MatrixParams Params = {};
596+ Params.CompType = ComponentType::F16;
597+ Params.M = 16 ;
598+ Params.N = 16 ;
599+ Params.Use = MatrixUse::Accumulator;
600+ Params.Scope = MatrixScope::Thread;
601+ Params.Layout = LinalgMatrixLayout::RowMajor;
602+ Params.NumThreads = 1 ;
603+ Params.Enable16Bit = true ;
604+ runAccumulateDescriptor (D3DDevice, DxcSupport, Params, 42 .0f , VerboseLogging);
605+ }
606+
531607static const char ElementAccessShader[] = R"(
532608 RWByteAddressBuffer Input : register(u0);
533609 RWByteAddressBuffer Output : register(u1);
0 commit comments