@@ -307,6 +307,14 @@ class DxilConf_SM610_LinAlg {
307307
308308 // Element access
309309 TEST_METHOD (ElementAccess_Wave_16x16_F16);
310+ TEST_METHOD (ElementSet_Wave_16x16_F16);
311+
312+ // Cast/Convert
313+ TEST_METHOD (CopyConvert_Wave_16x16_F16);
314+ TEST_METHOD (CopyConvert_Wave_16x16_F16_Transpose);
315+
316+ // Matrix Arithmetic
317+ TEST_METHOD (MatMatMul_Wave_16x16x16_F16);
310318
311319private:
312320 CComPtr<ID3D12Device> D3DDevice;
@@ -537,14 +545,9 @@ static void runElementAccess(ID3D12Device *Device,
537545 const MatrixParams &Params, bool Verbose) {
538546 const size_t NumElements = Params.totalElements ();
539547 const size_t NumThreads = Params.NumThreads ;
540- const size_t InputBufSize = Params.totalBytes ();
541- const size_t ElementSize = elementSize (Params.CompType );
542-
543- // Output: ElementSize bytes per element
544- // 1 element for each mat idx
545- // 1 uint for each thread's length
546- const size_t OutputBufSize =
547- NumElements * ElementSize + NumThreads * sizeof (uint32_t );
548+ const size_t MatrixSize = Params.totalBytes ();
549+ // OutputBuf needs to fit the Matrix plus one uint per thread
550+ const size_t OutputBufSize = MatrixSize + NumThreads * sizeof (uint32_t );
548551
549552 std::stringstream ExtraDefs;
550553 std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
@@ -555,7 +558,7 @@ static void runElementAccess(ID3D12Device *Device,
555558
556559 auto Op = createComputeOp (ElementAccessShader, " cs_6_10" , " UAV(u0), UAV(u1)" ,
557560 Args.c_str ());
558- addUAVBuffer (Op.get (), " Input" , InputBufSize , false , " byname" );
561+ addUAVBuffer (Op.get (), " Input" , MatrixSize , false , " byname" );
559562 addUAVBuffer (Op.get (), " Output" , OutputBufSize, true );
560563 addRootUAV (Op.get (), 0 , " Input" );
561564 addRootUAV (Op.get (), 1 , " Output" );
@@ -579,9 +582,8 @@ static void runElementAccess(ID3D12Device *Device,
579582 // Verify the end of the buffer is NumThreads number of lengths, whose
580583 // sum is greater than or equal to NumElements
581584 const BYTE *Out = static_cast <const BYTE *>(OutData.data ());
582- size_t MatrixEndOffset = NumElements * ElementSize;
583585 const uint32_t *Lengths =
584- reinterpret_cast <const uint32_t *>(Out + MatrixEndOffset );
586+ reinterpret_cast <const uint32_t *>(Out + MatrixSize );
585587 uint32_t TotalLength = 0 ;
586588 for (size_t I = 0 ; I < NumThreads; ++I)
587589 TotalLength += Lengths[I];
@@ -602,4 +604,251 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
602604 runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging);
603605}
604606
607+ static const char ElementSetShader[] = R"(
608+ RWByteAddressBuffer Input : register(u0);
609+ RWByteAddressBuffer Output : register(u1);
610+
611+ [WaveSize(4, 64)]
612+ [numthreads(NUMTHREADS, 1, 1)]
613+ void main(uint threadID : SV_GroupIndex) {
614+ if (WaveReadLaneFirst(threadID) != 0)
615+ return;
616+
617+ __builtin_LinAlgMatrix
618+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
619+ Mat;
620+ __builtin_LinAlg_MatrixLoadFromDescriptor(
621+ Mat, Input, 0, STRIDE, LAYOUT, 128);
622+
623+ // Increment every element by 5
624+ for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
625+ ELEM_TYPE Elem;
626+ __builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
627+ Elem = Elem + 5;
628+ __builtin_LinAlg_MatrixSetElement(Mat, Mat, I, Elem);
629+ }
630+
631+ __builtin_LinAlg_MatrixStoreToDescriptor(
632+ Mat, Output, 0, STRIDE, LAYOUT, 128);
633+ }
634+ )" ;
635+
636+ static void runElementSet (ID3D12Device *Device,
637+ dxc::SpecificDllLoader &DxcSupport,
638+ const MatrixParams &Params, bool Verbose) {
639+ const size_t NumElements = Params.totalElements ();
640+ const size_t MatrixSize = Params.totalBytes ();
641+
642+ std::stringstream ExtraDefs;
643+ std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
644+
645+ compileShader (DxcSupport, ElementSetShader, " cs_6_10" , Args, Verbose);
646+
647+ // Start counting from 6 since each element was increased by 5
648+ auto Expected = makeExpected (Params.CompType , Params.M , Params.N , 6 );
649+
650+ auto Op = createComputeOp (ElementSetShader, " cs_6_10" , " UAV(u0), UAV(u1)" ,
651+ Args.c_str ());
652+ addUAVBuffer (Op.get (), " Input" , MatrixSize, false , " byname" );
653+ addUAVBuffer (Op.get (), " Output" , MatrixSize, true );
654+ addRootUAV (Op.get (), 0 , " Input" );
655+ addRootUAV (Op.get (), 1 , " Output" );
656+
657+ auto Result =
658+ runShaderOp (Device, DxcSupport, std::move (Op),
659+ [NumElements, Params](LPCSTR Name, std::vector<BYTE> &Data,
660+ st::ShaderOp *) {
661+ VERIFY_IS_TRUE (fillInputBuffer (Name, Data, Params.CompType ,
662+ NumElements),
663+ " Saw unsupported component type" );
664+ });
665+
666+ MappedData OutData;
667+ Result->Test ->GetReadBackData (" Output" , &OutData);
668+
669+ // Verify the front of the buffer is a list of elements of the expected type
670+ VERIFY_IS_TRUE (verifyComponentBuffer (Params.CompType , OutData.data (),
671+ Expected, NumElements, Verbose));
672+
673+ }
674+
675+ void DxilConf_SM610_LinAlg::ElementSet_Wave_16x16_F16 () {
676+ MatrixParams Params = {};
677+ Params.CompType = ComponentType::F16;
678+ Params.M = 16 ;
679+ Params.N = 16 ;
680+ Params.Use = MatrixUse::Accumulator;
681+ Params.Scope = MatrixScope::Wave;
682+ Params.Layout = LinalgMatrixLayout::RowMajor;
683+ Params.NumThreads = 64 ;
684+ Params.Enable16Bit = true ;
685+ runElementSet (D3DDevice, DxcSupport, Params, VerboseLogging);
686+ }
687+
688+ static const char CopyConvertShader[] = R"(
689+ RWByteAddressBuffer Input : register(u0);
690+ RWByteAddressBuffer Output : register(u1);
691+
692+ [WaveSize(4, 64)]
693+ [numthreads(NUMTHREADS, 1, 1)]
694+ void main(uint threadID : SV_GroupIndex) {
695+ if (WaveReadLaneFirst(threadID) != 0)
696+ return;
697+
698+ __builtin_LinAlgMatrix
699+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
700+ Src;
701+ __builtin_LinAlgMatrix
702+ [[__LinAlgMatrix_Attributes(COMP_TYPE, N_DIM, M_DIM, USE, SCOPE)]]
703+ Dst;
704+
705+ __builtin_LinAlg_MatrixLoadFromDescriptor(
706+ Src, Input, 0, STRIDE, LAYOUT, 128);
707+ __builtin_LinAlg_CopyConvertMatrix(Dst, Src, TRANSPOSE);
708+ __builtin_LinAlg_MatrixStoreToDescriptor(
709+ Dst, Output, 0, STRIDE, LAYOUT, 128);
710+ }
711+ )" ;
712+
713+ static void runCopyConvert (ID3D12Device *Device,
714+ dxc::SpecificDllLoader &DxcSupport,
715+ const MatrixParams &Params, bool Verbose, bool Transpose) {
716+ const size_t NumElements = Params.totalElements ();
717+ const size_t BufferSize = Params.totalBytes ();
718+
719+ std::stringstream ExtraDefs;
720+ ExtraDefs << " -DTRANSPOSE=" << Transpose;
721+
722+ std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
723+
724+ compileShader (DxcSupport, CopyConvertShader, " cs_6_10" , Args, Verbose);
725+
726+ auto Expected = makeExpected (Params.CompType , Params.M , Params.N , 1 , /* Increment=*/ true , Transpose);
727+
728+ // Construct the ShaderOp: two UAV buffers, load from one, store to other.
729+ auto Op = createComputeOp (CopyConvertShader, " cs_6_10" , " UAV(u0), UAV(u1)" ,
730+ Args.c_str ());
731+ addUAVBuffer (Op.get (), " Input" , BufferSize, false , " byname" );
732+ addUAVBuffer (Op.get (), " Output" , BufferSize, true );
733+ addRootUAV (Op.get (), 0 , " Input" );
734+ addRootUAV (Op.get (), 1 , " Output" );
735+
736+ auto Result =
737+ runShaderOp (Device, DxcSupport, std::move (Op),
738+ [NumElements, Params](LPCSTR Name, std::vector<BYTE> &Data,
739+ st::ShaderOp *) {
740+ VERIFY_IS_TRUE (fillInputBuffer (Name, Data, Params.CompType ,
741+ NumElements),
742+ " Saw unsupported component type" );
743+ });
744+
745+ MappedData OutData;
746+ Result->Test ->GetReadBackData (" Output" , &OutData);
747+
748+ VERIFY_IS_TRUE (verifyComponentBuffer (Params.CompType , OutData.data (),
749+ Expected, NumElements, Verbose));
750+ }
751+
752+ void DxilConf_SM610_LinAlg::CopyConvert_Wave_16x16_F16 () {
753+ MatrixParams Params = {};
754+ Params.CompType = ComponentType::F16;
755+ Params.M = 16 ;
756+ Params.N = 16 ;
757+ Params.Use = MatrixUse::A;
758+ Params.Scope = MatrixScope::Wave;
759+ Params.Layout = LinalgMatrixLayout::RowMajor;
760+ Params.NumThreads = 64 ;
761+ Params.Enable16Bit = true ;
762+ runCopyConvert (D3DDevice, DxcSupport, Params, VerboseLogging, /* Transpose=*/ false );
763+ }
764+
765+ void DxilConf_SM610_LinAlg::CopyConvert_Wave_16x16_F16_Transpose () {
766+ MatrixParams Params = {};
767+ Params.CompType = ComponentType::F16;
768+ Params.M = 16 ;
769+ Params.N = 16 ;
770+ Params.Use = MatrixUse::A;
771+ Params.Scope = MatrixScope::Wave;
772+ Params.Layout = LinalgMatrixLayout::RowMajor;
773+ Params.NumThreads = 64 ;
774+ Params.Enable16Bit = true ;
775+ runCopyConvert (D3DDevice, DxcSupport, Params, VerboseLogging, /* Transpose=*/ true );
776+ }
777+
778+ static const char MatMatMulShader[] = R"(
779+ #define USE_A 0
780+ #define USE_B 1
781+ #define USE_ACC 2
782+
783+ RWByteAddressBuffer Output : register(u0);
784+
785+ [WaveSize(4, 64)]
786+ [numthreads(NUMTHREADS, 1, 1)]
787+ void main(uint threadID : SV_GroupIndex) {
788+ if (WaveReadLaneFirst(threadID) != 0)
789+ return;
790+
791+ __builtin_LinAlgMatrix
792+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, K_DIM, USE_A, SCOPE)]]
793+ MatA;
794+ __builtin_LinAlg_FillMatrix(MatA, A_FILL);
795+
796+ __builtin_LinAlgMatrix
797+ [[__LinAlgMatrix_Attributes(COMP_TYPE, K_DIM, N_DIM, USE_B, SCOPE)]]
798+ MatB;
799+ __builtin_LinAlg_FillMatrix(MatB, B_FILL);
800+
801+ __builtin_LinAlgMatrix
802+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]]
803+ MatC;
804+ __builtin_LinAlg_MatrixMatrixMultiply(MatC, MatA, MatB);
805+
806+ __builtin_LinAlg_MatrixStoreToDescriptor(
807+ MatC, Output, 0, STRIDE, LAYOUT, 128);
808+ }
809+ )" ;
810+
811+ static void runMatMatMul (ID3D12Device *Device,
812+ dxc::SpecificDllLoader &DxcSupport,
813+ const MatrixParams &Params, bool Verbose, MatrixDim K, float AFill, float BFill) {
814+ const size_t NumElements = Params.totalElements ();
815+ const size_t BufferSize = Params.totalBytes ();
816+
817+ std::stringstream ExtraDefs;
818+ ExtraDefs << " -DK_DIM=" << K;
819+ ExtraDefs << " -DA_FILL=" << AFill;
820+ ExtraDefs << " -DB_FILL=" << BFill;
821+
822+ std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
823+
824+ compileShader (DxcSupport, MatMatMulShader, " cs_6_10" , Args, Verbose);
825+
826+ auto Expected = makeExpected (Params.CompType , Params.M , Params.N , AFill * BFill * K, /* Increment=*/ false );
827+
828+ auto Op =
829+ createComputeOp (MatMatMulShader, " cs_6_10" , " UAV(u0)" , Args.c_str ());
830+ addUAVBuffer (Op.get (), " Output" , BufferSize, true );
831+ addRootUAV (Op.get (), 0 , " Output" );
832+
833+ auto Result = runShaderOp (Device, DxcSupport, std::move (Op));
834+
835+ MappedData OutData;
836+ Result->Test ->GetReadBackData (" Output" , &OutData);
837+
838+ VERIFY_IS_TRUE (verifyComponentBuffer (Params.CompType , OutData.data (),
839+ Expected, NumElements, Verbose));
840+ }
841+
842+ void DxilConf_SM610_LinAlg::MatMatMul_Wave_16x16x16_F16 () {
843+ MatrixParams Params = {};
844+ Params.CompType = ComponentType::F16;
845+ Params.M = 16 ;
846+ Params.N = 16 ;
847+ Params.Scope = MatrixScope::Wave;
848+ Params.Layout = LinalgMatrixLayout::RowMajor;
849+ Params.NumThreads = 64 ;
850+ Params.Enable16Bit = true ;
851+ runMatMatMul (D3DDevice, DxcSupport, Params, VerboseLogging, /* K=*/ 16 , /* AFill=*/ 2 .0f , /* BFill=*/ 3 .0f );
852+ }
853+
605854} // namespace LinAlg
0 commit comments