@@ -468,13 +468,9 @@ static const char ElementAccessShader[] = R"(
468468 RWByteAddressBuffer Input : register(u0);
469469 RWByteAddressBuffer Output : register(u1);
470470
471- // 0,0 = 0
472- // 0,1 = 4
473- // 1,0 = 16
474- // 3,3 = 60
475- // TODO: this assumes M=4,N=4
471+ // flatten the 2D index into a 1d index then scale by element size
476472 uint cordToByteOffset(uint2 coord) {
477- return (coord.x * 4 + coord.y) * 4 ;
473+ return (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE ;
478474 }
479475
480476 [numthreads(NUMTHREADS, 1, 1)]
@@ -501,24 +497,28 @@ static const char ElementAccessShader[] = R"(
501497 }
502498
503499 // Store each threads Length in the output after the copied matrix
504- uint finalIdx = (M_DIM * N_DIM + threadIndex) * 4 ;
500+ uint finalIdx = (M_DIM * N_DIM + threadIndex) * ELEM_SIZE ;
505501 uint Len = __builtin_LinAlg_MatrixLength(Mat);
506502 Output.Store(finalIdx, Len);
507503 }
508504)" ;
509505
510506static void runElementAccess (ID3D12Device *Device,
511507 dxc::SpecificDllLoader &DxcSupport,
512- const MatrixParams &Params, bool Verbose) {
508+ const MatrixParams &Params, int MajorDim, bool Verbose) {
513509 const size_t NumElements = Params.totalElements ();
514510 const size_t NumThreads = Params.NumThreads ;
515511 const size_t InputBufSize = Params.totalBytes ();
516- // Output: 4 bytes per element
512+ const size_t ElementSize = elemSize (Params.CompType );
513+ // Output: ElementSize bytes per element
517514 // 1 element for each mat idx
518515 // 1 element for each thread's length
519- const size_t OutputBufSize = (NumElements + NumThreads) * 4 ;
516+ const size_t OutputBufSize = (NumElements + NumThreads) * ElementSize ;
520517
521- std::string Args = buildCompilerArgs (Params);
518+ std::stringstream ExtraDefs;
519+ ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
520+ ExtraDefs << " -DELEM_SIZE=" << ElementSize;
521+ std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
522522
523523 compileShader (DxcSupport, ElementAccessShader, " cs_6_10" , Args, Verbose);
524524
@@ -633,7 +633,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_F32() {
633633 Params.Layout = LinalgMatrixLayout::RowMajor;
634634 Params.NumThreads = 4 ;
635635 Params.Enable16Bit = false ;
636- runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging);
636+ runElementAccess (D3DDevice, DxcSupport, Params, Params. M , VerboseLogging);
637637}
638638
639639void DxilConf_SM610_LinAlg::ElementAccess_Wave_I32 () {
@@ -646,7 +646,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_I32() {
646646 Params.Layout = LinalgMatrixLayout::RowMajor;
647647 Params.NumThreads = 4 ;
648648 Params.Enable16Bit = false ;
649- runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging);
649+ runElementAccess (D3DDevice, DxcSupport, Params, Params. M , VerboseLogging);
650650}
651651
652652} // namespace LinAlg
0 commit comments