Skip to content

Commit e4bceac

Browse files
spirv-val: Add remaining OpSpecConstantOp (KhronosGroup#6596)
My personal final PR for KhronosGroup#6564 There is still the access chains left, but those are going to be ugly, and since its only a `Kernel` feature, I don't have the bandwidth to look into right now. (But did left some test for the person who does in the future)
1 parent 989e29a commit e4bceac

3 files changed

Lines changed: 117 additions & 43 deletions

File tree

source/val/validate_composites.cpp

Lines changed: 60 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
// Validates correctness of composite SPIR-V instructions.
1818

1919
#include <climits>
20+
#include <cstdint>
2021

2122
#include "source/opcode.h"
2223
#include "source/spirv_target_env.h"
@@ -36,14 +37,11 @@ namespace {
3637
// deep).
3738
spv_result_t GetExtractInsertValueType(ValidationState_t& _,
3839
const Instruction* inst,
39-
uint32_t* member_type) {
40-
const spv::Op opcode = inst->opcode();
41-
assert(opcode == spv::Op::OpCompositeExtract ||
42-
opcode == spv::Op::OpCompositeInsert);
43-
uint32_t word_index = opcode == spv::Op::OpCompositeExtract ? 4 : 5;
44-
const uint32_t num_words = static_cast<uint32_t>(inst->words().size());
45-
const uint32_t composite_id_index = word_index - 1;
46-
const uint32_t num_indices = num_words - word_index;
40+
uint32_t* member_type,
41+
uint32_t composite_id_index) {
42+
const uint32_t num_operands = static_cast<uint32_t>(inst->operands().size());
43+
const uint32_t first_literal_index = composite_id_index + 1;
44+
const uint32_t num_indices = num_operands - first_literal_index;
4745
const uint32_t kCompositeExtractInsertMaxNumIndices = 255;
4846

4947
if (num_indices == 0) {
@@ -53,19 +51,21 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _,
5351

5452
} else if (num_indices > kCompositeExtractInsertMaxNumIndices) {
5553
return _.diag(SPV_ERROR_INVALID_DATA, inst)
56-
<< "The number of indexes in Op" << spvOpcodeString(opcode)
54+
<< "The number of indexes in Op" << spvOpcodeString(inst->opcode())
5755
<< " may not exceed " << kCompositeExtractInsertMaxNumIndices
5856
<< ". Found " << num_indices << " indexes.";
5957
}
6058

61-
*member_type = _.GetTypeId(inst->word(composite_id_index));
59+
*member_type = _.GetOperandTypeId(inst, composite_id_index);
6260
if (*member_type == 0) {
6361
return _.diag(SPV_ERROR_INVALID_DATA, inst)
6462
<< "Expected Composite to be an object of composite type";
6563
}
6664

67-
for (; word_index < num_words; ++word_index) {
68-
const uint32_t component_index = inst->word(word_index);
65+
for (uint32_t operand_index = first_literal_index;
66+
operand_index < num_operands; ++operand_index) {
67+
const uint32_t component_index =
68+
inst->GetOperandAs<uint32_t>(operand_index);
6969
const Instruction* const type_inst = _.FindDef(*member_type);
7070
assert(type_inst);
7171
switch (type_inst->opcode()) {
@@ -471,9 +471,12 @@ spv_result_t ValidateCompositeConstructReplicate(ValidationState_t& _,
471471
}
472472

473473
spv_result_t ValidateCompositeExtract(ValidationState_t& _,
474-
const Instruction* inst) {
474+
const Instruction* inst,
475+
uint32_t operand_index = 2) {
475476
uint32_t member_type = 0;
476-
if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) {
477+
478+
if (spv_result_t error =
479+
GetExtractInsertValueType(_, inst, &member_type, operand_index)) {
477480
return error;
478481
}
479482

@@ -496,9 +499,10 @@ spv_result_t ValidateCompositeExtract(ValidationState_t& _,
496499
}
497500

498501
spv_result_t ValidateCompositeInsert(ValidationState_t& _,
499-
const Instruction* inst) {
500-
const uint32_t object_type = _.GetOperandTypeId(inst, 2);
501-
const uint32_t composite_type = _.GetOperandTypeId(inst, 3);
502+
const Instruction* inst,
503+
uint32_t operand_index = 2) {
504+
const uint32_t object_type = _.GetOperandTypeId(inst, operand_index);
505+
const uint32_t composite_type = _.GetOperandTypeId(inst, operand_index + 1);
502506
const uint32_t result_type = inst->type_id();
503507
if (result_type != composite_type) {
504508
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -508,7 +512,8 @@ spv_result_t ValidateCompositeInsert(ValidationState_t& _,
508512
}
509513

510514
uint32_t member_type = 0;
511-
if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) {
515+
if (spv_result_t error =
516+
GetExtractInsertValueType(_, inst, &member_type, operand_index + 1)) {
512517
return error;
513518
}
514519

@@ -589,58 +594,58 @@ spv_result_t ValidateTranspose(ValidationState_t& _, const Instruction* inst) {
589594
}
590595

591596
spv_result_t ValidateVectorShuffle(ValidationState_t& _,
592-
const Instruction* inst) {
593-
auto resultType = _.FindDef(inst->type_id());
594-
if (!_.IsVectorType(resultType->id())) {
597+
const Instruction* inst,
598+
uint32_t operand_index = 2) {
599+
auto result_type = _.FindDef(inst->type_id());
600+
if (!_.IsVectorType(result_type->id())) {
595601
return _.diag(SPV_ERROR_INVALID_ID, inst)
596602
<< "The Result Type of OpVectorShuffle must be"
597603
<< " a vector type. Found Op"
598-
<< spvOpcodeString(resultType->opcode()) << ".";
604+
<< spvOpcodeString(result_type->opcode()) << ".";
599605
}
600606

601607
// The number of components in Result Type must be the same as the number of
602608
// Component operands.
603-
auto componentCount = inst->operands().size() - 4;
604-
auto resultVectorDimension = _.GetDimension(resultType->id());
605-
if (resultVectorDimension > 0 && componentCount != resultVectorDimension) {
609+
uint32_t first_literal_index = operand_index + 2;
610+
uint32_t component_count =
611+
static_cast<uint32_t>(inst->operands().size()) - first_literal_index;
612+
auto result_vec_dimension = _.GetDimension(result_type->id());
613+
if (result_vec_dimension > 0 && component_count != result_vec_dimension) {
606614
return _.diag(SPV_ERROR_INVALID_ID, inst)
607615
<< "OpVectorShuffle component literals count does not match "
608616
"Result Type <id> "
609-
<< _.getIdName(resultType->id()) << "s vector component count.";
617+
<< _.getIdName(result_type->id()) << "s vector component count.";
610618
}
611619

612620
// Vector 1 and Vector 2 must both have vector types, with the same Component
613621
// Type as Result Type.
614-
auto vector1Object = _.FindDef(inst->GetOperandAs<uint32_t>(2));
615-
auto vector1Type = _.FindDef(vector1Object->type_id());
616-
auto vector2Object = _.FindDef(inst->GetOperandAs<uint32_t>(3));
617-
auto vector2Type = _.FindDef(vector2Object->type_id());
618-
if (!_.IsVectorType(vector1Type->id())) {
622+
auto vec1_type = _.FindDef(_.GetOperandTypeId(inst, operand_index));
623+
auto vec2_type = _.FindDef(_.GetOperandTypeId(inst, operand_index + 1));
624+
if (!_.IsVectorType(vec1_type->id())) {
619625
return _.diag(SPV_ERROR_INVALID_ID, inst)
620626
<< "The type of Vector 1 must be a vector type.";
621627
}
622-
if (!_.IsVectorType(vector2Type->id())) {
628+
if (!_.IsVectorType(vec2_type->id())) {
623629
return _.diag(SPV_ERROR_INVALID_ID, inst)
624630
<< "The type of Vector 2 must be a vector type.";
625631
}
626632

627-
auto resultComponentType = resultType->GetOperandAs<uint32_t>(1);
628-
if (vector1Type->GetOperandAs<uint32_t>(1) != resultComponentType) {
633+
uint32_t result_component_type = result_type->GetOperandAs<uint32_t>(1);
634+
if (vec1_type->GetOperandAs<uint32_t>(1) != result_component_type) {
629635
return _.diag(SPV_ERROR_INVALID_ID, inst)
630636
<< "The Component Type of Vector 1 must be the same as ResultType.";
631637
}
632-
if (vector2Type->GetOperandAs<uint32_t>(1) != resultComponentType) {
638+
if (vec2_type->GetOperandAs<uint32_t>(1) != result_component_type) {
633639
return _.diag(SPV_ERROR_INVALID_ID, inst)
634640
<< "The Component Type of Vector 2 must be the same as ResultType.";
635641
}
636642

637643
// All Component literals must either be FFFFFFFF or in [0, N - 1].
638-
auto vector1ComponentCount = vector1Type->GetOperandAs<uint32_t>(2);
639-
auto vector2ComponentCount = vector2Type->GetOperandAs<uint32_t>(2);
640-
auto N = vector1ComponentCount + vector2ComponentCount;
641-
auto firstLiteralIndex = 4;
642-
for (size_t i = firstLiteralIndex; i < inst->operands().size(); ++i) {
643-
auto literal = inst->GetOperandAs<uint32_t>(i);
644+
uint32_t vec1_component_count = vec1_type->GetOperandAs<uint32_t>(2);
645+
uint32_t vec2_component_count = vec2_type->GetOperandAs<uint32_t>(2);
646+
uint32_t N = vec1_component_count + vec2_component_count;
647+
for (size_t i = first_literal_index; i < inst->operands().size(); ++i) {
648+
uint32_t literal = inst->GetOperandAs<uint32_t>(i);
644649
if (literal != 0xFFFFFFFF && literal >= N) {
645650
return _.diag(SPV_ERROR_INVALID_ID, inst)
646651
<< "Component index " << literal << " is out of bounds for "
@@ -1168,6 +1173,20 @@ spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) {
11681173
return ValidateCompositeExtractCoopMatQCOM(_, inst);
11691174
case spv::Op::OpExtractSubArrayQCOM:
11701175
return ValidateExtractSubArrayQCOM(_, inst);
1176+
1177+
case spv::Op::OpSpecConstantOp: {
1178+
switch (inst->GetOperandAs<spv::Op>(2u)) {
1179+
case spv::Op::OpVectorShuffle:
1180+
return ValidateVectorShuffle(_, inst, 3);
1181+
case spv::Op::OpCompositeExtract:
1182+
return ValidateCompositeExtract(_, inst, 3);
1183+
case spv::Op::OpCompositeInsert:
1184+
return ValidateCompositeInsert(_, inst, 3);
1185+
default:
1186+
break;
1187+
}
1188+
}
1189+
11711190
default:
11721191
break;
11731192
}

source/val/validate_memory.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3535,13 +3535,14 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
35353535
return ValidatePtrComparison(_, inst);
35363536
case spv::Op::OpImageTexelPointer:
35373537
case spv::Op::OpGenericPtrMemSemantics:
3538-
3538+
break; // no validation currently
35393539
case spv::Op::OpSpecConstantOp: {
35403540
switch (inst->GetOperandAs<spv::Op>(2u)) {
35413541
case spv::Op::OpCooperativeMatrixLengthKHR:
35423542
return ValidateCooperativeMatrixLength(_, inst, true, 3);
35433543
case spv::Op::OpCooperativeMatrixLengthNV:
35443544
return ValidateCooperativeMatrixLength(_, inst, false, 3);
3545+
// TODO - Add AccesChains
35453546
default:
35463547
break;
35473548
}

test/val/val_constants_test.cpp

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,12 +772,16 @@ OpMemoryModel Logical VulkanKHR
772772
HasSubstr("must be OpTypeCooperativeMatrixKHR"));
773773
}
774774

775+
// Some check use SPV_ERROR_INVALID_DATA vs SPV_ERROR_INVALID_ID
775776
#define BAD_KERNEL_OPERANDS(STR, ERR) \
776777
{ \
777778
SPV_ENV_UNIVERSAL_1_0, kKernelPreamble kBasicTypes STR, false, ERR, \
778779
SPV_ERROR_INVALID_DATA \
779780
}
780781

782+
#define BAD_KERNEL_OPERANDS_ID(STR, ERR) \
783+
{ SPV_ENV_UNIVERSAL_1_0, kKernelPreamble kBasicTypes STR, false, ERR, }
784+
781785
// 2 of each, first has bad return type, second has bad operand
782786
INSTANTIATE_TEST_SUITE_P(
783787
BadOperandsKernel, ValidateConstantOp,
@@ -1079,7 +1083,57 @@ INSTANTIATE_TEST_SUITE_P(
10791083
"float vector or scalar type"),
10801084
BAD_KERNEL_OPERANDS("%v = OpSpecConstantOp %uint Bitcast %true",
10811085
"Expected input to be a pointer or int or float "
1082-
"vector or scalar")}));
1086+
"vector or scalar"),
1087+
1088+
BAD_KERNEL_OPERANDS_ID(
1089+
"%v = OpSpecConstantOp %float VectorShuffle %uint2_0 %uint2_0 1 3",
1090+
"The Result Type of OpVectorShuffle must be a vector type"),
1091+
BAD_KERNEL_OPERANDS_ID(
1092+
"%v = OpSpecConstantOp %uint2 VectorShuffle %uint2_0 %uint_0 1 3",
1093+
"The type of Vector 2 must be a vector type"),
1094+
BAD_KERNEL_OPERANDS(
1095+
"%v = OpSpecConstantOp %float CompositeExtract %uint2_0 1",
1096+
"Result type (OpTypeFloat) does not match the type that results "
1097+
"from indexing into the composite (OpTypeInt)"),
1098+
BAD_KERNEL_OPERANDS(
1099+
"%v = OpSpecConstantOp %uint CompositeExtract %uint_0 1",
1100+
"Reached non-composite type while indexes still remain to be "
1101+
"traversed"),
1102+
BAD_KERNEL_OPERANDS(
1103+
"%v = OpSpecConstantOp %float CompositeInsert %uint_0 %uint2_0 1",
1104+
"The Result Type must be the same as Composite type in "
1105+
"OpSpecConstantOp yielding Result Id 5"),
1106+
BAD_KERNEL_OPERANDS(
1107+
"%v = OpSpecConstantOp %uint2 CompositeInsert %uint_0 %uint_0 1",
1108+
"The Result Type must be the same as Composite type in "
1109+
"OpSpecConstantOp yielding Result Id 4"),
1110+
1111+
// TODO - Still need to add access chains
1112+
//
1113+
// BAD_KERNEL_OPERANDS("%v = OpSpecConstantOp %uint AccessChain %null",
1114+
// "AccessChain"),
1115+
// BAD_KERNEL_OPERANDS("%v = OpSpecConstantOp %_ptr_uint AccessChain
1116+
// %null %float_0",
1117+
// "AccessChain"),
1118+
// BAD_KERNEL_OPERANDS("%v = OpSpecConstantOp %uint InBoundsAccessChain
1119+
// %null",
1120+
// "InBoundsAccessChain"),
1121+
// BAD_KERNEL_OPERANDS("%v = OpSpecConstantOp %_ptr_uint
1122+
// InBoundsAccessChain %null %float_0",
1123+
// "InBoundsAccessChain"),
1124+
// BAD_KERNEL_OPERANDS("%v = OpSpecConstantOp %uint PtrAccessChain %null
1125+
// %uint_0",
1126+
// "PtrAccessChain"),
1127+
// BAD_KERNEL_OPERANDS("%v = OpSpecConstantOp %_ptr_uint PtrAccessChain
1128+
// %float_0 %float_0",
1129+
// "PtrAccessChain"),
1130+
// BAD_KERNEL_OPERANDS("%v = OpSpecConstantOp %uint
1131+
// InBoundsPtrAccessChain %null %uint_0",
1132+
// "InBoundsPtrAccessChain"),
1133+
// BAD_KERNEL_OPERANDS("%v = OpSpecConstantOp %_ptr_uint
1134+
// InBoundsPtrAccessChain %float_0 %float_0",
1135+
// "InBoundsPtrAccessChain"),
1136+
}));
10831137

10841138
} // namespace
10851139
} // namespace val

0 commit comments

Comments
 (0)