@@ -1006,11 +1006,44 @@ static bool CheckMatrixLayout(unsigned Input) {
10061006 static_cast <unsigned >(DXIL::DXILMatrixLayout::OuterProductOptimal));
10071007}
10081008
1009+ static bool CheckTransposeForMatrixLayout (DXIL::DXILMatrixLayout Layout,
1010+ bool Transposed) {
1011+ switch (Layout) {
1012+ case DXIL::DXILMatrixLayout::RowMajor:
1013+ case DXIL::DXILMatrixLayout::ColumnMajor:
1014+ return !Transposed;
1015+
1016+ default :
1017+ return true ;
1018+ }
1019+ }
1020+
1021+ static bool CheckUnsignedFlag (Type *VecTy, bool IsUnsigned) {
1022+ Type *ElemTy = VecTy->getScalarType ();
1023+ if (ElemTy->isFloatingPointTy ())
1024+ return !IsUnsigned;
1025+
1026+ return true ;
1027+ }
1028+
1029+ static Value *getMatVecOpIsOutputUnsigned (CallInst *CI, DXIL::OpCode OpCode) {
1030+ switch (OpCode) {
1031+ case DXIL::OpCode::MatVecMul:
1032+ return CI->getOperand (DXIL::OperandIndex::kMatVecMulIsOutputUnsignedIdx );
1033+ case DXIL::OpCode::MatVecMulAdd:
1034+ return CI->getOperand (DXIL::OperandIndex::kMatVecMulAddIsOutputUnsignedIdx );
1035+
1036+ default :
1037+ DXASSERT_NOMSG (false );
1038+ return nullptr ;
1039+ }
1040+ }
1041+
10091042static void ValidateImmOperandsForMatVecOps (CallInst *CI, DXIL::OpCode OpCode,
10101043 ValidationContext &ValCtx) {
10111044
10121045 // Check Common operands
1013- llvm::Value *InputIsUnsigned =
1046+ llvm::Value *IsInputUnsigned =
10141047 CI->getOperand (DXIL::OperandIndex::kMatVecMulIsInputUnsignedIdx );
10151048 llvm::Value *InputInterpretation =
10161049 CI->getOperand (DXIL::OperandIndex::kMatVecMulInputInterpretationIdx );
@@ -1024,76 +1057,101 @@ static void ValidateImmOperandsForMatVecOps(CallInst *CI, DXIL::OpCode OpCode,
10241057 CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixLayoutIdx );
10251058 llvm::Value *MatrixTranspose =
10261059 CI->getOperand (DXIL::OperandIndex::kMatVecMulMatrixTransposeIdx );
1060+ llvm::Value *IsOutputUnsigned = getMatVecOpIsOutputUnsigned (CI, OpCode);
10271061
1028- if (!llvm::isa<llvm::Constant>(InputIsUnsigned)) {
1062+ ConstantInt *IsInputUnsignedConst =
1063+ dyn_cast<llvm::ConstantInt>(IsInputUnsigned);
1064+ if (!IsInputUnsignedConst) {
10291065 ValCtx.EmitInstrError (CI,
10301066 ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1067+ return ;
10311068 }
10321069
1033- if (!llvm::isa<llvm::Constant>(InputInterpretation) ||
1034- !llvm::isa<llvm::Constant>(MatrixInterpretation)) {
1070+ ConstantInt *IsOutputUnsignedConst =
1071+ dyn_cast<llvm::ConstantInt>(IsOutputUnsigned);
1072+ if (!IsOutputUnsignedConst) {
1073+ ValCtx.EmitInstrError (CI,
1074+ ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1075+ return ;
1076+ }
1077+
1078+ ConstantInt *II = dyn_cast<ConstantInt>(InputInterpretation);
1079+ ConstantInt *MI = dyn_cast<ConstantInt>(MatrixInterpretation);
1080+ if (!II || !MI) {
10351081 ValCtx.EmitInstrError (
10361082 CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1083+ return ;
10371084 }
10381085
10391086 // Check if InputInterpretation and MatrixInterpretation are valid
1040- ConstantInt *II = cast<ConstantInt>(InputInterpretation);
10411087 auto IIValue = II->getLimitedValue ();
10421088 if (!CheckFromRegisterInterpretations (IIValue)) {
10431089 ValCtx.EmitInstrError (
10441090 CI, ValidationRule::InstrLinalgInvalidRegisterInterpValue);
1091+ return ;
10451092 }
10461093
1047- ConstantInt *MI = cast<ConstantInt>(MatrixInterpretation);
10481094 auto MIValue = MI->getLimitedValue ();
10491095 if (!CheckInMemoryInterpretations (MIValue)) {
10501096 ValCtx.EmitInstrError (CI,
10511097 ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1098+ return ;
10521099 }
10531100
1101+ ConstantInt *MatrixTransposeConst = dyn_cast<ConstantInt>(MatrixTranspose);
1102+ ConstantInt *MatrixLayoutConst = dyn_cast<ConstantInt>(MatrixLayout);
10541103 if (!llvm::isa<llvm::Constant>(MatrixM) ||
1055- !llvm::isa<llvm::Constant>(MatrixK) ||
1056- !llvm::isa<llvm::Constant>(MatrixLayout) ||
1057- !llvm::isa<llvm::Constant>(MatrixTranspose)) {
1104+ !llvm::isa<llvm::Constant>(MatrixK) || !MatrixLayoutConst ||
1105+ !MatrixTransposeConst) {
10581106 ValCtx.EmitInstrError (CI,
10591107 ValidationRule::InstrLinalgMatrixShapeParamsAreConst);
1108+ return ;
10601109 }
10611110
1062- ConstantInt *ML = cast<ConstantInt>(MatrixLayout);
1063- auto MLValue = ML->getLimitedValue ();
1111+ auto MLValue = MatrixLayoutConst->getLimitedValue ();
10641112 if (!CheckMatrixLayout (MLValue)) {
10651113 ValCtx.EmitInstrError (CI,
10661114 ValidationRule::InstrLinalgInvalidMatrixLayoutValue);
1115+ return ;
10671116 }
10681117
1069- switch (OpCode) {
1070- case DXIL::OpCode::MatVecMul: {
1071- llvm::Value *OutputIsUnsigned =
1072- CI->getOperand (DXIL::OperandIndex::kMatVecMulIsOutputUnsignedIdx );
1073- if (!llvm::isa<llvm::Constant>(OutputIsUnsigned)) {
1074- ValCtx.EmitInstrError (
1075- CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1076- }
1118+ if (!CheckTransposeForMatrixLayout (
1119+ static_cast <DXIL::DXILMatrixLayout>(MLValue),
1120+ MatrixTransposeConst->getLimitedValue ())) {
1121+ ValCtx.EmitInstrError (
1122+ CI, ValidationRule::InstrLinalgMatrixLayoutNotTransposable);
1123+ return ;
1124+ }
10771125
1078- } break ;
1126+ llvm::Value *InputVector =
1127+ CI->getOperand (DXIL::OperandIndex::kMatVecMulInputVectorIdx );
1128+ if (!CheckUnsignedFlag (InputVector->getType (),
1129+ IsInputUnsignedConst->getLimitedValue ())) {
1130+ ValCtx.EmitInstrError (CI, ValidationRule::InstrLinalgNotAnUnsignedType);
1131+ return ;
1132+ }
1133+
1134+ if (!CheckUnsignedFlag (CI->getType (),
1135+ IsOutputUnsignedConst->getLimitedValue ())) {
1136+ ValCtx.EmitInstrError (CI, ValidationRule::InstrLinalgNotAnUnsignedType);
1137+ return ;
1138+ }
1139+
1140+ switch (OpCode) {
10791141 case DXIL::OpCode::MatVecMulAdd: {
1080- llvm::Value *OutputIsUnsigned =
1081- CI->getOperand (DXIL::OperandIndex::kMatVecMulAddIsOutputUnsignedIdx );
1082- if (!llvm::isa<llvm::Constant>(OutputIsUnsigned)) {
1083- ValCtx.EmitInstrError (
1084- CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1085- }
10861142 llvm::Value *BiasInterpretation =
10871143 CI->getOperand (DXIL::OperandIndex::kMatVecMulAddBiasInterpretation );
10881144 if (!llvm::isa<llvm::Constant>(BiasInterpretation)) {
10891145 ValCtx.EmitInstrError (
10901146 CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1147+ return ;
10911148 }
10921149 ConstantInt *BI = cast<ConstantInt>(BiasInterpretation);
10931150 auto BIValue = BI->getLimitedValue ();
10941151 if (!CheckInMemoryInterpretations (BIValue)) {
10951152 ValCtx.EmitInstrError (
10961153 CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1154+ return ;
10971155 }
10981156 } break ;
10991157 default :
0 commit comments