@@ -152,11 +152,18 @@ class DxilConf_SM610_LinAlg {
152152 TEST_CLASS_SETUP (setupClass);
153153 TEST_METHOD_SETUP (setupMethod);
154154
155+ // Load/Store
155156 TEST_METHOD (LoadStoreRoundtrip_Wave_F32);
156157 TEST_METHOD (LoadStoreRoundtrip_Wave_I32);
158+
159+ // Splat Store
157160 TEST_METHOD (SplatStore_Wave_F32);
158161 TEST_METHOD (SplatStore_Wave_I32);
159162
163+ // Element access
164+ TEST_METHOD (ElementAccess_Wave_F32);
165+ TEST_METHOD (ElementAccess_Wave_I32);
166+
160167private:
161168 bool createDevice ();
162169
@@ -389,6 +396,7 @@ static void runSplatStore(ID3D12Device *Device,
389396 }
390397#endif
391398
399+ // Build expected data.
392400 std::vector<float > ExpectedFloats;
393401 std::vector<int32_t > ExpectedInts;
394402 switch (Params.CompType ) {
@@ -456,4 +464,187 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_I32() {
456464 runSplatStore (D3DDevice, DxcSupport, Params, 7 .0f , VerboseLogging);
457465}
458466
467+ static const char ElementAccessShader[] = R"(
468+ RWByteAddressBuffer Input : register(u0);
469+ RWByteAddressBuffer Output : register(u1);
470+
471+ // 0,0 = 0
472+ // 0,1 = 4
473+ // 1,0 = 16
474+ // 3,3 = 60
475+ // TODO: this assumes M=4,N=4
476+ uint cordToByteOffset(uint2 coord) {
477+ return (coord.x * 4 + coord.y) * 4;
478+ }
479+
480+ [numthreads(NUMTHREADS, 1, 1)]
481+ void main(uint threadIndex : SV_GroupIndex) {
482+ __builtin_LinAlgMatrix
483+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
484+ Mat;
485+ __builtin_LinAlg_MatrixLoadFromDescriptor(
486+ Mat, Input, 0, STRIDE, LAYOUT, 0);
487+
488+ // Copy Matrix values from input to output without assuming order
489+ for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
490+ uint2 Coord = __builtin_LinAlg_MatrixGetCoordinate(Mat, I);
491+ uint Offset = cordToByteOffset(Coord);
492+ #if COMP_TYPE == 9
493+ float Elem;
494+ __builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
495+ Output.Store(Offset, asuint(Elem));
496+ #else
497+ uint Elem;
498+ __builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
499+ Output.Store(Offset, Elem);
500+ #endif
501+ }
502+
503+ // Store each threads Length in the output after the copied matrix
504+ uint finalIdx = (M_DIM * N_DIM + threadIndex) * 4;
505+ uint Len = __builtin_LinAlg_MatrixLength(Mat);
506+ Output.Store(finalIdx, Len);
507+ }
508+ )" ;
509+
510+ static void runElementAccess (ID3D12Device *Device,
511+ dxc::SpecificDllLoader &DxcSupport,
512+ const MatrixParams &Params, bool Verbose) {
513+ const size_t NumElements = Params.totalElements ();
514+ const size_t NumThreads = Params.NumThreads ;
515+ const size_t InputBufSize = Params.totalBytes ();
516+ // Output: 4 bytes per element
517+ // 1 element for each mat idx
518+ // 1 element for each thread's length
519+ const size_t OutputBufSize = (NumElements + NumThreads) * 4 ;
520+
521+ std::string Args = buildCompilerArgs (Params);
522+
523+ compileShader (DxcSupport, ElementAccessShader, " cs_6_10" , Args, Verbose);
524+
525+ #ifndef _HLK_CONF
526+ // Skip GPU execution if no device.
527+ if (!Device) {
528+ hlsl_test::LogCommentFmt (
529+ L" Shader compiled OK; skipping execution (no SM 6.10 device)" );
530+ WEX::Logging::Log::Result (WEX::Logging::TestResults::Skipped);
531+ return ;
532+ }
533+ #endif
534+
535+ // Build expected data.
536+ std::vector<float > ExpectedFloats (NumElements);
537+ std::vector<int32_t > ExpectedInts (NumElements);
538+ for (size_t I = 0 ; I < NumElements; I++) {
539+ ExpectedFloats[I] = static_cast <float >(I + 1 );
540+ ExpectedInts[I] = static_cast <int32_t >(I + 1 );
541+ }
542+
543+ auto Op = createComputeOp (ElementAccessShader, " cs_6_10" ,
544+ " UAV(u0), UAV(u1)" , Args.c_str ());
545+ addUAVBuffer (Op.get (), " Input" , InputBufSize, false , " byname" );
546+ addUAVBuffer (Op.get (), " Output" , OutputBufSize, true );
547+ addRootUAV (Op.get (), 0 , " Input" );
548+ addRootUAV (Op.get (), 1 , " Output" );
549+
550+ auto Result = runShaderOp (
551+ Device, DxcSupport, std::move (Op),
552+ [&](LPCSTR Name, std::vector<BYTE> &Data, st::ShaderOp *) {
553+ if (_stricmp (Name, " Input" ) != 0 )
554+ return ;
555+
556+ switch (Params.CompType ) {
557+ case ComponentType::F32: {
558+ float *Ptr = reinterpret_cast <float *>(Data.data ());
559+ for (size_t I = 0 ; I < NumElements; I++)
560+ Ptr[I] = static_cast <float >(I + 1 );
561+ break ;
562+ }
563+ case ComponentType::I32: {
564+ int32_t *Ptr = reinterpret_cast <int32_t *>(Data.data ());
565+ for (size_t I = 0 ; I < NumElements; I++)
566+ Ptr[I] = static_cast <int32_t >(I + 1 );
567+ break ;
568+ }
569+ default :
570+ VERIFY_IS_TRUE (false , " Saw unsupported component type" );
571+ break ;
572+ }
573+ });
574+
575+ MappedData OutData;
576+ Result->Test ->GetReadBackData (" Output" , &OutData);
577+ const uint32_t *Out = static_cast <const uint32_t *>(OutData.data ());
578+
579+ // Build actual data.
580+ std::vector<float > ActualFloats (NumElements);
581+ std::vector<int32_t > ActualInts (NumElements);
582+ for (size_t I = 0 ; I < NumElements * 4 ; I = I + 4 ) {
583+ switch (Params.CompType ) {
584+ case ComponentType::F32: {
585+ float Actual;
586+ memcpy (&Actual, &Out[I], sizeof (float ));
587+ ActualFloats[I/4 ] = Actual;
588+ break ;
589+ }
590+ case ComponentType::I32: {
591+ ActualInts[I/4 ] = Out[I];
592+ break ;
593+ }
594+ default :
595+ VERIFY_IS_TRUE (false , " Saw unsupported component type" );
596+ break ;
597+ }
598+ }
599+
600+ // Verify element values match input data.
601+ switch (Params.CompType ) {
602+ case ComponentType::F32:
603+ VERIFY_IS_TRUE (verifyFloatBuffer (ActualFloats.data (), ExpectedFloats.data (),
604+ NumElements, Verbose));
605+ break ;
606+ case ComponentType::I32:
607+ VERIFY_IS_TRUE (verifyIntBuffer (ActualInts.data (), ExpectedInts.data (),
608+ NumElements, Verbose));
609+ break ;
610+ default :
611+ VERIFY_IS_TRUE (false , " Saw unsupported component type" );
612+ break ;
613+ }
614+
615+ // The sum of the values returned by Length across all threads must be
616+ // greater than or equal to the total number of matrix elements
617+ size_t TotalLength = 0 ;
618+ for (size_t I = NumElements * 4 ; I < (NumElements + NumThreads) * 4 ; I = I + 4 ) {
619+ TotalLength += Out[I];
620+ }
621+ VERIFY_IS_TRUE (TotalLength >= NumElements, " Sum of all lengths must be gte num elements" );
622+ }
623+
624+ void DxilConf_SM610_LinAlg::ElementAccess_Wave_F32 () {
625+ MatrixParams Params = {};
626+ Params.CompType = ComponentType::F32;
627+ Params.M = 4 ;
628+ Params.N = 4 ;
629+ Params.Use = MatrixUse::Accumulator;
630+ Params.Scope = MatrixScope::Wave;
631+ Params.Layout = LinalgMatrixLayout::RowMajor;
632+ Params.NumThreads = 4 ;
633+ Params.Enable16Bit = false ;
634+ runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging);
635+ }
636+
637+ void DxilConf_SM610_LinAlg::ElementAccess_Wave_I32 () {
638+ MatrixParams Params = {};
639+ Params.CompType = ComponentType::I32;
640+ Params.M = 4 ;
641+ Params.N = 4 ;
642+ Params.Use = MatrixUse::Accumulator;
643+ Params.Scope = MatrixScope::Wave;
644+ Params.Layout = LinalgMatrixLayout::RowMajor;
645+ Params.NumThreads = 4 ;
646+ Params.Enable16Bit = false ;
647+ runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging);
648+ }
649+
459650} // namespace LinAlg
0 commit comments