Skip to content

Commit d83d936

Browse files
authored
val: more validation of OpTypeFloat (KhronosGroup#6568)
- add tests for too many arguments - check if 32-bit and 64-bit float types have an encoding arg - check encoding arg for 16 bit types
1 parent 46b1e00 commit d83d936

4 files changed

Lines changed: 255 additions & 36 deletions

File tree

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,7 @@ cc_library(
672672
":val_test_lib",
673673
"@googletest//:gtest",
674674
"@googletest//:gtest_main",
675+
"@spirv_headers//:spirv_cpp11_headers",
675676
],
676677
) for f in glob(
677678
["test/val/val_*_test.cpp"],

source/val/validate_type.cpp

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
// Ensures type declarations are unique unless allowed by the specification.
1919

20+
#include <optional>
21+
2022
#include "source/opcode.h"
2123
#include "source/spirv_target_env.h"
2224
#include "source/val/instruction.h"
@@ -107,61 +109,78 @@ spv_result_t ValidateTypeInt(ValidationState_t& _, const Instruction* inst) {
107109
}
108110

109111
spv_result_t ValidateTypeFloat(ValidationState_t& _, const Instruction* inst) {
110-
// Validates that the number of bits specified for an Int type is valid.
111-
// Scalar integer types can be parameterized only with 32-bits.
112-
// Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit
113-
// integers, respectively.
112+
// Validates:
113+
// - the number of bits specified for a float type is valid
114+
// - the fp encoding is valid, and only used on matching bit widths
115+
// - required capabilities are declared
114116
auto num_bits = inst->GetOperandAs<const uint32_t>(1);
115-
const bool has_encoding = inst->operands().size() > 2;
117+
118+
std::optional<spv::FPEncoding> encoding;
119+
if (inst->operands().size() > 2) {
120+
encoding = inst->GetOperandAs<spv::FPEncoding>(2);
121+
}
122+
// The number of operands is already checked by the grammar structure.
123+
// The fp encoding operand is an optional enum, and there are no further
124+
// operands.
125+
116126
if (num_bits == 32) {
127+
if (encoding.has_value()) {
128+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
129+
<< "32-bit floating point type must not have encoding parameter.";
130+
}
117131
return SPV_SUCCESS;
118132
}
119-
auto operands = inst->words();
120133

121134
if (num_bits == 16) {
122135
// An absence of FP encoding implies IEEE 754. The Float16 and Float16Buffer
123136
// capabilities only enable IEEE 754 binary 16
124-
if (has_encoding || _.features().declare_float16_type) {
125-
return SPV_SUCCESS;
137+
if (!encoding.has_value() && !_.features().declare_float16_type) {
138+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
139+
<< "Using a 16-bit floating point "
140+
<< "type requires the Float16 or Float16Buffer capability,"
141+
" or an extension that explicitly enables 16-bit floating "
142+
"point.";
126143
}
127-
return _.diag(SPV_ERROR_INVALID_DATA, inst)
128-
<< "Using a 16-bit floating point "
129-
<< "type requires the Float16 or Float16Buffer capability,"
130-
" or an extension that explicitly enables 16-bit floating point.";
144+
if (encoding.has_value() &&
145+
encoding.value() != spv::FPEncoding::BFloat16KHR) {
146+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
147+
<< "Unsupported 16-bit floating point encoding ("
148+
<< static_cast<uint32_t>(encoding.value()) << ").";
149+
}
150+
return SPV_SUCCESS;
131151
}
132152
if (num_bits == 8) {
133153
if (!_.features().declare_float8_type) {
134154
return _.diag(SPV_ERROR_INVALID_DATA, inst)
135155
<< "Using a 8-bit floating point "
136156
<< "type requires the Float8EXT capability.";
137157
}
138-
if (!has_encoding) {
158+
if (encoding.has_value()) {
159+
const auto enc = encoding.value();
160+
if (enc != spv::FPEncoding::Float8E4M3EXT &&
161+
enc != spv::FPEncoding::Float8E5M2EXT) {
162+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
163+
<< "Unsupported 8-bit floating point encoding ("
164+
<< static_cast<uint32_t>(enc) << ").";
165+
}
166+
} else {
139167
// we don't support fp8 without encoding
140168
return _.diag(SPV_ERROR_INVALID_DATA, inst)
141169
<< "8-bit floating point type requires an encoding.";
142170
}
143-
const spvtools::OperandDesc* desc = nullptr;
144-
const std::set<spv::FPEncoding> known_encodings{
145-
spv::FPEncoding::Float8E4M3EXT, spv::FPEncoding::Float8E5M2EXT};
146-
spv_result_t status = spvtools::LookupOperand(SPV_OPERAND_TYPE_FPENCODING,
147-
inst->words()[3], &desc);
148-
if ((status != SPV_SUCCESS) ||
149-
(known_encodings.find(static_cast<spv::FPEncoding>(desc->value)) ==
150-
known_encodings.end())) {
151-
return _.diag(SPV_ERROR_INVALID_DATA, inst)
152-
<< "Unsupported 8-bit floating point encoding ("
153-
<< desc->name().data() << ").";
154-
}
155-
156171
return SPV_SUCCESS;
157172
}
158173
if (num_bits == 64) {
159-
if (_.HasCapability(spv::Capability::Float64)) {
160-
return SPV_SUCCESS;
174+
if (!_.HasCapability(spv::Capability::Float64)) {
175+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
176+
<< "Using a 64-bit floating point "
177+
<< "type requires the Float64 capability.";
161178
}
162-
return _.diag(SPV_ERROR_INVALID_DATA, inst)
163-
<< "Using a 64-bit floating point "
164-
<< "type requires the Float64 capability.";
179+
if (encoding.has_value()) {
180+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
181+
<< "64-bit floating point type must not have encoding parameter.";
182+
}
183+
return SPV_SUCCESS;
165184
}
166185
return _.diag(SPV_ERROR_INVALID_DATA, inst)
167186
<< "Invalid number of bits (" << num_bits << ") used for OpTypeFloat.";

test/val/val_data_test.cpp

Lines changed: 197 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414

1515
// Validation tests for Data Rules.
1616

17+
#include <sstream>
1718
#include <string>
1819
#include <utility>
1920

2021
#include "gmock/gmock.h"
22+
#include "spirv/unified1/spirv.hpp11"
2123
#include "test/unit_spirv.h"
2224
#include "test/val/val_fixtures.h"
2325

@@ -128,6 +130,15 @@ std::string header_with_float64 = R"(
128130
OpMemoryModel Logical GLSL450
129131
)";
130132

133+
std::string header_with_float64_bfloat16 = R"(
134+
OpCapability Shader
135+
OpCapability Linkage
136+
OpCapability Float64
137+
OpCapability BFloat16TypeKHR
138+
OpExtension "SPV_KHR_bfloat16"
139+
OpMemoryModel Logical GLSL450
140+
)";
141+
131142
std::string invalid_comp_error = "Illegal number of components";
132143
std::string missing_cap_error = "requires the Vector16 capability";
133144
std::string missing_int8_cap_error = "requires the Int8 capability";
@@ -378,7 +389,7 @@ TEST_F(ValidateData, float8_good) {
378389
%3 = OpTypeFloat 8 Float8E5M2EXT
379390
)";
380391
CompileSuccessfully(str.c_str());
381-
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
392+
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
382393
}
383394

384395
TEST_F(ValidateData, bfloat16_good) {
@@ -438,25 +449,98 @@ TEST_F(ValidateData, float16_buffer_good) {
438449
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
439450
}
440451

452+
TEST_F(ValidateData, float32_with_encoding_number_bad) {
453+
std::string str = header_with_bfloat16 + "%2 = OpTypeFloat 32 !9999";
454+
const auto& err = CompileFailure(str.c_str());
455+
EXPECT_THAT(err, HasSubstr("Invalid OpTypeFloat encoding"));
456+
}
457+
458+
TEST_F(ValidateData, float32_with_encoding_enum_bad) {
459+
std::string str = header_with_bfloat16 + "%2 = OpTypeFloat 32 BFloat16";
460+
const auto& err = CompileFailure(str.c_str());
461+
EXPECT_THAT(err, HasSubstr("Invalid FP encoding 'BFloat16'"));
462+
}
463+
464+
TEST_F(ValidateData, float64_with_encoding_number_bad) {
465+
std::string str = header_with_float64 + "%2 = OpTypeFloat 64 !9999";
466+
const auto& err = CompileFailure(str.c_str());
467+
EXPECT_THAT(err, HasSubstr("Invalid OpTypeFloat encoding"));
468+
}
469+
470+
TEST_F(ValidateData, float64_with_encoding_enum_bad) {
471+
std::string str = header_with_float64 + "%2 = OpTypeFloat 64 BFloat16";
472+
const auto& err = CompileFailure(str.c_str());
473+
EXPECT_THAT(err, HasSubstr("Invalid FP encoding 'BFloat16'"));
474+
}
475+
441476
TEST_F(ValidateData, float16_bad) {
442477
std::string str = header + "%2 = OpTypeFloat 16";
443478
CompileSuccessfully(str.c_str());
444479
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
445480
EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_float16_cap_error));
446481
}
447482

448-
TEST_F(ValidateData, bfloat16_bad) {
483+
TEST_F(ValidateData, bfloat16_missing_cap_bad) {
449484
std::string str = header + "%2 = OpTypeFloat 16 BFloat16KHR";
450485
CompileSuccessfully(str.c_str());
451486
ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
452487
EXPECT_THAT(getDiagnosticString(),
453488
HasSubstr("requires one of these capabilities: BFloat16TypeKHR"));
454489
}
455490

456-
TEST_F(ValidateData, float8_bad) {
491+
TEST_F(ValidateData, bfloat16_wrong_width_15_bad) {
492+
std::stringstream ss;
493+
ss << header_with_bfloat16 << "!"
494+
<< (4u << 16 | static_cast<uint32_t>(spv::Op::OpTypeFloat)) << " 99 15 "
495+
<< static_cast<uint32_t>(spv::FPEncoding::BFloat16KHR);
496+
CompileSuccessfully(ss.str().c_str());
497+
OverwriteIdBound(100);
498+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
499+
EXPECT_THAT(getDiagnosticString(),
500+
HasSubstr("Invalid number of bits (15) used for OpTypeFloat"))
501+
<< getDiagnosticString();
502+
}
503+
504+
TEST_F(ValidateData, bfloat16_wrong_width_8_bad) {
505+
std::stringstream ss;
506+
ss << header_with_float8_and_bfloat16 << "!"
507+
<< (4u << 16 | static_cast<uint32_t>(spv::Op::OpTypeFloat)) << " 99 8 "
508+
<< static_cast<uint32_t>(spv::FPEncoding::BFloat16KHR);
509+
CompileSuccessfully(ss.str().c_str());
510+
OverwriteIdBound(100);
511+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
512+
EXPECT_THAT(getDiagnosticString(),
513+
HasSubstr("Unsupported 8-bit floating point encoding"))
514+
<< getDiagnosticString();
515+
}
516+
517+
TEST_F(ValidateData, bfloat16_too_many_operands_bad) {
518+
std::stringstream ss;
519+
ss << header_with_float8_and_bfloat16 << "!"
520+
<< (5u << 16 | static_cast<uint32_t>(spv::Op::OpTypeFloat)) << " 99 16 "
521+
<< static_cast<uint32_t>(spv::FPEncoding::BFloat16KHR) << " 0";
522+
CompileSuccessfully(ss.str().c_str());
523+
OverwriteIdBound(100);
524+
ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions());
525+
EXPECT_THAT(getDiagnosticString(),
526+
HasSubstr("expected no more operands after 4 words, but stated "
527+
"word count is 5"))
528+
<< getDiagnosticString();
529+
}
530+
531+
TEST_F(ValidateData, float8_E4M3_missing_cap_bad) {
457532
std::string str = header +
458533
R"(%2 = OpTypeFloat 8 Float8E4M3EXT
459-
%3 = OpTypeFloat 8 Float8E5M2EXT
534+
)";
535+
CompileSuccessfully(str.c_str());
536+
ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
537+
EXPECT_THAT(getDiagnosticString(),
538+
HasSubstr("requires one of these capabilities: Float8EXT"));
539+
}
540+
541+
TEST_F(ValidateData, float8_E5M2_missing_cap_bad) {
542+
std::string str = header +
543+
R"(%2 = OpTypeFloat 8 Float8E5M2EXT
460544
)";
461545
CompileSuccessfully(str.c_str());
462546
ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
@@ -480,6 +564,86 @@ TEST_F(ValidateData, float8_bad_encoding) {
480564
"BFloat16KHR; expected 16"));
481565
}
482566

567+
TEST_F(ValidateData, float8_E4M3_wrong_width_7_bad) {
568+
std::stringstream ss;
569+
ss << header_with_float8 << "!"
570+
<< ((4u << 16) | static_cast<uint32_t>(spv::Op::OpTypeFloat)) << " 99 7 "
571+
<< static_cast<uint32_t>(spv::FPEncoding::Float8E4M3EXT);
572+
CompileSuccessfully(ss.str().c_str());
573+
OverwriteIdBound(100);
574+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
575+
EXPECT_THAT(getDiagnosticString(),
576+
HasSubstr("Invalid number of bits (7) used for OpTypeFloat"))
577+
<< getDiagnosticString();
578+
}
579+
580+
TEST_F(ValidateData, float8_E4M3_wrong_width_16_bad) {
581+
std::stringstream ss;
582+
ss << header_with_float8 << "!"
583+
<< ((4u << 16) | static_cast<uint32_t>(spv::Op::OpTypeFloat)) << " 99 16 "
584+
<< static_cast<uint32_t>(spv::FPEncoding::Float8E4M3EXT);
585+
CompileSuccessfully(ss.str().c_str());
586+
OverwriteIdBound(100);
587+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
588+
EXPECT_THAT(getDiagnosticString(),
589+
HasSubstr("Unsupported 16-bit floating point encoding (4214)"))
590+
<< getDiagnosticString();
591+
}
592+
593+
TEST_F(ValidateData, float8_E5M2_wrong_width_7_bad) {
594+
std::stringstream ss;
595+
ss << header_with_float8 << "!"
596+
<< ((4u << 16) | static_cast<uint32_t>(spv::Op::OpTypeFloat)) << " 99 7 "
597+
<< static_cast<uint32_t>(spv::FPEncoding::Float8E5M2EXT);
598+
CompileSuccessfully(ss.str().c_str());
599+
OverwriteIdBound(100);
600+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
601+
EXPECT_THAT(getDiagnosticString(),
602+
HasSubstr("Invalid number of bits (7) used for OpTypeFloat"))
603+
<< getDiagnosticString();
604+
}
605+
606+
TEST_F(ValidateData, float8_E5M2_wrong_width_16_bad) {
607+
std::stringstream ss;
608+
ss << header_with_float8 << "!"
609+
<< ((4u << 16) | static_cast<uint32_t>(spv::Op::OpTypeFloat)) << " 99 16 "
610+
<< static_cast<uint32_t>(spv::FPEncoding::Float8E5M2EXT);
611+
CompileSuccessfully(ss.str().c_str());
612+
OverwriteIdBound(100);
613+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
614+
EXPECT_THAT(getDiagnosticString(),
615+
HasSubstr("Unsupported 16-bit floating point encoding (4215)"))
616+
<< getDiagnosticString();
617+
}
618+
619+
TEST_F(ValidateData, float8_e4m3_too_many_operands_bad) {
620+
std::stringstream ss;
621+
ss << header_with_float8 << "!"
622+
<< (5u << 16 | static_cast<uint32_t>(spv::Op::OpTypeFloat)) << " 99 8 "
623+
<< static_cast<uint32_t>(spv::FPEncoding::Float8E4M3EXT) << " 0";
624+
CompileSuccessfully(ss.str().c_str());
625+
OverwriteIdBound(100);
626+
ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions());
627+
EXPECT_THAT(getDiagnosticString(),
628+
HasSubstr("expected no more operands after 4 words, but stated "
629+
"word count is 5"))
630+
<< getDiagnosticString();
631+
}
632+
633+
TEST_F(ValidateData, float8_e5m2_too_many_operands_bad) {
634+
std::stringstream ss;
635+
ss << header_with_float8 << "!"
636+
<< (5u << 16 | static_cast<uint32_t>(spv::Op::OpTypeFloat)) << " 99 8 "
637+
<< static_cast<uint32_t>(spv::FPEncoding::Float8E5M2EXT) << " 0";
638+
CompileSuccessfully(ss.str().c_str());
639+
OverwriteIdBound(100);
640+
ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions());
641+
EXPECT_THAT(getDiagnosticString(),
642+
HasSubstr("expected no more operands after 4 words, but stated "
643+
"word count is 5"))
644+
<< getDiagnosticString();
645+
}
646+
483647
TEST_F(ValidateData, dot_bfloat16_bad) {
484648
std::string str = R"(
485649
OpCapability Shader
@@ -526,13 +690,41 @@ TEST_F(ValidateData, float64_good) {
526690
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
527691
}
528692

529-
TEST_F(ValidateData, float64_bad) {
693+
TEST_F(ValidateData, float64_missing_cap_bad) {
530694
std::string str = header + "%2 = OpTypeFloat 64";
531695
CompileSuccessfully(str.c_str());
532696
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
533697
EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_float64_cap_error));
534698
}
535699

700+
TEST_F(ValidateData, float32_encoding_param_bad) {
701+
std::stringstream ss;
702+
ss << header_with_bfloat16 << "!"
703+
<< (4u << 16 | static_cast<uint32_t>(spv::Op::OpTypeFloat)) << " 99 32 "
704+
<< static_cast<uint32_t>(spv::FPEncoding::BFloat16KHR);
705+
CompileSuccessfully(ss.str().c_str());
706+
OverwriteIdBound(100);
707+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
708+
EXPECT_THAT(
709+
getDiagnosticString(),
710+
HasSubstr("32-bit floating point type must not have encoding parameter"))
711+
<< getDiagnosticString();
712+
}
713+
714+
TEST_F(ValidateData, float64_encoding_param_bad) {
715+
std::stringstream ss;
716+
ss << header_with_float64_bfloat16 << "!"
717+
<< (4u << 16 | static_cast<uint32_t>(spv::Op::OpTypeFloat)) << " 99 64 "
718+
<< static_cast<uint32_t>(spv::FPEncoding::BFloat16KHR);
719+
CompileSuccessfully(ss.str().c_str());
720+
OverwriteIdBound(100);
721+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
722+
EXPECT_THAT(
723+
getDiagnosticString(),
724+
HasSubstr("64-bit floating point type must not have encoding parameter"))
725+
<< getDiagnosticString();
726+
}
727+
536728
// Number of bits in a float may be only one of: {16,32,64}
537729
TEST_F(ValidateData, float_invalid_num_bits) {
538730
std::string str = header + "%2 = OpTypeFloat 48";

0 commit comments

Comments
 (0)