@@ -39,8 +39,8 @@ spv_result_t ValidateShaderBitWidth(ValidationState_t& _,
3939 return SPV_SUCCESS;
4040}
4141
42- spv_result_t ValidateConvertFToU (ValidationState_t& _,
43- const Instruction* inst ) {
42+ spv_result_t ValidateConvertFToU (ValidationState_t& _, const Instruction* inst,
43+ uint32_t operand_index = 2 ) {
4444 const spv::Op opcode = inst->opcode ();
4545 const uint32_t result_type = inst->type_id ();
4646 if (!_.IsUnsignedIntScalarType (result_type) &&
@@ -51,7 +51,7 @@ spv_result_t ValidateConvertFToU(ValidationState_t& _,
5151 << " Expected unsigned int scalar or vector type as Result Type: "
5252 << spvOpcodeString (opcode);
5353
54- const uint32_t input_type = _.GetOperandTypeId (inst, 2 );
54+ const uint32_t input_type = _.GetOperandTypeId (inst, operand_index );
5555 if (!input_type ||
5656 (!_.IsFloatScalarType (input_type) && !_.IsFloatVectorType (input_type) &&
5757 !_.IsFloatCooperativeMatrixType (input_type) &&
@@ -82,8 +82,8 @@ spv_result_t ValidateConvertFToU(ValidationState_t& _,
8282 return SPV_SUCCESS;
8383}
8484
85- spv_result_t ValidateConvertFToS (ValidationState_t& _,
86- const Instruction* inst ) {
85+ spv_result_t ValidateConvertFToS (ValidationState_t& _, const Instruction* inst,
86+ uint32_t operand_index = 2 ) {
8787 const spv::Op opcode = inst->opcode ();
8888 const uint32_t result_type = inst->type_id ();
8989 if (!_.IsIntScalarType (result_type) && !_.IsIntVectorType (result_type) &&
@@ -93,7 +93,7 @@ spv_result_t ValidateConvertFToS(ValidationState_t& _,
9393 << " Expected int scalar or vector type as Result Type: "
9494 << spvOpcodeString (opcode);
9595
96- const uint32_t input_type = _.GetOperandTypeId (inst, 2 );
96+ const uint32_t input_type = _.GetOperandTypeId (inst, operand_index );
9797 if (!input_type ||
9898 (!_.IsFloatScalarType (input_type) && !_.IsFloatVectorType (input_type) &&
9999 !_.IsFloatCooperativeMatrixType (input_type) &&
@@ -125,7 +125,8 @@ spv_result_t ValidateConvertFToS(ValidationState_t& _,
125125}
126126
127127spv_result_t ValidateConvertIntToF (ValidationState_t& _,
128- const Instruction* inst) {
128+ const Instruction* inst,
129+ uint32_t operand_index = 2 ) {
129130 const spv::Op opcode = inst->opcode ();
130131 const uint32_t result_type = inst->type_id ();
131132 if (!_.IsFloatScalarType (result_type) && !_.IsFloatVectorType (result_type) &&
@@ -135,7 +136,7 @@ spv_result_t ValidateConvertIntToF(ValidationState_t& _,
135136 << " Expected float scalar or vector type as Result Type: "
136137 << spvOpcodeString (opcode);
137138
138- const uint32_t input_type = _.GetOperandTypeId (inst, 2 );
139+ const uint32_t input_type = _.GetOperandTypeId (inst, operand_index );
139140 if (!input_type ||
140141 (!_.IsIntScalarType (input_type) && !_.IsIntVectorType (input_type) &&
141142 !_.IsIntCooperativeMatrixType (input_type) &&
@@ -307,7 +308,8 @@ spv_result_t ValidateFConvert(ValidationState_t& _, const Instruction* inst,
307308}
308309
309310spv_result_t ValidateQuantizeToF16 (ValidationState_t& _,
310- const Instruction* inst) {
311+ const Instruction* inst,
312+ uint32_t operand_index = 2 ) {
311313 const spv::Op opcode = inst->opcode ();
312314 const uint32_t result_type = inst->type_id ();
313315 if ((!_.IsFloatScalarType (result_type) &&
@@ -317,7 +319,7 @@ spv_result_t ValidateQuantizeToF16(ValidationState_t& _,
317319 << " Expected 32-bit float scalar or vector type as Result Type: "
318320 << spvOpcodeString (opcode);
319321
320- const uint32_t input_type = _.GetOperandTypeId (inst, 2 );
322+ const uint32_t input_type = _.GetOperandTypeId (inst, operand_index );
321323 if (input_type != result_type)
322324 return _.diag (SPV_ERROR_INVALID_DATA, inst)
323325 << " Expected input type to be equal to Result Type: "
@@ -326,15 +328,16 @@ spv_result_t ValidateQuantizeToF16(ValidationState_t& _,
326328}
327329
328330spv_result_t ValidateConvertPtrToU (ValidationState_t& _,
329- const Instruction* inst) {
331+ const Instruction* inst,
332+ uint32_t operand_index = 2 ) {
330333 const spv::Op opcode = inst->opcode ();
331334 const uint32_t result_type = inst->type_id ();
332335 if (!_.IsUnsignedIntScalarType (result_type))
333336 return _.diag (SPV_ERROR_INVALID_DATA, inst)
334337 << " Expected unsigned int scalar type as Result Type: "
335338 << spvOpcodeString (opcode);
336339
337- const uint32_t input_type = _.GetOperandTypeId (inst, 2 );
340+ const uint32_t input_type = _.GetOperandTypeId (inst, operand_index );
338341 if (!_.IsPointerType (input_type))
339342 return _.diag (SPV_ERROR_INVALID_DATA, inst)
340343 << " Expected input to be a pointer: " << spvOpcodeString (opcode);
@@ -389,15 +392,16 @@ spv_result_t ValidateSatConvertInt(ValidationState_t& _,
389392}
390393
391394spv_result_t ValidateConvertUToPtr (ValidationState_t& _,
392- const Instruction* inst) {
395+ const Instruction* inst,
396+ uint32_t operand_index = 2 ) {
393397 const spv::Op opcode = inst->opcode ();
394398 const uint32_t result_type = inst->type_id ();
395399 if (!_.IsPointerType (result_type))
396400 return _.diag (SPV_ERROR_INVALID_DATA, inst)
397401 << " Expected Result Type to be a pointer: "
398402 << spvOpcodeString (opcode);
399403
400- const uint32_t input_type = _.GetOperandTypeId (inst, 2 );
404+ const uint32_t input_type = _.GetOperandTypeId (inst, operand_index );
401405 if (!input_type || !_.IsIntScalarType (input_type))
402406 return _.diag (SPV_ERROR_INVALID_DATA, inst)
403407 << " Expected int scalar as input: " << spvOpcodeString (opcode);
@@ -843,6 +847,19 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
843847 return ValidateSConvert (_, inst, 3 );
844848 case spv::Op::OpFConvert:
845849 return ValidateFConvert (_, inst, 3 );
850+ case spv::Op::OpConvertSToF:
851+ case spv::Op::OpConvertUToF:
852+ return ValidateConvertIntToF (_, inst, 3 );
853+ case spv::Op::OpConvertFToS:
854+ return ValidateConvertFToS (_, inst, 3 );
855+ case spv::Op::OpConvertFToU:
856+ return ValidateConvertFToU (_, inst, 3 );
857+ case spv::Op::OpQuantizeToF16:
858+ return ValidateQuantizeToF16 (_, inst, 3 );
859+ case spv::Op::OpConvertPtrToU:
860+ return ValidateConvertPtrToU (_, inst, 3 );
861+ case spv::Op::OpConvertUToPtr:
862+ return ValidateConvertUToPtr (_, inst, 3 );
846863 default :
847864 break ;
848865 }
0 commit comments