@@ -970,6 +970,175 @@ static void ValidateImmOperandForMathDxilOp(CallInst *CI, DXIL::OpCode opcode,
970970 }
971971}
972972
973+ static bool CheckFromRegisterInterpretations (uint32_t Ri) {
974+ std::set<DXIL::ComponentType> ValidSet = {
975+ DXIL::ComponentType::I16, DXIL::ComponentType::U16,
976+ DXIL::ComponentType::I32, DXIL::ComponentType::U32,
977+ DXIL::ComponentType::F16, DXIL::ComponentType::F32,
978+ DXIL::ComponentType::PackedS8x32, DXIL::ComponentType::PackedU8x32,
979+ DXIL::ComponentType::U8, DXIL::ComponentType::I8,
980+ DXIL::ComponentType::F8_E4M3, DXIL::ComponentType::F8_E5M2};
981+
982+ if (ValidSet.find (static_cast <DXIL::ComponentType>(Ri)) != ValidSet.end ()) {
983+ return true ;
984+ }
985+ return false ;
986+ }
987+
988+ static bool CheckInMemoryInterpretations (uint32_t Mi) {
989+ std::set<DXIL::ComponentType> ValidSet = {
990+ DXIL::ComponentType::I16, DXIL::ComponentType::U16,
991+ DXIL::ComponentType::I32, DXIL::ComponentType::U32,
992+ DXIL::ComponentType::F16, DXIL::ComponentType::F32,
993+ DXIL::ComponentType::U8, DXIL::ComponentType::I8,
994+ DXIL::ComponentType::F8_E4M3, DXIL::ComponentType::F8_E5M2};
995+
996+ if (ValidSet.find (static_cast <DXIL::ComponentType>(Mi)) != ValidSet.end ()) {
997+ return true ;
998+ }
999+ return false ;
1000+ }
1001+
1002+ static bool CheckMatrixLayout (uint32_t Ml) {
1003+ std::set<DXIL::DXILMatrixLayout> ValidSet = {
1004+ DXIL::DXILMatrixLayout::RowMajor, DXIL::DXILMatrixLayout::ColumnMajor,
1005+ DXIL::DXILMatrixLayout::MulOptimal,
1006+ DXIL::DXILMatrixLayout::OuterProductOptimal};
1007+
1008+ if (ValidSet.find (static_cast <DXIL::DXILMatrixLayout>(Ml)) !=
1009+ ValidSet.end ()) {
1010+ return true ;
1011+ }
1012+ return false ;
1013+ }
1014+
1015+ static void ValidateImmOperandsForMatVecOps (CallInst *CI, DXIL::OpCode opcode,
1016+ ValidationContext &ValCtx) {
1017+
1018+ // Check Common operands
1019+ llvm::Value *InputIsUnsigned =
1020+ CI->getOperand (DXIL::OperandIndex::kMatVecMulIsInputUnsignedIdx );
1021+ llvm::Value *InputInterpretation =
1022+ CI->getOperand (DXIL::OperandIndex::kMatVecMulInputInterpretationIdx );
1023+ llvm::Value *MatrixInterpretation =
1024+ CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixInterpretationIdx );
1025+ llvm::Value *MatrixM =
1026+ CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixMIdx );
1027+ llvm::Value *MatrixK =
1028+ CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixKIdx );
1029+ llvm::Value *MatrixLayout =
1030+ CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixLayoutIdx );
1031+ llvm::Value *MatrixTranspose =
1032+ CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixTransposeIdx );
1033+
1034+ if (!llvm::isa<llvm::Constant>(InputIsUnsigned)) {
1035+ ValCtx.EmitInstrError (CI,
1036+ ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1037+ }
1038+
1039+ if (!llvm::isa<llvm::Constant>(InputInterpretation) ||
1040+ !llvm::isa<llvm::Constant>(MatrixInterpretation)) {
1041+ ValCtx.EmitInstrError (
1042+ CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1043+ }
1044+
1045+ // Check if InputInterpretation and MatrixInterpretation are valid
1046+ ConstantInt *Ii = cast<ConstantInt>(InputInterpretation);
1047+ auto IiValue = Ii->getLimitedValue ();
1048+ if (!CheckFromRegisterInterpretations (IiValue)) {
1049+ ValCtx.EmitInstrError (
1050+ CI, ValidationRule::InstrLinalgInvalidRegisterInterpValue);
1051+ }
1052+
1053+ ConstantInt *Mi = cast<ConstantInt>(MatrixInterpretation);
1054+ auto MiValue = Mi->getLimitedValue ();
1055+ if (!CheckInMemoryInterpretations (MiValue)) {
1056+ ValCtx.EmitInstrError (CI,
1057+ ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1058+ }
1059+
1060+ if (!llvm::isa<llvm::Constant>(MatrixM) ||
1061+ !llvm::isa<llvm::Constant>(MatrixK) ||
1062+ !llvm::isa<llvm::Constant>(MatrixLayout) ||
1063+ !llvm::isa<llvm::Constant>(MatrixTranspose)) {
1064+ ValCtx.EmitInstrError (CI,
1065+ ValidationRule::InstrLinalgMatrixShapeParamsAreConst);
1066+ }
1067+
1068+ ConstantInt *Ml = cast<ConstantInt>(MatrixLayout);
1069+ auto MlValue = Ml->getLimitedValue ();
1070+ if (!CheckMatrixLayout (MlValue)) {
1071+ ValCtx.EmitInstrError (CI,
1072+ ValidationRule::InstrLinalgInvalidMatrixLayoutValue);
1073+ }
1074+
1075+ switch (opcode) {
1076+ case DXIL::OpCode::MatVecMul: {
1077+ llvm::Value *OutputIsUnsigned =
1078+ CI->getOperand (DXIL::OperandIndex::kMatVecMulIsOutputUnsignedIdx );
1079+ if (!llvm::isa<llvm::Constant>(OutputIsUnsigned)) {
1080+ ValCtx.EmitInstrError (
1081+ CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1082+ }
1083+
1084+ } break ;
1085+ case DXIL::OpCode::MatVecMulAdd: {
1086+ llvm::Value *OutputIsUnsigned =
1087+ CI->getOperand (DXIL::OperandIndex::kMatVecMulAddIsOutputUnsignedIdx );
1088+ if (!llvm::isa<llvm::Constant>(OutputIsUnsigned)) {
1089+ ValCtx.EmitInstrError (
1090+ CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1091+ }
1092+ llvm::Value *BiasInterpretation =
1093+ CI->getOperand (DXIL::OperandIndex::kMatVecMulAddBiasInterpretation );
1094+ if (!llvm::isa<llvm::Constant>(BiasInterpretation)) {
1095+ ValCtx.EmitInstrError (
1096+ CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1097+ }
1098+ ConstantInt *Bi = cast<ConstantInt>(BiasInterpretation);
1099+ auto BiValue = Bi->getLimitedValue ();
1100+ if (!CheckInMemoryInterpretations (BiValue)) {
1101+ ValCtx.EmitInstrError (
1102+ CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1103+ }
1104+ } break ;
1105+ default :
1106+ break ;
1107+ }
1108+ }
1109+
1110+ static void ValidateImmOperandsForOuterProdAcc (CallInst *CI,
1111+ DXIL::OpCode opcode,
1112+ ValidationContext &ValCtx) {
1113+
1114+ llvm::Value *MatrixInterpretation =
1115+ CI->getOperand (DXIL::OperandIndex::kOuterProdAccMatrixInterpretation );
1116+ llvm::Value *MatrixLayout =
1117+ CI->getOperand (DXIL::OperandIndex::kOuterProdAccMatrixLayout );
1118+
1119+ if (!llvm::isa<llvm::Constant>(MatrixInterpretation)) {
1120+ ValCtx.EmitInstrError (
1121+ CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1122+ }
1123+ ConstantInt *Mi = cast<ConstantInt>(MatrixInterpretation);
1124+ auto MiValue = Mi->getLimitedValue ();
1125+ if (!CheckInMemoryInterpretations (MiValue)) {
1126+ ValCtx.EmitInstrError (CI,
1127+ ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1128+ }
1129+
1130+ if (!llvm::isa<llvm::Constant>(MatrixLayout)) {
1131+ ValCtx.EmitInstrError (CI,
1132+ ValidationRule::InstrLinalgMatrixShapeParamsAreConst);
1133+ }
1134+ ConstantInt *Ml = cast<ConstantInt>(MatrixLayout);
1135+ auto MlValue = Ml->getLimitedValue ();
1136+ if (!CheckMatrixLayout (MlValue)) {
1137+ ValCtx.EmitInstrError (CI,
1138+ ValidationRule::InstrLinalgInvalidMatrixLayoutValue);
1139+ }
1140+ }
1141+
9731142// Validate the type-defined mask compared to the store value mask which
9741143// indicates which parts were defined returns true if caller should continue
9751144// validation
@@ -1942,6 +2111,16 @@ static void ValidateDxilOperationCallInProfile(CallInst *CI,
19422111 GetLaunchTypeStr (nodeLaunchType)});
19432112
19442113 break ;
2114+ case DXIL::OpCode::MatVecMul:
2115+ case DXIL::OpCode::MatVecMulAdd:
2116+ ValidateImmOperandsForMatVecOps (CI, opcode, ValCtx);
2117+ break ;
2118+ case DXIL::OpCode::OuterProductAccumulate:
2119+ ValidateImmOperandsForOuterProdAcc (CI, opcode, ValCtx);
2120+ break ;
2121+ case DXIL::OpCode::VectorAccumulate:
2122+
2123+ break ;
19452124
19462125 default :
19472126 // TODO: make sure every opcode is checked.
0 commit comments