@@ -72,21 +72,15 @@ struct MatrixParams {
7272 int NumThreads;
7373 bool Enable16Bit;
7474 bool EmulateTest;
75- bool GroupSharedMemory = false ;
7675
77- size_t rowStride () const {
78- // If not Row/Col major, spec says to list 0.
79- size_t RowElementCount = 0 ;
76+ size_t strideBytes () const {
77+ uint32_t ES = elementSize (CompType);
8078 if (Layout == LinalgMatrixLayout::RowMajor)
81- RowElementCount = N ;
79+ return N * ES ;
8280 if (Layout == LinalgMatrixLayout::ColumnMajor)
83- RowElementCount = M;
84-
85- if (GroupSharedMemory)
86- return RowElementCount;
87-
88- uint32_t ElementSize = elementSize (CompType);
89- return RowElementCount * ElementSize;
81+ return M * ES;
82+ // If not Row/Col major, spec says to use 0
83+ return 0 ;
9084 }
9185
9286 size_t totalElements () const { return M * N; }
@@ -103,7 +97,7 @@ static std::string buildCompilerArgs(const MatrixParams &Params,
10397 SS << " -DN_DIM=" << Params.N ;
10498 SS << " -DUSE=" << static_cast <int >(Params.Use );
10599 SS << " -DSCOPE=" << static_cast <int >(Params.Scope );
106- SS << " -DSTRIDE=" << Params.rowStride ();
100+ SS << " -DSTRIDE=" << Params.strideBytes ();
107101 SS << " -DLAYOUT=" << static_cast <int >(Params.Layout );
108102 SS << " -DELEM_SIZE=" << static_cast <int >(elementSize (Params.CompType ));
109103 SS << " -DNUMTHREADS=" << Params.NumThreads ;
@@ -1468,9 +1462,9 @@ static const char LoadMemoryShader[] = R"(
14681462 [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
14691463 Mat;
14701464 __builtin_LinAlg_MatrixLoadFromMemory(
1471- Mat, GsData, OFFSET, STRIDE, LAYOUT);
1465+ Mat, GsData, OFFSET / ELEM_SIZE , STRIDE / ELEM_SIZE , LAYOUT);
14721466 __builtin_LinAlg_MatrixStoreToDescriptor(
1473- Mat, Output, OFFSET, STRIDE * ELEM_SIZE , LAYOUT, 128);
1467+ Mat, Output, OFFSET, STRIDE, LAYOUT, 128);
14741468 }
14751469)" ;
14761470
@@ -1522,7 +1516,6 @@ void DxilConf_SM610_LinAlg::LoadMemory_Wave_16x16_F16() {
15221516 Params.Layout = LinalgMatrixLayout::RowMajor;
15231517 Params.NumThreads = 64 ;
15241518 Params.Enable16Bit = true ;
1525- Params.GroupSharedMemory = true ;
15261519 runLoadMemory (D3DDevice, DxcSupport, Params, VerboseLogging);
15271520}
15281521
@@ -1542,7 +1535,7 @@ static const char StoreMemoryShader[] = R"(
15421535 __builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);
15431536
15441537 __builtin_LinAlg_MatrixStoreToMemory(
1545- Mat, GsData, OFFSET, STRIDE, LAYOUT);
1538+ Mat, GsData, OFFSET / ELEM_SIZE , STRIDE / ELEM_SIZE , LAYOUT);
15461539
15471540 for (uint I = 0; I < M_DIM*N_DIM; ++I) {
15481541 Output.Store<ELEM_TYPE>(I*ELEM_SIZE, GsData[I]);
@@ -1592,7 +1585,6 @@ void DxilConf_SM610_LinAlg::StoreMemory_Wave_16x16_F16() {
15921585 Params.Layout = LinalgMatrixLayout::RowMajor;
15931586 Params.NumThreads = 64 ;
15941587 Params.Enable16Bit = true ;
1595- Params.GroupSharedMemory = true ;
15961588 runStoreMemory (D3DDevice, DxcSupport, Params, VerboseLogging,
15971589 /* FillValue=*/ 7 .0f );
15981590}
@@ -1623,7 +1615,7 @@ static const char AccumulateMemoryShader[] = R"(
16231615 __builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);
16241616
16251617 __builtin_LinAlg_MatrixAccumulateToMemory(
1626- Mat, GsData, OFFSET, STRIDE, LAYOUT);
1618+ Mat, GsData, OFFSET / ELEM_SIZE , STRIDE / ELEM_SIZE , LAYOUT);
16271619
16281620 for (uint I = 0; I < M_DIM*N_DIM; ++I) {
16291621 Output.Store<ELEM_TYPE>(I*ELEM_SIZE, GsData[I]);
@@ -1673,7 +1665,6 @@ void DxilConf_SM610_LinAlg::AccumulateMemory_Wave_16x16_F16() {
16731665 Params.Layout = LinalgMatrixLayout::RowMajor;
16741666 Params.NumThreads = 64 ;
16751667 Params.Enable16Bit = true ;
1676- Params.GroupSharedMemory = true ;
16771668 runAccumulateMemory (D3DDevice, DxcSupport, Params, VerboseLogging,
16781669 /* FillValue=*/ 7 .0f );
16791670}
0 commit comments