Skip to content

Commit c75e349

Browse files
authored
Add validation for SPV_EXT_replicated_composites (KhronosGroup#6583)
Fixes KhronosGroup#6529 * OpConstantCompositeReplicateEXT * OpSpecConstantCompositeReplicateEXT * OpCompositeConstructReplicateEXT
1 parent 85dc1ee commit c75e349

4 files changed

Lines changed: 486 additions & 2 deletions

File tree

source/val/validate_composites.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,67 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
409409
return SPV_SUCCESS;
410410
}
411411

412+
spv_result_t ValidateCompositeConstructReplicate(ValidationState_t& _,
413+
const Instruction* inst) {
414+
const auto result_type = _.FindDef(inst->type_id());
415+
const uint32_t operand_type = _.GetOperandTypeId(inst, 2);
416+
417+
switch (result_type->opcode()) {
418+
case spv::Op::OpTypeVector:
419+
case spv::Op::OpTypeVectorIdEXT:
420+
case spv::Op::OpTypeMatrix:
421+
case spv::Op::OpTypeArray:
422+
case spv::Op::OpTypeCooperativeMatrixKHR:
423+
case spv::Op::OpTypeCooperativeMatrixNV: {
424+
const auto element_type = result_type->GetOperandAs<uint32_t>(1);
425+
if (operand_type != element_type) {
426+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
427+
<< "Expected Value type to be equal to the "
428+
<< "result's element type";
429+
}
430+
break;
431+
}
432+
case spv::Op::OpTypeStruct: {
433+
for (uint32_t operand_index = 1;
434+
operand_index < result_type->operands().size(); ++operand_index) {
435+
const uint32_t member_type =
436+
result_type->GetOperandAs<uint32_t>(operand_index);
437+
if (operand_type != member_type) {
438+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
439+
<< "Expected Value type to be equal to the "
440+
<< "corresponding member type of the result";
441+
}
442+
}
443+
break;
444+
}
445+
case spv::Op::OpTypeTensorARM: {
446+
const uint32_t component_type = result_type->GetOperandAs<uint32_t>(1);
447+
if (operand_type != component_type) {
448+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
449+
<< "Expected Value type to be equal to the result's element "
450+
"type";
451+
}
452+
if (result_type->operands().size() <= 3) {
453+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
454+
<< "Result tensor type is not a composite type because it lacks "
455+
"a shape operand";
456+
}
457+
break;
458+
}
459+
default: {
460+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
461+
<< "Expected Result Type to be a composite type";
462+
}
463+
}
464+
465+
if (_.HasCapability(spv::Capability::Shader) &&
466+
_.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
467+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
468+
<< "Cannot create a composite containing 8- or 16-bit types";
469+
}
470+
return SPV_SUCCESS;
471+
}
472+
412473
spv_result_t ValidateCompositeExtract(ValidationState_t& _,
413474
const Instruction* inst) {
414475
uint32_t member_type = 0;
@@ -1089,6 +1150,8 @@ spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) {
10891150
return ValidateVectorShuffle(_, inst);
10901151
case spv::Op::OpCompositeConstruct:
10911152
return ValidateCompositeConstruct(_, inst);
1153+
case spv::Op::OpCompositeConstructReplicateEXT:
1154+
return ValidateCompositeConstructReplicate(_, inst);
10921155
case spv::Op::OpCompositeExtract:
10931156
return ValidateCompositeExtract(_, inst);
10941157
case spv::Op::OpCompositeInsert:

source/val/validate_constants.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,65 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
430430
return SPV_SUCCESS;
431431
}
432432

433+
spv_result_t ValidateConstantCompositeReplicate(ValidationState_t& _,
434+
const Instruction* inst) {
435+
std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());
436+
437+
const auto result_type = _.FindDef(inst->type_id());
438+
if (!result_type || !isCompositeType(result_type)) {
439+
return _.diag(SPV_ERROR_INVALID_ID, inst)
440+
<< opcode_name << " Result Type <id> "
441+
<< _.getIdName(inst->type_id()) << " is not a composite type.";
442+
}
443+
444+
const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
445+
const auto constituent = _.FindDef(constituent_id);
446+
if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
447+
return _.diag(SPV_ERROR_INVALID_ID, inst)
448+
<< opcode_name << " Constituent <id> " << _.getIdName(constituent_id)
449+
<< " is not a constant or undef.";
450+
}
451+
452+
switch (result_type->opcode()) {
453+
case spv::Op::OpTypeVector:
454+
case spv::Op::OpTypeVectorIdEXT:
455+
case spv::Op::OpTypeMatrix:
456+
case spv::Op::OpTypeArray:
457+
case spv::Op::OpTypeCooperativeMatrixKHR:
458+
case spv::Op::OpTypeCooperativeMatrixNV:
459+
case spv::Op::OpTypeTensorARM: {
460+
const auto component_type = result_type->GetOperandAs<uint32_t>(1);
461+
if (component_type != constituent->type_id()) {
462+
return _.diag(SPV_ERROR_INVALID_ID, inst)
463+
<< opcode_name << " Constituent <id> "
464+
<< _.getIdName(constituent_id)
465+
<< "s type does not match Result Type <id> "
466+
<< _.getIdName(result_type->id()) << "s element type.";
467+
}
468+
break;
469+
}
470+
case spv::Op::OpTypeStruct: {
471+
const auto member_count = result_type->operands().size() - 1;
472+
for (uint32_t member_index = 1; member_index <= member_count;
473+
member_index++) {
474+
const auto member_type_id =
475+
result_type->GetOperandAs<uint32_t>(member_index);
476+
if (member_type_id != constituent->type_id()) {
477+
return _.diag(SPV_ERROR_INVALID_ID, inst)
478+
<< opcode_name << " Constituent <id> "
479+
<< _.getIdName(constituent_id)
480+
<< " type does not match the Result Type <id> "
481+
<< _.getIdName(result_type->id()) << "s member type.";
482+
}
483+
}
484+
break;
485+
}
486+
default:
487+
break;
488+
}
489+
return SPV_SUCCESS;
490+
}
491+
433492
spv_result_t ValidateConstantSampler(ValidationState_t& _,
434493
const Instruction* inst) {
435494
const auto result_type = _.FindDef(inst->type_id());
@@ -617,6 +676,11 @@ spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
617676
case spv::Op::OpSpecConstantComposite:
618677
if (auto error = ValidateConstantComposite(_, inst)) return error;
619678
break;
679+
case spv::Op::OpConstantCompositeReplicateEXT:
680+
case spv::Op::OpSpecConstantCompositeReplicateEXT:
681+
if (auto error = ValidateConstantCompositeReplicate(_, inst))
682+
return error;
683+
break;
620684
case spv::Op::OpConstantSampler:
621685
if (auto error = ValidateConstantSampler(_, inst)) return error;
622686
break;

test/val/val_composites_test.cpp

Lines changed: 203 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,17 @@ using ValidateComposites = spvtest::ValidateBase<bool>;
3333
std::string GenerateShaderCode(
3434
const std::string& body,
3535
const std::string& capabilities_and_extensions = "",
36-
const std::string& execution_model = "Fragment") {
36+
const std::string& execution_model = "Fragment",
37+
const std::string& extra_types = "",
38+
const std::string& memory_model = "Logical GLSL450") {
3739
std::ostringstream ss;
3840
ss << R"(
3941
OpCapability Shader
4042
OpCapability Float64
4143
)";
4244

4345
ss << capabilities_and_extensions;
44-
ss << "OpMemoryModel Logical GLSL450\n";
46+
ss << "OpMemoryModel " << memory_model << "\n";
4547
ss << "OpEntryPoint " << execution_model << " %main \"main\"\n";
4648
if (execution_model == "Fragment") {
4749
ss << "OpExecutionMode %main OriginUpperLeft\n";
@@ -93,7 +95,11 @@ OpCapability Float64
9395
9496
%ptr_big_struct = OpTypePointer Uniform %big_struct
9597
%var_big_struct = OpVariable %ptr_big_struct Uniform
98+
)";
99+
100+
ss << extra_types;
96101

102+
ss << R"(
97103
%main = OpFunction %void None %func
98104
%main_entry = OpLabel
99105
)";
@@ -536,6 +542,201 @@ TEST_F(ValidateComposites, CompositeConstructStructWrongConstituent) {
536542
"corresponding member type of Result Type struct"));
537543
}
538544

545+
TEST_F(ValidateComposites, CompositeConstructReplicateVectorGood) {
546+
const std::string body = R"(
547+
%val1 = OpCompositeConstructReplicateEXT %f32vec4 %f32_0
548+
)";
549+
550+
CompileSuccessfully(
551+
GenerateShaderCode(body,
552+
"OpCapability ReplicatedCompositesEXT\nOpExtension "
553+
"\"SPV_EXT_replicated_composites\"\n")
554+
.c_str());
555+
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
556+
}
557+
558+
TEST_F(ValidateComposites, CompositeConstructReplicateMatrixGood) {
559+
const std::string body = R"(
560+
%val1 = OpCompositeConstructReplicateEXT %f32mat22 %f32vec2_01
561+
)";
562+
563+
CompileSuccessfully(
564+
GenerateShaderCode(body,
565+
"OpCapability ReplicatedCompositesEXT\nOpExtension "
566+
"\"SPV_EXT_replicated_composites\"\n",
567+
"Fragment")
568+
.c_str());
569+
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
570+
}
571+
572+
TEST_F(ValidateComposites, CompositeConstructReplicateArrayGood) {
573+
const std::string body = R"(
574+
%val1 = OpCompositeConstructReplicateEXT %f32vec2arr3 %f32vec2_12
575+
)";
576+
577+
CompileSuccessfully(
578+
GenerateShaderCode(body,
579+
"OpCapability ReplicatedCompositesEXT\nOpExtension "
580+
"\"SPV_EXT_replicated_composites\"\n")
581+
.c_str());
582+
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
583+
}
584+
585+
TEST_F(ValidateComposites, CompositeConstructReplicateStructGood) {
586+
const std::string copy_types = R"(
587+
%f32struct = OpTypeStruct %f32 %f32 %f32
588+
)";
589+
590+
const std::string body = R"(
591+
%val1 = OpCompositeConstructReplicateEXT %f32struct %f32_0
592+
)";
593+
594+
CompileSuccessfully(
595+
GenerateShaderCode(body,
596+
"OpCapability ReplicatedCompositesEXT\nOpExtension "
597+
"\"SPV_EXT_replicated_composites\"\n",
598+
"Fragment", copy_types)
599+
.c_str());
600+
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
601+
}
602+
603+
TEST_F(ValidateComposites, CompositeConstructReplicateCoopMatGood) {
604+
const std::string extra_types = R"(
605+
%u32_8 = OpConstant %u32 8
606+
%u32_16 = OpConstant %u32 16
607+
%subgroup = OpConstant %u32 3
608+
%useA = OpConstant %u32 0
609+
%f32mat_nv = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_8 %u32_8
610+
%f32mat_khr = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_16 %u32_16 %useA
611+
)";
612+
613+
const std::string body = R"(
614+
%val1 = OpCompositeConstructReplicateEXT %f32mat_nv %f32_0
615+
%val2 = OpCompositeConstructReplicateEXT %f32mat_khr %f32_0
616+
)";
617+
618+
CompileSuccessfully(
619+
GenerateShaderCode(body,
620+
"OpCapability ReplicatedCompositesEXT\n"
621+
"OpCapability CooperativeMatrixNV\n"
622+
"OpCapability CooperativeMatrixKHR\n"
623+
"OpCapability VulkanMemoryModel\n"
624+
"OpCapability Float16\n"
625+
"OpExtension \"SPV_EXT_replicated_composites\"\n"
626+
"OpExtension \"SPV_NV_cooperative_matrix\"\n"
627+
"OpExtension \"SPV_KHR_cooperative_matrix\"\n"
628+
"OpExtension \"SPV_KHR_vulkan_memory_model\"\n",
629+
"Fragment", extra_types, "Logical Vulkan")
630+
.c_str(),
631+
SPV_ENV_UNIVERSAL_1_3);
632+
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3));
633+
}
634+
635+
TEST_F(ValidateComposites, CompositeConstructReplicateTensorGood) {
636+
const std::string extra_types = R"(
637+
%arr = OpTypeArray %u32 %u32_1
638+
%c_arr = OpConstantNull %arr
639+
%tensor = OpTypeTensorARM %f32 %u32_1 %c_arr
640+
)";
641+
642+
const std::string body = R"(
643+
%val1 = OpCompositeConstructReplicateEXT %tensor %f32_0
644+
)";
645+
646+
CompileSuccessfully(
647+
GenerateShaderCode(body,
648+
"OpCapability ReplicatedCompositesEXT\nOpCapability "
649+
"TensorsARM\nOpExtension "
650+
"\"SPV_EXT_replicated_composites\"\nOpExtension "
651+
"\"SPV_ARM_tensors\"\n",
652+
"Fragment", extra_types)
653+
.c_str());
654+
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
655+
}
656+
657+
TEST_F(ValidateComposites, CompositeConstructReplicateCoopMatWrongOperand) {
658+
const std::string extra_types = R"(
659+
%u32_8 = OpConstant %u32 8
660+
%subgroup = OpConstant %u32 3
661+
%f32mat_nv = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_8 %u32_8
662+
)";
663+
664+
const std::string body = R"(
665+
%val1 = OpCompositeConstructReplicateEXT %f32mat_nv %u32_0
666+
)";
667+
668+
CompileSuccessfully(
669+
GenerateShaderCode(body,
670+
"OpCapability ReplicatedCompositesEXT\n"
671+
"OpCapability CooperativeMatrixNV\n"
672+
"OpCapability Float16\n"
673+
"OpExtension \"SPV_EXT_replicated_composites\"\n"
674+
"OpExtension \"SPV_NV_cooperative_matrix\"\n",
675+
"Fragment", extra_types)
676+
.c_str());
677+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
678+
EXPECT_THAT(
679+
getDiagnosticString(),
680+
HasSubstr(
681+
"Expected Value type to be equal to the result's element type"));
682+
}
683+
684+
TEST_F(ValidateComposites, CompositeConstructReplicateTensorWrongOperand) {
685+
const std::string extra_types = R"(
686+
%tensor = OpTypeTensorARM %f32
687+
)";
688+
689+
const std::string body = R"(
690+
%val1 = OpCompositeConstructReplicateEXT %tensor %u32_0
691+
)";
692+
693+
CompileSuccessfully(
694+
GenerateShaderCode(body,
695+
"OpCapability ReplicatedCompositesEXT\nOpCapability "
696+
"TensorsARM\nOpExtension "
697+
"\"SPV_EXT_replicated_composites\"\nOpExtension "
698+
"\"SPV_ARM_tensors\"\n",
699+
"Fragment", extra_types)
700+
.c_str());
701+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
702+
EXPECT_THAT(
703+
getDiagnosticString(),
704+
HasSubstr(
705+
"Expected Value type to be equal to the result's element type"));
706+
}
707+
708+
TEST_F(ValidateComposites, CompositeConstructReplicateWrongOperandType) {
709+
const std::string body = R"(
710+
%val1 = OpCompositeConstructReplicateEXT %f32vec4 %u32_0
711+
)";
712+
713+
CompileSuccessfully(
714+
GenerateShaderCode(body,
715+
"OpCapability ReplicatedCompositesEXT\nOpExtension "
716+
"\"SPV_EXT_replicated_composites\"\n")
717+
.c_str());
718+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
719+
EXPECT_THAT(
720+
getDiagnosticString(),
721+
HasSubstr(
722+
"Expected Value type to be equal to the result's element type"));
723+
}
724+
725+
TEST_F(ValidateComposites, CompositeConstructReplicateNotComposite) {
726+
const std::string body = R"(
727+
%val1 = OpCompositeConstructReplicateEXT %f32 %f32_0
728+
)";
729+
730+
CompileSuccessfully(
731+
GenerateShaderCode(body,
732+
"OpCapability ReplicatedCompositesEXT\nOpExtension "
733+
"\"SPV_EXT_replicated_composites\"\n")
734+
.c_str());
735+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
736+
EXPECT_THAT(getDiagnosticString(),
737+
HasSubstr("Expected Result Type to be a composite type"));
738+
}
739+
539740
TEST_F(ValidateComposites, CopyObjectSuccess) {
540741
const std::string body = R"(
541742
%val1 = OpCopyObject %f32 %f32_0

0 commit comments

Comments
 (0)