@@ -92,38 +92,42 @@ spv_result_t ValidateFloatCompare(ValidationState_t& _,
9292}
9393
9494spv_result_t ValidateLogicalCompare (ValidationState_t& _,
95- const Instruction* inst) {
95+ const Instruction* inst,
96+ uint32_t operand_index = 2 ) {
9697 const spv::Op opcode = inst->opcode ();
9798 const uint32_t result_type = inst->type_id ();
9899 if (!_.IsBoolScalarType (result_type) && !_.IsBoolVectorType (result_type))
99100 return _.diag (SPV_ERROR_INVALID_DATA, inst)
100101 << " Expected bool scalar or vector type as Result Type: "
101102 << spvOpcodeString (opcode);
102103
103- if (result_type != _.GetOperandTypeId (inst, 2 ) ||
104- result_type != _.GetOperandTypeId (inst, 3 ))
104+ const uint32_t operand_1 = _.GetOperandTypeId (inst, operand_index);
105+ const uint32_t operand_2 = _.GetOperandTypeId (inst, operand_index + 1 );
106+ if (result_type != operand_1 || result_type != operand_2)
105107 return _.diag (SPV_ERROR_INVALID_DATA, inst)
106108 << " Expected both operands to be of Result Type: "
107109 << spvOpcodeString (opcode);
108110 return SPV_SUCCESS;
109111}
110112
111- spv_result_t ValidateLogicalNot (ValidationState_t& _, const Instruction* inst) {
113+ spv_result_t ValidateLogicalNot (ValidationState_t& _, const Instruction* inst,
114+ uint32_t operand_index = 2 ) {
112115 const spv::Op opcode = inst->opcode ();
113116 const uint32_t result_type = inst->type_id ();
114117 if (!_.IsBoolScalarType (result_type) && !_.IsBoolVectorType (result_type))
115118 return _.diag (SPV_ERROR_INVALID_DATA, inst)
116119 << " Expected bool scalar or vector type as Result Type: "
117120 << spvOpcodeString (opcode);
118121
119- if (result_type != _.GetOperandTypeId (inst, 2 ))
122+ if (result_type != _.GetOperandTypeId (inst, operand_index ))
120123 return _.diag (SPV_ERROR_INVALID_DATA, inst)
121124 << " Expected operand to be of Result Type: "
122125 << spvOpcodeString (opcode);
123126 return SPV_SUCCESS;
124127}
125128
126- spv_result_t ValidateSelect (ValidationState_t& _, const Instruction* inst) {
129+ spv_result_t ValidateSelect (ValidationState_t& _, const Instruction* inst,
130+ uint32_t operand_index = 2 ) {
127131 const spv::Op opcode = inst->opcode ();
128132 const uint32_t result_type = inst->type_id ();
129133 uint32_t dimension = 1 ;
@@ -186,9 +190,9 @@ spv_result_t ValidateSelect(ValidationState_t& _, const Instruction* inst) {
186190 return fail ();
187191 }
188192
189- const uint32_t condition_type = _.GetOperandTypeId (inst, 2 );
190- const uint32_t left_type = _.GetOperandTypeId (inst, 3 );
191- const uint32_t right_type = _.GetOperandTypeId (inst, 4 );
193+ const uint32_t condition_type = _.GetOperandTypeId (inst, operand_index );
194+ const uint32_t left_type = _.GetOperandTypeId (inst, operand_index + 1 );
195+ const uint32_t right_type = _.GetOperandTypeId (inst, operand_index + 2 );
192196
193197 if (!condition_type || (!_.IsBoolScalarType (condition_type) &&
194198 !_.IsBoolVectorType (condition_type)))
@@ -216,16 +220,17 @@ spv_result_t ValidateSelect(ValidationState_t& _, const Instruction* inst) {
216220 return SPV_SUCCESS;
217221}
218222
219- spv_result_t ValidateIntCompare (ValidationState_t& _, const Instruction* inst) {
223+ spv_result_t ValidateIntCompare (ValidationState_t& _, const Instruction* inst,
224+ uint32_t operand_index = 2 ) {
220225 const spv::Op opcode = inst->opcode ();
221226 const uint32_t result_type = inst->type_id ();
222227 if (!_.IsBoolScalarType (result_type) && !_.IsBoolVectorType (result_type))
223228 return _.diag (SPV_ERROR_INVALID_DATA, inst)
224229 << " Expected bool scalar or vector type as Result Type: "
225230 << spvOpcodeString (opcode);
226231
227- const uint32_t left_type = _.GetOperandTypeId (inst, 2 );
228- const uint32_t right_type = _.GetOperandTypeId (inst, 3 );
232+ const uint32_t left_type = _.GetOperandTypeId (inst, operand_index );
233+ const uint32_t right_type = _.GetOperandTypeId (inst, operand_index + 1 );
229234
230235 if (!left_type ||
231236 (!_.IsIntScalarType (left_type) && !_.IsIntVectorType (left_type)))
@@ -305,6 +310,35 @@ spv_result_t LogicalsPass(ValidationState_t& _, const Instruction* inst) {
305310 case spv::Op::OpSLessThan:
306311 case spv::Op::OpSLessThanEqual:
307312 return ValidateIntCompare (_, inst);
313+
314+ case spv::Op::OpSpecConstantOp: {
315+ switch (inst->GetOperandAs <spv::Op>(2u )) {
316+ case spv::Op::OpLogicalEqual:
317+ case spv::Op::OpLogicalNotEqual:
318+ case spv::Op::OpLogicalOr:
319+ case spv::Op::OpLogicalAnd:
320+ return ValidateLogicalCompare (_, inst, 3 );
321+ case spv::Op::OpLogicalNot:
322+ return ValidateLogicalNot (_, inst, 3 );
323+ case spv::Op::OpSelect:
324+ return ValidateSelect (_, inst, 3 );
325+ case spv::Op::OpIEqual:
326+ case spv::Op::OpINotEqual:
327+ case spv::Op::OpUGreaterThan:
328+ case spv::Op::OpUGreaterThanEqual:
329+ case spv::Op::OpULessThan:
330+ case spv::Op::OpULessThanEqual:
331+ case spv::Op::OpSGreaterThan:
332+ case spv::Op::OpSGreaterThanEqual:
333+ case spv::Op::OpSLessThan:
334+ case spv::Op::OpSLessThanEqual:
335+ return ValidateIntCompare (_, inst, 3 );
336+ default :
337+ break ;
338+ }
339+ break ;
340+ }
341+
308342 default :
309343 break ;
310344 }
0 commit comments