@@ -331,6 +331,7 @@ class DxilConf_SM610_LinAlg {
331331 // Matrix Vector Arithmetic
332332 TEST_METHOD (MatVecMul_Thread_16x16_F16);
333333 TEST_METHOD (MatVecMulAdd_Thread_16x16_F16);
334+ TEST_METHOD (OuterProduct_Thread_16x16_F16);
334335
335336private:
336337 CComPtr<ID3D12Device> D3DDevice;
@@ -870,8 +871,8 @@ void DxilConf_SM610_LinAlg::MatMatMul_Wave_16x16x16_F16() {
870871 Params.Layout = LinalgMatrixLayout::RowMajor;
871872 Params.NumThreads = 64 ;
872873 Params.Enable16Bit = true ;
873- // runMatMatMul(D3DDevice, DxcSupport, Params, VerboseLogging, /*K=*/16,
874- // /*AFill=*/2.0f, /*BFill=*/3.0f);
874+ runMatMatMul (D3DDevice, DxcSupport, Params, VerboseLogging, /* K=*/ 16 ,
875+ /* AFill=*/ 2 .0f , /* BFill=*/ 3 .0f );
875876}
876877
877878static const char MatMatMulAccumShader[] = R"(
@@ -1209,4 +1210,85 @@ void DxilConf_SM610_LinAlg::MatVecMulAdd_Thread_16x16_F16() {
12091210 ComponentType::F16);
12101211}
12111212
// HLSL compute shader source for the outer-product test; it is compiled and
// dispatched by runOuterProduct().
//
// The shader reads two vectors back to back from the Input buffer (u0): VecA
// is M_DIM elements starting at byte offset 0, and VecB is N_DIM elements
// starting at byte offset M_DIM * ELEM_SIZE. It then calls
// __builtin_LinAlg_MatrixOuterProduct(Mat, VecA, VecB) and writes Mat to the
// Output buffer (u1) at offset 0 with
// __builtin_LinAlg_MatrixStoreToDescriptor.
//
// NUMTHREADS, ELEM_TYPE, ELEM_SIZE, M_DIM, N_DIM, COMP_TYPE, STRIDE and LAYOUT
// are not defined in this text; presumably they are supplied as compiler
// defines built from the MatrixParams (see buildCompilerArgs) - TODO confirm.
//
// NOTE(review): the result matrix is declared with USE_A (0); confirm this is
// the intended matrix "use" for an outer-product result.
// NOTE(review): the trailing 128 passed to the store looks like an alignment
// argument - confirm against the builtin's signature.
1213+ static const char OuterProductShader[] = R"(
1214+ #define USE_A 0
1215+ #define SCOPE_THREAD 0
1216+
1217+ RWByteAddressBuffer Input : register(u0);
1218+ RWByteAddressBuffer Output : register(u1);
1219+
1220+ [numthreads(NUMTHREADS, 1, 1)]
1221+ void main() {
1222+ vector<ELEM_TYPE, M_DIM> VecA;
1223+ for (uint I = 0; I < M_DIM; ++I) {
1224+ VecA[I] = Input.Load<ELEM_TYPE>(I * ELEM_SIZE);
1225+ }
1226+
1227+ uint EndVecA = M_DIM * ELEM_SIZE;
1228+
1229+ vector<ELEM_TYPE, N_DIM> VecB;
1230+ for (uint I = 0; I < N_DIM; ++I) {
1231+ VecB[I] = Input.Load<ELEM_TYPE>(EndVecA + I * ELEM_SIZE);
1232+ }
1233+
1234+ __builtin_LinAlgMatrix
1235+ [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]]
1236+ Mat;
1237+ __builtin_LinAlg_MatrixOuterProduct(Mat, VecA, VecB);
1238+
1239+ __builtin_LinAlg_MatrixStoreToDescriptor(
1240+ Mat, Output, 0, STRIDE, LAYOUT, 128);
1241+ }
1242+ )" ;
1243+
1244+ static void runOuterProduct (ID3D12Device *Device,
1245+ dxc::SpecificDllLoader &DxcSupport,
1246+ const MatrixParams &Params, bool Verbose) {
1247+ const size_t NumVecElements = Params.M + Params.N ;
1248+ const size_t InBuffSize = NumVecElements * elementSize (Params.CompType );
1249+ const size_t NumMatElements = Params.totalElements ();
1250+ const size_t OutBufferSize = Params.totalBytes ();
1251+
1252+ std::string Args = buildCompilerArgs (Params);
1253+
1254+ compileShader (DxcSupport, OuterProductShader, " cs_6_10" , Args, Verbose);
1255+
1256+ auto Expected = makeExpected (Params.CompType , Params.M , Params.N ,
1257+ 4 , /* Increment=*/ false );
1258+
1259+ auto Op = createComputeOp (OuterProductShader, " cs_6_10" , " UAV(u0), UAV(u1)" ,
1260+ Args.c_str ());
1261+ addUAVBuffer (Op.get (), " Input" , InBuffSize, false , " byname" );
1262+ addUAVBuffer (Op.get (), " Output" , OutBufferSize, true );
1263+ addRootUAV (Op.get (), 0 , " Input" );
1264+ addRootUAV (Op.get (), 1 , " Output" );
1265+
1266+ auto Result = runShaderOp (
1267+ Device, DxcSupport, std::move (Op),
1268+ [NumVecElements, Params](LPCSTR Name, std::vector<BYTE> &Data,
1269+ st::ShaderOp *) {
1270+ VERIFY_IS_TRUE (fillInputBuffer (Name, Data, Params.CompType , NumVecElements,
1271+ /* StartingVal=*/ 2 , /* Increment=*/ false ),
1272+ " Saw unsupported component type" );
1273+ });
1274+
1275+ MappedData OutData;
1276+ Result->Test ->GetReadBackData (" Output" , &OutData);
1277+
1278+ VERIFY_IS_TRUE (verifyComponentBuffer (Params.CompType , OutData.data (),
1279+ Expected, NumMatElements, Verbose));
1280+ }
1281+
1282+ void DxilConf_SM610_LinAlg::OuterProduct_Thread_16x16_F16 () {
1283+ MatrixParams Params = {};
1284+ Params.CompType = ComponentType::F16;
1285+ Params.M = 16 ;
1286+ Params.N = 16 ;
1287+ Params.Scope = MatrixScope::Thread;
1288+ Params.Layout = LinalgMatrixLayout::RowMajor;
1289+ Params.NumThreads = 1 ;
1290+ Params.Enable16Bit = true ;
1291+ runOuterProduct (D3DDevice, DxcSupport, Params, VerboseLogging);
1292+ }
1293+
12121294} // namespace LinAlg
0 commit comments