Skip to content

Commit bb28c90

Browse files
committed
[SM6.10][Exec] Implement Remaining Smoke Tests
1 parent 5f8d05f commit bb28c90

1 file changed

Lines changed: 260 additions & 11 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 260 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,14 @@ class DxilConf_SM610_LinAlg {
307307

308308
// Element access
309309
TEST_METHOD(ElementAccess_Wave_16x16_F16);
310+
TEST_METHOD(ElementSet_Wave_16x16_F16);
311+
312+
// Cast/Convert
313+
TEST_METHOD(CopyConvert_Wave_16x16_F16);
314+
TEST_METHOD(CopyConvert_Wave_16x16_F16_Transpose);
315+
316+
// Matrix Arithmetic
317+
TEST_METHOD(MatMatMul_Wave_16x16x16_F16);
310318

311319
private:
312320
CComPtr<ID3D12Device> D3DDevice;
@@ -537,14 +545,9 @@ static void runElementAccess(ID3D12Device *Device,
537545
const MatrixParams &Params, bool Verbose) {
538546
const size_t NumElements = Params.totalElements();
539547
const size_t NumThreads = Params.NumThreads;
540-
const size_t InputBufSize = Params.totalBytes();
541-
const size_t ElementSize = elementSize(Params.CompType);
542-
543-
// Output: ElementSize bytes per element
544-
// 1 element for each mat idx
545-
// 1 uint for each thread's length
546-
const size_t OutputBufSize =
547-
NumElements * ElementSize + NumThreads * sizeof(uint32_t);
548+
const size_t MatrixSize = Params.totalBytes();
549+
// OutputBuf needs to fit the Matrix plus one uint per thread
550+
const size_t OutputBufSize = MatrixSize + NumThreads * sizeof(uint32_t);
548551

549552
std::stringstream ExtraDefs;
550553
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
@@ -555,7 +558,7 @@ static void runElementAccess(ID3D12Device *Device,
555558

556559
auto Op = createComputeOp(ElementAccessShader, "cs_6_10", "UAV(u0), UAV(u1)",
557560
Args.c_str());
558-
addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname");
561+
addUAVBuffer(Op.get(), "Input", MatrixSize, false, "byname");
559562
addUAVBuffer(Op.get(), "Output", OutputBufSize, true);
560563
addRootUAV(Op.get(), 0, "Input");
561564
addRootUAV(Op.get(), 1, "Output");
@@ -579,9 +582,8 @@ static void runElementAccess(ID3D12Device *Device,
579582
// Verify the end of the buffer is NumThreads number of lengths, whose
580583
// sum is greater than or equal to NumElements
581584
const BYTE *Out = static_cast<const BYTE *>(OutData.data());
582-
size_t MatrixEndOffset = NumElements * ElementSize;
583585
const uint32_t *Lengths =
584-
reinterpret_cast<const uint32_t *>(Out + MatrixEndOffset);
586+
reinterpret_cast<const uint32_t *>(Out + MatrixSize);
585587
uint32_t TotalLength = 0;
586588
for (size_t I = 0; I < NumThreads; ++I)
587589
TotalLength += Lengths[I];
@@ -602,4 +604,251 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
602604
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
603605
}
604606

607+
static const char ElementSetShader[] = R"(
608+
RWByteAddressBuffer Input : register(u0);
609+
RWByteAddressBuffer Output : register(u1);
610+
611+
[WaveSize(4, 64)]
612+
[numthreads(NUMTHREADS, 1, 1)]
613+
void main(uint threadID : SV_GroupIndex) {
614+
if (WaveReadLaneFirst(threadID) != 0)
615+
return;
616+
617+
__builtin_LinAlgMatrix
618+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
619+
Mat;
620+
__builtin_LinAlg_MatrixLoadFromDescriptor(
621+
Mat, Input, 0, STRIDE, LAYOUT, 128);
622+
623+
// Increment every element by 5
624+
for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
625+
ELEM_TYPE Elem;
626+
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
627+
Elem = Elem + 5;
628+
__builtin_LinAlg_MatrixSetElement(Mat, Mat, I, Elem);
629+
}
630+
631+
__builtin_LinAlg_MatrixStoreToDescriptor(
632+
Mat, Output, 0, STRIDE, LAYOUT, 128);
633+
}
634+
)";
635+
636+
static void runElementSet(ID3D12Device *Device,
637+
dxc::SpecificDllLoader &DxcSupport,
638+
const MatrixParams &Params, bool Verbose) {
639+
const size_t NumElements = Params.totalElements();
640+
const size_t MatrixSize = Params.totalBytes();
641+
642+
std::stringstream ExtraDefs;
643+
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
644+
645+
compileShader(DxcSupport, ElementSetShader, "cs_6_10", Args, Verbose);
646+
647+
// Start counting from 6 since each element was increased by 5
648+
auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 6);
649+
650+
auto Op = createComputeOp(ElementSetShader, "cs_6_10", "UAV(u0), UAV(u1)",
651+
Args.c_str());
652+
addUAVBuffer(Op.get(), "Input", MatrixSize, false, "byname");
653+
addUAVBuffer(Op.get(), "Output", MatrixSize, true);
654+
addRootUAV(Op.get(), 0, "Input");
655+
addRootUAV(Op.get(), 1, "Output");
656+
657+
auto Result =
658+
runShaderOp(Device, DxcSupport, std::move(Op),
659+
[NumElements, Params](LPCSTR Name, std::vector<BYTE> &Data,
660+
st::ShaderOp *) {
661+
VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType,
662+
NumElements),
663+
"Saw unsupported component type");
664+
});
665+
666+
MappedData OutData;
667+
Result->Test->GetReadBackData("Output", &OutData);
668+
669+
// Verify the front of the buffer is a list of elements of the expected type
670+
VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
671+
Expected, NumElements, Verbose));
672+
673+
}
674+
675+
void DxilConf_SM610_LinAlg::ElementSet_Wave_16x16_F16() {
676+
MatrixParams Params = {};
677+
Params.CompType = ComponentType::F16;
678+
Params.M = 16;
679+
Params.N = 16;
680+
Params.Use = MatrixUse::Accumulator;
681+
Params.Scope = MatrixScope::Wave;
682+
Params.Layout = LinalgMatrixLayout::RowMajor;
683+
Params.NumThreads = 64;
684+
Params.Enable16Bit = true;
685+
runElementSet(D3DDevice, DxcSupport, Params, VerboseLogging);
686+
}
687+
688+
static const char CopyConvertShader[] = R"(
689+
RWByteAddressBuffer Input : register(u0);
690+
RWByteAddressBuffer Output : register(u1);
691+
692+
[WaveSize(4, 64)]
693+
[numthreads(NUMTHREADS, 1, 1)]
694+
void main(uint threadID : SV_GroupIndex) {
695+
if (WaveReadLaneFirst(threadID) != 0)
696+
return;
697+
698+
__builtin_LinAlgMatrix
699+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
700+
Src;
701+
__builtin_LinAlgMatrix
702+
[[__LinAlgMatrix_Attributes(COMP_TYPE, N_DIM, M_DIM, USE, SCOPE)]]
703+
Dst;
704+
705+
__builtin_LinAlg_MatrixLoadFromDescriptor(
706+
Src, Input, 0, STRIDE, LAYOUT, 128);
707+
__builtin_LinAlg_CopyConvertMatrix(Dst, Src, TRANSPOSE);
708+
__builtin_LinAlg_MatrixStoreToDescriptor(
709+
Dst, Output, 0, STRIDE, LAYOUT, 128);
710+
}
711+
)";
712+
713+
static void runCopyConvert(ID3D12Device *Device,
714+
dxc::SpecificDllLoader &DxcSupport,
715+
const MatrixParams &Params, bool Verbose, bool Transpose) {
716+
const size_t NumElements = Params.totalElements();
717+
const size_t BufferSize = Params.totalBytes();
718+
719+
std::stringstream ExtraDefs;
720+
ExtraDefs << " -DTRANSPOSE=" << Transpose;
721+
722+
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
723+
724+
compileShader(DxcSupport, CopyConvertShader, "cs_6_10", Args, Verbose);
725+
726+
auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1, /*Increment=*/true, Transpose);
727+
728+
// Construct the ShaderOp: two UAV buffers, load from one, store to other.
729+
auto Op = createComputeOp(CopyConvertShader, "cs_6_10", "UAV(u0), UAV(u1)",
730+
Args.c_str());
731+
addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname");
732+
addUAVBuffer(Op.get(), "Output", BufferSize, true);
733+
addRootUAV(Op.get(), 0, "Input");
734+
addRootUAV(Op.get(), 1, "Output");
735+
736+
auto Result =
737+
runShaderOp(Device, DxcSupport, std::move(Op),
738+
[NumElements, Params](LPCSTR Name, std::vector<BYTE> &Data,
739+
st::ShaderOp *) {
740+
VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType,
741+
NumElements),
742+
"Saw unsupported component type");
743+
});
744+
745+
MappedData OutData;
746+
Result->Test->GetReadBackData("Output", &OutData);
747+
748+
VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
749+
Expected, NumElements, Verbose));
750+
}
751+
752+
void DxilConf_SM610_LinAlg::CopyConvert_Wave_16x16_F16() {
753+
MatrixParams Params = {};
754+
Params.CompType = ComponentType::F16;
755+
Params.M = 16;
756+
Params.N = 16;
757+
Params.Use = MatrixUse::A;
758+
Params.Scope = MatrixScope::Wave;
759+
Params.Layout = LinalgMatrixLayout::RowMajor;
760+
Params.NumThreads = 64;
761+
Params.Enable16Bit = true;
762+
runCopyConvert(D3DDevice, DxcSupport, Params, VerboseLogging, /*Transpose=*/false);
763+
}
764+
765+
void DxilConf_SM610_LinAlg::CopyConvert_Wave_16x16_F16_Transpose() {
766+
MatrixParams Params = {};
767+
Params.CompType = ComponentType::F16;
768+
Params.M = 16;
769+
Params.N = 16;
770+
Params.Use = MatrixUse::A;
771+
Params.Scope = MatrixScope::Wave;
772+
Params.Layout = LinalgMatrixLayout::RowMajor;
773+
Params.NumThreads = 64;
774+
Params.Enable16Bit = true;
775+
runCopyConvert(D3DDevice, DxcSupport, Params, VerboseLogging, /*Transpose=*/true);
776+
}
777+
778+
static const char MatMatMulShader[] = R"(
779+
#define USE_A 0
780+
#define USE_B 1
781+
#define USE_ACC 2
782+
783+
RWByteAddressBuffer Output : register(u0);
784+
785+
[WaveSize(4, 64)]
786+
[numthreads(NUMTHREADS, 1, 1)]
787+
void main(uint threadID : SV_GroupIndex) {
788+
if (WaveReadLaneFirst(threadID) != 0)
789+
return;
790+
791+
__builtin_LinAlgMatrix
792+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, K_DIM, USE_A, SCOPE)]]
793+
MatA;
794+
__builtin_LinAlg_FillMatrix(MatA, A_FILL);
795+
796+
__builtin_LinAlgMatrix
797+
[[__LinAlgMatrix_Attributes(COMP_TYPE, K_DIM, N_DIM, USE_B, SCOPE)]]
798+
MatB;
799+
__builtin_LinAlg_FillMatrix(MatB, B_FILL);
800+
801+
__builtin_LinAlgMatrix
802+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]]
803+
MatC;
804+
__builtin_LinAlg_MatrixMatrixMultiply(MatC, MatA, MatB);
805+
806+
__builtin_LinAlg_MatrixStoreToDescriptor(
807+
MatC, Output, 0, STRIDE, LAYOUT, 128);
808+
}
809+
)";
810+
811+
static void runMatMatMul(ID3D12Device *Device,
812+
dxc::SpecificDllLoader &DxcSupport,
813+
const MatrixParams &Params, bool Verbose, MatrixDim K, float AFill, float BFill) {
814+
const size_t NumElements = Params.totalElements();
815+
const size_t BufferSize = Params.totalBytes();
816+
817+
std::stringstream ExtraDefs;
818+
ExtraDefs << " -DK_DIM=" << K;
819+
ExtraDefs << " -DA_FILL=" << AFill;
820+
ExtraDefs << " -DB_FILL=" << BFill;
821+
822+
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
823+
824+
compileShader(DxcSupport, MatMatMulShader, "cs_6_10", Args, Verbose);
825+
826+
auto Expected = makeExpected(Params.CompType, Params.M, Params.N, AFill * BFill * K, /*Increment=*/false);
827+
828+
auto Op =
829+
createComputeOp(MatMatMulShader, "cs_6_10", "UAV(u0)", Args.c_str());
830+
addUAVBuffer(Op.get(), "Output", BufferSize, true);
831+
addRootUAV(Op.get(), 0, "Output");
832+
833+
auto Result = runShaderOp(Device, DxcSupport, std::move(Op));
834+
835+
MappedData OutData;
836+
Result->Test->GetReadBackData("Output", &OutData);
837+
838+
VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
839+
Expected, NumElements, Verbose));
840+
}
841+
842+
void DxilConf_SM610_LinAlg::MatMatMul_Wave_16x16x16_F16() {
843+
MatrixParams Params = {};
844+
Params.CompType = ComponentType::F16;
845+
Params.M = 16;
846+
Params.N = 16;
847+
Params.Scope = MatrixScope::Wave;
848+
Params.Layout = LinalgMatrixLayout::RowMajor;
849+
Params.NumThreads = 64;
850+
Params.Enable16Bit = true;
851+
runMatMatMul(D3DDevice, DxcSupport, Params, VerboseLogging, /*K=*/16, /*AFill=*/2.0f, /*BFill=*/3.0f);
852+
}
853+
605854
} // namespace LinAlg

0 commit comments

Comments
 (0)