Skip to content

Commit 7e6700d

Browse files
committed
Address comments
1 parent 9df4f0a commit 7e6700d

1 file changed

Lines changed: 11 additions & 20 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -72,21 +72,15 @@ struct MatrixParams {
7272
int NumThreads;
7373
bool Enable16Bit;
7474
bool EmulateTest;
75-
bool GroupSharedMemory = false;
7675

77-
size_t rowStride() const {
78-
// If not Row/Col major, spec says to list 0.
79-
size_t RowElementCount = 0;
76+
size_t strideBytes() const {
77+
uint32_t ES = elementSize(CompType);
8078
if (Layout == LinalgMatrixLayout::RowMajor)
81-
RowElementCount = N;
79+
return N * ES;
8280
if (Layout == LinalgMatrixLayout::ColumnMajor)
83-
RowElementCount = M;
84-
85-
if (GroupSharedMemory)
86-
return RowElementCount;
87-
88-
uint32_t ElementSize = elementSize(CompType);
89-
return RowElementCount * ElementSize;
81+
return M * ES;
82+
// If not Row/Col major, spec says to use 0
83+
return 0;
9084
}
9185

9286
size_t totalElements() const { return M * N; }
@@ -103,7 +97,7 @@ static std::string buildCompilerArgs(const MatrixParams &Params,
10397
SS << " -DN_DIM=" << Params.N;
10498
SS << " -DUSE=" << static_cast<int>(Params.Use);
10599
SS << " -DSCOPE=" << static_cast<int>(Params.Scope);
106-
SS << " -DSTRIDE=" << Params.rowStride();
100+
SS << " -DSTRIDE=" << Params.strideBytes();
107101
SS << " -DLAYOUT=" << static_cast<int>(Params.Layout);
108102
SS << " -DELEM_SIZE=" << static_cast<int>(elementSize(Params.CompType));
109103
SS << " -DNUMTHREADS=" << Params.NumThreads;
@@ -1468,9 +1462,9 @@ static const char LoadMemoryShader[] = R"(
14681462
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
14691463
Mat;
14701464
__builtin_LinAlg_MatrixLoadFromMemory(
1471-
Mat, GsData, OFFSET, STRIDE, LAYOUT);
1465+
Mat, GsData, OFFSET / ELEM_SIZE, STRIDE / ELEM_SIZE, LAYOUT);
14721466
__builtin_LinAlg_MatrixStoreToDescriptor(
1473-
Mat, Output, OFFSET, STRIDE * ELEM_SIZE, LAYOUT, 128);
1467+
Mat, Output, OFFSET, STRIDE, LAYOUT, 128);
14741468
}
14751469
)";
14761470

@@ -1522,7 +1516,6 @@ void DxilConf_SM610_LinAlg::LoadMemory_Wave_16x16_F16() {
15221516
Params.Layout = LinalgMatrixLayout::RowMajor;
15231517
Params.NumThreads = 64;
15241518
Params.Enable16Bit = true;
1525-
Params.GroupSharedMemory = true;
15261519
runLoadMemory(D3DDevice, DxcSupport, Params, VerboseLogging);
15271520
}
15281521

@@ -1542,7 +1535,7 @@ static const char StoreMemoryShader[] = R"(
15421535
__builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);
15431536
15441537
__builtin_LinAlg_MatrixStoreToMemory(
1545-
Mat, GsData, OFFSET, STRIDE, LAYOUT);
1538+
Mat, GsData, OFFSET / ELEM_SIZE, STRIDE / ELEM_SIZE, LAYOUT);
15461539
15471540
for (uint I = 0; I < M_DIM*N_DIM; ++I) {
15481541
Output.Store<ELEM_TYPE>(I*ELEM_SIZE, GsData[I]);
@@ -1592,7 +1585,6 @@ void DxilConf_SM610_LinAlg::StoreMemory_Wave_16x16_F16() {
15921585
Params.Layout = LinalgMatrixLayout::RowMajor;
15931586
Params.NumThreads = 64;
15941587
Params.Enable16Bit = true;
1595-
Params.GroupSharedMemory = true;
15961588
runStoreMemory(D3DDevice, DxcSupport, Params, VerboseLogging,
15971589
/*FillValue=*/7.0f);
15981590
}
@@ -1623,7 +1615,7 @@ static const char AccumulateMemoryShader[] = R"(
16231615
__builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);
16241616
16251617
__builtin_LinAlg_MatrixAccumulateToMemory(
1626-
Mat, GsData, OFFSET, STRIDE, LAYOUT);
1618+
Mat, GsData, OFFSET / ELEM_SIZE, STRIDE / ELEM_SIZE, LAYOUT);
16271619
16281620
for (uint I = 0; I < M_DIM*N_DIM; ++I) {
16291621
Output.Store<ELEM_TYPE>(I*ELEM_SIZE, GsData[I]);
@@ -1673,7 +1665,6 @@ void DxilConf_SM610_LinAlg::AccumulateMemory_Wave_16x16_F16() {
16731665
Params.Layout = LinalgMatrixLayout::RowMajor;
16741666
Params.NumThreads = 64;
16751667
Params.Enable16Bit = true;
1676-
Params.GroupSharedMemory = true;
16771668
runAccumulateMemory(D3DDevice, DxcSupport, Params, VerboseLogging,
16781669
/*FillValue=*/7.0f);
16791670
}

0 commit comments

Comments
 (0)