@@ -1131,15 +1131,16 @@ static const char MatVecMulShader[] = R"(
11311131 #define USE_A 0
11321132 #define SCOPE_THREAD 0
11331133
1134- RWByteAddressBuffer Input : register(u0 );
1134+ ByteAddressBuffer Input : register(t0 );
11351135 RWByteAddressBuffer Output : register(u1);
11361136
11371137 [numthreads(NUMTHREADS, 1, 1)]
11381138 void main() {
11391139 __builtin_LinAlgMatrix
11401140 [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]]
11411141 Mat;
1142- __builtin_LinAlg_FillMatrix(Mat, MAT_FILL);
1142+ __builtin_LinAlg_MatrixLoadFromDescriptor(
1143+ Mat, Input, 0, STRIDE, LAYOUT, 128);
11431144
11441145 vector<ELEM_TYPE, M_DIM> InVec;
11451146 for (uint I = 0; I < M_DIM; ++I) {
@@ -1159,44 +1160,43 @@ static const char MatVecMulShader[] = R"(
11591160static void runMatVecMul (ID3D12Device *Device,
11601161 dxc::SpecificDllLoader &DxcSupport,
11611162 const MatrixParams &Params, bool Verbose,
1162- float MatFill , bool OutputSigned,
1163+ int FillValue , bool OutputSigned,
11631164 ComponentType InputInterp) {
1164- const size_t NumElements = Params.M ;
1165- const size_t BufferSize = elementSize ( Params.CompType ) * NumElements ;
1165+ const size_t NumElements = Params.totalElements () ;
1166+ const size_t BufferSize = Params.totalBytes () ;
11661167
11671168 std::stringstream ExtraDefs;
1168- ExtraDefs << " -DMAT_FILL=" << MatFill;
11691169 ExtraDefs << " -DOUTPUT_SIGNED=" << OutputSigned;
11701170 ExtraDefs << " -DIN_INTERP=" << static_cast <int >(InputInterp);
11711171
11721172 std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
11731173
11741174 compileShader (DxcSupport, MatVecMulShader, " cs_6_10" , Args, Verbose);
11751175
1176- auto Expected = makeExpectedVec (Params.CompType , Params.M , MatFill * Params. N ,
1177- /* Increment=*/ false );
1176+ auto Expected = makeExpectedVec (Params.CompType , Params.M ,
1177+ static_cast < float >(FillValue * FillValue * Params. N ), /* Increment=*/ false );
11781178
1179- auto Op = createComputeOp (MatVecMulShader, " cs_6_10" , " UAV(u0), UAV(u1) " ,
1180- Args.c_str ());
1179+ auto Op = createComputeOp (MatVecMulShader, " cs_6_10" ,
1180+ " SRV(t0), UAV(u1) " , Args.c_str ());
11811181 addUAVBuffer (Op.get (), " Input" , BufferSize, false , " byname" );
11821182 addUAVBuffer (Op.get (), " Output" , BufferSize, true );
11831183 addRootUAV (Op.get (), 0 , " Input" );
11841184 addRootUAV (Op.get (), 1 , " Output" );
11851185
1186- auto Result = runShaderOp (
1187- Device, DxcSupport, std::move (Op),
1188- [NumElements, Params](LPCSTR Name, std::vector<BYTE> &Data,
1189- st::ShaderOp *) {
1190- VERIFY_IS_TRUE (fillInputBuffer (Name, Data, Params.CompType , NumElements ,
1191- /* StartingVal=*/ 1 , /* Increment=*/ false ),
1192- " Saw unsupported component type" );
1193- });
1186+ auto Result =
1187+ runShaderOp ( Device, DxcSupport, std::move (Op),
1188+ [NumElements, Params, FillValue ](LPCSTR Name, std::vector<BYTE> &Data,
1189+ st::ShaderOp *) {
1190+ VERIFY_IS_TRUE (fillInputBuffer (Name, Data, Params.CompType ,
1191+ NumElements, /* StartingVal=*/ FillValue , /* Increment=*/ false ),
1192+ " Saw unsupported component type" );
1193+ });
11941194
11951195 MappedData OutData;
11961196 Result->Test ->GetReadBackData (" Output" , &OutData);
11971197
11981198 VERIFY_IS_TRUE (verifyComponentBuffer (Params.CompType , OutData.data (),
1199- Expected, NumElements , Verbose));
1199+ Expected, Params. M , Verbose));
12001200}
12011201
12021202void DxilConf_SM610_LinAlg::MatVecMul_Thread_16x16_F16 () {
@@ -1209,29 +1209,29 @@ void DxilConf_SM610_LinAlg::MatVecMul_Thread_16x16_F16() {
12091209 Params.NumThreads = 1 ;
12101210 Params.Enable16Bit = true ;
12111211 runMatVecMul (D3DDevice, DxcSupport, Params, VerboseLogging,
1212- /* MatFill =*/ 2 . 0f , /* OutputSigned=*/ true , ComponentType::F16);
1212+ /* FillValue =*/ 2 , /* OutputSigned=*/ true , ComponentType::F16);
12131213}
12141214
12151215static const char MatVecMulAddShader[] = R"(
12161216 #define USE_A 0
12171217 #define SCOPE_THREAD 0
12181218
1219- RWByteAddressBuffer Input : register(u0 );
1219+ ByteAddressBuffer Input : register(t0 );
12201220 RWByteAddressBuffer Output : register(u1);
12211221
12221222 [numthreads(NUMTHREADS, 1, 1)]
12231223 void main() {
12241224 __builtin_LinAlgMatrix
12251225 [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]]
12261226 Mat;
1227- __builtin_LinAlg_FillMatrix(Mat, MAT_FILL);
1227+ __builtin_LinAlg_MatrixLoadFromDescriptor(
1228+ Mat, Input, 0, STRIDE, LAYOUT, 128);
12281229
12291230 vector<ELEM_TYPE, M_DIM> InVec;
12301231 for (uint I = 0; I < M_DIM; ++I) {
12311232 InVec[I] = Input.Load<ELEM_TYPE>(I * ELEM_SIZE);
12321233 }
12331234
1234- // TODO: this is just copying InVec but it should be a unique value
12351235 vector<ELEM_TYPE, M_DIM> BiasVec;
12361236 for (uint I = 0; I < M_DIM; ++I) {
12371237 BiasVec[I] = Input.Load<ELEM_TYPE>(I * ELEM_SIZE);
@@ -1250,14 +1250,13 @@ static const char MatVecMulAddShader[] = R"(
12501250static void runMatVecMulAdd (ID3D12Device *Device,
12511251 dxc::SpecificDllLoader &DxcSupport,
12521252 const MatrixParams &Params, bool Verbose,
1253- float MatFill , bool OutputSigned,
1253+ int FillValue , bool OutputSigned,
12541254 ComponentType InputInterp,
12551255 ComponentType BiasInterp) {
1256- const size_t NumElements = Params.M ;
1257- const size_t BufferSize = elementSize ( Params.CompType ) * NumElements ;
1256+ const size_t NumElements = Params.totalElements () ;
1257+ const size_t BufferSize = Params.totalBytes () ;
12581258
12591259 std::stringstream ExtraDefs;
1260- ExtraDefs << " -DMAT_FILL=" << MatFill;
12611260 ExtraDefs << " -DOUTPUT_SIGNED=" << OutputSigned;
12621261 ExtraDefs << " -DIN_INTERP=" << static_cast <int >(InputInterp);
12631262 ExtraDefs << " -DBIAS_INTERP=" << static_cast <int >(BiasInterp);
@@ -1267,29 +1266,29 @@ static void runMatVecMulAdd(ID3D12Device *Device,
12671266 compileShader (DxcSupport, MatVecMulAddShader, " cs_6_10" , Args, Verbose);
12681267
12691268 auto Expected = makeExpectedVec (Params.CompType , Params.M ,
1270- MatFill * Params.N + 1 , /* Increment=*/ false );
1269+ static_cast < float >(FillValue * FillValue * Params.N + FillValue) , /* Increment=*/ false );
12711270
1272- auto Op = createComputeOp (MatVecMulAddShader, " cs_6_10" , " UAV(u0), UAV(u1) " ,
1273- Args.c_str ());
1271+ auto Op = createComputeOp (MatVecMulAddShader, " cs_6_10" ,
1272+ " SRV(t0), UAV(u1) " , Args.c_str ());
12741273 addUAVBuffer (Op.get (), " Input" , BufferSize, false , " byname" );
12751274 addUAVBuffer (Op.get (), " Output" , BufferSize, true );
12761275 addRootUAV (Op.get (), 0 , " Input" );
12771276 addRootUAV (Op.get (), 1 , " Output" );
12781277
1279- auto Result = runShaderOp (
1280- Device, DxcSupport, std::move (Op),
1281- [NumElements, Params](LPCSTR Name, std::vector<BYTE> &Data,
1282- st::ShaderOp *) {
1283- VERIFY_IS_TRUE (fillInputBuffer (Name, Data, Params.CompType , NumElements ,
1284- /* StartingVal=*/ 1 , /* Increment=*/ false ),
1285- " Saw unsupported component type" );
1286- });
1278+ auto Result =
1279+ runShaderOp ( Device, DxcSupport, std::move (Op),
1280+ [NumElements, Params, FillValue ](LPCSTR Name, std::vector<BYTE> &Data,
1281+ st::ShaderOp *) {
1282+ VERIFY_IS_TRUE (fillInputBuffer (Name, Data, Params.CompType ,
1283+ NumElements, /* StartingVal=*/ FillValue , /* Increment=*/ false ),
1284+ " Saw unsupported component type" );
1285+ });
12871286
12881287 MappedData OutData;
12891288 Result->Test ->GetReadBackData (" Output" , &OutData);
12901289
12911290 VERIFY_IS_TRUE (verifyComponentBuffer (Params.CompType , OutData.data (),
1292- Expected, NumElements , Verbose));
1291+ Expected, Params. M , Verbose));
12931292}
12941293
12951294void DxilConf_SM610_LinAlg::MatVecMulAdd_Thread_16x16_F16 () {
@@ -1302,7 +1301,7 @@ void DxilConf_SM610_LinAlg::MatVecMulAdd_Thread_16x16_F16() {
13021301 Params.NumThreads = 1 ;
13031302 Params.Enable16Bit = true ;
13041303 runMatVecMulAdd (D3DDevice, DxcSupport, Params, VerboseLogging,
1305- /* MatFill =*/ 2 . 0f , /* OutputSigned=*/ true , ComponentType::F16,
1304+ /* FillValue =*/ 2 , /* OutputSigned=*/ true , ComponentType::F16,
13061305 ComponentType::F16);
13071306}
13081307
0 commit comments