Skip to content

Commit 989e29a

Browse files
authored
Improve constant composite validation (KhronosGroup#6598)
* Refactored a lot of repetitive code * Improved operand checks to disallow spec constants in regular constants * Removed some tests that no longer trigger appropriate messages * replaced more generally
1 parent 5d6745b commit 989e29a

6 files changed

Lines changed: 115 additions & 167 deletions

File tree

source/val/validate_constants.cpp

Lines changed: 38 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,30 @@ bool isCompositeType(const Instruction* inst) {
4040
(is_tensor && tensor_is_shaped);
4141
}
4242

43+
spv_result_t ValidateConstantOperand(ValidationState_t& _,
44+
const Instruction* inst, size_t operand) {
45+
std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());
46+
47+
const auto operand_id = inst->GetOperandAs<uint32_t>(operand);
48+
const bool inst_is_spec_constant = spvOpcodeIsSpecConstant(inst->opcode());
49+
const auto operand_opcode = _.GetIdOpcode(operand_id);
50+
const bool is_constant = spvOpcodeIsConstantOrUndef(operand_opcode);
51+
const bool is_spec_constant = spvOpcodeIsSpecConstant(operand_opcode);
52+
if (!is_constant) {
53+
// All operands must be constant or undef.
54+
return _.diag(SPV_ERROR_INVALID_ID, inst)
55+
<< opcode_name << " must only have constant or undef operands: <id> "
56+
<< _.getIdName(operand_id);
57+
} else if (!inst_is_spec_constant && is_spec_constant) {
58+
// Spec constants are only allowed for spec constant opcodes.
59+
return _.diag(SPV_ERROR_INVALID_ID, inst)
60+
<< opcode_name << " must not have spec constant operands: <id> "
61+
<< _.getIdName(operand_id);
62+
}
63+
64+
return SPV_SUCCESS;
65+
}
66+
4367
spv_result_t ValidateConstantComposite(ValidationState_t& _,
4468
const Instruction* inst) {
4569
std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());
@@ -51,7 +75,7 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
5175
<< _.getIdName(inst->type_id()) << " is not a composite type.";
5276
}
5377

54-
const auto constituent_count = inst->words().size() - 3;
78+
const auto constituent_count = inst->operands().size() - 2;
5579
switch (result_type->opcode()) {
5680
case spv::Op::OpTypeVector:
5781
case spv::Op::OpTypeVectorIdEXT: {
@@ -83,13 +107,6 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
83107
const auto constituent_id =
84108
inst->GetOperandAs<uint32_t>(constituent_index);
85109
const auto constituent = _.FindDef(constituent_id);
86-
if (!constituent ||
87-
!spvOpcodeIsConstantOrUndef(constituent->opcode())) {
88-
return _.diag(SPV_ERROR_INVALID_ID, inst)
89-
<< opcode_name << " Constituent <id> "
90-
<< _.getIdName(constituent_id)
91-
<< " is not a constant or undef.";
92-
}
93110
const auto constituent_result_type = _.FindDef(constituent->type_id());
94111
if (!constituent_result_type ||
95112
component_type->id() != constituent_result_type->id()) {
@@ -112,7 +129,8 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
112129
<< _.getIdName(result_type->id()) << "s matrix column count.";
113130
}
114131

115-
const auto column_type = _.FindDef(result_type->words()[2]);
132+
const auto column_type =
133+
_.FindDef(result_type->GetOperandAs<uint32_t>(1));
116134
if (!column_type) {
117135
return _.diag(SPV_ERROR_INVALID_ID, result_type)
118136
<< "Column type is not defined.";
@@ -130,15 +148,6 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
130148
const auto constituent_id =
131149
inst->GetOperandAs<uint32_t>(constituent_index);
132150
const auto constituent = _.FindDef(constituent_id);
133-
if (!constituent ||
134-
!spvOpcodeIsConstantOrUndef(constituent->opcode())) {
135-
// The message says "... or undef" because the spec does not say
136-
// undef is a constant.
137-
return _.diag(SPV_ERROR_INVALID_ID, inst)
138-
<< opcode_name << " Constituent <id> "
139-
<< _.getIdName(constituent_id)
140-
<< " is not a constant or undef.";
141-
}
142151
const auto vector = _.FindDef(constituent->type_id());
143152
if (!vector) {
144153
return _.diag(SPV_ERROR_INVALID_ID, constituent)
@@ -161,7 +170,7 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
161170
<< _.getIdName(result_type->id())
162171
<< "s matrix column component type.";
163172
}
164-
if (component_count != vector->words()[3]) {
173+
if (component_count != vector->GetOperandAs<uint32_t>(2)) {
165174
return _.diag(SPV_ERROR_INVALID_ID, inst)
166175
<< opcode_name << " Constituent <id> "
167176
<< _.getIdName(constituent_id)
@@ -198,13 +207,6 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
198207
const auto constituent_id =
199208
inst->GetOperandAs<uint32_t>(constituent_index);
200209
const auto constituent = _.FindDef(constituent_id);
201-
if (!constituent ||
202-
!spvOpcodeIsConstantOrUndef(constituent->opcode())) {
203-
return _.diag(SPV_ERROR_INVALID_ID, inst)
204-
<< opcode_name << " Constituent <id> "
205-
<< _.getIdName(constituent_id)
206-
<< " is not a constant or undef.";
207-
}
208210
const auto constituent_type = _.FindDef(constituent->type_id());
209211
if (!constituent_type) {
210212
return _.diag(SPV_ERROR_INVALID_ID, constituent)
@@ -220,7 +222,7 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
220222
}
221223
} break;
222224
case spv::Op::OpTypeStruct: {
223-
const auto member_count = result_type->words().size() - 2;
225+
const auto member_count = result_type->operands().size() - 1;
224226
if (member_count != constituent_count) {
225227
return _.diag(SPV_ERROR_INVALID_ID, inst)
226228
<< opcode_name << " Constituent <id> "
@@ -234,13 +236,6 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
234236
const auto constituent_id =
235237
inst->GetOperandAs<uint32_t>(constituent_index);
236238
const auto constituent = _.FindDef(constituent_id);
237-
if (!constituent ||
238-
!spvOpcodeIsConstantOrUndef(constituent->opcode())) {
239-
return _.diag(SPV_ERROR_INVALID_ID, inst)
240-
<< opcode_name << " Constituent <id> "
241-
<< _.getIdName(constituent_id)
242-
<< " is not a constant or undef.";
243-
}
244239
const auto constituent_type = _.FindDef(constituent->type_id());
245240
if (!constituent_type) {
246241
return _.diag(SPV_ERROR_INVALID_ID, constituent)
@@ -268,11 +263,6 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
268263
}
269264
const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
270265
const auto constituent = _.FindDef(constituent_id);
271-
if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
272-
return _.diag(SPV_ERROR_INVALID_ID, inst)
273-
<< opcode_name << " Constituent <id> "
274-
<< _.getIdName(constituent_id) << " is not a constant or undef.";
275-
}
276266
const auto constituent_type = _.FindDef(constituent->type_id());
277267
if (!constituent_type) {
278268
return _.diag(SPV_ERROR_INVALID_ID, constituent)
@@ -328,13 +318,6 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
328318
const auto constituent_id =
329319
inst->GetOperandAs<uint32_t>(constituent_index);
330320
const auto constituent = _.FindDef(constituent_id);
331-
if (!constituent ||
332-
!spvOpcodeIsConstantOrUndef(constituent->opcode())) {
333-
return _.diag(SPV_ERROR_INVALID_ID, inst)
334-
<< opcode_name << " Constituent <id> "
335-
<< _.getIdName(constituent_id)
336-
<< " is not a constant or undef.";
337-
}
338321
const auto constituent_type = _.FindDef(constituent->type_id());
339322
if (!constituent_type) {
340323
return _.diag(SPV_ERROR_INVALID_ID, constituent)
@@ -427,6 +410,13 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
427410
default:
428411
break;
429412
}
413+
414+
for (size_t i = 2; i < inst->operands().size(); i++) {
415+
if (auto error = ValidateConstantOperand(_, inst, i)) {
416+
return error;
417+
}
418+
}
419+
430420
return SPV_SUCCESS;
431421
}
432422

@@ -443,12 +433,6 @@ spv_result_t ValidateConstantCompositeReplicate(ValidationState_t& _,
443433

444434
const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
445435
const auto constituent = _.FindDef(constituent_id);
446-
if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
447-
return _.diag(SPV_ERROR_INVALID_ID, inst)
448-
<< opcode_name << " Constituent <id> " << _.getIdName(constituent_id)
449-
<< " is not a constant or undef.";
450-
}
451-
452436
switch (result_type->opcode()) {
453437
case spv::Op::OpTypeVector:
454438
case spv::Op::OpTypeVectorIdEXT:
@@ -486,7 +470,8 @@ spv_result_t ValidateConstantCompositeReplicate(ValidationState_t& _,
486470
default:
487471
break;
488472
}
489-
return SPV_SUCCESS;
473+
474+
return ValidateConstantOperand(_, inst, 2);
490475
}
491476

492477
spv_result_t ValidateConstantSampler(ValidationState_t& _,

test/opt/ccp_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ TEST_F(CCPTest, SkipSpecConstantInstrucitons) {
581581
EXPECT_EQ(std::get<1>(res), Pass::Status::SuccessWithoutChange);
582582
}
583583

584-
TEST_F(CCPTest, FoldConstantCompositeInstrucitonsWithSpecConst) {
584+
TEST_F(CCPTest, FoldConstantCompositeInstructionsWithSpecConst) {
585585
const std::string spv_asm = R"(
586586
OpCapability Shader
587587
OpMemoryModel Logical GLSL450
@@ -595,7 +595,7 @@ TEST_F(CCPTest, FoldConstantCompositeInstrucitonsWithSpecConst) {
595595
%true = OpConstantTrue %bool
596596
; CHECK: [[spec_const:%\w+]] = OpSpecConstantComposite %v3bool
597597
%11 = OpSpecConstantComposite %v3bool %true %true %true
598-
%12 = OpConstantComposite %_struct_8 %11
598+
%12 = OpSpecConstantComposite %_struct_8 %11
599599
; CHECK: OpFunction
600600
%1 = OpFunction %void None %4
601601
%29 = OpLabel

test/val/val_constants_test.cpp

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ TEST_F(ValidateConstant, ConstantCompositeReplicateWrongOperandType) {
612612
"does not match Result Type <id> '17[%v4int]'s element type"));
613613
}
614614

615-
TEST_F(ValidateConstant, ConstantCompositeReplicateWrongOperandClass) {
615+
TEST_F(ValidateConstant, ConstantCompositeReplicateSpecOperand) {
616616
std::string spirv =
617617
std::string(
618618
"OpCapability Shader\nOpCapability Linkage\nOpCapability "
@@ -622,16 +622,84 @@ TEST_F(ValidateConstant, ConstantCompositeReplicateWrongOperandClass) {
622622
"\"SPV_EXT_replicated_composites\"\nOpMemoryModel Logical Simple\n") +
623623
kBasicTypes + R"(
624624
%int = OpTypeInt 32 1
625-
%v4int = OpTypeVector %int 4
626-
%var = OpVariable %_ptr_uint Workgroup
627-
%const_vector = OpConstantCompositeReplicateEXT %v4int %var
625+
%int_4 = OpConstant %int 4
626+
%arr = OpTypeArray %int %int_4
627+
%int_0 = OpSpecConstant %int 0
628+
%const_arr = OpConstantCompositeReplicateEXT %arr %int_0
629+
)";
630+
CompileSuccessfully(spirv);
631+
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
632+
EXPECT_THAT(getDiagnosticString(),
633+
HasSubstr("OpConstantCompositeReplicateEXT must not have spec "
634+
"constant operands: <id>"));
635+
}
636+
637+
TEST_F(ValidateConstant, ConstantCompositeReplicateNotConstant) {
638+
std::string spirv =
639+
std::string(
640+
"OpCapability Kernel\nOpCapability Linkage\nOpCapability "
641+
"Int64\nOpCapability Float64\nOpCapability "
642+
"VariablePointers\nOpCapability Addresses\nOpCapability "
643+
"ReplicatedCompositesEXT\nOpExtension "
644+
"\"SPV_KHR_variable_pointers\"\nOpExtension "
645+
"\"SPV_EXT_replicated_composites\"\nOpMemoryModel Physical64 "
646+
"OpenCL\n") +
647+
kBasicTypes + R"(
648+
%uint_4 = OpConstant %uint 4
649+
%ptr = OpTypePointer Private %uint
650+
%var = OpVariable %ptr Private
651+
%arr = OpTypeArray %ptr %uint_4
652+
%const_arr = OpConstantCompositeReplicateEXT %arr %var
653+
)";
654+
CompileSuccessfully(spirv);
655+
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
656+
EXPECT_THAT(getDiagnosticString(),
657+
HasSubstr("OpConstantCompositeReplicateEXT must only have "
658+
"constant or undef operands: <id>"));
659+
}
660+
661+
TEST_F(ValidateConstant, ConstantCompositeSpecOperand) {
662+
std::string spirv =
663+
std::string(
664+
"OpCapability Shader\nOpCapability Linkage\nOpCapability "
665+
"Int64\nOpCapability Float64\nOpCapability "
666+
"VariablePointers\nOpCapability ReplicatedCompositesEXT\nOpExtension "
667+
"\"SPV_KHR_variable_pointers\"\nOpExtension "
668+
"\"SPV_EXT_replicated_composites\"\nOpMemoryModel Logical Simple\n") +
669+
kBasicTypes + R"(
670+
%int = OpTypeInt 32 1
671+
%int_4 = OpConstant %int 4
672+
%arr = OpTypeArray %int %int_4
673+
%int_0 = OpSpecConstant %int 0
674+
%const_arr = OpConstantComposite %arr %int_0 %int_0 %int_0 %int_0
628675
)";
629676
CompileSuccessfully(spirv);
630677
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
631678
EXPECT_THAT(
632679
getDiagnosticString(),
633-
HasSubstr("OpConstantCompositeReplicateEXT Constituent <id> '18[%18]' "
634-
"is not a constant or undef"));
680+
HasSubstr(
681+
"OpConstantComposite must not have spec constant operands: <id>"));
682+
}
683+
684+
TEST_F(ValidateConstant, ConstantCompositeNotConstant) {
685+
std::string spirv =
686+
std::string(
687+
"OpCapability Kernel\nOpCapability Linkage\nOpCapability "
688+
"Int64\nOpCapability Float64\nOpCapability "
689+
"VariablePointers\nOpCapability Addresses\nOpExtension "
690+
"\"SPV_KHR_variable_pointers\"\nOpMemoryModel Physical64 OpenCL\n") +
691+
kBasicTypes + R"(
692+
%uint_4 = OpConstant %uint 4
693+
%ptr = OpTypePointer Private %uint
694+
%var = OpVariable %ptr Private
695+
%arr = OpTypeArray %ptr %uint_4
696+
%const_arr = OpConstantComposite %arr %var %var %var %var
697+
)";
698+
CompileSuccessfully(spirv);
699+
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
700+
EXPECT_THAT(getDiagnosticString(),
701+
HasSubstr("OpConstantComposite must only have constant or undef "
702+
"operands: <id>"));
635703
}
636704

637705
TEST_F(ValidateConstant, ConstantCompositeReplicateNotComposite) {

0 commit comments

Comments
 (0)