Skip to content

Commit 740cd62

Browse files
committed
Add more tests
1 parent 6b0b6a5 commit 740cd62

1 file changed

Lines changed: 158 additions & 0 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,8 @@ class DxilConf_SM610_LinAlg {
315315

316316
// Matrix Arithmetic
317317
TEST_METHOD(MatMatMul_Wave_16x16x16_F16);
318+
TEST_METHOD(MatMatMulAccum_Wave_16x16x16_F16);
319+
TEST_METHOD(MatAccum_Wave_16x16_F16);
318320

319321
private:
320322
CComPtr<ID3D12Device> D3DDevice;
@@ -857,4 +859,160 @@ void DxilConf_SM610_LinAlg::MatMatMul_Wave_16x16x16_F16() {
857859
/*AFill=*/2.0f, /*BFill=*/3.0f);
858860
}
859861

862+
static const char MatMatMulAccumShader[] = R"(
863+
#define USE_A 0
864+
#define USE_B 1
865+
#define USE_ACC 2
866+
867+
RWByteAddressBuffer Output : register(u0);
868+
869+
[WaveSize(4, 64)]
870+
[numthreads(NUMTHREADS, 1, 1)]
871+
void main(uint threadID : SV_GroupIndex) {
872+
if (WaveReadLaneFirst(threadID) != 0)
873+
return;
874+
875+
__builtin_LinAlgMatrix
876+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, K_DIM, USE_A, SCOPE)]]
877+
MatA;
878+
__builtin_LinAlg_FillMatrix(MatA, A_FILL);
879+
880+
__builtin_LinAlgMatrix
881+
[[__LinAlgMatrix_Attributes(COMP_TYPE, K_DIM, N_DIM, USE_B, SCOPE)]]
882+
MatB;
883+
__builtin_LinAlg_FillMatrix(MatB, B_FILL);
884+
885+
__builtin_LinAlgMatrix
886+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]]
887+
MatC;
888+
__builtin_LinAlg_FillMatrix(MatC, C_FILL);
889+
890+
__builtin_LinAlg_MatrixMatrixMultiplyAccumulate(MatC, MatA, MatB, MatC);
891+
892+
__builtin_LinAlg_MatrixStoreToDescriptor(
893+
MatC, Output, 0, STRIDE, LAYOUT, 128);
894+
}
895+
)";
896+
897+
static void runMatMatMulAccum(ID3D12Device *Device,
898+
dxc::SpecificDllLoader &DxcSupport,
899+
const MatrixParams &Params, bool Verbose, MatrixDim K,
900+
float AFill, float BFill, float CFill) {
901+
const size_t NumElements = Params.totalElements();
902+
const size_t BufferSize = Params.totalBytes();
903+
904+
std::stringstream ExtraDefs;
905+
ExtraDefs << " -DK_DIM=" << K;
906+
ExtraDefs << " -DA_FILL=" << AFill;
907+
ExtraDefs << " -DB_FILL=" << BFill;
908+
ExtraDefs << " -DC_FILL=" << CFill;
909+
910+
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
911+
912+
compileShader(DxcSupport, MatMatMulAccumShader, "cs_6_10", Args, Verbose);
913+
914+
auto Expected = makeExpected(Params.CompType, Params.M, Params.N,
915+
AFill * BFill * K + CFill, /*Increment=*/false);
916+
917+
auto Op =
918+
createComputeOp(MatMatMulAccumShader, "cs_6_10", "UAV(u0)", Args.c_str());
919+
addUAVBuffer(Op.get(), "Output", BufferSize, true);
920+
addRootUAV(Op.get(), 0, "Output");
921+
922+
auto Result = runShaderOp(Device, DxcSupport, std::move(Op));
923+
924+
MappedData OutData;
925+
Result->Test->GetReadBackData("Output", &OutData);
926+
927+
VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
928+
Expected, NumElements, Verbose));
929+
}
930+
931+
void DxilConf_SM610_LinAlg::MatMatMulAccum_Wave_16x16x16_F16() {
932+
MatrixParams Params = {};
933+
Params.CompType = ComponentType::F16;
934+
Params.M = 16;
935+
Params.N = 16;
936+
Params.Scope = MatrixScope::Wave;
937+
Params.Layout = LinalgMatrixLayout::RowMajor;
938+
Params.NumThreads = 64;
939+
Params.Enable16Bit = true;
940+
runMatMatMulAccum(D3DDevice, DxcSupport, Params, VerboseLogging, /*K=*/16,
941+
/*AFill=*/2.0f, /*BFill=*/3.0f, /*CFill=*/4.0f);
942+
}
943+
944+
static const char MatAccumShader[] = R"(
945+
#define USE_A 0
946+
#define USE_ACC 2
947+
948+
RWByteAddressBuffer Output : register(u0);
949+
950+
[WaveSize(4, 64)]
951+
[numthreads(NUMTHREADS, 1, 1)]
952+
void main(uint threadID : SV_GroupIndex) {
953+
if (WaveReadLaneFirst(threadID) != 0)
954+
return;
955+
956+
__builtin_LinAlgMatrix
957+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]]
958+
MatLHS;
959+
__builtin_LinAlg_FillMatrix(MatLHS, LHS_FILL);
960+
961+
__builtin_LinAlgMatrix
962+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE)]]
963+
MatRHS;
964+
__builtin_LinAlg_FillMatrix(MatRHS, RHS_FILL);
965+
966+
__builtin_LinAlg_MatrixAccumulate(MatLHS, MatLHS, MatRHS);
967+
968+
__builtin_LinAlg_MatrixStoreToDescriptor(
969+
MatLHS, Output, 0, STRIDE, LAYOUT, 128);
970+
}
971+
)";
972+
973+
static void runMatAccum(ID3D12Device *Device,
974+
dxc::SpecificDllLoader &DxcSupport,
975+
const MatrixParams &Params, bool Verbose,
976+
float LHSFill, float RHSFill) {
977+
const size_t NumElements = Params.totalElements();
978+
const size_t BufferSize = Params.totalBytes();
979+
980+
std::stringstream ExtraDefs;
981+
ExtraDefs << " -DLHS_FILL=" << LHSFill;
982+
ExtraDefs << " -DRHS_FILL=" << RHSFill;
983+
984+
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
985+
986+
compileShader(DxcSupport, MatAccumShader, "cs_6_10", Args, Verbose);
987+
988+
auto Expected = makeExpected(Params.CompType, Params.M, Params.N,
989+
LHSFill + RHSFill, /*Increment=*/false);
990+
991+
auto Op =
992+
createComputeOp(MatAccumShader, "cs_6_10", "UAV(u0)", Args.c_str());
993+
addUAVBuffer(Op.get(), "Output", BufferSize, true);
994+
addRootUAV(Op.get(), 0, "Output");
995+
996+
auto Result = runShaderOp(Device, DxcSupport, std::move(Op));
997+
998+
MappedData OutData;
999+
Result->Test->GetReadBackData("Output", &OutData);
1000+
1001+
VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
1002+
Expected, NumElements, Verbose));
1003+
}
1004+
1005+
void DxilConf_SM610_LinAlg::MatAccum_Wave_16x16_F16() {
1006+
MatrixParams Params = {};
1007+
Params.CompType = ComponentType::F16;
1008+
Params.M = 16;
1009+
Params.N = 16;
1010+
Params.Scope = MatrixScope::Wave;
1011+
Params.Layout = LinalgMatrixLayout::RowMajor;
1012+
Params.NumThreads = 64;
1013+
Params.Enable16Bit = true;
1014+
runMatAccum(D3DDevice, DxcSupport, Params, VerboseLogging,
1015+
/*LHSFill=*/2.0f, /*RHSFill=*/3.0f);
1016+
}
1017+
8601018
} // namespace LinAlg

0 commit comments

Comments
 (0)