@@ -325,6 +325,7 @@ class DxilConf_SM610_LinAlg {
325325 // Load/Store/Accumulate Memory
326326 TEST_METHOD (LoadMemory_Wave_16x16_F16);
327327 TEST_METHOD (StoreMemory_Wave_16x16_F16);
328+ TEST_METHOD (AccumulateMemory_Wave_16x16_F16);
328329
329330 // Element access
330331 TEST_METHOD (ElementAccess_Wave_16x16_F16);
@@ -1444,7 +1445,6 @@ void DxilConf_SM610_LinAlg::QueryAccumLayout() {
14441445static const char LoadMemoryShader[] = R"(
14451446 RWByteAddressBuffer Input : register(u0);
14461447 RWByteAddressBuffer Output : register(u1);
1447-
14481448 groupshared ELEM_TYPE GsData[M_DIM * N_DIM];
14491449
14501450 #define ELEM_PER_THREAD (M_DIM * N_DIM / NUMTHREADS)
@@ -1593,4 +1593,83 @@ void DxilConf_SM610_LinAlg::StoreMemory_Wave_16x16_F16() {
15931593 runStoreMemory (D3DDevice, DxcSupport, Params, VerboseLogging, /* FillValue=*/ 7 .0f );
15941594}
15951595
1596+ static const char AccumulateMemoryShader[] = R"(
1597+ RWByteAddressBuffer Output : register(u0);
1598+ groupshared ELEM_TYPE GsData[M_DIM * N_DIM];
1599+
1600+ #define ELEM_PER_THREAD (M_DIM * N_DIM / NUMTHREADS)
1601+
1602+ [WaveSize(4, 64)]
1603+ [numthreads(NUMTHREADS, 1, 1)]
1604+ void main(uint threadID : SV_GroupIndex) {
1605+ ELEM_TYPE fill = FILL_VALUE;
1606+ for (uint I = 0; I < ELEM_PER_THREAD; ++I) {
1607+ uint Index = threadID * ELEM_PER_THREAD + I;
1608+ GsData[Index] = fill;
1609+ }
1610+
1611+ GroupMemoryBarrierWithGroupSync();
1612+
1613+ if (WaveReadLaneFirst(threadID) != 0)
1614+ return;
1615+
1616+ __builtin_LinAlgMatrix
1617+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
1618+ Mat;
1619+ __builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);
1620+
1621+ __builtin_LinAlg_MatrixAccumulateToMemory(
1622+ Mat, GsData, OFFSET, STRIDE, LAYOUT);
1623+
1624+ for (uint I = 0; I < M_DIM*N_DIM; ++I) {
1625+ Output.Store<ELEM_TYPE>(I*ELEM_SIZE, GsData[I]);
1626+ }
1627+ }
1628+ )" ;
1629+
1630+ static void runAccumulateMemory (ID3D12Device *Device,
1631+ dxc::SpecificDllLoader &DxcSupport,
1632+ const MatrixParams &Params, bool Verbose,
1633+ float FillValue) {
1634+ const size_t NumElements = Params.totalElements ();
1635+ const size_t BufferSize = Params.totalBytes ();
1636+
1637+ std::stringstream ExtraDefs;
1638+ ExtraDefs << " -DOFFSET=" << 0 ;
1639+ ExtraDefs << " -DFILL_VALUE=" << FillValue;
1640+
1641+ std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
1642+
1643+ compileShader (DxcSupport, AccumulateMemoryShader, " cs_6_10" , Args,
1644+ Verbose);
1645+
1646+ auto Expected = makeExpectedMat (Params.CompType , Params.M , Params.N , FillValue * 2 , /* Increment=*/ false );
1647+
1648+ auto Op = createComputeOp (AccumulateMemoryShader, " cs_6_10" , " UAV(u0)" ,
1649+ Args.c_str ());
1650+ addUAVBuffer (Op.get (), " Output" , BufferSize, true );
1651+ addRootUAV (Op.get (), 0 , " Output" );
1652+
1653+ auto Result = runShaderOp (Device, DxcSupport, std::move (Op));
1654+
1655+ MappedData OutData;
1656+ Result->Test ->GetReadBackData (" Output" , &OutData);
1657+
1658+ VERIFY_IS_TRUE (verifyComponentBuffer (Params.CompType , OutData.data (),
1659+ Expected, NumElements, Verbose));
1660+ }
1661+
1662+ void DxilConf_SM610_LinAlg::AccumulateMemory_Wave_16x16_F16 () {
1663+ MatrixParams Params = {};
1664+ Params.CompType = ComponentType::F16;
1665+ Params.M = 16 ;
1666+ Params.N = 16 ;
1667+ Params.Use = MatrixUse::Accumulator;
1668+ Params.Scope = MatrixScope::Wave;
1669+ Params.Layout = LinalgMatrixLayout::RowMajor;
1670+ Params.NumThreads = 64 ;
1671+ Params.Enable16Bit = true ;
1672+ runAccumulateMemory (D3DDevice, DxcSupport, Params, VerboseLogging, /* FillValue=*/ 7 .0f );
1673+ }
1674+
15961675} // namespace LinAlg
0 commit comments