1717// Validates correctness of composite SPIR-V instructions.
1818
1919#include < climits>
20+ #include < cstdint>
2021
2122#include " source/opcode.h"
2223#include " source/spirv_target_env.h"
@@ -36,14 +37,11 @@ namespace {
3637// deep).
3738spv_result_t GetExtractInsertValueType (ValidationState_t& _,
3839 const Instruction* inst,
39- uint32_t * member_type) {
40- const spv::Op opcode = inst->opcode ();
41- assert (opcode == spv::Op::OpCompositeExtract ||
42- opcode == spv::Op::OpCompositeInsert);
43- uint32_t word_index = opcode == spv::Op::OpCompositeExtract ? 4 : 5 ;
44- const uint32_t num_words = static_cast <uint32_t >(inst->words ().size ());
45- const uint32_t composite_id_index = word_index - 1 ;
46- const uint32_t num_indices = num_words - word_index;
40+ uint32_t * member_type,
41+ uint32_t composite_id_index) {
42+ const uint32_t num_operands = static_cast <uint32_t >(inst->operands ().size ());
43+ const uint32_t first_literal_index = composite_id_index + 1 ;
44+ const uint32_t num_indices = num_operands - first_literal_index;
4745 const uint32_t kCompositeExtractInsertMaxNumIndices = 255 ;
4846
4947 if (num_indices == 0 ) {
@@ -53,19 +51,21 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _,
5351
5452 } else if (num_indices > kCompositeExtractInsertMaxNumIndices ) {
5553 return _.diag (SPV_ERROR_INVALID_DATA, inst)
56- << " The number of indexes in Op" << spvOpcodeString (opcode)
54+ << " The number of indexes in Op" << spvOpcodeString (inst-> opcode () )
5755 << " may not exceed " << kCompositeExtractInsertMaxNumIndices
5856 << " . Found " << num_indices << " indexes." ;
5957 }
6058
61- *member_type = _.GetTypeId (inst-> word ( composite_id_index) );
59+ *member_type = _.GetOperandTypeId (inst, composite_id_index);
6260 if (*member_type == 0 ) {
6361 return _.diag (SPV_ERROR_INVALID_DATA, inst)
6462 << " Expected Composite to be an object of composite type" ;
6563 }
6664
67- for (; word_index < num_words; ++word_index) {
68- const uint32_t component_index = inst->word (word_index);
65+ for (uint32_t operand_index = first_literal_index;
66+ operand_index < num_operands; ++operand_index) {
67+ const uint32_t component_index =
68+ inst->GetOperandAs <uint32_t >(operand_index);
6969 const Instruction* const type_inst = _.FindDef (*member_type);
7070 assert (type_inst);
7171 switch (type_inst->opcode ()) {
@@ -471,9 +471,12 @@ spv_result_t ValidateCompositeConstructReplicate(ValidationState_t& _,
471471}
472472
473473spv_result_t ValidateCompositeExtract (ValidationState_t& _,
474- const Instruction* inst) {
474+ const Instruction* inst,
475+ uint32_t operand_index = 2 ) {
475476 uint32_t member_type = 0 ;
476- if (spv_result_t error = GetExtractInsertValueType (_, inst, &member_type)) {
477+
478+ if (spv_result_t error =
479+ GetExtractInsertValueType (_, inst, &member_type, operand_index)) {
477480 return error;
478481 }
479482
@@ -496,9 +499,10 @@ spv_result_t ValidateCompositeExtract(ValidationState_t& _,
496499}
497500
498501spv_result_t ValidateCompositeInsert (ValidationState_t& _,
499- const Instruction* inst) {
500- const uint32_t object_type = _.GetOperandTypeId (inst, 2 );
501- const uint32_t composite_type = _.GetOperandTypeId (inst, 3 );
502+ const Instruction* inst,
503+ uint32_t operand_index = 2 ) {
504+ const uint32_t object_type = _.GetOperandTypeId (inst, operand_index);
505+ const uint32_t composite_type = _.GetOperandTypeId (inst, operand_index + 1 );
502506 const uint32_t result_type = inst->type_id ();
503507 if (result_type != composite_type) {
504508 return _.diag (SPV_ERROR_INVALID_DATA, inst)
@@ -508,7 +512,8 @@ spv_result_t ValidateCompositeInsert(ValidationState_t& _,
508512 }
509513
510514 uint32_t member_type = 0 ;
511- if (spv_result_t error = GetExtractInsertValueType (_, inst, &member_type)) {
515+ if (spv_result_t error =
516+ GetExtractInsertValueType (_, inst, &member_type, operand_index + 1 )) {
512517 return error;
513518 }
514519
@@ -589,58 +594,58 @@ spv_result_t ValidateTranspose(ValidationState_t& _, const Instruction* inst) {
589594}
590595
591596spv_result_t ValidateVectorShuffle (ValidationState_t& _,
592- const Instruction* inst) {
593- auto resultType = _.FindDef (inst->type_id ());
594- if (!_.IsVectorType (resultType->id ())) {
597+ const Instruction* inst,
598+ uint32_t operand_index = 2 ) {
599+ auto result_type = _.FindDef (inst->type_id ());
600+ if (!_.IsVectorType (result_type->id ())) {
595601 return _.diag (SPV_ERROR_INVALID_ID, inst)
596602 << " The Result Type of OpVectorShuffle must be"
597603 << " a vector type. Found Op"
598- << spvOpcodeString (resultType ->opcode ()) << " ." ;
604+ << spvOpcodeString (result_type ->opcode ()) << " ." ;
599605 }
600606
601607 // The number of components in Result Type must be the same as the number of
602608 // Component operands.
603- auto componentCount = inst->operands ().size () - 4 ;
604- auto resultVectorDimension = _.GetDimension (resultType->id ());
605- if (resultVectorDimension > 0 && componentCount != resultVectorDimension) {
609+ uint32_t first_literal_index = operand_index + 2 ;
610+ uint32_t component_count =
611+ static_cast <uint32_t >(inst->operands ().size ()) - first_literal_index;
612+ auto result_vec_dimension = _.GetDimension (result_type->id ());
613+ if (result_vec_dimension > 0 && component_count != result_vec_dimension) {
606614 return _.diag (SPV_ERROR_INVALID_ID, inst)
607615 << " OpVectorShuffle component literals count does not match "
608616 " Result Type <id> "
609- << _.getIdName (resultType ->id ()) << " s vector component count." ;
617+ << _.getIdName (result_type ->id ()) << " s vector component count." ;
610618 }
611619
612620 // Vector 1 and Vector 2 must both have vector types, with the same Component
613621 // Type as Result Type.
614- auto vector1Object = _.FindDef (inst->GetOperandAs <uint32_t >(2 ));
615- auto vector1Type = _.FindDef (vector1Object->type_id ());
616- auto vector2Object = _.FindDef (inst->GetOperandAs <uint32_t >(3 ));
617- auto vector2Type = _.FindDef (vector2Object->type_id ());
618- if (!_.IsVectorType (vector1Type->id ())) {
622+ auto vec1_type = _.FindDef (_.GetOperandTypeId (inst, operand_index));
623+ auto vec2_type = _.FindDef (_.GetOperandTypeId (inst, operand_index + 1 ));
624+ if (!_.IsVectorType (vec1_type->id ())) {
619625 return _.diag (SPV_ERROR_INVALID_ID, inst)
620626 << " The type of Vector 1 must be a vector type." ;
621627 }
622- if (!_.IsVectorType (vector2Type ->id ())) {
628+ if (!_.IsVectorType (vec2_type ->id ())) {
623629 return _.diag (SPV_ERROR_INVALID_ID, inst)
624630 << " The type of Vector 2 must be a vector type." ;
625631 }
626632
627- auto resultComponentType = resultType ->GetOperandAs <uint32_t >(1 );
628- if (vector1Type ->GetOperandAs <uint32_t >(1 ) != resultComponentType ) {
633+ uint32_t result_component_type = result_type ->GetOperandAs <uint32_t >(1 );
634+ if (vec1_type ->GetOperandAs <uint32_t >(1 ) != result_component_type ) {
629635 return _.diag (SPV_ERROR_INVALID_ID, inst)
630636 << " The Component Type of Vector 1 must be the same as ResultType." ;
631637 }
632- if (vector2Type ->GetOperandAs <uint32_t >(1 ) != resultComponentType ) {
638+ if (vec2_type ->GetOperandAs <uint32_t >(1 ) != result_component_type ) {
633639 return _.diag (SPV_ERROR_INVALID_ID, inst)
634640 << " The Component Type of Vector 2 must be the same as ResultType." ;
635641 }
636642
637643 // All Component literals must either be FFFFFFFF or in [0, N - 1].
638- auto vector1ComponentCount = vector1Type->GetOperandAs <uint32_t >(2 );
639- auto vector2ComponentCount = vector2Type->GetOperandAs <uint32_t >(2 );
640- auto N = vector1ComponentCount + vector2ComponentCount;
641- auto firstLiteralIndex = 4 ;
642- for (size_t i = firstLiteralIndex; i < inst->operands ().size (); ++i) {
643- auto literal = inst->GetOperandAs <uint32_t >(i);
644+ uint32_t vec1_component_count = vec1_type->GetOperandAs <uint32_t >(2 );
645+ uint32_t vec2_component_count = vec2_type->GetOperandAs <uint32_t >(2 );
646+ uint32_t N = vec1_component_count + vec2_component_count;
647+ for (size_t i = first_literal_index; i < inst->operands ().size (); ++i) {
648+ uint32_t literal = inst->GetOperandAs <uint32_t >(i);
644649 if (literal != 0xFFFFFFFF && literal >= N) {
645650 return _.diag (SPV_ERROR_INVALID_ID, inst)
646651 << " Component index " << literal << " is out of bounds for "
@@ -1168,6 +1173,20 @@ spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) {
11681173 return ValidateCompositeExtractCoopMatQCOM (_, inst);
11691174 case spv::Op::OpExtractSubArrayQCOM:
11701175 return ValidateExtractSubArrayQCOM (_, inst);
1176+
1177+ case spv::Op::OpSpecConstantOp: {
1178+ switch (inst->GetOperandAs <spv::Op>(2u )) {
1179+ case spv::Op::OpVectorShuffle:
1180+ return ValidateVectorShuffle (_, inst, 3 );
1181+ case spv::Op::OpCompositeExtract:
1182+ return ValidateCompositeExtract (_, inst, 3 );
1183+ case spv::Op::OpCompositeInsert:
1184+ return ValidateCompositeInsert (_, inst, 3 );
1185+ default :
1186+ break ;
1187+ }
1188+ }
1189+
11711190 default :
11721191 break ;
11731192 }
0 commit comments