Skip to content

Commit 2c71946

Browse files
committed
More test fixes
1 parent 90b5414 commit 2c71946

1 file changed

Lines changed: 15 additions & 19 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,6 @@ class DxilConf_SM610_LinAlg {
322322
TEST_METHOD(LoadStoreDescriptor_Wave_16x16_F16);
323323
TEST_METHOD(SplatStore_Wave_16x16_F16);
324324
TEST_METHOD(AccumulateDescriptor_Wave_16x16_F16);
325-
TEST_METHOD(AccumulateDescriptor_Thread_16x16_F16);
326325

327326
// Load/Store/Accumulate Memory
328327
TEST_METHOD(LoadMemory_Wave_16x16_F16);
@@ -539,6 +538,9 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
539538
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging);
540539
}
541540

541+
// Since MatrixAccumulateToDescriptor requires an accumulator matrix and
542+
// MatrixLoadFromDescriptor always returns an A matrix when loading a Thread
543+
// matrix this shader only makes sense for Wave/ThreadGroup
542544
static const char AccumulateDescriptorShader[] = R"(
543545
#define USE_ACC 2
544546
@@ -615,19 +617,6 @@ void DxilConf_SM610_LinAlg::AccumulateDescriptor_Wave_16x16_F16() {
615617
runAccumulateDescriptor(D3DDevice, DxcSupport, Params, 12, VerboseLogging);
616618
}
617619

618-
void DxilConf_SM610_LinAlg::AccumulateDescriptor_Thread_16x16_F16() {
619-
MatrixParams Params = {};
620-
Params.CompType = ComponentType::F16;
621-
Params.M = 16;
622-
Params.N = 16;
623-
Params.Use = MatrixUse::Accumulator;
624-
Params.Scope = MatrixScope::Thread;
625-
Params.Layout = LinalgMatrixLayout::OuterProductOptimal;
626-
Params.NumThreads = 1;
627-
Params.Enable16Bit = true;
628-
runAccumulateDescriptor(D3DDevice, DxcSupport, Params, 19, VerboseLogging);
629-
}
630-
631620
static const char ElementAccessShader[] = R"(
632621
RWByteAddressBuffer Input : register(u0);
633622
RWByteAddressBuffer Output : register(u1);
@@ -1222,7 +1211,7 @@ void DxilConf_SM610_LinAlg::MatVecMul_Thread_16x16_F16() {
12221211
Params.M = 16;
12231212
Params.N = 16;
12241213
Params.Scope = MatrixScope::Thread;
1225-
Params.Layout = LinalgMatrixLayout::OuterProductOptimal;
1214+
Params.Layout = LinalgMatrixLayout::RowMajor;
12261215
Params.NumThreads = 1;
12271216
Params.Enable16Bit = true;
12281217
runMatVecMul(D3DDevice, DxcSupport, Params, VerboseLogging,
@@ -1317,7 +1306,7 @@ void DxilConf_SM610_LinAlg::MatVecMulAdd_Thread_16x16_F16() {
13171306
Params.M = 16;
13181307
Params.N = 16;
13191308
Params.Scope = MatrixScope::Thread;
1320-
Params.Layout = LinalgMatrixLayout::OuterProductOptimal;
1309+
Params.Layout = LinalgMatrixLayout::RowMajor;
13211310
Params.NumThreads = 1;
13221311
Params.Enable16Bit = true;
13231312
runMatVecMulAdd(D3DDevice, DxcSupport, Params, VerboseLogging,
@@ -1326,8 +1315,14 @@ void DxilConf_SM610_LinAlg::MatVecMulAdd_Thread_16x16_F16() {
13261315
}
13271316

13281317
static const char OuterProductShader[] = R"(
1329-
#define USE_A 0
1318+
// OuterProduct Matrix must be Thread scope
13301319
#define SCOPE_THREAD 0
1320+
// OuterProduct/Accumulate must be Accumulator use
1321+
#define USE_ACC 2
1322+
// Accumulate Layout must be OuterProductOptimal
1323+
#define LAYOUT_OUTER_PROD_OPT 4
1324+
// Accumulate Stride msut be 0 for non Row/Col Major
1325+
#define STRIDE 0
13311326
13321327
RWByteAddressBuffer Input : register(u0);
13331328
RWByteAddressBuffer Output : register(u1);
@@ -1347,12 +1342,12 @@ static const char OuterProductShader[] = R"(
13471342
}
13481343
13491344
__builtin_LinAlgMatrix
1350-
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]]
1345+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE_THREAD)]]
13511346
Mat;
13521347
__builtin_LinAlg_MatrixOuterProduct(Mat, VecA, VecB);
13531348
13541349
__builtin_LinAlg_MatrixAccumulateToDescriptor(
1355-
Mat, Output, 0, STRIDE, LAYOUT, 128);
1350+
Mat, Output, 0, STRIDE, LAYOUT_OUTER_PROD_OPT, 128);
13561351
}
13571352
)";
13581353

@@ -1400,6 +1395,7 @@ void DxilConf_SM610_LinAlg::OuterProduct_Thread_16x16_F16() {
14001395
Params.CompType = ComponentType::F16;
14011396
Params.M = 16;
14021397
Params.N = 16;
1398+
Params.Use = MatrixUse::Accumulator;
14031399
Params.Scope = MatrixScope::Thread;
14041400
Params.Layout = LinalgMatrixLayout::OuterProductOptimal;
14051401
Params.NumThreads = 1;

0 commit comments

Comments
 (0)