Skip to content

Commit be24353

Browse files
committed
outer product
1 parent 8f45a0c commit be24353

1 file changed

Lines changed: 84 additions & 2 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ class DxilConf_SM610_LinAlg {
331331
// Matrix Vector Arithmetic
332332
TEST_METHOD(MatVecMul_Thread_16x16_F16);
333333
TEST_METHOD(MatVecMulAdd_Thread_16x16_F16);
334+
TEST_METHOD(OuterProduct_Thread_16x16_F16);
334335

335336
private:
336337
CComPtr<ID3D12Device> D3DDevice;
@@ -870,8 +871,8 @@ void DxilConf_SM610_LinAlg::MatMatMul_Wave_16x16x16_F16() {
870871
Params.Layout = LinalgMatrixLayout::RowMajor;
871872
Params.NumThreads = 64;
872873
Params.Enable16Bit = true;
873-
// runMatMatMul(D3DDevice, DxcSupport, Params, VerboseLogging, /*K=*/16,
874-
// /*AFill=*/2.0f, /*BFill=*/3.0f);
874+
runMatMatMul(D3DDevice, DxcSupport, Params, VerboseLogging, /*K=*/16,
875+
/*AFill=*/2.0f, /*BFill=*/3.0f);
875876
}
876877

877878
static const char MatMatMulAccumShader[] = R"(
@@ -1209,4 +1210,85 @@ void DxilConf_SM610_LinAlg::MatVecMulAdd_Thread_16x16_F16() {
12091210
ComponentType::F16);
12101211
}
12111212

1213+
static const char OuterProductShader[] = R"(
1214+
#define USE_A 0
1215+
#define SCOPE_THREAD 0
1216+
1217+
RWByteAddressBuffer Input : register(u0);
1218+
RWByteAddressBuffer Output : register(u1);
1219+
1220+
[numthreads(NUMTHREADS, 1, 1)]
1221+
void main() {
1222+
vector<ELEM_TYPE, M_DIM> VecA;
1223+
for (uint I = 0; I < M_DIM; ++I) {
1224+
VecA[I] = Input.Load<ELEM_TYPE>(I * ELEM_SIZE);
1225+
}
1226+
1227+
uint EndVecA = M_DIM * ELEM_SIZE;
1228+
1229+
vector<ELEM_TYPE, N_DIM> VecB;
1230+
for (uint I = 0; I < N_DIM; ++I) {
1231+
VecB[I] = Input.Load<ELEM_TYPE>(EndVecA + I * ELEM_SIZE);
1232+
}
1233+
1234+
__builtin_LinAlgMatrix
1235+
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]]
1236+
Mat;
1237+
__builtin_LinAlg_MatrixOuterProduct(Mat, VecA, VecB);
1238+
1239+
__builtin_LinAlg_MatrixStoreToDescriptor(
1240+
Mat, Output, 0, STRIDE, LAYOUT, 128);
1241+
}
1242+
)";
1243+
1244+
static void runOuterProduct(ID3D12Device *Device,
1245+
dxc::SpecificDllLoader &DxcSupport,
1246+
const MatrixParams &Params, bool Verbose) {
1247+
const size_t NumVecElements = Params.M + Params.N;
1248+
const size_t InBuffSize = NumVecElements * elementSize(Params.CompType);
1249+
const size_t NumMatElements = Params.totalElements();
1250+
const size_t OutBufferSize = Params.totalBytes();
1251+
1252+
std::string Args = buildCompilerArgs(Params);
1253+
1254+
compileShader(DxcSupport, OuterProductShader, "cs_6_10", Args, Verbose);
1255+
1256+
auto Expected = makeExpected(Params.CompType, Params.M, Params.N,
1257+
4, /*Increment=*/false);
1258+
1259+
auto Op = createComputeOp(OuterProductShader, "cs_6_10", "UAV(u0), UAV(u1)",
1260+
Args.c_str());
1261+
addUAVBuffer(Op.get(), "Input", InBuffSize, false, "byname");
1262+
addUAVBuffer(Op.get(), "Output", OutBufferSize, true);
1263+
addRootUAV(Op.get(), 0, "Input");
1264+
addRootUAV(Op.get(), 1, "Output");
1265+
1266+
auto Result = runShaderOp(
1267+
Device, DxcSupport, std::move(Op),
1268+
[NumVecElements, Params](LPCSTR Name, std::vector<BYTE> &Data,
1269+
st::ShaderOp *) {
1270+
VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, NumVecElements,
1271+
/*StartingVal=*/2, /*Increment=*/false),
1272+
"Saw unsupported component type");
1273+
});
1274+
1275+
MappedData OutData;
1276+
Result->Test->GetReadBackData("Output", &OutData);
1277+
1278+
VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
1279+
Expected, NumMatElements, Verbose));
1280+
}
1281+
1282+
void DxilConf_SM610_LinAlg::OuterProduct_Thread_16x16_F16() {
1283+
MatrixParams Params = {};
1284+
Params.CompType = ComponentType::F16;
1285+
Params.M = 16;
1286+
Params.N = 16;
1287+
Params.Scope = MatrixScope::Thread;
1288+
Params.Layout = LinalgMatrixLayout::RowMajor;
1289+
Params.NumThreads = 1;
1290+
Params.Enable16Bit = true;
1291+
runOuterProduct(D3DDevice, DxcSupport, Params, VerboseLogging);
1292+
}
1293+
12121294
} // namespace LinAlg

0 commit comments

Comments
 (0)