@@ -72,12 +72,21 @@ struct MatrixParams {
7272 int NumThreads;
7373 bool Enable16Bit;
7474 bool EmulateTest;
75+ bool GroupSharedMemory = false ;
7576
76- size_t strideBytes () const {
77- uint32_t ES = elementSize (CompType);
77+ size_t rowStride () const {
78+ // If not Row/Col major, spec says to list 0.
79+ size_t RowElementCount = 0 ;
7880 if (Layout == LinalgMatrixLayout::RowMajor)
79- return N * ES;
80- return M * ES;
81+ RowElementCount = N;
82+ if (Layout == LinalgMatrixLayout::ColumnMajor)
83+ RowElementCount = M;
84+
85+ if (GroupSharedMemory)
86+ return RowElementCount;
87+
88+ uint32_t ElementSize = elementSize (CompType);
89+ return RowElementCount * ElementSize;
8190 }
8291
8392 size_t totalElements () const { return M * N; }
@@ -94,7 +103,7 @@ static std::string buildCompilerArgs(const MatrixParams &Params,
94103 SS << " -DN_DIM=" << Params.N ;
95104 SS << " -DUSE=" << static_cast <int >(Params.Use );
96105 SS << " -DSCOPE=" << static_cast <int >(Params.Scope );
97- SS << " -DSTRIDE=" << Params.strideBytes ();
106+ SS << " -DSTRIDE=" << Params.rowStride ();
98107 SS << " -DLAYOUT=" << static_cast <int >(Params.Layout );
99108 SS << " -DELEM_SIZE=" << static_cast <int >(elementSize (Params.CompType ));
100109 SS << " -DNUMTHREADS=" << Params.NumThreads ;
@@ -320,7 +329,6 @@ class DxilConf_SM610_LinAlg {
320329 TEST_METHOD (LoadStoreDescriptor_Wave_16x16_F16);
321330 TEST_METHOD (SplatStore_Wave_16x16_F16);
322331 TEST_METHOD (AccumulateDescriptor_Wave_16x16_F16);
323- TEST_METHOD (AccumulateDescriptor_Thread_16x16_F16);
324332
325333 // Load/Store/Accumulate Memory
326334 TEST_METHOD (LoadMemory_Wave_16x16_F16);
@@ -613,27 +621,14 @@ void DxilConf_SM610_LinAlg::AccumulateDescriptor_Wave_16x16_F16() {
613621 runAccumulateDescriptor (D3DDevice, DxcSupport, Params, 12 , VerboseLogging);
614622}
615623
616- void DxilConf_SM610_LinAlg::AccumulateDescriptor_Thread_16x16_F16 () {
617- MatrixParams Params = {};
618- Params.CompType = ComponentType::F16;
619- Params.M = 16 ;
620- Params.N = 16 ;
621- Params.Use = MatrixUse::Accumulator;
622- Params.Scope = MatrixScope::Thread;
623- Params.Layout = LinalgMatrixLayout::RowMajor;
624- Params.NumThreads = 1 ;
625- Params.Enable16Bit = true ;
626- runAccumulateDescriptor (D3DDevice, DxcSupport, Params, 19 , VerboseLogging);
627- }
628-
629624static const char ElementAccessShader[] = R"(
630625 RWByteAddressBuffer Input : register(u0);
631626 RWByteAddressBuffer Output : register(u1);
632627
633628 // flatten the 2D index into a 1D index then scale by element size
634629 // Always store row-major and work it out in the test runner
635630 uint coordToByteOffset(uint2 coord) {
636- return (coord.y * N_DIM + coord.x) * ELEM_SIZE;
631+ return (coord.y * M_DIM + coord.x) * ELEM_SIZE;
637632 }
638633
639634 [WaveSize(4, 64)]
@@ -1394,6 +1389,7 @@ static void runOuterProduct(ID3D12Device *Device,
13941389}
13951390
13961391void DxilConf_SM610_LinAlg::OuterProduct_Thread_16x16_F16 () {
1392+ /*
13971393 MatrixParams Params = {};
13981394 Params.CompType = ComponentType::F16;
13991395 Params.M = 16;
@@ -1403,6 +1399,10 @@ void DxilConf_SM610_LinAlg::OuterProduct_Thread_16x16_F16() {
14031399 Params.NumThreads = 1;
14041400 Params.Enable16Bit = true;
14051401 runOuterProduct(D3DDevice, DxcSupport, Params, VerboseLogging);
1402+ */
1403+ hlsl_test::LogCommentFmt (
1404+ L" Skipping test as not implemented" );
1405+ WEX::Logging::Log::Result (WEX::Logging::TestResults::Skipped);
14061406}
14071407
14081408static const char QueryAccumLayoutShader[] = R"(
@@ -1471,7 +1471,7 @@ static const char LoadMemoryShader[] = R"(
14711471 __builtin_LinAlg_MatrixLoadFromMemory(
14721472 Mat, GsData, OFFSET, STRIDE, LAYOUT);
14731473 __builtin_LinAlg_MatrixStoreToDescriptor(
1474- Mat, Output, OFFSET, STRIDE, LAYOUT, 128);
1474+ Mat, Output, OFFSET, STRIDE * ELEM_SIZE , LAYOUT, 128);
14751475 }
14761476)" ;
14771477
@@ -1523,6 +1523,7 @@ void DxilConf_SM610_LinAlg::LoadMemory_Wave_16x16_F16() {
15231523 Params.Layout = LinalgMatrixLayout::RowMajor;
15241524 Params.NumThreads = 64 ;
15251525 Params.Enable16Bit = true ;
1526+ Params.GroupSharedMemory = true ;
15261527 runLoadMemory (D3DDevice, DxcSupport, Params, VerboseLogging);
15271528}
15281529
@@ -1592,6 +1593,7 @@ void DxilConf_SM610_LinAlg::StoreMemory_Wave_16x16_F16() {
15921593 Params.Layout = LinalgMatrixLayout::RowMajor;
15931594 Params.NumThreads = 64 ;
15941595 Params.Enable16Bit = true ;
1596+ Params.GroupSharedMemory = true ;
15951597 runStoreMemory (D3DDevice, DxcSupport, Params, VerboseLogging,
15961598 /* FillValue=*/ 7 .0f );
15971599}
@@ -1672,6 +1674,7 @@ void DxilConf_SM610_LinAlg::AccumulateMemory_Wave_16x16_F16() {
16721674 Params.Layout = LinalgMatrixLayout::RowMajor;
16731675 Params.NumThreads = 64 ;
16741676 Params.Enable16Bit = true ;
1677+ Params.GroupSharedMemory = true ;
16751678 runAccumulateMemory (D3DDevice, DxcSupport, Params, VerboseLogging,
16761679 /* FillValue=*/ 7 .0f );
16771680}
0 commit comments