@@ -315,6 +315,8 @@ class DxilConf_SM610_LinAlg {
315315
316316 // Matrix Arithmetic
317317 TEST_METHOD (MatMatMul_Wave_16x16x16_F16);
318+ TEST_METHOD (MatMatMulAccum_Wave_16x16x16_F16);
319+ TEST_METHOD (MatAccum_Wave_16x16_F16);
318320
319321private:
320322 CComPtr<ID3D12Device> D3DDevice;
@@ -857,4 +859,160 @@ void DxilConf_SM610_LinAlg::MatMatMul_Wave_16x16x16_F16() {
857859 /* AFill=*/ 2 .0f , /* BFill=*/ 3 .0f );
858860}
859861
862+ static const char MatMatMulAccumShader[] = R"(
863+ #define USE_A 0
864+ #define USE_B 1
865+ #define USE_ACC 2
866+
867+ RWByteAddressBuffer Output : register(u0);
868+
869+ [WaveSize(4, 64)]
870+ [numthreads(NUMTHREADS, 1, 1)]
871+ void main(uint threadID : SV_GroupIndex) {
872+ if (WaveReadLaneFirst(threadID) != 0)
873+ return;
874+
875+ __builtin_LinAlgMatrix
876+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, K_DIM, USE_A, SCOPE)]]
877+ MatA;
878+ __builtin_LinAlg_FillMatrix(MatA, A_FILL);
879+
880+ __builtin_LinAlgMatrix
881+ [[__LinAlgMatrix_Attributes(COMP_TYPE, K_DIM, N_DIM, USE_B, SCOPE)]]
882+ MatB;
883+ __builtin_LinAlg_FillMatrix(MatB, B_FILL);
884+
885+ __builtin_LinAlgMatrix
886+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]]
887+ MatC;
888+ __builtin_LinAlg_FillMatrix(MatC, C_FILL);
889+
890+ __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(MatC, MatA, MatB, MatC);
891+
892+ __builtin_LinAlg_MatrixStoreToDescriptor(
893+ MatC, Output, 0, STRIDE, LAYOUT, 128);
894+ }
895+ )" ;
896+
897+ static void runMatMatMulAccum (ID3D12Device *Device,
898+ dxc::SpecificDllLoader &DxcSupport,
899+ const MatrixParams &Params, bool Verbose, MatrixDim K,
900+ float AFill, float BFill, float CFill) {
901+ const size_t NumElements = Params.totalElements ();
902+ const size_t BufferSize = Params.totalBytes ();
903+
904+ std::stringstream ExtraDefs;
905+ ExtraDefs << " -DK_DIM=" << K;
906+ ExtraDefs << " -DA_FILL=" << AFill;
907+ ExtraDefs << " -DB_FILL=" << BFill;
908+ ExtraDefs << " -DC_FILL=" << CFill;
909+
910+ std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
911+
912+ compileShader (DxcSupport, MatMatMulAccumShader, " cs_6_10" , Args, Verbose);
913+
914+ auto Expected = makeExpected (Params.CompType , Params.M , Params.N ,
915+ AFill * BFill * K + CFill, /* Increment=*/ false );
916+
917+ auto Op =
918+ createComputeOp (MatMatMulAccumShader, " cs_6_10" , " UAV(u0)" , Args.c_str ());
919+ addUAVBuffer (Op.get (), " Output" , BufferSize, true );
920+ addRootUAV (Op.get (), 0 , " Output" );
921+
922+ auto Result = runShaderOp (Device, DxcSupport, std::move (Op));
923+
924+ MappedData OutData;
925+ Result->Test ->GetReadBackData (" Output" , &OutData);
926+
927+ VERIFY_IS_TRUE (verifyComponentBuffer (Params.CompType , OutData.data (),
928+ Expected, NumElements, Verbose));
929+ }
930+
931+ void DxilConf_SM610_LinAlg::MatMatMulAccum_Wave_16x16x16_F16 () {
932+ MatrixParams Params = {};
933+ Params.CompType = ComponentType::F16;
934+ Params.M = 16 ;
935+ Params.N = 16 ;
936+ Params.Scope = MatrixScope::Wave;
937+ Params.Layout = LinalgMatrixLayout::RowMajor;
938+ Params.NumThreads = 64 ;
939+ Params.Enable16Bit = true ;
940+ runMatMatMulAccum (D3DDevice, DxcSupport, Params, VerboseLogging, /* K=*/ 16 ,
941+ /* AFill=*/ 2 .0f , /* BFill=*/ 3 .0f , /* CFill=*/ 4 .0f );
942+ }
943+
944+ static const char MatAccumShader[] = R"(
945+ #define USE_A 0
946+ #define USE_ACC 2
947+
948+ RWByteAddressBuffer Output : register(u0);
949+
950+ [WaveSize(4, 64)]
951+ [numthreads(NUMTHREADS, 1, 1)]
952+ void main(uint threadID : SV_GroupIndex) {
953+ if (WaveReadLaneFirst(threadID) != 0)
954+ return;
955+
956+ __builtin_LinAlgMatrix
957+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]]
958+ MatLHS;
959+ __builtin_LinAlg_FillMatrix(MatLHS, LHS_FILL);
960+
961+ __builtin_LinAlgMatrix
962+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE)]]
963+ MatRHS;
964+ __builtin_LinAlg_FillMatrix(MatRHS, RHS_FILL);
965+
966+ __builtin_LinAlg_MatrixAccumulate(MatLHS, MatLHS, MatRHS);
967+
968+ __builtin_LinAlg_MatrixStoreToDescriptor(
969+ MatLHS, Output, 0, STRIDE, LAYOUT, 128);
970+ }
971+ )" ;
972+
973+ static void runMatAccum (ID3D12Device *Device,
974+ dxc::SpecificDllLoader &DxcSupport,
975+ const MatrixParams &Params, bool Verbose,
976+ float LHSFill, float RHSFill) {
977+ const size_t NumElements = Params.totalElements ();
978+ const size_t BufferSize = Params.totalBytes ();
979+
980+ std::stringstream ExtraDefs;
981+ ExtraDefs << " -DLHS_FILL=" << LHSFill;
982+ ExtraDefs << " -DRHS_FILL=" << RHSFill;
983+
984+ std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
985+
986+ compileShader (DxcSupport, MatAccumShader, " cs_6_10" , Args, Verbose);
987+
988+ auto Expected = makeExpected (Params.CompType , Params.M , Params.N ,
989+ LHSFill + RHSFill, /* Increment=*/ false );
990+
991+ auto Op =
992+ createComputeOp (MatAccumShader, " cs_6_10" , " UAV(u0)" , Args.c_str ());
993+ addUAVBuffer (Op.get (), " Output" , BufferSize, true );
994+ addRootUAV (Op.get (), 0 , " Output" );
995+
996+ auto Result = runShaderOp (Device, DxcSupport, std::move (Op));
997+
998+ MappedData OutData;
999+ Result->Test ->GetReadBackData (" Output" , &OutData);
1000+
1001+ VERIFY_IS_TRUE (verifyComponentBuffer (Params.CompType , OutData.data (),
1002+ Expected, NumElements, Verbose));
1003+ }
1004+
1005+ void DxilConf_SM610_LinAlg::MatAccum_Wave_16x16_F16 () {
1006+ MatrixParams Params = {};
1007+ Params.CompType = ComponentType::F16;
1008+ Params.M = 16 ;
1009+ Params.N = 16 ;
1010+ Params.Scope = MatrixScope::Wave;
1011+ Params.Layout = LinalgMatrixLayout::RowMajor;
1012+ Params.NumThreads = 64 ;
1013+ Params.Enable16Bit = true ;
1014+ runMatAccum (D3DDevice, DxcSupport, Params, VerboseLogging,
1015+ /* LHSFill=*/ 2 .0f , /* RHSFill=*/ 3 .0f );
1016+ }
1017+
8601018} // namespace LinAlg
0 commit comments