Skip to content

Commit 4c2ec2a

Browse files
authored
spirv-val: add validation for SPV_INTEL_predicated_io (#6665)
Add spirv-val validation for `OpPredicatedLoadINTEL` and `OpPredicatedStoreINTEL` instructions introduced by the `SPV_INTEL_predicated_io` extension (KhronosGroup/SPIRV-Registry@7be4570): - Result/Object type must be scalar or vector of numerical type - Predicate must be a Boolean scalar - Default Value type must match Result Type (for load) - Pointer type must match Result/Object type (for typed pointers) - Volatile memory operand is not allowed Co-Authored-By: Claude (commercial)
1 parent c8bda96 commit 4c2ec2a

4 files changed

Lines changed: 496 additions & 4 deletions

File tree

source/val/validate_logical_pointers.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,9 @@ spv_result_t ValidateLogicalPointerOperands(ValidationState_t& _,
305305
case spv::Op::OpNodePayloadArrayLengthAMDX:
306306
case spv::Op::OpIsNodePayloadValidAMDX:
307307
case spv::Op::OpFinishWritingNodePayloadAMDX:
308+
// SPV_INTEL_predicated_io
309+
case spv::Op::OpPredicatedLoadINTEL:
310+
case spv::Op::OpPredicatedStoreINTEL:
308311
// SPV_ARM_graph
309312
case spv::Op::OpGraphEntryPointARM:
310313
return SPV_SUCCESS;

source/val/validate_memory.cpp

Lines changed: 160 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ std::pair<Instruction*, Instruction*> GetPointerTypes(ValidationState_t& _,
207207
case spv::Op::OpCooperativeMatrixLoadTensorNV:
208208
case spv::Op::OpCooperativeMatrixLoadKHR:
209209
case spv::Op::OpCooperativeVectorLoadNV:
210-
case spv::Op::OpLoad: {
210+
case spv::Op::OpLoad:
211+
case spv::Op::OpPredicatedLoadINTEL: {
211212
auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
212213
dst_pointer_type = _.FindDef(load_pointer->type_id());
213214
break;
@@ -216,7 +217,8 @@ std::pair<Instruction*, Instruction*> GetPointerTypes(ValidationState_t& _,
216217
case spv::Op::OpCooperativeMatrixStoreTensorNV:
217218
case spv::Op::OpCooperativeMatrixStoreKHR:
218219
case spv::Op::OpCooperativeVectorStoreNV:
219-
case spv::Op::OpStore: {
220+
case spv::Op::OpStore:
221+
case spv::Op::OpPredicatedStoreINTEL: {
220222
auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
221223
dst_pointer_type = _.FindDef(store_pointer->type_id());
222224
break;
@@ -315,7 +317,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
315317
inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV ||
316318
inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV ||
317319
inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR ||
318-
inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) {
320+
inst->opcode() == spv::Op::OpCooperativeVectorLoadNV ||
321+
inst->opcode() == spv::Op::OpPredicatedLoadINTEL) {
319322
return _.diag(SPV_ERROR_INVALID_ID, inst)
320323
<< "MakePointerAvailableKHR cannot be used with OpLoad.";
321324
}
@@ -337,7 +340,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
337340
inst->opcode() == spv::Op::OpCooperativeMatrixStoreNV ||
338341
inst->opcode() == spv::Op::OpCooperativeMatrixStoreKHR ||
339342
inst->opcode() == spv::Op::OpCooperativeMatrixStoreTensorNV ||
340-
inst->opcode() == spv::Op::OpCooperativeVectorStoreNV) {
343+
inst->opcode() == spv::Op::OpCooperativeVectorStoreNV ||
344+
inst->opcode() == spv::Op::OpPredicatedStoreINTEL) {
341345
return _.diag(SPV_ERROR_INVALID_ID, inst)
342346
<< "MakePointerVisibleKHR cannot be used with OpStore.";
343347
}
@@ -3475,6 +3479,154 @@ spv_result_t ValidatePtrComparison(ValidationState_t& _,
34753479
return SPV_SUCCESS;
34763480
}
34773481

3482+
spv_result_t ValidatePredicatedLoadINTEL(ValidationState_t& _,
3483+
const Instruction* inst) {
3484+
const auto result_type_id = inst->type_id();
3485+
if (!_.IsIntScalarOrVectorType(result_type_id) &&
3486+
!_.IsFloatScalarOrVectorType(result_type_id)) {
3487+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3488+
<< "OpPredicatedLoadINTEL Result Type <id> "
3489+
<< _.getIdName(result_type_id)
3490+
<< " must be a scalar or vector of numerical type.";
3491+
}
3492+
3493+
const auto pointer_id = inst->GetOperandAs<uint32_t>(2);
3494+
const auto pointer = _.FindDef(pointer_id);
3495+
if (!pointer ||
3496+
((_.addressing_model() == spv::AddressingModel::Logical) &&
3497+
((!_.features().variable_pointers &&
3498+
!spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
3499+
(_.features().variable_pointers &&
3500+
!spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
3501+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3502+
<< "OpPredicatedLoadINTEL Pointer <id> " << _.getIdName(pointer_id)
3503+
<< " is not a logical pointer.";
3504+
}
3505+
3506+
const auto pointer_type = _.FindDef(pointer->type_id());
3507+
if (!pointer_type ||
3508+
(pointer_type->opcode() != spv::Op::OpTypePointer &&
3509+
pointer_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
3510+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3511+
<< "OpPredicatedLoadINTEL type for pointer <id> "
3512+
<< _.getIdName(pointer_id) << " is not a pointer type.";
3513+
}
3514+
3515+
if (pointer_type->opcode() == spv::Op::OpTypePointer) {
3516+
const auto pointee_type =
3517+
_.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
3518+
if (!pointee_type || result_type_id != pointee_type->id()) {
3519+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3520+
<< "OpPredicatedLoadINTEL Result Type <id> "
3521+
<< _.getIdName(result_type_id) << " does not match Pointer <id> "
3522+
<< _.getIdName(pointer->id()) << "s type.";
3523+
}
3524+
}
3525+
3526+
const auto predicate_id = inst->GetOperandAs<uint32_t>(3);
3527+
const auto predicate = _.FindDef(predicate_id);
3528+
if (!predicate || !_.IsBoolScalarType(predicate->type_id())) {
3529+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3530+
<< "OpPredicatedLoadINTEL Predicate <id> "
3531+
<< _.getIdName(predicate_id) << " must be a Boolean scalar.";
3532+
}
3533+
3534+
const auto default_value_id = inst->GetOperandAs<uint32_t>(4);
3535+
const auto default_value = _.FindDef(default_value_id);
3536+
if (!default_value || default_value->type_id() != result_type_id) {
3537+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3538+
<< "OpPredicatedLoadINTEL Default Value <id> "
3539+
<< _.getIdName(default_value_id)
3540+
<< " type does not match Result Type.";
3541+
}
3542+
3543+
if (inst->operands().size() > 5) {
3544+
const auto mask = inst->GetOperandAs<uint32_t>(5);
3545+
if (mask & uint32_t(spv::MemoryAccessMask::Volatile)) {
3546+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3547+
<< "OpPredicatedLoadINTEL does not allow the Volatile memory "
3548+
"operand.";
3549+
}
3550+
}
3551+
3552+
if (auto error = CheckMemoryAccess(_, inst, 5)) return error;
3553+
3554+
return SPV_SUCCESS;
3555+
}
3556+
3557+
spv_result_t ValidatePredicatedStoreINTEL(ValidationState_t& _,
3558+
const Instruction* inst) {
3559+
const auto pointer_id = inst->GetOperandAs<uint32_t>(0);
3560+
const auto pointer = _.FindDef(pointer_id);
3561+
if (!pointer ||
3562+
(_.addressing_model() == spv::AddressingModel::Logical &&
3563+
((!_.features().variable_pointers &&
3564+
!spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
3565+
(_.features().variable_pointers &&
3566+
!spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
3567+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3568+
<< "OpPredicatedStoreINTEL Pointer <id> " << _.getIdName(pointer_id)
3569+
<< " is not a logical pointer.";
3570+
}
3571+
3572+
const auto pointer_type = _.FindDef(pointer->type_id());
3573+
if (!pointer_type ||
3574+
(pointer_type->opcode() != spv::Op::OpTypePointer &&
3575+
pointer_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
3576+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3577+
<< "OpPredicatedStoreINTEL type for pointer <id> "
3578+
<< _.getIdName(pointer_id) << " is not a pointer type.";
3579+
}
3580+
3581+
const auto object_id = inst->GetOperandAs<uint32_t>(1);
3582+
const auto object = _.FindDef(object_id);
3583+
if (!object || !object->type_id()) {
3584+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3585+
<< "OpPredicatedStoreINTEL Object <id> " << _.getIdName(object_id)
3586+
<< " is not an object.";
3587+
}
3588+
3589+
const auto object_type_id = object->type_id();
3590+
if (!_.IsIntScalarOrVectorType(object_type_id) &&
3591+
!_.IsFloatScalarOrVectorType(object_type_id)) {
3592+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3593+
<< "OpPredicatedStoreINTEL Object <id> " << _.getIdName(object_id)
3594+
<< " type must be a scalar or vector of numerical type.";
3595+
}
3596+
3597+
if (pointer_type->opcode() == spv::Op::OpTypePointer) {
3598+
const auto pointee_type =
3599+
_.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
3600+
if (!pointee_type || pointee_type->id() != object_type_id) {
3601+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3602+
<< "OpPredicatedStoreINTEL Pointer <id> "
3603+
<< _.getIdName(pointer_id) << "s type does not match Object <id> "
3604+
<< _.getIdName(object->id()) << "s type.";
3605+
}
3606+
}
3607+
3608+
const auto predicate_id = inst->GetOperandAs<uint32_t>(2);
3609+
const auto predicate = _.FindDef(predicate_id);
3610+
if (!predicate || !_.IsBoolScalarType(predicate->type_id())) {
3611+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3612+
<< "OpPredicatedStoreINTEL Predicate <id> "
3613+
<< _.getIdName(predicate_id) << " must be a Boolean scalar.";
3614+
}
3615+
3616+
if (inst->operands().size() > 3) {
3617+
const auto mask = inst->GetOperandAs<uint32_t>(3);
3618+
if (mask & uint32_t(spv::MemoryAccessMask::Volatile)) {
3619+
return _.diag(SPV_ERROR_INVALID_ID, inst)
3620+
<< "OpPredicatedStoreINTEL does not allow the Volatile memory "
3621+
"operand.";
3622+
}
3623+
}
3624+
3625+
if (auto error = CheckMemoryAccess(_, inst, 3)) return error;
3626+
3627+
return SPV_SUCCESS;
3628+
}
3629+
34783630
} // namespace
34793631

34803632
spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
@@ -3529,6 +3681,10 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
35293681
case spv::Op::OpCooperativeVectorMatrixMulNV:
35303682
case spv::Op::OpCooperativeVectorMatrixMulAddNV:
35313683
return ValidateCooperativeVectorMatrixMulNV(_, inst);
3684+
case spv::Op::OpPredicatedLoadINTEL:
3685+
return ValidatePredicatedLoadINTEL(_, inst);
3686+
case spv::Op::OpPredicatedStoreINTEL:
3687+
return ValidatePredicatedStoreINTEL(_, inst);
35323688
case spv::Op::OpPtrEqual:
35333689
case spv::Op::OpPtrNotEqual:
35343690
case spv::Op::OpPtrDiff:

test/val/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ add_spvtools_unittest(TARGET val_abcde
5353
val_extension_spv_intel_arbitrary_precision_integers_test.cpp
5454
val_extension_spv_intel_function_variants.cpp
5555
val_extension_spv_intel_inline_assembly.cpp
56+
val_extension_spv_intel_predicated_io_test.cpp
5657
val_extension_spv_ext_descriptor_heap.cpp
5758
val_ext_inst_test.cpp
5859
val_ext_inst_debug_test.cpp

0 commit comments

Comments
 (0)