Skip to content

Commit f7fabd9

Browse files
Add validation rules for DXIL ops
1 parent a3bdc34 commit f7fabd9

4 files changed

Lines changed: 245 additions & 9 deletions

File tree

include/dxc/DXIL/DxilConstants.h

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,23 +163,31 @@ const unsigned kDxilMaxOloadDims = 2;
163163
enum class ComponentType : uint32_t {
164164
Invalid = 0,
165165
I1,
166-
I16,
167-
U16,
168-
I32,
169-
U32,
166+
I16, // = 2
167+
U16, // = 3
168+
I32, // = 4
169+
U32, // = 5
170170
I64,
171171
U64,
172-
F16,
173-
F32,
172+
F16, // = 8
173+
F32, // = 9
174174
F64,
175175
SNormF16,
176176
UNormF16,
177177
SNormF32,
178178
UNormF32,
179179
SNormF64,
180180
UNormF64,
181-
PackedS8x32,
182-
PackedU8x32,
181+
PackedS8x32, // = 17
182+
PackedU8x32, // = 18
183+
184+
// BEGIN NEW FOR SM 6.9
185+
U8, // = 19
186+
I8, // = 20
187+
F8_E4M3, // = 21
188+
F8_E5M2, // = 22
189+
// END
190+
183191
LastEntry
184192
};
185193

@@ -1536,6 +1544,28 @@ const unsigned kMSStoreOutputColOpIdx = 3;
15361544
const unsigned kMSStoreOutputVIdxOpIdx = 4;
15371545
const unsigned kMSStoreOutputValOpIdx = 5;
15381546

1547+
// MatVec Ops
1548+
const unsigned kMatVecMulInputVectorIdx = 1;
1549+
const unsigned kMatVecMulIsInputUnsignedIdx = 2;
1550+
const unsigned kMatVecMulInputInterpretationIdx = 3;
1551+
const unsigned kMatVecMulMatrixBufferIdx = 4;
1552+
const unsigned kMatVecMulMatrixOffsetIdx = 5;
1553+
const unsigned kMatVecMulMatrixInterpretationIdx = 6;
1554+
const unsigned kMatVecMulMatrixMIdx = 7;
1555+
const unsigned kMatVecMulMatrixKIdx = 8;
1556+
const unsigned kMatVecMulMatrixLayoutIdx = 9;
1557+
const unsigned kMatVecMulMatrixTransposeIdx = 10;
1558+
const unsigned kMatVecMulMatrixStrideIdx = 11;
1559+
const unsigned kMatVecMulIsOutputUnsignedIdx = 12;
1560+
1561+
// MatVecAdd
1562+
const unsigned kMatVecMulAddBiasInterpretation = 14;
1563+
const unsigned kMatVecMulAddIsOutputUnsignedIdx = 15;
1564+
1565+
// Outer Product Accumulate
1566+
const unsigned kOuterProdAccMatrixInterpretation = 5;
1567+
const unsigned kOuterProdAccMatrixLayout = 6;
1568+
15391569
// TODO: add operand index for all the OpCodeClass.
15401570
} // namespace OperandIndex
15411571

@@ -2105,6 +2135,13 @@ extern const char *kHostLayoutTypePrefix;
21052135

21062136
extern const char *kWaveOpsIncludeHelperLanesString;
21072137

2138+
enum class DXILMatrixLayout : uint32_t {
2139+
RowMajor = 0,
2140+
ColumnMajor = 1,
2141+
MulOptimal = 2,
2142+
OuterProductOptimal = 3,
2143+
};
2144+
21082145
} // namespace DXIL
21092146

21102147
} // namespace hlsl

include/dxc/DxilContainer/RDAT_LibraryTypes.inl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,9 +564,13 @@ RDAT_DXIL_ENUM_START(hlsl::DXIL::ComponentType, uint32_t)
564564
RDAT_ENUM_VALUE_NODEF(UNormF64)
565565
RDAT_ENUM_VALUE_NODEF(PackedS8x32)
566566
RDAT_ENUM_VALUE_NODEF(PackedU8x32)
567+
RDAT_ENUM_VALUE_NODEF(U8)
568+
RDAT_ENUM_VALUE_NODEF(I8)
569+
RDAT_ENUM_VALUE_NODEF(F8_E4M3)
570+
RDAT_ENUM_VALUE_NODEF(F8_E5M2)
567571
RDAT_ENUM_VALUE_NODEF(LastEntry)
568572
#if DEF_RDAT_ENUMS == DEF_RDAT_DUMP_IMPL
569-
static_assert((unsigned)hlsl::DXIL::ComponentType::LastEntry == 19,
573+
static_assert((unsigned)hlsl::DXIL::ComponentType::LastEntry == 23,
570574
"otherwise, RDAT_DXIL_ENUM definition needs updating");
571575
#endif
572576
RDAT_ENUM_END()

lib/DxilValidation/DxilValidation.cpp

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,175 @@ static void ValidateImmOperandForMathDxilOp(CallInst *CI, DXIL::OpCode opcode,
970970
}
971971
}
972972

973+
static bool CheckFromRegisterInterpretations(uint32_t Ri) {
974+
std::set<DXIL::ComponentType> ValidSet = {
975+
DXIL::ComponentType::I16, DXIL::ComponentType::U16,
976+
DXIL::ComponentType::I32, DXIL::ComponentType::U32,
977+
DXIL::ComponentType::F16, DXIL::ComponentType::F32,
978+
DXIL::ComponentType::PackedS8x32, DXIL::ComponentType::PackedU8x32,
979+
DXIL::ComponentType::U8, DXIL::ComponentType::I8,
980+
DXIL::ComponentType::F8_E4M3, DXIL::ComponentType::F8_E5M2};
981+
982+
if (ValidSet.find(static_cast<DXIL::ComponentType>(Ri)) != ValidSet.end()) {
983+
return true;
984+
}
985+
return false;
986+
}
987+
988+
static bool CheckInMemoryInterpretations(uint32_t Mi) {
989+
std::set<DXIL::ComponentType> ValidSet = {
990+
DXIL::ComponentType::I16, DXIL::ComponentType::U16,
991+
DXIL::ComponentType::I32, DXIL::ComponentType::U32,
992+
DXIL::ComponentType::F16, DXIL::ComponentType::F32,
993+
DXIL::ComponentType::U8, DXIL::ComponentType::I8,
994+
DXIL::ComponentType::F8_E4M3, DXIL::ComponentType::F8_E5M2};
995+
996+
if (ValidSet.find(static_cast<DXIL::ComponentType>(Mi)) != ValidSet.end()) {
997+
return true;
998+
}
999+
return false;
1000+
}
1001+
1002+
static bool CheckMatrixLayout(uint32_t Ml) {
1003+
std::set<DXIL::DXILMatrixLayout> ValidSet = {
1004+
DXIL::DXILMatrixLayout::RowMajor, DXIL::DXILMatrixLayout::ColumnMajor,
1005+
DXIL::DXILMatrixLayout::MulOptimal,
1006+
DXIL::DXILMatrixLayout::OuterProductOptimal};
1007+
1008+
if (ValidSet.find(static_cast<DXIL::DXILMatrixLayout>(Ml)) !=
1009+
ValidSet.end()) {
1010+
return true;
1011+
}
1012+
return false;
1013+
}
1014+
1015+
static void ValidateImmOperandsForMatVecOps(CallInst *CI, DXIL::OpCode opcode,
1016+
ValidationContext &ValCtx) {
1017+
1018+
// Check Common operands
1019+
llvm::Value *InputIsUnsigned =
1020+
CI->getOperand(DXIL::OperandIndex::kMatVecMulIsInputUnsignedIdx);
1021+
llvm::Value *InputInterpretation =
1022+
CI->getOperand(DXIL::OperandIndex::kMatVecMulInputInterpretationIdx);
1023+
llvm::Value *MatrixInterpretation =
1024+
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixInterpretationIdx);
1025+
llvm::Value *MatrixM =
1026+
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixMIdx);
1027+
llvm::Value *MatrixK =
1028+
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixKIdx);
1029+
llvm::Value *MatrixLayout =
1030+
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixLayoutIdx);
1031+
llvm::Value *MatrixTranspose =
1032+
CI->getOperand(DXIL::OperandIndex::kMatVecMulMatrixTransposeIdx);
1033+
1034+
if (!llvm::isa<llvm::Constant>(InputIsUnsigned)) {
1035+
ValCtx.EmitInstrError(CI,
1036+
ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1037+
}
1038+
1039+
if (!llvm::isa<llvm::Constant>(InputInterpretation) ||
1040+
!llvm::isa<llvm::Constant>(MatrixInterpretation)) {
1041+
ValCtx.EmitInstrError(
1042+
CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1043+
}
1044+
1045+
// Check if InputInterpretation and MatrixInterpretation are valid
1046+
ConstantInt *Ii = cast<ConstantInt>(InputInterpretation);
1047+
auto IiValue = Ii->getLimitedValue();
1048+
if (!CheckFromRegisterInterpretations(IiValue)) {
1049+
ValCtx.EmitInstrError(
1050+
CI, ValidationRule::InstrLinalgInvalidRegisterInterpValue);
1051+
}
1052+
1053+
ConstantInt *Mi = cast<ConstantInt>(MatrixInterpretation);
1054+
auto MiValue = Mi->getLimitedValue();
1055+
if (!CheckInMemoryInterpretations(MiValue)) {
1056+
ValCtx.EmitInstrError(CI,
1057+
ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1058+
}
1059+
1060+
if (!llvm::isa<llvm::Constant>(MatrixM) ||
1061+
!llvm::isa<llvm::Constant>(MatrixK) ||
1062+
!llvm::isa<llvm::Constant>(MatrixLayout) ||
1063+
!llvm::isa<llvm::Constant>(MatrixTranspose)) {
1064+
ValCtx.EmitInstrError(CI,
1065+
ValidationRule::InstrLinalgMatrixShapeParamsAreConst);
1066+
}
1067+
1068+
ConstantInt *Ml = cast<ConstantInt>(MatrixLayout);
1069+
auto MlValue = Ml->getLimitedValue();
1070+
if (!CheckMatrixLayout(MlValue)) {
1071+
ValCtx.EmitInstrError(CI,
1072+
ValidationRule::InstrLinalgInvalidMatrixLayoutValue);
1073+
}
1074+
1075+
switch (opcode) {
1076+
case DXIL::OpCode::MatVecMul: {
1077+
llvm::Value *OutputIsUnsigned =
1078+
CI->getOperand(DXIL::OperandIndex::kMatVecMulIsOutputUnsignedIdx);
1079+
if (!llvm::isa<llvm::Constant>(OutputIsUnsigned)) {
1080+
ValCtx.EmitInstrError(
1081+
CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1082+
}
1083+
1084+
} break;
1085+
case DXIL::OpCode::MatVecMulAdd: {
1086+
llvm::Value *OutputIsUnsigned =
1087+
CI->getOperand(DXIL::OperandIndex::kMatVecMulAddIsOutputUnsignedIdx);
1088+
if (!llvm::isa<llvm::Constant>(OutputIsUnsigned)) {
1089+
ValCtx.EmitInstrError(
1090+
CI, ValidationRule::InstrMatVecOpIsUnsignedFlagsAreConst);
1091+
}
1092+
llvm::Value *BiasInterpretation =
1093+
CI->getOperand(DXIL::OperandIndex::kMatVecMulAddBiasInterpretation);
1094+
if (!llvm::isa<llvm::Constant>(BiasInterpretation)) {
1095+
ValCtx.EmitInstrError(
1096+
CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1097+
}
1098+
ConstantInt *Bi = cast<ConstantInt>(BiasInterpretation);
1099+
auto BiValue = Bi->getLimitedValue();
1100+
if (!CheckInMemoryInterpretations(BiValue)) {
1101+
ValCtx.EmitInstrError(
1102+
CI, ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1103+
}
1104+
} break;
1105+
default:
1106+
break;
1107+
}
1108+
}
1109+
1110+
static void ValidateImmOperandsForOuterProdAcc(CallInst *CI,
1111+
DXIL::OpCode opcode,
1112+
ValidationContext &ValCtx) {
1113+
1114+
llvm::Value *MatrixInterpretation =
1115+
CI->getOperand(DXIL::OperandIndex::kOuterProdAccMatrixInterpretation);
1116+
llvm::Value *MatrixLayout =
1117+
CI->getOperand(DXIL::OperandIndex::kOuterProdAccMatrixLayout);
1118+
1119+
if (!llvm::isa<llvm::Constant>(MatrixInterpretation)) {
1120+
ValCtx.EmitInstrError(
1121+
CI, ValidationRule::InstrLinalgInterpretationParamAreConst);
1122+
}
1123+
ConstantInt *Mi = cast<ConstantInt>(MatrixInterpretation);
1124+
auto MiValue = Mi->getLimitedValue();
1125+
if (!CheckInMemoryInterpretations(MiValue)) {
1126+
ValCtx.EmitInstrError(CI,
1127+
ValidationRule::InstrLinalgInvalidMemoryInterpValue);
1128+
}
1129+
1130+
if (!llvm::isa<llvm::Constant>(MatrixLayout)) {
1131+
ValCtx.EmitInstrError(CI,
1132+
ValidationRule::InstrLinalgMatrixShapeParamsAreConst);
1133+
}
1134+
ConstantInt *Ml = cast<ConstantInt>(MatrixLayout);
1135+
auto MlValue = Ml->getLimitedValue();
1136+
if (!CheckMatrixLayout(MlValue)) {
1137+
ValCtx.EmitInstrError(CI,
1138+
ValidationRule::InstrLinalgInvalidMatrixLayoutValue);
1139+
}
1140+
}
1141+
9731142
// Validate the type-defined mask compared to the store value mask which
9741143
// indicates which parts were defined returns true if caller should continue
9751144
// validation
@@ -1942,6 +2111,16 @@ static void ValidateDxilOperationCallInProfile(CallInst *CI,
19422111
GetLaunchTypeStr(nodeLaunchType)});
19432112

19442113
break;
2114+
case DXIL::OpCode::MatVecMul:
2115+
case DXIL::OpCode::MatVecMulAdd:
2116+
ValidateImmOperandsForMatVecOps(CI, opcode, ValCtx);
2117+
break;
2118+
case DXIL::OpCode::OuterProductAccumulate:
2119+
ValidateImmOperandsForOuterProdAcc(CI, opcode, ValCtx);
2120+
break;
2121+
case DXIL::OpCode::VectorAccumulate:
2122+
2123+
break;
19452124

19462125
default:
19472126
// TODO: make sure every opcode is checked.

utils/hct/hctdb.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7878,6 +7878,22 @@ def build_valrules(self):
78787878
"Invalid use of completed record handle.",
78797879
)
78807880

7881+
# Linalg ops
7882+
self.add_valrule("Instr.MatVecOpIsUnsignedFlagsAreConst", "MatVec Ops Is Unsigned flag is a constant")
7883+
7884+
self.add_valrule("Instr.LinalgInterpretationParamAreConst", "Interpretation values are constants")
7885+
7886+
self.add_valrule("Instr.LinalgInvalidRegisterInterpValue", "From Register Interpretation value not in valid set")
7887+
7888+
self.add_valrule("Instr.LinalgInvalidMemoryInterpValue", "In Memory Interpolation value not in valid set")
7889+
7890+
self.add_valrule("Instr.LinalgMatrixShapeParamsAreConst", "Matrix Layout, Dimensions and isTranspose are immediate constants")
7891+
7892+
self.add_valrule("Instr.LinalgInvalidMatrixLayoutValue", "Matrix Layout for Linalg ops not in valid set")
7893+
7894+
7895+
7896+
78817897
# Some legacy rules:
78827898
# - space is only supported for shader targets 5.1 and higher
78837899
# - multiple rules regarding derivatives, which isn't a supported feature for DXIL

0 commit comments

Comments
 (0)