@@ -992,14 +992,29 @@ static bool CheckLinalgInterpretation(uint32_t Input, bool InRegister) {
992992 }
993993}
994994
995- static bool CheckMatrixLayout (unsigned Input ) {
996- return Input <=
995+ static bool CheckMatrixLayoutForMatVecMulOps (unsigned Layout ) {
996+ return Layout <=
997997 static_cast <unsigned >(DXIL::LinalgMatrixLayout::OuterProductOptimal);
998998}
999999
1000- static bool CheckTransposeForMatrixLayout (DXIL::LinalgMatrixLayout Layout,
1001- bool Transposed) {
1002- switch (Layout) {
1000+ std::string GetMatrixLayoutStr (unsigned Layout) {
1001+ switch (static_cast <DXIL::LinalgMatrixLayout>(Layout)) {
1002+ case DXIL::LinalgMatrixLayout::RowMajor:
1003+ return " RowMajor" ;
1004+ case DXIL::LinalgMatrixLayout::ColumnMajor:
1005+ return " ColumnMajor" ;
1006+ case DXIL::LinalgMatrixLayout::MulOptimal:
1007+ return " MulOptimal" ;
1008+ case DXIL::LinalgMatrixLayout::OuterProductOptimal:
1009+ return " OuterProductOptimal" ;
1010+ default :
1011+ DXASSERT_NOMSG (false );
1012+ return " Invalid" ;
1013+ }
1014+ }
1015+
1016+ static bool CheckTransposeForMatrixLayout (unsigned Layout, bool Transposed) {
1017+ switch (static_cast <DXIL::LinalgMatrixLayout>(Layout)) {
10031018 case DXIL::LinalgMatrixLayout::RowMajor:
10041019 case DXIL::LinalgMatrixLayout::ColumnMajor:
10051020 return !Transposed;
@@ -1033,115 +1048,151 @@ static Value *GetMatVecOpIsOutputUnsigned(CallInst *CI, DXIL::OpCode OpCode) {
10331048static void ValidateImmOperandsForMatVecOps (CallInst *CI, DXIL::OpCode OpCode,
10341049 ValidationContext &ValCtx) {
10351050
1036- // Check Common operands
10371051 llvm::Value *IsInputUnsigned =
10381052 CI->getOperand (DXIL::OperandIndex::kMatVecMulIsInputUnsignedIdx );
1039- llvm::Value *InputInterpretation =
1040- CI->getOperand (DXIL::OperandIndex::kMatVecMulInputInterpretationIdx );
1041- llvm::Value *MatrixInterpretation =
1042- CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixInterpretationIdx );
1043- llvm::Value *MatrixM =
1044- CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixMIdx );
1045- llvm::Value *MatrixK =
1046- CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixKIdx );
1047- llvm::Value *MatrixLayout =
1048- CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixLayoutIdx );
1049- llvm::Value *MatrixTranspose =
1050- CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixTransposeIdx );
1051- llvm::Value *IsOutputUnsigned = GetMatVecOpIsOutputUnsigned (CI, OpCode);
1052-
10531053 ConstantInt *IsInputUnsignedConst =
10541054 dyn_cast<llvm::ConstantInt>(IsInputUnsigned);
10551055 if (!IsInputUnsignedConst) {
1056- ValCtx.EmitInstrError (CI,
1057- ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1056+ ValCtx.EmitInstrFormatError (
1057+ CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst,
1058+ {" IsInputUnsigned" });
10581059 return ;
10591060 }
10601061
1062+ llvm::Value *IsOutputUnsigned = GetMatVecOpIsOutputUnsigned (CI, OpCode);
10611063 ConstantInt *IsOutputUnsignedConst =
10621064 dyn_cast<llvm::ConstantInt>(IsOutputUnsigned);
10631065 if (!IsOutputUnsignedConst) {
1064- ValCtx.EmitInstrError (CI,
1065- ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1066+ ValCtx.EmitInstrFormatError (
1067+ CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst,
1068+ {" IsOutputUnsigned" });
10661069 return ;
10671070 }
10681071
1072+ llvm::Value *InputInterpretation =
1073+ CI->getOperand (DXIL::OperandIndex::kMatVecMulInputInterpretationIdx );
10691074 ConstantInt *II = dyn_cast<ConstantInt>(InputInterpretation);
1070- ConstantInt *MI = dyn_cast<ConstantInt>(MatrixInterpretation);
1071- if (!II || !MI) {
1072- ValCtx. EmitInstrError (
1073- CI, ValidationRule::InstrLinalgInterpretationParamAreConst );
1075+ if (!II) {
1076+ ValCtx. EmitInstrFormatError (
1077+ CI, ValidationRule::InstrLinalgInterpretationParamAreConst,
1078+ { " InputInterpretation " } );
10741079 return ;
10751080 }
1076-
1077- // Check if InputInterpretation and MatrixInterpretation are valid
10781081 uint64_t IIValue = II->getLimitedValue ();
10791082 if (!CheckLinalgInterpretation (IIValue, true )) {
1080- ValCtx.EmitInstrError (
1081- CI, ValidationRule::InstrLinalgInvalidRegisterInterpValue);
1083+ ValCtx.EmitInstrFormatError (
1084+ CI, ValidationRule::InstrLinalgInvalidRegisterInterpValue,
1085+ {std::to_string (IIValue), " Input" });
10821086 return ;
10831087 }
10841088
1089+ llvm::Value *MatrixInterpretation =
1090+ CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixInterpretationIdx );
1091+ ConstantInt *MI = dyn_cast<ConstantInt>(MatrixInterpretation);
1092+ if (!MI) {
1093+ ValCtx.EmitInstrFormatError (
1094+ CI, ValidationRule::InstrLinalgInterpretationParamAreConst,
1095+ {" MatrixInterpretation" });
1096+ return ;
1097+ }
10851098 uint64_t MIValue = MI->getLimitedValue ();
10861099 if (!CheckLinalgInterpretation (MIValue, false )) {
1087- ValCtx.EmitInstrError (CI,
1088- ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1100+ ValCtx.EmitInstrFormatError (
1101+ CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue,
1102+ {std::to_string (MIValue), " Matrix" });
10891103 return ;
10901104 }
10911105
1092- ConstantInt *MatrixTransposeConst = dyn_cast<ConstantInt>(MatrixTranspose);
1093- ConstantInt *MatrixLayoutConst = dyn_cast<ConstantInt>(MatrixLayout);
1094- if (!llvm::isa<llvm::Constant>(MatrixM) ||
1095- !llvm::isa<llvm::Constant>(MatrixK) || !MatrixLayoutConst ||
1096- !MatrixTransposeConst) {
1097- ValCtx.EmitInstrError (CI,
1098- ValidationRule::InstrLinalgMatrixShapeParamsAreConst);
1106+ llvm::Value *MatrixM =
1107+ CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixMIdx );
1108+ if (!llvm::isa<llvm::Constant>(MatrixM)) {
1109+ ValCtx.EmitInstrFormatError (
1110+ CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst,
1111+ {" Matrix M dimension" });
1112+ return ;
1113+ }
1114+
1115+ llvm::Value *MatrixK =
1116+ CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixKIdx );
1117+ if (!llvm::isa<llvm::Constant>(MatrixK)) {
1118+ ValCtx.EmitInstrFormatError (
1119+ CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst,
1120+ {" Matrix K dimension" });
10991121 return ;
11001122 }
11011123
1124+ llvm::Value *MatrixLayout =
1125+ CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixLayoutIdx );
1126+
1127+ ConstantInt *MatrixLayoutConst = dyn_cast<ConstantInt>(MatrixLayout);
1128+ if (!MatrixLayoutConst) {
1129+ ValCtx.EmitInstrFormatError (
1130+ CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst,
1131+ {" Matrix Layout" });
1132+ return ;
1133+ }
11021134 uint64_t MLValue = MatrixLayoutConst->getLimitedValue ();
1103- if (!CheckMatrixLayout (MLValue)) {
1104- ValCtx.EmitInstrError (CI,
1105- ValidationRule::InstrLinalgInvalidMatrixLayoutValue);
1135+ if (!CheckMatrixLayoutForMatVecMulOps (MLValue)) {
1136+ ValCtx.EmitInstrFormatError (
1137+ CI, ValidationRule::InstrLinalgInvalidMatrixLayoutValueForMatVecOps,
1138+ {std::to_string (MLValue),
1139+ std::to_string (
1140+ static_cast <unsigned >(DXIL::LinalgMatrixLayout::RowMajor)),
1141+ std::to_string (static_cast <unsigned >(
1142+ DXIL::LinalgMatrixLayout::OuterProductOptimal))});
1143+ return ;
1144+ }
1145+
1146+ llvm::Value *MatrixTranspose =
1147+ CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixTransposeIdx );
1148+ ConstantInt *MatrixTransposeConst = dyn_cast<ConstantInt>(MatrixTranspose);
1149+ if (!MatrixTransposeConst) {
1150+ ValCtx.EmitInstrFormatError (
1151+ CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst,
1152+ {" MatrixTranspose" });
11061153 return ;
11071154 }
11081155
1109- if (!CheckTransposeForMatrixLayout (
1110- static_cast <DXIL::LinalgMatrixLayout>(MLValue),
1111- MatrixTransposeConst-> getLimitedValue ())) {
1112- ValCtx. EmitInstrError (
1113- CI, ValidationRule::InstrLinalgMatrixLayoutNotTransposable );
1156+ if (!CheckTransposeForMatrixLayout (MLValue,
1157+ MatrixTransposeConst-> getLimitedValue ())) {
1158+ ValCtx. EmitInstrFormatError (
1159+ CI, ValidationRule::InstrLinalgMatrixLayoutNotTransposable,
1160+ { GetMatrixLayoutStr (MLValue)} );
11141161 return ;
11151162 }
11161163
11171164 llvm::Value *InputVector =
11181165 CI->getOperand (DXIL::OperandIndex::kMatVecMulInputVectorIdx );
11191166 if (!CheckUnsignedFlag (InputVector->getType (),
11201167 IsInputUnsignedConst->getLimitedValue ())) {
1121- ValCtx.EmitInstrError (CI, ValidationRule::InstrLinalgNotAnUnsignedType);
1168+ ValCtx.EmitInstrFormatError (
1169+ CI, ValidationRule::InstrLinalgNotAnUnsignedType, {" Input" });
11221170 return ;
11231171 }
11241172
11251173 if (!CheckUnsignedFlag (CI->getType (),
11261174 IsOutputUnsignedConst->getLimitedValue ())) {
1127- ValCtx.EmitInstrError (CI, ValidationRule::InstrLinalgNotAnUnsignedType);
1175+ ValCtx.EmitInstrFormatError (
1176+ CI, ValidationRule::InstrLinalgNotAnUnsignedType, {" Output" });
11281177 return ;
11291178 }
11301179
11311180 switch (OpCode) {
11321181 case DXIL::OpCode::MatVecMulAdd: {
11331182 llvm::Value *BiasInterpretation =
11341183 CI->getOperand (DXIL::OperandIndex::kMatVecMulAddBiasInterpretation );
1135- if (!llvm::isa<llvm::Constant>(BiasInterpretation)) {
1136- ValCtx.EmitInstrError (
1137- CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1184+ ConstantInt *BI = cast<ConstantInt>(BiasInterpretation);
1185+ if (!BI) {
1186+ ValCtx.EmitInstrFormatError (
1187+ CI, ValidationRule::InstrLinalgInterpretationParamAreConst,
1188+ {" BiasInterpretation" });
11381189 return ;
11391190 }
1140- ConstantInt *BI = cast<ConstantInt>(BiasInterpretation);
11411191 uint64_t BIValue = BI->getLimitedValue ();
11421192 if (!CheckLinalgInterpretation (BIValue, false )) {
1143- ValCtx.EmitInstrError (
1144- CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1193+ ValCtx.EmitInstrFormatError (
1194+ CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue,
1195+ {std::to_string (BIValue), " Bias vector" });
11451196 return ;
11461197 }
11471198 } break ;
@@ -1155,26 +1206,29 @@ static void ValidateImmOperandsForOuterProdAcc(CallInst *CI,
11551206
11561207 llvm::Value *MatrixInterpretation =
11571208 CI->getOperand (DXIL::OperandIndex::kOuterProdAccMatrixInterpretation );
1158- llvm::Value *MatrixLayout =
1159- CI->getOperand (DXIL::OperandIndex::kOuterProdAccMatrixLayout );
1160-
1161- if (!llvm::isa<llvm::Constant>(MatrixInterpretation))
1162- ValCtx.EmitInstrError (
1163- CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
11641209 ConstantInt *MI = cast<ConstantInt>(MatrixInterpretation);
1210+ if (!MI) {
1211+ ValCtx.EmitInstrFormatError (
1212+ CI, ValidationRule::InstrLinalgInterpretationParamAreConst,
1213+ {" MatrixInterpretation" });
1214+ return ;
1215+ }
11651216 uint64_t MIValue = MI->getLimitedValue ();
1166- if (!CheckLinalgInterpretation (MIValue, false ))
1167- ValCtx.EmitInstrError (CI,
1168- ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1217+ if (!CheckLinalgInterpretation (MIValue, false )) {
1218+ ValCtx.EmitInstrFormatError (
1219+ CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue,
1220+ {std::to_string (MIValue), " Matrix" });
1221+ return ;
1222+ }
11691223
1170- if (! llvm::isa<llvm::Constant>( MatrixLayout))
1171- ValCtx. EmitInstrError (CI,
1172- ValidationRule::InstrLinalgMatrixShapeParamsAreConst);
1173- ConstantInt *ML = cast<ConstantInt>(MatrixLayout);
1174- uint64_t MLValue = ML-> getLimitedValue ();
1175- if (! CheckMatrixLayout (MLValue))
1176- ValCtx. EmitInstrError (CI,
1177- ValidationRule::InstrLinalgInvalidMatrixLayoutValue);
1224+ llvm::Value * MatrixLayout =
1225+ CI-> getOperand (DXIL::OperandIndex:: kOuterProdAccMatrixLayout );
1226+ if (!llvm::isa<llvm::Constant>(MatrixLayout)) {
1227+ ValCtx. EmitInstrFormatError (
1228+ CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst,
1229+ { " MatrixLayout " });
1230+ return ;
1231+ }
11781232}
11791233
11801234// Validate the type-defined mask compared to the store value mask which
0 commit comments