@@ -530,7 +530,9 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
530530
531531static const char AccumulateDescriptorShader[] = R"(
532532 #define USE_ACC 2
533- RWByteAddressBuffer Output : register(u0);
533+
534+ ByteAddressBuffer Input : register(t0);
535+ RWByteAddressBuffer Output : register(u1);
534536
535537 [WaveSize(4, 64)]
536538 [numthreads(NUMTHREADS, 1, 1)]
@@ -541,35 +543,44 @@ static const char AccumulateDescriptorShader[] = R"(
541543 __builtin_LinAlgMatrix
542544 [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]]
543545 Mat;
544- __builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);
546+ __builtin_LinAlg_MatrixLoadFromDescriptor(
547+ Mat, Input, 0, STRIDE, LAYOUT, 128);
548+ __builtin_LinAlg_MatrixAccumulateToDescriptor(
549+ Mat, Output, 0, STRIDE, LAYOUT, 128);
545550 __builtin_LinAlg_MatrixAccumulateToDescriptor(
546551 Mat, Output, 0, STRIDE, LAYOUT, 128);
547552 }
548553)" ;
549554
550555static void runAccumulateDescriptor (ID3D12Device *Device,
551556 dxc::SpecificDllLoader &DxcSupport,
552- const MatrixParams &Params, float FillValue,
557+ const MatrixParams &Params, int FillValue,
553558 bool Verbose) {
554559 const size_t NumElements = Params.totalElements ();
555560 const size_t BufferSize = Params.totalBytes ();
556561
557- std::stringstream ExtraDefs;
558- ExtraDefs << " -DFILL_VALUE=" << FillValue;
559-
560- std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
562+ std::string Args = buildCompilerArgs (Params);
561563
562564 compileShader (DxcSupport, AccumulateDescriptorShader, " cs_6_10" , Args, Verbose);
563565
564566 auto Expected =
565- makeExpectedMat (Params.CompType , Params.M , Params.N , FillValue, false );
567+ makeExpectedMat (Params.CompType , Params.M , Params.N , static_cast < float >( FillValue) * 2 , false );
566568
567- auto Op =
568- createComputeOp (AccumulateDescriptorShader, " cs_6_10" , " UAV(u0)" , Args.c_str ());
569+ auto Op = createComputeOp (AccumulateDescriptorShader, " cs_6_10" ,
570+ " SRV(t0), UAV(u1)" , Args.c_str ());
571+ addUAVBuffer (Op.get (), " Input" , BufferSize, false , " byname" );
569572 addUAVBuffer (Op.get (), " Output" , BufferSize, true );
570- addRootUAV (Op.get (), 0 , " Output" );
573+ addRootUAV (Op.get (), 0 , " Input" );
574+ addRootUAV (Op.get (), 1 , " Output" );
571575
572- auto Result = runShaderOp (Device, DxcSupport, std::move (Op));
576+ auto Result =
577+ runShaderOp (Device, DxcSupport, std::move (Op),
578+ [NumElements, Params, FillValue](LPCSTR Name, std::vector<BYTE> &Data,
579+ st::ShaderOp *) {
580+ VERIFY_IS_TRUE (fillInputBuffer (Name, Data, Params.CompType ,
581+ NumElements, /* StartingVal=*/ FillValue, /* Increment=*/ false ),
582+ " Saw unsupported component type" );
583+ });
573584
574585 MappedData OutData;
575586 Result->Test ->GetReadBackData (" Output" , &OutData);
@@ -588,7 +599,7 @@ void DxilConf_SM610_LinAlg::AccumulateDescriptor_Wave_16x16_F16() {
588599 Params.Layout = LinalgMatrixLayout::RowMajor;
589600 Params.NumThreads = 64 ;
590601 Params.Enable16Bit = true ;
591- runAccumulateDescriptor (D3DDevice, DxcSupport, Params, 42 . 0f , VerboseLogging);
602+ runAccumulateDescriptor (D3DDevice, DxcSupport, Params, 12 , VerboseLogging);
592603}
593604
594605void DxilConf_SM610_LinAlg::AccumulateDescriptor_Thread_16x16_F16 () {
@@ -601,7 +612,7 @@ void DxilConf_SM610_LinAlg::AccumulateDescriptor_Thread_16x16_F16() {
601612 Params.Layout = LinalgMatrixLayout::RowMajor;
602613 Params.NumThreads = 1 ;
603614 Params.Enable16Bit = true ;
604- runAccumulateDescriptor (D3DDevice, DxcSupport, Params, 42 . 0f , VerboseLogging);
615+ runAccumulateDescriptor (D3DDevice, DxcSupport, Params, 19 , VerboseLogging);
605616}
606617
607618static const char ElementAccessShader[] = R"(
0 commit comments