Skip to content

Commit 334a9b7

Browse files
spirv-val: More OpSpecConstantOp (KhronosGroup#6585)
All is left after for KhronosGroup#6564 is the VectorShuffle/CompositeInsert/CompositeExtract and the AccessChains, these will need some more changes to passing in the operand index, so saving them for last
1 parent c75e349 commit 334a9b7

4 files changed

Lines changed: 222 additions & 30 deletions

File tree

source/val/validate_conversion.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,8 @@ spv_result_t ValidateConvertUToPtr(ValidationState_t& _,
433433
}
434434

435435
spv_result_t ValidatePtrCastToGeneric(ValidationState_t& _,
436-
const Instruction* inst) {
436+
const Instruction* inst,
437+
uint32_t operand_index = 2) {
437438
const spv::Op opcode = inst->opcode();
438439
const uint32_t result_type = inst->type_id();
439440
spv::StorageClass result_storage_class;
@@ -449,7 +450,7 @@ spv_result_t ValidatePtrCastToGeneric(ValidationState_t& _,
449450
<< "Expected Result Type to have storage class Generic: "
450451
<< spvOpcodeString(opcode);
451452

452-
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
453+
const uint32_t input_type = _.GetOperandTypeId(inst, operand_index);
453454
spv::StorageClass input_storage_class;
454455
uint32_t input_data_type = 0;
455456
if (!_.GetPointerTypeInfo(input_type, &input_data_type, &input_storage_class))
@@ -471,7 +472,8 @@ spv_result_t ValidatePtrCastToGeneric(ValidationState_t& _,
471472
}
472473

473474
spv_result_t ValidateGenericCastToPtr(ValidationState_t& _,
474-
const Instruction* inst) {
475+
const Instruction* inst,
476+
uint32_t operand_index = 2) {
475477
const spv::Op opcode = inst->opcode();
476478
const uint32_t result_type = inst->type_id();
477479
spv::StorageClass result_storage_class;
@@ -489,7 +491,7 @@ spv_result_t ValidateGenericCastToPtr(ValidationState_t& _,
489491
<< "Expected Result Type to have storage class Workgroup, "
490492
<< "CrossWorkgroup or Function: " << spvOpcodeString(opcode);
491493

492-
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
494+
const uint32_t input_type = _.GetOperandTypeId(inst, operand_index);
493495
spv::StorageClass input_storage_class;
494496
uint32_t input_data_type = 0;
495497
if (!_.GetPointerTypeInfo(input_type, &input_data_type, &input_storage_class))
@@ -552,10 +554,11 @@ spv_result_t ValidateGenericCastToPtrExplicit(ValidationState_t& _,
552554
return SPV_SUCCESS;
553555
}
554556

555-
spv_result_t ValidateBitcast(ValidationState_t& _, const Instruction* inst) {
557+
spv_result_t ValidateBitcast(ValidationState_t& _, const Instruction* inst,
558+
uint32_t operand_index = 2) {
556559
const spv::Op opcode = inst->opcode();
557560
const uint32_t result_type = inst->type_id();
558-
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
561+
const uint32_t input_type = _.GetOperandTypeId(inst, operand_index);
559562
if (!input_type)
560563
return _.diag(SPV_ERROR_INVALID_DATA, inst)
561564
<< "Expected input to have a type: " << spvOpcodeString(opcode);
@@ -860,6 +863,12 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
860863
return ValidateConvertPtrToU(_, inst, 3);
861864
case spv::Op::OpConvertUToPtr:
862865
return ValidateConvertUToPtr(_, inst, 3);
866+
case spv::Op::OpGenericCastToPtr:
867+
return ValidateGenericCastToPtr(_, inst, 3);
868+
case spv::Op::OpPtrCastToGeneric:
869+
return ValidatePtrCastToGeneric(_, inst, 3);
870+
case spv::Op::OpBitcast:
871+
return ValidateBitcast(_, inst, 3);
863872
default:
864873
break;
865874
}

source/val/validate_logicals.cpp

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,38 +92,42 @@ spv_result_t ValidateFloatCompare(ValidationState_t& _,
9292
}
9393

9494
spv_result_t ValidateLogicalCompare(ValidationState_t& _,
95-
const Instruction* inst) {
95+
const Instruction* inst,
96+
uint32_t operand_index = 2) {
9697
const spv::Op opcode = inst->opcode();
9798
const uint32_t result_type = inst->type_id();
9899
if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type))
99100
return _.diag(SPV_ERROR_INVALID_DATA, inst)
100101
<< "Expected bool scalar or vector type as Result Type: "
101102
<< spvOpcodeString(opcode);
102103

103-
if (result_type != _.GetOperandTypeId(inst, 2) ||
104-
result_type != _.GetOperandTypeId(inst, 3))
104+
const uint32_t operand_1 = _.GetOperandTypeId(inst, operand_index);
105+
const uint32_t operand_2 = _.GetOperandTypeId(inst, operand_index + 1);
106+
if (result_type != operand_1 || result_type != operand_2)
105107
return _.diag(SPV_ERROR_INVALID_DATA, inst)
106108
<< "Expected both operands to be of Result Type: "
107109
<< spvOpcodeString(opcode);
108110
return SPV_SUCCESS;
109111
}
110112

111-
spv_result_t ValidateLogicalNot(ValidationState_t& _, const Instruction* inst) {
113+
spv_result_t ValidateLogicalNot(ValidationState_t& _, const Instruction* inst,
114+
uint32_t operand_index = 2) {
112115
const spv::Op opcode = inst->opcode();
113116
const uint32_t result_type = inst->type_id();
114117
if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type))
115118
return _.diag(SPV_ERROR_INVALID_DATA, inst)
116119
<< "Expected bool scalar or vector type as Result Type: "
117120
<< spvOpcodeString(opcode);
118121

119-
if (result_type != _.GetOperandTypeId(inst, 2))
122+
if (result_type != _.GetOperandTypeId(inst, operand_index))
120123
return _.diag(SPV_ERROR_INVALID_DATA, inst)
121124
<< "Expected operand to be of Result Type: "
122125
<< spvOpcodeString(opcode);
123126
return SPV_SUCCESS;
124127
}
125128

126-
spv_result_t ValidateSelect(ValidationState_t& _, const Instruction* inst) {
129+
spv_result_t ValidateSelect(ValidationState_t& _, const Instruction* inst,
130+
uint32_t operand_index = 2) {
127131
const spv::Op opcode = inst->opcode();
128132
const uint32_t result_type = inst->type_id();
129133
uint32_t dimension = 1;
@@ -186,9 +190,9 @@ spv_result_t ValidateSelect(ValidationState_t& _, const Instruction* inst) {
186190
return fail();
187191
}
188192

189-
const uint32_t condition_type = _.GetOperandTypeId(inst, 2);
190-
const uint32_t left_type = _.GetOperandTypeId(inst, 3);
191-
const uint32_t right_type = _.GetOperandTypeId(inst, 4);
193+
const uint32_t condition_type = _.GetOperandTypeId(inst, operand_index);
194+
const uint32_t left_type = _.GetOperandTypeId(inst, operand_index + 1);
195+
const uint32_t right_type = _.GetOperandTypeId(inst, operand_index + 2);
192196

193197
if (!condition_type || (!_.IsBoolScalarType(condition_type) &&
194198
!_.IsBoolVectorType(condition_type)))
@@ -216,16 +220,17 @@ spv_result_t ValidateSelect(ValidationState_t& _, const Instruction* inst) {
216220
return SPV_SUCCESS;
217221
}
218222

219-
spv_result_t ValidateIntCompare(ValidationState_t& _, const Instruction* inst) {
223+
spv_result_t ValidateIntCompare(ValidationState_t& _, const Instruction* inst,
224+
uint32_t operand_index = 2) {
220225
const spv::Op opcode = inst->opcode();
221226
const uint32_t result_type = inst->type_id();
222227
if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type))
223228
return _.diag(SPV_ERROR_INVALID_DATA, inst)
224229
<< "Expected bool scalar or vector type as Result Type: "
225230
<< spvOpcodeString(opcode);
226231

227-
const uint32_t left_type = _.GetOperandTypeId(inst, 2);
228-
const uint32_t right_type = _.GetOperandTypeId(inst, 3);
232+
const uint32_t left_type = _.GetOperandTypeId(inst, operand_index);
233+
const uint32_t right_type = _.GetOperandTypeId(inst, operand_index + 1);
229234

230235
if (!left_type ||
231236
(!_.IsIntScalarType(left_type) && !_.IsIntVectorType(left_type)))
@@ -305,6 +310,35 @@ spv_result_t LogicalsPass(ValidationState_t& _, const Instruction* inst) {
305310
case spv::Op::OpSLessThan:
306311
case spv::Op::OpSLessThanEqual:
307312
return ValidateIntCompare(_, inst);
313+
314+
case spv::Op::OpSpecConstantOp: {
315+
switch (inst->GetOperandAs<spv::Op>(2u)) {
316+
case spv::Op::OpLogicalEqual:
317+
case spv::Op::OpLogicalNotEqual:
318+
case spv::Op::OpLogicalOr:
319+
case spv::Op::OpLogicalAnd:
320+
return ValidateLogicalCompare(_, inst, 3);
321+
case spv::Op::OpLogicalNot:
322+
return ValidateLogicalNot(_, inst, 3);
323+
case spv::Op::OpSelect:
324+
return ValidateSelect(_, inst, 3);
325+
case spv::Op::OpIEqual:
326+
case spv::Op::OpINotEqual:
327+
case spv::Op::OpUGreaterThan:
328+
case spv::Op::OpUGreaterThanEqual:
329+
case spv::Op::OpULessThan:
330+
case spv::Op::OpULessThanEqual:
331+
case spv::Op::OpSGreaterThan:
332+
case spv::Op::OpSGreaterThanEqual:
333+
case spv::Op::OpSLessThan:
334+
case spv::Op::OpSLessThanEqual:
335+
return ValidateIntCompare(_, inst, 3);
336+
default:
337+
break;
338+
}
339+
break;
340+
}
341+
308342
default:
309343
break;
310344
}

source/val/validate_memory.cpp

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2280,8 +2280,10 @@ spv_result_t ValidateArrayLength(ValidationState_t& state,
22802280
return SPV_SUCCESS;
22812281
}
22822282

2283-
spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
2284-
const Instruction* inst) {
2283+
spv_result_t ValidateCooperativeMatrixLength(ValidationState_t& state,
2284+
const Instruction* inst,
2285+
bool is_khr,
2286+
uint32_t operand_index = 2) {
22852287
const spv::Op opcode = inst->opcode();
22862288
// Result type must be a 32-bit unsigned int.
22872289
const uint32_t result_type_id = inst->type_id();
@@ -2293,15 +2295,14 @@ spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
22932295
<< " must be OpTypeInt with width 32 and signedness 0.";
22942296
}
22952297

2296-
bool isKhr = inst->opcode() == spv::Op::OpCooperativeMatrixLengthKHR;
2297-
auto type_id = inst->GetOperandAs<uint32_t>(2);
2298+
auto type_id = inst->GetOperandAs<uint32_t>(operand_index);
22982299
auto type = state.FindDef(type_id);
2299-
if (isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
2300+
if (is_khr && type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
23002301
return state.diag(SPV_ERROR_INVALID_ID, inst)
23012302
<< "The type in Op" << spvOpcodeString(opcode) << " <id> "
23022303
<< state.getIdName(type_id)
23032304
<< " must be OpTypeCooperativeMatrixKHR.";
2304-
} else if (!isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
2305+
} else if (!is_khr && type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
23052306
return state.diag(SPV_ERROR_INVALID_ID, inst)
23062307
<< "The type in Op" << spvOpcodeString(opcode) << " <id> "
23072308
<< state.getIdName(type_id) << " must be OpTypeCooperativeMatrixNV.";
@@ -3391,8 +3392,9 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
33913392
case spv::Op::OpCooperativeMatrixStoreNV:
33923393
return ValidateCooperativeMatrixLoadStoreNV(_, inst);
33933394
case spv::Op::OpCooperativeMatrixLengthKHR:
3395+
return ValidateCooperativeMatrixLength(_, inst, true);
33943396
case spv::Op::OpCooperativeMatrixLengthNV:
3395-
return ValidateCooperativeMatrixLengthNV(_, inst);
3397+
return ValidateCooperativeMatrixLength(_, inst, false);
33963398
case spv::Op::OpCooperativeMatrixLoadKHR:
33973399
case spv::Op::OpCooperativeMatrixStoreKHR:
33983400
return ValidateCooperativeMatrixLoadStoreKHR(_, inst);
@@ -3415,6 +3417,18 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
34153417
return ValidatePtrComparison(_, inst);
34163418
case spv::Op::OpImageTexelPointer:
34173419
case spv::Op::OpGenericPtrMemSemantics:
3420+
3421+
case spv::Op::OpSpecConstantOp: {
3422+
switch (inst->GetOperandAs<spv::Op>(2u)) {
3423+
case spv::Op::OpCooperativeMatrixLengthKHR:
3424+
return ValidateCooperativeMatrixLength(_, inst, true, 3);
3425+
case spv::Op::OpCooperativeMatrixLengthNV:
3426+
return ValidateCooperativeMatrixLength(_, inst, false, 3);
3427+
default:
3428+
break;
3429+
}
3430+
}
3431+
34183432
default:
34193433
break;
34203434
}

0 commit comments

Comments
 (0)