@@ -38,7 +38,7 @@ using hlsl::DXIL::MatrixScope;
3838using hlsl::DXIL::MatrixUse;
3939
4040/// Return the byte size of a single element for the given component type.
41- static int elemSize (ComponentType CT) {
41+ static int elementSize (ComponentType CT) {
4242 switch (CT) {
4343 case ComponentType::F16:
4444 case ComponentType::I16:
@@ -64,22 +64,23 @@ struct MatrixParams {
6464 bool Enable16Bit;
6565
// Returns the stride in bytes: N elements for a RowMajor layout, M elements
// for any other layout, each scaled by the element size of CompType.
6666 int strideBytes () const {
67- int ES = elemSize (CompType);
67+ int ES = elementSize (CompType);
6868 if (Layout == LinalgMatrixLayout::RowMajor)
6969 return N * ES;
7070 return M * ES;
7171 }
7272
// Number of elements in the matrix (M * N); M is widened to size_t before the
// multiplication.
7373 size_t totalElements () const { return static_cast <size_t >(M) * N; }
7474
// Size of the whole matrix in bytes: the element count multiplied by the
// element size of CompType.
75- size_t totalBytes () const { return totalElements () * elemSize (CompType); }
75+ size_t totalBytes () const { return totalElements () * elementSize (CompType); }
7676};
7777
7878static std::string buildCompilerArgs (const MatrixParams &Params,
7979 const char *ExtraDefines = nullptr ) {
8080 std::stringstream SS;
8181 SS << " -HV 202x" ;
8282 SS << " -DCOMP_TYPE=" << static_cast <int >(Params.CompType );
83+ SS << " -DCOMP_TYPE_F32=" << 9 ;
8384 SS << " -DM_DIM=" << Params.M ;
8485 SS << " -DN_DIM=" << Params.N ;
8586 SS << " -DUSE=" << static_cast <int >(Params.Use );
@@ -152,11 +153,18 @@ class DxilConf_SM610_LinAlg {
152153 TEST_CLASS_SETUP (setupClass);
153154 TEST_METHOD_SETUP (setupMethod);
154155
156+ // Load/Store
155157 TEST_METHOD (LoadStoreRoundtrip_Wave_F32);
156158 TEST_METHOD (LoadStoreRoundtrip_Wave_I32);
159+
160+ // Splat Store
157161 TEST_METHOD (SplatStore_Wave_F32);
158162 TEST_METHOD (SplatStore_Wave_I32);
159163
164+ // Element access
165+ TEST_METHOD (ElementAccess_Wave_F32);
166+ TEST_METHOD (ElementAccess_Wave_I32);
167+
160168private:
161169 bool createDevice ();
162170
@@ -266,7 +274,6 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
266274 }
267275#endif
268276
269- // Build expected data.
270277 std::vector<float > ExpectedFloats (NumElements);
271278 std::vector<int32_t > ExpectedInts (NumElements);
272279 for (size_t I = 0 ; I < NumElements; I++) {
@@ -456,4 +463,197 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_I32() {
456463 runSplatStore (D3DDevice, DxcSupport, Params, 7 .0f , VerboseLogging);
457464}
458465
// HLSL compute shader source used by runElementAccess.
//
// Each thread loads the matrix from Input, then walks the indices
// 0 .. __builtin_LinAlg_MatrixLength(Mat). For every index it asks for the
// element's 2D coordinate, reads the element, and stores it into Output at the
// byte offset (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE. Afterwards the
// thread stores its own matrix length in the uint slot located at
// (M_DIM * N_DIM * ELEM_SIZE) + threadIndex * sizeof(uint), i.e. directly
// behind the matrix data.
//
// The source depends on preprocessor defines passed on the compiler command
// line: MAJOR_DIM and ELEM_SIZE are added by runElementAccess; the remaining
// ones (COMP_TYPE, COMP_TYPE_F32, M_DIM, N_DIM, USE, SCOPE, STRIDE, LAYOUT,
// NUMTHREADS) are presumably all emitted by buildCompilerArgs -- only part of
// that function is visible here, so confirm.
466+ static const char ElementAccessShader[] = R"(
467+ RWByteAddressBuffer Input : register(u0);
468+ RWByteAddressBuffer Output : register(u1);
469+
470+ // flatten the 2D index into a 1D index then scale by element size
471+ uint coordToByteOffset(uint2 coord) {
472+ return (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE;
473+ }
474+
475+ [numthreads(NUMTHREADS, 1, 1)]
476+ void main(uint threadIndex : SV_GroupIndex) {
477+ __builtin_LinAlgMatrix
478+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
479+ Mat;
480+ __builtin_LinAlg_MatrixLoadFromDescriptor(
481+ Mat, Input, 0, STRIDE, LAYOUT, 0);
482+
483+ // Copy Matrix values from input to output without assuming order
484+ for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
485+ uint2 Coord = __builtin_LinAlg_MatrixGetCoordinate(Mat, I);
486+ uint Offset = coordToByteOffset(Coord);
487+ #if COMP_TYPE == COMP_TYPE_F32
488+ float Elem;
489+ __builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
490+ Output.Store(Offset, asuint(Elem));
491+ #else
492+ uint Elem;
493+ __builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
494+ Output.Store(Offset, Elem);
495+ #endif
496+ }
497+
498+ // Save the matrix length that this thread saw. The length is written
499+ // to the output right after the matrix, offset by the thread index
500+ uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
501+ uint Len = __builtin_LinAlg_MatrixLength(Mat);
502+ Output.Store(LenIdx, Len);
503+ }
504+ )" ;
505+
// Compiles ElementAccessShader for the given matrix parameters and, when a
// device is available, runs it and verifies two things:
//   1. the matrix region of the Output buffer holds the same 1..NumElements
//      sequence that was written into the Input buffer, and
//   2. the per-thread matrix lengths stored after the matrix region sum to at
//      least NumElements.
//
// Device     - device to execute on. Outside _HLK_CONF builds a null device
//              makes the function log a message, mark the test Skipped and
//              return right after the compile step.
// DxcSupport - compiler loader handed to compileShader and runShaderOp.
// Params     - matrix description; only F32 and I32 component types are
//              handled, anything else hits a failing VERIFY_IS_TRUE.
// Verbose    - forwarded to compileShader and the buffer verifiers.
506+ static void runElementAccess (ID3D12Device *Device,
507+ dxc::SpecificDllLoader &DxcSupport,
508+ const MatrixParams &Params, bool Verbose) {
509+ const size_t NumElements = Params.totalElements ();
510+ const size_t NumThreads = Params.NumThreads ;
511+ const size_t InputBufSize = Params.totalBytes ();
512+ const size_t ElementSize = elementSize (Params.CompType );
// NOTE(review): the shader flattens a coordinate as
// coord.x * MAJOR_DIM + coord.y, and this picks M for RowMajor, whereas
// MatrixParams::strideBytes() scales by N for RowMajor. The ElementAccess
// tests visible here use M == N == 16, so a mismatch would not show up --
// confirm against the coordinate order returned by MatrixGetCoordinate.
513+ const size_t MajorDim =
514+ Params.Layout == LinalgMatrixLayout::RowMajor ? Params.M : Params.N ;
515+ // Output: ElementSize bytes per element
516+ // 1 element for each mat idx
517+ // 1 uint for each thread's length
518+ const size_t OutputBufSize =
519+ NumElements * ElementSize + NumThreads * sizeof (uint32_t );
520+
// Shader-specific defines appended to the common compiler arguments.
521+ std::stringstream ExtraDefs;
522+ ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
523+ ExtraDefs << " -DELEM_SIZE=" << ElementSize;
524+ std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
525+
// Compile before the device check so the compile step runs even when
// execution is skipped.
526+ compileShader (DxcSupport, ElementAccessShader, " cs_6_10" , Args, Verbose);
527+
528+ #ifndef _HLK_CONF
529+ // Skip GPU execution if no device.
530+ if (!Device) {
531+ hlsl_test::LogCommentFmt (
532+ L" Shader compiled OK; skipping execution (no SM 6.10 device)" );
533+ WEX::Logging::Log::Result (WEX::Logging::TestResults::Skipped);
534+ return ;
535+ }
536+ #endif
537+
// Expected contents are the sequence 1..NumElements. Both the float and the
// int32 variants are filled; only the one matching Params.CompType is
// compared further down.
538+ std::vector<float > ExpectedFloats (NumElements);
539+ std::vector<int32_t > ExpectedInts (NumElements);
540+ for (size_t I = 0 ; I < NumElements; I++) {
541+ ExpectedFloats[I] = static_cast <float >(I + 1 );
542+ ExpectedInts[I] = static_cast <int32_t >(I + 1 );
543+ }
544+
545+ auto Op = createComputeOp (ElementAccessShader, " cs_6_10" , " UAV(u0), UAV(u1)" ,
546+ Args.c_str ());
547+ addUAVBuffer (Op.get (), " Input" , InputBufSize, false , " byname" );
548+ addUAVBuffer (Op.get (), " Output" , OutputBufSize, true );
549+ addRootUAV (Op.get (), 0 , " Input" );
550+ addRootUAV (Op.get (), 1 , " Output" );
551+
// The callback passed to runShaderOp fills the resource named "Input" with
// the 1..NumElements sequence in the component type under test and returns
// without touching Data for any other resource name.
552+ auto Result =
553+ runShaderOp (Device, DxcSupport, std::move (Op),
554+ [&](LPCSTR Name, std::vector<BYTE> &Data, st::ShaderOp *) {
555+ if (_stricmp (Name, " Input" ) != 0 )
556+ return ;
557+
558+ switch (Params.CompType ) {
559+ case ComponentType::F32: {
560+ float *Ptr = reinterpret_cast <float *>(Data.data ());
561+ for (size_t I = 0 ; I < NumElements; I++)
562+ Ptr[I] = static_cast <float >(I + 1 );
563+ break ;
564+ }
565+ case ComponentType::I32: {
566+ int32_t *Ptr = reinterpret_cast <int32_t *>(Data.data ());
567+ for (size_t I = 0 ; I < NumElements; I++)
568+ Ptr[I] = static_cast <int32_t >(I + 1 );
569+ break ;
570+ }
571+ default :
572+ VERIFY_IS_TRUE (false , " Saw unsupported component type" );
573+ break ;
574+ }
575+ });
576+
// Fetch the contents of the Output buffer and view them as raw bytes.
577+ MappedData OutData;
578+ Result->Test ->GetReadBackData (" Output" , &OutData);
579+ const BYTE *Out = static_cast <const BYTE *>(OutData.data ());
580+
// Decode the matrix region element by element: element I is copied out of
// the byte range starting at I * ElementSize.
581+ std::vector<float > ActualFloats (NumElements);
582+ std::vector<int32_t > ActualInts (NumElements);
583+ for (size_t I = 0 ; I < NumElements; ++I) {
584+ switch (Params.CompType ) {
585+ case ComponentType::F32: {
586+ float Actual;
587+ memcpy (&Actual, &Out[I * ElementSize], ElementSize);
588+ ActualFloats[I] = Actual;
589+ break ;
590+ }
591+ case ComponentType::I32: {
592+ int32_t Actual;
593+ memcpy (&Actual, &Out[I * ElementSize], ElementSize);
594+ ActualInts[I] = Actual;
595+ break ;
596+ }
597+ default :
598+ VERIFY_IS_TRUE (false , " Saw unsupported component type" );
599+ break ;
600+ }
601+ }
602+
603+ // Verify element values match input data.
604+ switch (Params.CompType ) {
605+ case ComponentType::F32:
606+ VERIFY_IS_TRUE (verifyFloatBuffer (ActualFloats.data (), ExpectedFloats.data (),
607+ NumElements, Verbose));
608+ break ;
609+ case ComponentType::I32:
610+ VERIFY_IS_TRUE (verifyIntBuffer (ActualInts.data (), ExpectedInts.data (),
611+ NumElements, Verbose));
612+ break ;
613+ default :
614+ VERIFY_IS_TRUE (false , " Saw unsupported component type" );
615+ break ;
616+ }
617+
618+ // The sum of the values returned by Length across all threads must be
619+ // greater than or equal to the total number of matrix elements
// The lengths are read as NumThreads consecutive uint32_t values starting at
// the first byte after the matrix region.
620+ size_t MatrixEndOffset = NumElements * ElementSize;
621+ size_t LengthValuesEnd = MatrixEndOffset + (NumThreads * sizeof (uint32_t ));
622+ size_t TotalLength = 0 ;
623+ for (size_t I = MatrixEndOffset; I < LengthValuesEnd;
624+ I = I + sizeof (uint32_t )) {
625+ uint32_t Length;
626+ memcpy (&Length, &Out[I], sizeof (uint32_t ));
627+ TotalLength += Length;
628+ }
629+ VERIFY_IS_TRUE (TotalLength >= NumElements,
630+ " Sum of all lengths must be gte num elements" );
631+ }
632+
// Element-access test for a 16x16 F32 matrix: Accumulator use, Wave scope,
// RowMajor layout, 4 threads, Enable16Bit off. All other MatrixParams fields
// keep their value-initialized state.
633+ void DxilConf_SM610_LinAlg::ElementAccess_Wave_F32 () {
634+ MatrixParams Params = {};
635+ Params.CompType = ComponentType::F32;
636+ Params.M = 16 ;
637+ Params.N = 16 ;
638+ Params.Use = MatrixUse::Accumulator;
639+ Params.Scope = MatrixScope::Wave;
640+ Params.Layout = LinalgMatrixLayout::RowMajor;
641+ Params.NumThreads = 4 ;
642+ Params.Enable16Bit = false ;
643+ runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging);
644+ }
645+
// Element-access test for a 16x16 I32 matrix: Accumulator use, Wave scope,
// RowMajor layout, 4 threads, Enable16Bit off. All other MatrixParams fields
// keep their value-initialized state.
646+ void DxilConf_SM610_LinAlg::ElementAccess_Wave_I32 () {
647+ MatrixParams Params = {};
648+ Params.CompType = ComponentType::I32;
649+ Params.M = 16 ;
650+ Params.N = 16 ;
651+ Params.Use = MatrixUse::Accumulator;
652+ Params.Scope = MatrixScope::Wave;
653+ Params.Layout = LinalgMatrixLayout::RowMajor;
654+ Params.NumThreads = 4 ;
655+ Params.Enable16Bit = false ;
656+ runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging);
657+ }
658+
459659} // namespace LinAlg
0 commit comments