@@ -322,7 +322,6 @@ class DxilConf_SM610_LinAlg {
322322 TEST_METHOD (LoadStoreDescriptor_Wave_16x16_F16);
323323 TEST_METHOD (SplatStore_Wave_16x16_F16);
324324 TEST_METHOD (AccumulateDescriptor_Wave_16x16_F16);
325- TEST_METHOD (AccumulateDescriptor_Thread_16x16_F16);
326325
327326 // Load/Store/Accumulate Memory
328327 TEST_METHOD (LoadMemory_Wave_16x16_F16);
@@ -539,6 +538,9 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
539538 runSplatStore (D3DDevice, DxcSupport, Params, 42 .0f , VerboseLogging);
540539}
541540
541+ // Since MatrixAccumulateToDescriptor requires an accumulator matrix and
542+ // MatrixLoadFromDescriptor always returns an A matrix when loading a Thread
543+ // matrix, this shader only makes sense for Wave/ThreadGroup scope.
542544static const char AccumulateDescriptorShader[] = R"(
543545 #define USE_ACC 2
544546
@@ -615,19 +617,6 @@ void DxilConf_SM610_LinAlg::AccumulateDescriptor_Wave_16x16_F16() {
615617 runAccumulateDescriptor (D3DDevice, DxcSupport, Params, 12 , VerboseLogging);
616618}
617619
618- void DxilConf_SM610_LinAlg::AccumulateDescriptor_Thread_16x16_F16 () {
619- MatrixParams Params = {};
620- Params.CompType = ComponentType::F16;
621- Params.M = 16 ;
622- Params.N = 16 ;
623- Params.Use = MatrixUse::Accumulator;
624- Params.Scope = MatrixScope::Thread;
625- Params.Layout = LinalgMatrixLayout::OuterProductOptimal;
626- Params.NumThreads = 1 ;
627- Params.Enable16Bit = true ;
628- runAccumulateDescriptor (D3DDevice, DxcSupport, Params, 19 , VerboseLogging);
629- }
630-
631620static const char ElementAccessShader[] = R"(
632621 RWByteAddressBuffer Input : register(u0);
633622 RWByteAddressBuffer Output : register(u1);
@@ -1222,7 +1211,7 @@ void DxilConf_SM610_LinAlg::MatVecMul_Thread_16x16_F16() {
12221211 Params.M = 16 ;
12231212 Params.N = 16 ;
12241213 Params.Scope = MatrixScope::Thread;
1225- Params.Layout = LinalgMatrixLayout::OuterProductOptimal ;
1214+ Params.Layout = LinalgMatrixLayout::RowMajor ;
12261215 Params.NumThreads = 1 ;
12271216 Params.Enable16Bit = true ;
12281217 runMatVecMul (D3DDevice, DxcSupport, Params, VerboseLogging,
@@ -1317,7 +1306,7 @@ void DxilConf_SM610_LinAlg::MatVecMulAdd_Thread_16x16_F16() {
13171306 Params.M = 16 ;
13181307 Params.N = 16 ;
13191308 Params.Scope = MatrixScope::Thread;
1320- Params.Layout = LinalgMatrixLayout::OuterProductOptimal ;
1309+ Params.Layout = LinalgMatrixLayout::RowMajor ;
13211310 Params.NumThreads = 1 ;
13221311 Params.Enable16Bit = true ;
13231312 runMatVecMulAdd (D3DDevice, DxcSupport, Params, VerboseLogging,
@@ -1326,8 +1315,14 @@ void DxilConf_SM610_LinAlg::MatVecMulAdd_Thread_16x16_F16() {
13261315}
13271316
13281317static const char OuterProductShader[] = R"(
1329- #define USE_A 0
1318+ // OuterProduct Matrix must be Thread scope
13301319 #define SCOPE_THREAD 0
1320+ // OuterProduct/Accumulate must be Accumulator use
1321+ #define USE_ACC 2
1322+ // Accumulate Layout must be OuterProductOptimal
1323+ #define LAYOUT_OUTER_PROD_OPT 4
1324+ // Accumulate Stride must be 0 for non Row/Col Major
1325+ #define STRIDE 0
13311326
13321327 RWByteAddressBuffer Input : register(u0);
13331328 RWByteAddressBuffer Output : register(u1);
@@ -1347,12 +1342,12 @@ static const char OuterProductShader[] = R"(
13471342 }
13481343
13491344 __builtin_LinAlgMatrix
1350- [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A , SCOPE_THREAD)]]
1345+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC , SCOPE_THREAD)]]
13511346 Mat;
13521347 __builtin_LinAlg_MatrixOuterProduct(Mat, VecA, VecB);
13531348
13541349 __builtin_LinAlg_MatrixAccumulateToDescriptor(
1355- Mat, Output, 0, STRIDE, LAYOUT , 128);
1350+ Mat, Output, 0, STRIDE, LAYOUT_OUTER_PROD_OPT , 128);
13561351 }
13571352)" ;
13581353
@@ -1400,6 +1395,7 @@ void DxilConf_SM610_LinAlg::OuterProduct_Thread_16x16_F16() {
14001395 Params.CompType = ComponentType::F16;
14011396 Params.M = 16 ;
14021397 Params.N = 16 ;
1398+ Params.Use = MatrixUse::Accumulator;
14031399 Params.Scope = MatrixScope::Thread;
14041400 Params.Layout = LinalgMatrixLayout::OuterProductOptimal;
14051401 Params.NumThreads = 1 ;
0 commit comments