@@ -89,8 +89,6 @@ static std::string buildCompilerArgs(const MatrixParams &Params,
8989 std::stringstream SS;
9090 SS << " -HV 202x" ;
9191 SS << " -DCOMP_TYPE=" << static_cast <int >(Params.CompType );
92- SS << " -DCOMP_TYPE_F16=" << 8 ;
93- SS << " -DCOMP_TYPE_F32=" << 9 ;
9492 SS << " -DM_DIM=" << Params.M ;
9593 SS << " -DN_DIM=" << Params.N ;
9694 SS << " -DUSE=" << static_cast <int >(Params.Use );
@@ -99,6 +97,17 @@ static std::string buildCompilerArgs(const MatrixParams &Params,
9997 SS << " -DLAYOUT=" << static_cast <int >(Params.Layout );
10098 SS << " -DELEM_SIZE=" << elementSize (Params.CompType );
10199 SS << " -DNUMTHREADS=" << Params.NumThreads ;
100+ switch (Params.CompType ) {
101+ case ComponentType::F16:
102+ SS << " -DELEM_TYPE=half" ;
103+ break ;
104+ case ComponentType::F32:
105+ SS << " -DELEM_TYPE=float" ;
106+ break ;
107+ default :
108+ SS << " -DELEM_TYPE=uint" ;
109+ break ;
110+ }
102111 if (Params.EmulateTest )
103112 SS << " -DEMULATE_TEST" ;
104113 if (Params.Enable16Bit )
@@ -389,7 +398,7 @@ static const char LoadStoreShader[] = R"(
389398 [numthreads(NUMTHREADS, 1, 1)]
390399 void main() {
391400 for (uint I = 0; I < M_DIM*N_DIM; ++I) {
392- Output.Store(I*ELEM_SIZE, Input.Load(I*ELEM_SIZE));
401+ Output.Store<ELEM_TYPE> (I*ELEM_SIZE, Input.Load<ELEM_TYPE> (I*ELEM_SIZE));
393402 }
394403 }
395404#endif
@@ -477,15 +486,9 @@ static const char SplatStoreShader[] = R"(
477486#else
478487 [numthreads(NUMTHREADS, 1, 1)]
479488 void main() {
480- #if COMP_TYPE == COMP_TYPE_F32
481- float fill = FILL_VALUE;
482- #elif COMP_TYPE == COMP_TYPE_F16
483- half fill = FILL_VALUE;
484- #else
485- uint fill = FILL_VALUE;
486- #endif
489+ ELEM_TYPE fill = FILL_VALUE;
487490 for (uint I = 0; I < M_DIM*N_DIM; ++I) {
488- Output.Store(I*ELEM_SIZE, fill);
491+ Output.Store<ELEM_TYPE> (I*ELEM_SIZE, fill);
489492 }
490493 }
491494#endif
@@ -567,34 +570,28 @@ static const char ElementAccessShader[] = R"(
567570 for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
568571 uint2 Coord = __builtin_LinAlg_MatrixGetCoordinate(Mat, I);
569572 uint Offset = coordToByteOffset(Coord);
570- #if COMP_TYPE == COMP_TYPE_F32
571- float Elem;
572- __builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
573- Output.Store(Offset, asuint(Elem));
574- #else
575- uint Elem;
576- __builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
577- Output.Store(Offset, Elem);
578- #endif
573+ ELEM_TYPE Elem;
574+ __builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
575+ Output.Store<ELEM_TYPE>(Offset, Elem);
579576 }
580577
581578 // Save the matrix length that this thread saw. The length is written
582579 // to the output right after the matrix, offset by the thread index
583580 uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
584581 uint Len = __builtin_LinAlg_MatrixLength(Mat);
585- Output.Store(LenIdx, Len);
582+ Output.Store<uint> (LenIdx, Len);
586583 }
587584#else
588585 [numthreads(NUMTHREADS, 1, 1)]
589586 void main(uint threadIndex : SV_GroupIndex) {
590587 uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
591- Output.Store(LenIdx, M_DIM * N_DIM / NUMTHREADS);
588+ Output.Store<uint> (LenIdx, M_DIM * N_DIM / NUMTHREADS);
592589
593590 if (threadIndex != 0)
594591 return;
595592
596593 for (uint I = 0; I < M_DIM*N_DIM; ++I) {
597- Output.Store(I*ELEM_SIZE, Input.Load(I*ELEM_SIZE));
594+ Output.Store<ELEM_TYPE> (I*ELEM_SIZE, Input.Load<ELEM_TYPE> (I*ELEM_SIZE));
598595 }
599596 }
600597#endif
0 commit comments