Skip to content

Commit 3e6b534

Browse files
committed
fix the matvecmul tests
1 parent e752af9 commit 3e6b534

1 file changed

Lines changed: 39 additions & 40 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,15 +1131,16 @@ static const char MatVecMulShader[] = R"(
11311131
#define USE_A 0
11321132
#define SCOPE_THREAD 0
11331133
1134-
RWByteAddressBuffer Input : register(u0);
1134+
ByteAddressBuffer Input : register(t0);
11351135
RWByteAddressBuffer Output : register(u1);
11361136
11371137
[numthreads(NUMTHREADS, 1, 1)]
11381138
void main() {
11391139
__builtin_LinAlgMatrix
11401140
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]]
11411141
Mat;
1142-
__builtin_LinAlg_FillMatrix(Mat, MAT_FILL);
1142+
__builtin_LinAlg_MatrixLoadFromDescriptor(
1143+
Mat, Input, 0, STRIDE, LAYOUT, 128);
11431144
11441145
vector<ELEM_TYPE, M_DIM> InVec;
11451146
for (uint I = 0; I < M_DIM; ++I) {
@@ -1159,44 +1160,43 @@ static const char MatVecMulShader[] = R"(
11591160
static void runMatVecMul(ID3D12Device *Device,
11601161
dxc::SpecificDllLoader &DxcSupport,
11611162
const MatrixParams &Params, bool Verbose,
1162-
float MatFill, bool OutputSigned,
1163+
int FillValue, bool OutputSigned,
11631164
ComponentType InputInterp) {
1164-
const size_t NumElements = Params.M;
1165-
const size_t BufferSize = elementSize(Params.CompType) * NumElements;
1165+
const size_t NumElements = Params.totalElements();
1166+
const size_t BufferSize = Params.totalBytes();
11661167

11671168
std::stringstream ExtraDefs;
1168-
ExtraDefs << " -DMAT_FILL=" << MatFill;
11691169
ExtraDefs << " -DOUTPUT_SIGNED=" << OutputSigned;
11701170
ExtraDefs << " -DIN_INTERP=" << static_cast<int>(InputInterp);
11711171

11721172
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
11731173

11741174
compileShader(DxcSupport, MatVecMulShader, "cs_6_10", Args, Verbose);
11751175

1176-
auto Expected = makeExpectedVec(Params.CompType, Params.M, MatFill * Params.N,
1177-
/*Increment=*/false);
1176+
auto Expected = makeExpectedVec(Params.CompType, Params.M,
1177+
static_cast<float>(FillValue * FillValue * Params.N), /*Increment=*/false);
11781178

1179-
auto Op = createComputeOp(MatVecMulShader, "cs_6_10", "UAV(u0), UAV(u1)",
1180-
Args.c_str());
1179+
auto Op = createComputeOp(MatVecMulShader, "cs_6_10",
1180+
"SRV(t0), UAV(u1)", Args.c_str());
11811181
addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname");
11821182
addUAVBuffer(Op.get(), "Output", BufferSize, true);
11831183
addRootUAV(Op.get(), 0, "Input");
11841184
addRootUAV(Op.get(), 1, "Output");
11851185

1186-
auto Result = runShaderOp(
1187-
Device, DxcSupport, std::move(Op),
1188-
[NumElements, Params](LPCSTR Name, std::vector<BYTE> &Data,
1189-
st::ShaderOp *) {
1190-
VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, NumElements,
1191-
/*StartingVal=*/1, /*Increment=*/false),
1192-
"Saw unsupported component type");
1193-
});
1186+
auto Result =
1187+
runShaderOp(Device, DxcSupport, std::move(Op),
1188+
[NumElements, Params, FillValue](LPCSTR Name, std::vector<BYTE> &Data,
1189+
st::ShaderOp *) {
1190+
VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType,
1191+
NumElements, /*StartingVal=*/ FillValue, /*Increment=*/false),
1192+
"Saw unsupported component type");
1193+
});
11941194

11951195
MappedData OutData;
11961196
Result->Test->GetReadBackData("Output", &OutData);
11971197

11981198
VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
1199-
Expected, NumElements, Verbose));
1199+
Expected, Params.M, Verbose));
12001200
}
12011201

12021202
void DxilConf_SM610_LinAlg::MatVecMul_Thread_16x16_F16() {
@@ -1209,29 +1209,29 @@ void DxilConf_SM610_LinAlg::MatVecMul_Thread_16x16_F16() {
12091209
Params.NumThreads = 1;
12101210
Params.Enable16Bit = true;
12111211
runMatVecMul(D3DDevice, DxcSupport, Params, VerboseLogging,
1212-
/*MatFill=*/2.0f, /*OutputSigned=*/true, ComponentType::F16);
1212+
/*FillValue=*/2, /*OutputSigned=*/true, ComponentType::F16);
12131213
}
12141214

12151215
static const char MatVecMulAddShader[] = R"(
12161216
#define USE_A 0
12171217
#define SCOPE_THREAD 0
12181218
1219-
RWByteAddressBuffer Input : register(u0);
1219+
ByteAddressBuffer Input : register(t0);
12201220
RWByteAddressBuffer Output : register(u1);
12211221
12221222
[numthreads(NUMTHREADS, 1, 1)]
12231223
void main() {
12241224
__builtin_LinAlgMatrix
12251225
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]]
12261226
Mat;
1227-
__builtin_LinAlg_FillMatrix(Mat, MAT_FILL);
1227+
__builtin_LinAlg_MatrixLoadFromDescriptor(
1228+
Mat, Input, 0, STRIDE, LAYOUT, 128);
12281229
12291230
vector<ELEM_TYPE, M_DIM> InVec;
12301231
for (uint I = 0; I < M_DIM; ++I) {
12311232
InVec[I] = Input.Load<ELEM_TYPE>(I * ELEM_SIZE);
12321233
}
12331234
1234-
// TODO: this is just copying InVec but it should be a unique value
12351235
vector<ELEM_TYPE, M_DIM> BiasVec;
12361236
for (uint I = 0; I < M_DIM; ++I) {
12371237
BiasVec[I] = Input.Load<ELEM_TYPE>(I * ELEM_SIZE);
@@ -1250,14 +1250,13 @@ static const char MatVecMulAddShader[] = R"(
12501250
static void runMatVecMulAdd(ID3D12Device *Device,
12511251
dxc::SpecificDllLoader &DxcSupport,
12521252
const MatrixParams &Params, bool Verbose,
1253-
float MatFill, bool OutputSigned,
1253+
int FillValue, bool OutputSigned,
12541254
ComponentType InputInterp,
12551255
ComponentType BiasInterp) {
1256-
const size_t NumElements = Params.M;
1257-
const size_t BufferSize = elementSize(Params.CompType) * NumElements;
1256+
const size_t NumElements = Params.totalElements();
1257+
const size_t BufferSize = Params.totalBytes();
12581258

12591259
std::stringstream ExtraDefs;
1260-
ExtraDefs << " -DMAT_FILL=" << MatFill;
12611260
ExtraDefs << " -DOUTPUT_SIGNED=" << OutputSigned;
12621261
ExtraDefs << " -DIN_INTERP=" << static_cast<int>(InputInterp);
12631262
ExtraDefs << " -DBIAS_INTERP=" << static_cast<int>(BiasInterp);
@@ -1267,29 +1266,29 @@ static void runMatVecMulAdd(ID3D12Device *Device,
12671266
compileShader(DxcSupport, MatVecMulAddShader, "cs_6_10", Args, Verbose);
12681267

12691268
auto Expected = makeExpectedVec(Params.CompType, Params.M,
1270-
MatFill * Params.N + 1, /*Increment=*/false);
1269+
static_cast<float>(FillValue * FillValue * Params.N + FillValue), /*Increment=*/false);
12711270

1272-
auto Op = createComputeOp(MatVecMulAddShader, "cs_6_10", "UAV(u0), UAV(u1)",
1273-
Args.c_str());
1271+
auto Op = createComputeOp(MatVecMulAddShader, "cs_6_10",
1272+
"SRV(t0), UAV(u1)", Args.c_str());
12741273
addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname");
12751274
addUAVBuffer(Op.get(), "Output", BufferSize, true);
12761275
addRootUAV(Op.get(), 0, "Input");
12771276
addRootUAV(Op.get(), 1, "Output");
12781277

1279-
auto Result = runShaderOp(
1280-
Device, DxcSupport, std::move(Op),
1281-
[NumElements, Params](LPCSTR Name, std::vector<BYTE> &Data,
1282-
st::ShaderOp *) {
1283-
VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, NumElements,
1284-
/*StartingVal=*/1, /*Increment=*/false),
1285-
"Saw unsupported component type");
1286-
});
1278+
auto Result =
1279+
runShaderOp(Device, DxcSupport, std::move(Op),
1280+
[NumElements, Params, FillValue](LPCSTR Name, std::vector<BYTE> &Data,
1281+
st::ShaderOp *) {
1282+
VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType,
1283+
NumElements, /*StartingVal=*/ FillValue, /*Increment=*/false),
1284+
"Saw unsupported component type");
1285+
});
12871286

12881287
MappedData OutData;
12891288
Result->Test->GetReadBackData("Output", &OutData);
12901289

12911290
VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
1292-
Expected, NumElements, Verbose));
1291+
Expected, Params.M, Verbose));
12931292
}
12941293

12951294
void DxilConf_SM610_LinAlg::MatVecMulAdd_Thread_16x16_F16() {
@@ -1302,7 +1301,7 @@ void DxilConf_SM610_LinAlg::MatVecMulAdd_Thread_16x16_F16() {
13021301
Params.NumThreads = 1;
13031302
Params.Enable16Bit = true;
13041303
runMatVecMulAdd(D3DDevice, DxcSupport, Params, VerboseLogging,
1305-
/*MatFill=*/2.0f, /*OutputSigned=*/true, ComponentType::F16,
1304+
/*FillValue=*/2, /*OutputSigned=*/true, ComponentType::F16,
13061305
ComponentType::F16);
13071306
}
13081307

0 commit comments

Comments
 (0)