Skip to content

Commit 56dcdbc

Browse files
[SM6.10][HLK] Use Template Load/Store for Shaders (#8350)
Builds on top of #8343; will move out of draft after that merges. Internal feedback was provided to use the templated Load/Store methods. This PR makes that change, and then simplifies the shaders/built-in defines in response to those changes. --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent ec6ba97 commit 56dcdbc

1 file changed

Lines changed: 20 additions & 23 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 20 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -89,8 +89,6 @@ static std::string buildCompilerArgs(const MatrixParams &Params,
8989
std::stringstream SS;
9090
SS << "-HV 202x";
9191
SS << " -DCOMP_TYPE=" << static_cast<int>(Params.CompType);
92-
SS << " -DCOMP_TYPE_F16=" << 8;
93-
SS << " -DCOMP_TYPE_F32=" << 9;
9492
SS << " -DM_DIM=" << Params.M;
9593
SS << " -DN_DIM=" << Params.N;
9694
SS << " -DUSE=" << static_cast<int>(Params.Use);
@@ -99,6 +97,17 @@ static std::string buildCompilerArgs(const MatrixParams &Params,
9997
SS << " -DLAYOUT=" << static_cast<int>(Params.Layout);
10098
SS << " -DELEM_SIZE=" << elementSize(Params.CompType);
10199
SS << " -DNUMTHREADS=" << Params.NumThreads;
100+
switch (Params.CompType) {
101+
case ComponentType::F16:
102+
SS << " -DELEM_TYPE=half";
103+
break;
104+
case ComponentType::F32:
105+
SS << " -DELEM_TYPE=float";
106+
break;
107+
default:
108+
SS << " -DELEM_TYPE=uint";
109+
break;
110+
}
102111
if (Params.EmulateTest)
103112
SS << " -DEMULATE_TEST";
104113
if (Params.Enable16Bit)
@@ -389,7 +398,7 @@ static const char LoadStoreShader[] = R"(
389398
[numthreads(NUMTHREADS, 1, 1)]
390399
void main() {
391400
for (uint I = 0; I < M_DIM*N_DIM; ++I) {
392-
Output.Store(I*ELEM_SIZE, Input.Load(I*ELEM_SIZE));
401+
Output.Store<ELEM_TYPE>(I*ELEM_SIZE, Input.Load<ELEM_TYPE>(I*ELEM_SIZE));
393402
}
394403
}
395404
#endif
@@ -477,15 +486,9 @@ static const char SplatStoreShader[] = R"(
477486
#else
478487
[numthreads(NUMTHREADS, 1, 1)]
479488
void main() {
480-
#if COMP_TYPE == COMP_TYPE_F32
481-
float fill = FILL_VALUE;
482-
#elif COMP_TYPE == COMP_TYPE_F16
483-
half fill = FILL_VALUE;
484-
#else
485-
uint fill = FILL_VALUE;
486-
#endif
489+
ELEM_TYPE fill = FILL_VALUE;
487490
for (uint I = 0; I < M_DIM*N_DIM; ++I) {
488-
Output.Store(I*ELEM_SIZE, fill);
491+
Output.Store<ELEM_TYPE>(I*ELEM_SIZE, fill);
489492
}
490493
}
491494
#endif
@@ -567,34 +570,28 @@ static const char ElementAccessShader[] = R"(
567570
for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
568571
uint2 Coord = __builtin_LinAlg_MatrixGetCoordinate(Mat, I);
569572
uint Offset = coordToByteOffset(Coord);
570-
#if COMP_TYPE == COMP_TYPE_F32
571-
float Elem;
572-
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
573-
Output.Store(Offset, asuint(Elem));
574-
#else
575-
uint Elem;
576-
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
577-
Output.Store(Offset, Elem);
578-
#endif
573+
ELEM_TYPE Elem;
574+
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
575+
Output.Store<ELEM_TYPE>(Offset, Elem);
579576
}
580577
581578
// Save the matrix length that this thread saw. The length is written
582579
// to the output right after the matrix, offset by the thread index
583580
uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
584581
uint Len = __builtin_LinAlg_MatrixLength(Mat);
585-
Output.Store(LenIdx, Len);
582+
Output.Store<uint>(LenIdx, Len);
586583
}
587584
#else
588585
[numthreads(NUMTHREADS, 1, 1)]
589586
void main(uint threadIndex : SV_GroupIndex) {
590587
uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
591-
Output.Store(LenIdx, M_DIM * N_DIM / NUMTHREADS);
588+
Output.Store<uint>(LenIdx, M_DIM * N_DIM / NUMTHREADS);
592589
593590
if (threadIndex != 0)
594591
return;
595592
596593
for (uint I = 0; I < M_DIM*N_DIM; ++I) {
597-
Output.Store(I*ELEM_SIZE, Input.Load(I*ELEM_SIZE));
594+
Output.Store<ELEM_TYPE>(I*ELEM_SIZE, Input.Load<ELEM_TYPE>(I*ELEM_SIZE));
598595
}
599596
}
600597
#endif

0 commit comments

Comments
 (0)