Skip to content

Commit 54dad8f

Browse files
committed
comments
1 parent e0f9777 commit 54dad8f

1 file changed

Lines changed: 13 additions & 13 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -468,13 +468,9 @@ static const char ElementAccessShader[] = R"(
468468
RWByteAddressBuffer Input : register(u0);
469469
RWByteAddressBuffer Output : register(u1);
470470
471-
// 0,0 = 0
472-
// 0,1 = 4
473-
// 1,0 = 16
474-
// 3,3 = 60
475-
// TODO: this assumes M=4,N=4
471+
// flatten the 2D index into a 1d index then scale by element size
476472
uint cordToByteOffset(uint2 coord) {
477-
return (coord.x * 4 + coord.y) * 4;
473+
return (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE;
478474
}
479475
480476
[numthreads(NUMTHREADS, 1, 1)]
@@ -501,24 +497,28 @@ static const char ElementAccessShader[] = R"(
501497
}
502498
503499
// Store each threads Length in the output after the copied matrix
504-
uint finalIdx = (M_DIM * N_DIM + threadIndex) * 4;
500+
uint finalIdx = (M_DIM * N_DIM + threadIndex) * ELEM_SIZE;
505501
uint Len = __builtin_LinAlg_MatrixLength(Mat);
506502
Output.Store(finalIdx, Len);
507503
}
508504
)";
509505

510506
static void runElementAccess(ID3D12Device *Device,
511507
dxc::SpecificDllLoader &DxcSupport,
512-
const MatrixParams &Params, bool Verbose) {
508+
const MatrixParams &Params, int MajorDim, bool Verbose) {
513509
const size_t NumElements = Params.totalElements();
514510
const size_t NumThreads = Params.NumThreads;
515511
const size_t InputBufSize = Params.totalBytes();
516-
// Output: 4 bytes per element
512+
const size_t ElementSize = elemSize(Params.CompType);
513+
// Output: ElementSize bytes per element
517514
// 1 element for each mat idx
518515
// 1 element for each thread's length
519-
const size_t OutputBufSize = (NumElements + NumThreads) * 4;
516+
const size_t OutputBufSize = (NumElements + NumThreads) * ElementSize;
520517

521-
std::string Args = buildCompilerArgs(Params);
518+
std::stringstream ExtraDefs;
519+
ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
520+
ExtraDefs << " -DELEM_SIZE=" << ElementSize;
521+
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
522522

523523
compileShader(DxcSupport, ElementAccessShader, "cs_6_10", Args, Verbose);
524524

@@ -633,7 +633,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_F32() {
633633
Params.Layout = LinalgMatrixLayout::RowMajor;
634634
Params.NumThreads = 4;
635635
Params.Enable16Bit = false;
636-
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
636+
runElementAccess(D3DDevice, DxcSupport, Params, Params.M, VerboseLogging);
637637
}
638638

639639
void DxilConf_SM610_LinAlg::ElementAccess_Wave_I32() {
@@ -646,7 +646,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_I32() {
646646
Params.Layout = LinalgMatrixLayout::RowMajor;
647647
Params.NumThreads = 4;
648648
Params.Enable16Bit = false;
649-
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
649+
runElementAccess(D3DDevice, DxcSupport, Params, Params.M, VerboseLogging);
650650
}
651651

652652
} // namespace LinAlg

0 commit comments

Comments
 (0)