Skip to content

Commit d5f03f2

Browse files
Remvoved MatrixLayout check for OuterProductAccumulate, updated DXIL validation errors per review feedback, some cleanup
1 parent 791dff3 commit d5f03f2

3 files changed

Lines changed: 161 additions & 99 deletions

File tree

docs/DXIL.rst

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3138,14 +3138,14 @@ INSTR.ILLEGALDXILOPCODE DXILOpCode must be [0..%0]
31383138
INSTR.ILLEGALDXILOPFUNCTION '%0' is not a DXILOpFuncition for DXILOpcode '%1'.
31393139
INSTR.IMMBIASFORSAMPLEB bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.
31403140
INSTR.INBOUNDSACCESS Access to out-of-bounds memory is disallowed.
3141-
INSTR.LINALGINTERPRETATIONPARAMARECONST Interpretation values are constants
3142-
INSTR.LINALGINVALIDMATRIXLAYOUTVALUE Matrix Layout for Linalg ops not in valid set
3143-
INSTR.LINALGINVALIDMEMORYINTERPVALUE In Memory Interpolation value not in valid set
3144-
INSTR.LINALGINVALIDREGISTERINTERPVALUE From Register Interpretation value not in valid set
3145-
INSTR.LINALGMATRIXLAYOUTNOTTRANSPOSABLE Matrix Layout not transposable
3146-
INSTR.LINALGMATRIXSHAPEPARAMSARECONST Matrix Layout, Dimensions and isTranspose are immediate constants
3147-
INSTR.LINALGNOTANUNSIGNEDTYPE Unsigned flag set for signed type
3148-
INSTR.MATVECOPISUNSIGNEDFLAGSARECONST MatVec Ops Is Unsigned flag is a constant
3141+
INSTR.LINALGINTERPRETATIONPARAMARECONST In Linalg operations, Interpretation value is a constant.
3142+
INSTR.LINALGINVALIDMATRIXLAYOUTVALUEFORMATVECOPS Matrix Layout for Linalg Mul/MulAdd operation must be valid.
3143+
INSTR.LINALGINVALIDMEMORYINTERPVALUE In Memory Interpolation value must be valid.
3144+
INSTR.LINALGINVALIDREGISTERINTERPVALUE From Register Interpretation value must be valid.
3145+
INSTR.LINALGMATRIXLAYOUTNOTTRANSPOSABLE Row Major and Column Major matrix layouts are not transposable.
3146+
INSTR.LINALGMATRIXSHAPEPARAMSARECONST Matrix Layout, Dimensions and isTranspose are constants
3147+
INSTR.LINALGNOTANUNSIGNEDTYPE Unsigned flag set for a float signed type
3148+
INSTR.MATVECOPISUNSIGNEDFLAGSARECONST In Linalg Mul/MulAdd functions, IsUnsigned flag is a constant.
31493149
INSTR.MAYREORDERTHREADUNDEFCOHERENCEHINTPARAM Use of undef coherence hint or num coherence hint bits in MaybeReorderThread.
31503150
INSTR.MINPRECISIONNOTPRECISE Instructions marked precise may not refer to minprecision values.
31513151
INSTR.MINPRECISONBITCAST Bitcast on minprecison types is not allowed.

lib/DxilValidation/DxilValidation.cpp

Lines changed: 128 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -992,14 +992,29 @@ static bool CheckLinalgInterpretation(uint32_t Input, bool InRegister) {
992992
}
993993
}
994994

995-
static bool CheckMatrixLayout(unsigned Input) {
996-
return Input <=
995+
static bool CheckMatrixLayoutForMatVecMulOps(unsigned Layout) {
996+
return Layout <=
997997
static_cast<unsigned>(DXIL::LinalgMatrixLayout::OuterProductOptimal);
998998
}
999999

1000-
static bool CheckTransposeForMatrixLayout(DXIL::LinalgMatrixLayout Layout,
1001-
bool Transposed) {
1002-
switch (Layout) {
1000+
std::string GetMatrixLayoutStr(unsigned Layout) {
1001+
switch (static_cast<DXIL::LinalgMatrixLayout>(Layout)) {
1002+
case DXIL::LinalgMatrixLayout::RowMajor:
1003+
return "RowMajor";
1004+
case DXIL::LinalgMatrixLayout::ColumnMajor:
1005+
return "ColumnMajor";
1006+
case DXIL::LinalgMatrixLayout::MulOptimal:
1007+
return "MulOptimal";
1008+
case DXIL::LinalgMatrixLayout::OuterProductOptimal:
1009+
return "OuterProductOptimal";
1010+
default:
1011+
DXASSERT_NOMSG(false);
1012+
return "Invalid";
1013+
}
1014+
}
1015+
1016+
static bool CheckTransposeForMatrixLayout(unsigned Layout, bool Transposed) {
1017+
switch (static_cast<DXIL::LinalgMatrixLayout>(Layout)) {
10031018
case DXIL::LinalgMatrixLayout::RowMajor:
10041019
case DXIL::LinalgMatrixLayout::ColumnMajor:
10051020
return !Transposed;
@@ -1033,115 +1048,151 @@ static Value *GetMatVecOpIsOutputUnsigned(CallInst *CI, DXIL::OpCode OpCode) {
10331048
static void ValidateImmOperandsForMatVecOps(CallInst *CI, DXIL::OpCode OpCode,
10341049
ValidationContext &ValCtx) {
10351050

1036-
// Check Common operands
10371051
llvm::Value *IsInputUnsigned =
10381052
CI->getOperand(DXIL::OperandIndex::kMatVecMulIsInputUnsignedIdx);
1039-
llvm::Value *InputInterpretation =
1040-
CI->getOperand(DXIL::OperandIndex::kMatVecMulInputInterpretationIdx);
1041-
llvm::Value *MatrixInterpretation =
1042-
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixInterpretationIdx);
1043-
llvm::Value *MatrixM =
1044-
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixMIdx);
1045-
llvm::Value *MatrixK =
1046-
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixKIdx);
1047-
llvm::Value *MatrixLayout =
1048-
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixLayoutIdx);
1049-
llvm::Value *MatrixTranspose =
1050-
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixTransposeIdx);
1051-
llvm::Value *IsOutputUnsigned = GetMatVecOpIsOutputUnsigned(CI, OpCode);
1052-
10531053
ConstantInt *IsInputUnsignedConst =
10541054
dyn_cast<llvm::ConstantInt>(IsInputUnsigned);
10551055
if (!IsInputUnsignedConst) {
1056-
ValCtx.EmitInstrError(CI,
1057-
ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1056+
ValCtx.EmitInstrFormatError(
1057+
CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst,
1058+
{"IsInputUnsigned"});
10581059
return;
10591060
}
10601061

1062+
llvm::Value *IsOutputUnsigned = GetMatVecOpIsOutputUnsigned(CI, OpCode);
10611063
ConstantInt *IsOutputUnsignedConst =
10621064
dyn_cast<llvm::ConstantInt>(IsOutputUnsigned);
10631065
if (!IsOutputUnsignedConst) {
1064-
ValCtx.EmitInstrError(CI,
1065-
ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1066+
ValCtx.EmitInstrFormatError(
1067+
CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst,
1068+
{"IsOutputUnsigned"});
10661069
return;
10671070
}
10681071

1072+
llvm::Value *InputInterpretation =
1073+
CI->getOperand(DXIL::OperandIndex::kMatVecMulInputInterpretationIdx);
10691074
ConstantInt *II = dyn_cast<ConstantInt>(InputInterpretation);
1070-
ConstantInt *MI = dyn_cast<ConstantInt>(MatrixInterpretation);
1071-
if (!II || !MI) {
1072-
ValCtx.EmitInstrError(
1073-
CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1075+
if (!II) {
1076+
ValCtx.EmitInstrFormatError(
1077+
CI, ValidationRule::InstrLinalgInterpretationParamAreConst,
1078+
{"InputInterpretation"});
10741079
return;
10751080
}
1076-
1077-
// Check if InputInterpretation and MatrixInterpretation are valid
10781081
uint64_t IIValue = II->getLimitedValue();
10791082
if (!CheckLinalgInterpretation(IIValue, true)) {
1080-
ValCtx.EmitInstrError(
1081-
CI, ValidationRule::InstrLinalgInvalidRegisterInterpValue);
1083+
ValCtx.EmitInstrFormatError(
1084+
CI, ValidationRule::InstrLinalgInvalidRegisterInterpValue,
1085+
{std::to_string(IIValue), "Input"});
10821086
return;
10831087
}
10841088

1089+
llvm::Value *MatrixInterpretation =
1090+
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixInterpretationIdx);
1091+
ConstantInt *MI = dyn_cast<ConstantInt>(MatrixInterpretation);
1092+
if (!MI) {
1093+
ValCtx.EmitInstrFormatError(
1094+
CI, ValidationRule::InstrLinalgInterpretationParamAreConst,
1095+
{"MatrixInterpretation"});
1096+
return;
1097+
}
10851098
uint64_t MIValue = MI->getLimitedValue();
10861099
if (!CheckLinalgInterpretation(MIValue, false)) {
1087-
ValCtx.EmitInstrError(CI,
1088-
ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1100+
ValCtx.EmitInstrFormatError(
1101+
CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue,
1102+
{std::to_string(MIValue), "Matrix"});
10891103
return;
10901104
}
10911105

1092-
ConstantInt *MatrixTransposeConst = dyn_cast<ConstantInt>(MatrixTranspose);
1093-
ConstantInt *MatrixLayoutConst = dyn_cast<ConstantInt>(MatrixLayout);
1094-
if (!llvm::isa<llvm::Constant>(MatrixM) ||
1095-
!llvm::isa<llvm::Constant>(MatrixK) || !MatrixLayoutConst ||
1096-
!MatrixTransposeConst) {
1097-
ValCtx.EmitInstrError(CI,
1098-
ValidationRule::InstrLinalgMatrixShapeParamsAreConst);
1106+
llvm::Value *MatrixM =
1107+
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixMIdx);
1108+
if (!llvm::isa<llvm::Constant>(MatrixM)) {
1109+
ValCtx.EmitInstrFormatError(
1110+
CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst,
1111+
{"Matrix M dimension"});
1112+
return;
1113+
}
1114+
1115+
llvm::Value *MatrixK =
1116+
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixKIdx);
1117+
if (!llvm::isa<llvm::Constant>(MatrixK)) {
1118+
ValCtx.EmitInstrFormatError(
1119+
CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst,
1120+
{"Matrix K dimension"});
10991121
return;
11001122
}
11011123

1124+
llvm::Value *MatrixLayout =
1125+
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixLayoutIdx);
1126+
1127+
ConstantInt *MatrixLayoutConst = dyn_cast<ConstantInt>(MatrixLayout);
1128+
if (!MatrixLayoutConst) {
1129+
ValCtx.EmitInstrFormatError(
1130+
CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst,
1131+
{"Matrix Layout"});
1132+
return;
1133+
}
11021134
uint64_t MLValue = MatrixLayoutConst->getLimitedValue();
1103-
if (!CheckMatrixLayout(MLValue)) {
1104-
ValCtx.EmitInstrError(CI,
1105-
ValidationRule::InstrLinalgInvalidMatrixLayoutValue);
1135+
if (!CheckMatrixLayoutForMatVecMulOps(MLValue)) {
1136+
ValCtx.EmitInstrFormatError(
1137+
CI, ValidationRule::InstrLinalgInvalidMatrixLayoutValueForMatVecOps,
1138+
{std::to_string(MLValue),
1139+
std::to_string(
1140+
static_cast<unsigned>(DXIL::LinalgMatrixLayout::RowMajor)),
1141+
std::to_string(static_cast<unsigned>(
1142+
DXIL::LinalgMatrixLayout::OuterProductOptimal))});
1143+
return;
1144+
}
1145+
1146+
llvm::Value *MatrixTranspose =
1147+
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixTransposeIdx);
1148+
ConstantInt *MatrixTransposeConst = dyn_cast<ConstantInt>(MatrixTranspose);
1149+
if (!MatrixTransposeConst) {
1150+
ValCtx.EmitInstrFormatError(
1151+
CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst,
1152+
{"MatrixTranspose"});
11061153
return;
11071154
}
11081155

1109-
if (!CheckTransposeForMatrixLayout(
1110-
static_cast<DXIL::LinalgMatrixLayout>(MLValue),
1111-
MatrixTransposeConst->getLimitedValue())) {
1112-
ValCtx.EmitInstrError(
1113-
CI, ValidationRule::InstrLinalgMatrixLayoutNotTransposable);
1156+
if (!CheckTransposeForMatrixLayout(MLValue,
1157+
MatrixTransposeConst->getLimitedValue())) {
1158+
ValCtx.EmitInstrFormatError(
1159+
CI, ValidationRule::InstrLinalgMatrixLayoutNotTransposable,
1160+
{GetMatrixLayoutStr(MLValue)});
11141161
return;
11151162
}
11161163

11171164
llvm::Value *InputVector =
11181165
CI->getOperand(DXIL::OperandIndex::kMatVecMulInputVectorIdx);
11191166
if (!CheckUnsignedFlag(InputVector->getType(),
11201167
IsInputUnsignedConst->getLimitedValue())) {
1121-
ValCtx.EmitInstrError(CI, ValidationRule::InstrLinalgNotAnUnsignedType);
1168+
ValCtx.EmitInstrFormatError(
1169+
CI, ValidationRule::InstrLinalgNotAnUnsignedType, {"Input"});
11221170
return;
11231171
}
11241172

11251173
if (!CheckUnsignedFlag(CI->getType(),
11261174
IsOutputUnsignedConst->getLimitedValue())) {
1127-
ValCtx.EmitInstrError(CI, ValidationRule::InstrLinalgNotAnUnsignedType);
1175+
ValCtx.EmitInstrFormatError(
1176+
CI, ValidationRule::InstrLinalgNotAnUnsignedType, {"Output"});
11281177
return;
11291178
}
11301179

11311180
switch (OpCode) {
11321181
case DXIL::OpCode::MatVecMulAdd: {
11331182
llvm::Value *BiasInterpretation =
11341183
CI->getOperand(DXIL::OperandIndex::kMatVecMulAddBiasInterpretation);
1135-
if (!llvm::isa<llvm::Constant>(BiasInterpretation)) {
1136-
ValCtx.EmitInstrError(
1137-
CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1184+
ConstantInt *BI = cast<ConstantInt>(BiasInterpretation);
1185+
if (!BI) {
1186+
ValCtx.EmitInstrFormatError(
1187+
CI, ValidationRule::InstrLinalgInterpretationParamAreConst,
1188+
{"BiasInterpretation"});
11381189
return;
11391190
}
1140-
ConstantInt *BI = cast<ConstantInt>(BiasInterpretation);
11411191
uint64_t BIValue = BI->getLimitedValue();
11421192
if (!CheckLinalgInterpretation(BIValue, false)) {
1143-
ValCtx.EmitInstrError(
1144-
CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1193+
ValCtx.EmitInstrFormatError(
1194+
CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue,
1195+
{std::to_string(BIValue), "Bias vector"});
11451196
return;
11461197
}
11471198
} break;
@@ -1155,26 +1206,29 @@ static void ValidateImmOperandsForOuterProdAcc(CallInst *CI,
11551206

11561207
llvm::Value *MatrixInterpretation =
11571208
CI->getOperand(DXIL::OperandIndex::kOuterProdAccMatrixInterpretation);
1158-
llvm::Value *MatrixLayout =
1159-
CI->getOperand(DXIL::OperandIndex::kOuterProdAccMatrixLayout);
1160-
1161-
if (!llvm::isa<llvm::Constant>(MatrixInterpretation))
1162-
ValCtx.EmitInstrError(
1163-
CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
11641209
ConstantInt *MI = cast<ConstantInt>(MatrixInterpretation);
1210+
if (!MI) {
1211+
ValCtx.EmitInstrFormatError(
1212+
CI, ValidationRule::InstrLinalgInterpretationParamAreConst,
1213+
{"MatrixInterpretation"});
1214+
return;
1215+
}
11651216
uint64_t MIValue = MI->getLimitedValue();
1166-
if (!CheckLinalgInterpretation(MIValue, false))
1167-
ValCtx.EmitInstrError(CI,
1168-
ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1217+
if (!CheckLinalgInterpretation(MIValue, false)) {
1218+
ValCtx.EmitInstrFormatError(
1219+
CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue,
1220+
{std::to_string(MIValue), "Matrix"});
1221+
return;
1222+
}
11691223

1170-
if (!llvm::isa<llvm::Constant>(MatrixLayout))
1171-
ValCtx.EmitInstrError(CI,
1172-
ValidationRule::InstrLinalgMatrixShapeParamsAreConst);
1173-
ConstantInt *ML = cast<ConstantInt>(MatrixLayout);
1174-
uint64_t MLValue = ML->getLimitedValue();
1175-
if (!CheckMatrixLayout(MLValue))
1176-
ValCtx.EmitInstrError(CI,
1177-
ValidationRule::InstrLinalgInvalidMatrixLayoutValue);
1224+
llvm::Value *MatrixLayout =
1225+
CI->getOperand(DXIL::OperandIndex::kOuterProdAccMatrixLayout);
1226+
if (!llvm::isa<llvm::Constant>(MatrixLayout)) {
1227+
ValCtx.EmitInstrFormatError(
1228+
CI, ValidationRule::InstrLinalgMatrixShapeParamsAreConst,
1229+
{"MatrixLayout"});
1230+
return;
1231+
}
11781232
}
11791233

11801234
// Validate the type-defined mask compared to the store value mask which

utils/hct/hctdb.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8406,44 +8406,52 @@ def build_valrules(self):
84068406
)
84078407

84088408
# Linalg ops
8409-
self.add_valrule(
8409+
self.add_valrule_msg(
84108410
"Instr.MatVecOpIsUnsignedFlagsAreConst",
8411-
"MatVec Ops Is Unsigned flag is a constant",
8411+
"In Linalg Mul/MulAdd functions, IsUnsigned flag is a constant.",
8412+
"'%1' is not a constant value.",
84128413
)
84138414

8414-
self.add_valrule(
8415+
self.add_valrule_msg(
84158416
"Instr.LinalgInterpretationParamAreConst",
8416-
"Interpretation values are constants",
8417+
"In Linalg operations, Interpretation value is a constant.",
8418+
"'%1' is not a constant value.",
84178419
)
84188420

8419-
self.add_valrule(
8421+
self.add_valrule_msg(
84208422
"Instr.LinalgInvalidRegisterInterpValue",
8421-
"From Register Interpretation value not in valid set",
8423+
"From Register Interpretation value must be valid.",
8424+
"'%0' is not a valid %1 interpretation value.",
84228425
)
84238426

8424-
self.add_valrule(
8427+
self.add_valrule_msg(
84258428
"Instr.LinalgInvalidMemoryInterpValue",
8426-
"In Memory Interpolation value not in valid set",
8429+
"In Memory Interpolation value must be valid.",
8430+
"'%0' is not a valid %1 interpretation value."
84278431
)
84288432

8429-
self.add_valrule(
8433+
self.add_valrule_msg(
84308434
"Instr.LinalgMatrixShapeParamsAreConst",
8431-
"Matrix Layout, Dimensions and isTranspose are immediate constants",
8435+
"Matrix Layout, Dimensions and isTranspose are constants",
8436+
"'%0' is not a constant value.",
84328437
)
84338438

8434-
self.add_valrule(
8435-
"Instr.LinalgInvalidMatrixLayoutValue",
8436-
"Matrix Layout for Linalg ops not in valid set",
8439+
self.add_valrule_msg(
8440+
"Instr.LinalgInvalidMatrixLayoutValueForMatVecOps",
8441+
"Matrix Layout for Linalg Mul/MulAdd operation must be valid.",
8442+
"Matrix Layout value '%0' is not valid. Must be between [%1 - %2].",
84378443
)
84388444

8439-
self.add_valrule(
8445+
self.add_valrule_msg(
84408446
"Instr.LinalgMatrixLayoutNotTransposable",
8441-
"Matrix Layout not transposable",
8447+
"Row Major and Column Major matrix layouts are not transposable.",
8448+
"'%0' matrix layout is not transposable.",
84428449
)
84438450

8444-
self.add_valrule(
8451+
self.add_valrule_msg(
84458452
"Instr.LinalgNotAnUnsignedType",
8446-
"Unsigned flag set for signed type",
8453+
"Unsigned flag set for a float signed type",
8454+
"IsUnsigned flag set to true for a float type '%0' vector",
84478455
)
84488456

84498457
# Some legacy rules:

0 commit comments

Comments
 (0)