Skip to content

Commit 5c61273

Browse files
committed
Diagnose unsigned<>type mismatches and transpose<>layout mismatch
1 parent 55d8fc7 commit 5c61273

2 files changed

Lines changed: 94 additions & 26 deletions

File tree

lib/DxilValidation/DxilValidation.cpp

Lines changed: 84 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,11 +1006,44 @@ static bool CheckMatrixLayout(unsigned Input) {
10061006
static_cast<unsigned>(DXIL::DXILMatrixLayout::OuterProductOptimal));
10071007
}
10081008

1009+
static bool CheckTransposeForMatrixLayout(DXIL::DXILMatrixLayout Layout,
1010+
bool Transposed) {
1011+
switch (Layout) {
1012+
case DXIL::DXILMatrixLayout::RowMajor:
1013+
case DXIL::DXILMatrixLayout::ColumnMajor:
1014+
return !Transposed;
1015+
1016+
default:
1017+
return true;
1018+
}
1019+
}
1020+
1021+
static bool CheckUnsignedFlag(Type *VecTy, bool IsUnsigned) {
1022+
Type *ElemTy = VecTy->getScalarType();
1023+
if (ElemTy->isFloatingPointTy())
1024+
return !IsUnsigned;
1025+
1026+
return true;
1027+
}
1028+
1029+
static Value *getMatVecOpIsOutputUnsigned(CallInst *CI, DXIL::OpCode OpCode) {
1030+
switch (OpCode) {
1031+
case DXIL::OpCode::MatVecMul:
1032+
return CI->getOperand(DXIL::OperandIndex::kMatVecMulIsOutputUnsignedIdx);
1033+
case DXIL::OpCode::MatVecMulAdd:
1034+
return CI->getOperand(DXIL::OperandIndex::kMatVecMulAddIsOutputUnsignedIdx);
1035+
1036+
default:
1037+
DXASSERT_NOMSG(false);
1038+
return nullptr;
1039+
}
1040+
}
1041+
10091042
static void ValidateImmOperandsForMatVecOps(CallInst *CI, DXIL::OpCode OpCode,
10101043
ValidationContext &ValCtx) {
10111044

10121045
// Check Common operands
1013-
llvm::Value *InputIsUnsigned =
1046+
llvm::Value *IsInputUnsigned =
10141047
CI->getOperand(DXIL::OperandIndex::kMatVecMulIsInputUnsignedIdx);
10151048
llvm::Value *InputInterpretation =
10161049
CI->getOperand(DXIL::OperandIndex::kMatVecMulInputInterpretationIdx);
@@ -1024,76 +1057,101 @@ static void ValidateImmOperandsForMatVecOps(CallInst *CI, DXIL::OpCode OpCode,
10241057
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixLayoutIdx);
10251058
llvm::Value *MatrixTranspose =
10261059
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixTransposeIdx);
1060+
llvm::Value *IsOutputUnsigned = getMatVecOpIsOutputUnsigned(CI, OpCode);
10271061

1028-
if (!llvm::isa<llvm::Constant>(InputIsUnsigned)) {
1062+
ConstantInt *IsInputUnsignedConst =
1063+
dyn_cast<llvm::ConstantInt>(IsInputUnsigned);
1064+
if (!IsInputUnsignedConst) {
10291065
ValCtx.EmitInstrError(CI,
10301066
ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1067+
return;
10311068
}
10321069

1033-
if (!llvm::isa<llvm::Constant>(InputInterpretation) ||
1034-
!llvm::isa<llvm::Constant>(MatrixInterpretation)) {
1070+
ConstantInt *IsOutputUnsignedConst =
1071+
dyn_cast<llvm::ConstantInt>(IsOutputUnsigned);
1072+
if (!IsOutputUnsignedConst) {
1073+
ValCtx.EmitInstrError(CI,
1074+
ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1075+
return;
1076+
}
1077+
1078+
ConstantInt *II = dyn_cast<ConstantInt>(InputInterpretation);
1079+
ConstantInt *MI = dyn_cast<ConstantInt>(MatrixInterpretation);
1080+
if (!II || !MI) {
10351081
ValCtx.EmitInstrError(
10361082
CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1083+
return;
10371084
}
10381085

10391086
// Check if InputInterpretation and MatrixInterpretation are valid
1040-
ConstantInt *II = cast<ConstantInt>(InputInterpretation);
10411087
auto IIValue = II->getLimitedValue();
10421088
if (!CheckFromRegisterInterpretations(IIValue)) {
10431089
ValCtx.EmitInstrError(
10441090
CI, ValidationRule::InstrLinalgInvalidRegisterInterpValue);
1091+
return;
10451092
}
10461093

1047-
ConstantInt *MI = cast<ConstantInt>(MatrixInterpretation);
10481094
auto MIValue = MI->getLimitedValue();
10491095
if (!CheckInMemoryInterpretations(MIValue)) {
10501096
ValCtx.EmitInstrError(CI,
10511097
ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1098+
return;
10521099
}
10531100

1101+
ConstantInt *MatrixTransposeConst = dyn_cast<ConstantInt>(MatrixTranspose);
1102+
ConstantInt *MatrixLayoutConst = dyn_cast<ConstantInt>(MatrixLayout);
10541103
if (!llvm::isa<llvm::Constant>(MatrixM) ||
1055-
!llvm::isa<llvm::Constant>(MatrixK) ||
1056-
!llvm::isa<llvm::Constant>(MatrixLayout) ||
1057-
!llvm::isa<llvm::Constant>(MatrixTranspose)) {
1104+
!llvm::isa<llvm::Constant>(MatrixK) || !MatrixLayoutConst ||
1105+
!MatrixTransposeConst) {
10581106
ValCtx.EmitInstrError(CI,
10591107
ValidationRule::InstrLinalgMatrixShapeParamsAreConst);
1108+
return;
10601109
}
10611110

1062-
ConstantInt *ML = cast<ConstantInt>(MatrixLayout);
1063-
auto MLValue = ML->getLimitedValue();
1111+
auto MLValue = MatrixLayoutConst->getLimitedValue();
10641112
if (!CheckMatrixLayout(MLValue)) {
10651113
ValCtx.EmitInstrError(CI,
10661114
ValidationRule::InstrLinalgInvalidMatrixLayoutValue);
1115+
return;
10671116
}
10681117

1069-
switch (OpCode) {
1070-
case DXIL::OpCode::MatVecMul: {
1071-
llvm::Value *OutputIsUnsigned =
1072-
CI->getOperand(DXIL::OperandIndex::kMatVecMulIsOutputUnsignedIdx);
1073-
if (!llvm::isa<llvm::Constant>(OutputIsUnsigned)) {
1074-
ValCtx.EmitInstrError(
1075-
CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1076-
}
1118+
if (!CheckTransposeForMatrixLayout(
1119+
static_cast<DXIL::DXILMatrixLayout>(MLValue),
1120+
MatrixTransposeConst->getLimitedValue())) {
1121+
ValCtx.EmitInstrError(
1122+
CI, ValidationRule::InstrLinalgMatrixLayoutNotTransposable);
1123+
return;
1124+
}
10771125

1078-
} break;
1126+
llvm::Value *InputVector =
1127+
CI->getOperand(DXIL::OperandIndex::kMatVecMulInputVectorIdx);
1128+
if (!CheckUnsignedFlag(InputVector->getType(),
1129+
IsInputUnsignedConst->getLimitedValue())) {
1130+
ValCtx.EmitInstrError(CI, ValidationRule::InstrLinalgNotAnUnsignedType);
1131+
return;
1132+
}
1133+
1134+
if (!CheckUnsignedFlag(CI->getType(),
1135+
IsOutputUnsignedConst->getLimitedValue())) {
1136+
ValCtx.EmitInstrError(CI, ValidationRule::InstrLinalgNotAnUnsignedType);
1137+
return;
1138+
}
1139+
1140+
switch (OpCode) {
10791141
case DXIL::OpCode::MatVecMulAdd: {
1080-
llvm::Value *OutputIsUnsigned =
1081-
CI->getOperand(DXIL::OperandIndex::kMatVecMulAddIsOutputUnsignedIdx);
1082-
if (!llvm::isa<llvm::Constant>(OutputIsUnsigned)) {
1083-
ValCtx.EmitInstrError(
1084-
CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1085-
}
10861142
llvm::Value *BiasInterpretation =
10871143
CI->getOperand(DXIL::OperandIndex::kMatVecMulAddBiasInterpretation);
10881144
if (!llvm::isa<llvm::Constant>(BiasInterpretation)) {
10891145
ValCtx.EmitInstrError(
10901146
CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1147+
return;
10911148
}
10921149
ConstantInt *BI = cast<ConstantInt>(BiasInterpretation);
10931150
auto BIValue = BI->getLimitedValue();
10941151
if (!CheckInMemoryInterpretations(BIValue)) {
10951152
ValCtx.EmitInstrError(
10961153
CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1154+
return;
10971155
}
10981156
} break;
10991157
default:

utils/hct/hctdb.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7928,6 +7928,16 @@ def build_valrules(self):
79287928
"Matrix Layout for Linalg ops not in valid set",
79297929
)
79307930

7931+
self.add_valrule(
7932+
"Instr.LinalgMatrixLayoutNotTransposable",
7933+
"Matrix Layout not transposable",
7934+
)
7935+
7936+
self.add_valrule(
7937+
"Instr.LinalgNotAnUnsignedType",
7938+
"Unsigned flag set for signed type",
7939+
)
7940+
79317941
# Some legacy rules:
79327942
# - space is only supported for shader targets 5.1 and higher
79337943
# - multiple rules regarding derivatives, which isn't a supported feature for DXIL

0 commit comments

Comments
 (0)