Skip to content

Commit e752af9

Browse files
committed
Fix thread shader to use SRV and Load
1 parent 7ce1fe2 commit e752af9

1 file changed

Lines changed: 25 additions & 14 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,9 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
530530

531531
static const char AccumulateDescriptorShader[] = R"(
532532
#define USE_ACC 2
533-
RWByteAddressBuffer Output : register(u0);
533+
534+
ByteAddressBuffer Input : register(t0);
535+
RWByteAddressBuffer Output : register(u1);
534536
535537
[WaveSize(4, 64)]
536538
[numthreads(NUMTHREADS, 1, 1)]
@@ -541,35 +543,44 @@ static const char AccumulateDescriptorShader[] = R"(
541543
__builtin_LinAlgMatrix
542544
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]]
543545
Mat;
544-
__builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);
546+
__builtin_LinAlg_MatrixLoadFromDescriptor(
547+
Mat, Input, 0, STRIDE, LAYOUT, 128);
548+
__builtin_LinAlg_MatrixAccumulateToDescriptor(
549+
Mat, Output, 0, STRIDE, LAYOUT, 128);
545550
__builtin_LinAlg_MatrixAccumulateToDescriptor(
546551
Mat, Output, 0, STRIDE, LAYOUT, 128);
547552
}
548553
)";
549554

550555
static void runAccumulateDescriptor(ID3D12Device *Device,
551556
dxc::SpecificDllLoader &DxcSupport,
552-
const MatrixParams &Params, float FillValue,
557+
const MatrixParams &Params, int FillValue,
553558
bool Verbose) {
554559
const size_t NumElements = Params.totalElements();
555560
const size_t BufferSize = Params.totalBytes();
556561

557-
std::stringstream ExtraDefs;
558-
ExtraDefs << "-DFILL_VALUE=" << FillValue;
559-
560-
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
562+
std::string Args = buildCompilerArgs(Params);
561563

562564
compileShader(DxcSupport, AccumulateDescriptorShader, "cs_6_10", Args, Verbose);
563565

564566
auto Expected =
565-
makeExpectedMat(Params.CompType, Params.M, Params.N, FillValue, false);
567+
makeExpectedMat(Params.CompType, Params.M, Params.N, static_cast<float>(FillValue) * 2, false);
566568

567-
auto Op =
568-
createComputeOp(AccumulateDescriptorShader, "cs_6_10", "UAV(u0)", Args.c_str());
569+
auto Op = createComputeOp(AccumulateDescriptorShader, "cs_6_10",
570+
"SRV(t0), UAV(u1)", Args.c_str());
571+
addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname");
569572
addUAVBuffer(Op.get(), "Output", BufferSize, true);
570-
addRootUAV(Op.get(), 0, "Output");
573+
addRootUAV(Op.get(), 0, "Input");
574+
addRootUAV(Op.get(), 1, "Output");
571575

572-
auto Result = runShaderOp(Device, DxcSupport, std::move(Op));
576+
auto Result =
577+
runShaderOp(Device, DxcSupport, std::move(Op),
578+
[NumElements, Params, FillValue](LPCSTR Name, std::vector<BYTE> &Data,
579+
st::ShaderOp *) {
580+
VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType,
581+
NumElements, /*StartingVal=*/ FillValue, /*Increment=*/false),
582+
"Saw unsupported component type");
583+
});
573584

574585
MappedData OutData;
575586
Result->Test->GetReadBackData("Output", &OutData);
@@ -588,7 +599,7 @@ void DxilConf_SM610_LinAlg::AccumulateDescriptor_Wave_16x16_F16() {
588599
Params.Layout = LinalgMatrixLayout::RowMajor;
589600
Params.NumThreads = 64;
590601
Params.Enable16Bit = true;
591-
runAccumulateDescriptor(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging);
602+
runAccumulateDescriptor(D3DDevice, DxcSupport, Params, 12, VerboseLogging);
592603
}
593604

594605
void DxilConf_SM610_LinAlg::AccumulateDescriptor_Thread_16x16_F16() {
@@ -601,7 +612,7 @@ void DxilConf_SM610_LinAlg::AccumulateDescriptor_Thread_16x16_F16() {
601612
Params.Layout = LinalgMatrixLayout::RowMajor;
602613
Params.NumThreads = 1;
603614
Params.Enable16Bit = true;
604-
runAccumulateDescriptor(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging);
615+
runAccumulateDescriptor(D3DDevice, DxcSupport, Params, 19, VerboseLogging);
605616
}
606617

607618
static const char ElementAccessShader[] = R"(

0 commit comments

Comments
 (0)