1414
1515// Validation tests for Data Rules.
1616
17+ #include < sstream>
1718#include < string>
1819#include < utility>
1920
2021#include " gmock/gmock.h"
22+ #include " spirv/unified1/spirv.hpp11"
2123#include " test/unit_spirv.h"
2224#include " test/val/val_fixtures.h"
2325
@@ -128,6 +130,15 @@ std::string header_with_float64 = R"(
128130 OpMemoryModel Logical GLSL450
129131)" ;
130132
133+ std::string header_with_float64_bfloat16 = R"(
134+ OpCapability Shader
135+ OpCapability Linkage
136+ OpCapability Float64
137+ OpCapability BFloat16TypeKHR
138+ OpExtension "SPV_KHR_bfloat16"
139+ OpMemoryModel Logical GLSL450
140+ )" ;
141+
131142std::string invalid_comp_error = " Illegal number of components" ;
132143std::string missing_cap_error = " requires the Vector16 capability" ;
133144std::string missing_int8_cap_error = " requires the Int8 capability" ;
@@ -378,7 +389,7 @@ TEST_F(ValidateData, float8_good) {
378389%3 = OpTypeFloat 8 Float8E5M2EXT
379390)" ;
380391 CompileSuccessfully (str.c_str ());
381- ASSERT_EQ (SPV_SUCCESS, ValidateInstructions ());
392+ ASSERT_EQ (SPV_SUCCESS, ValidateInstructions ()) << getDiagnosticString () ;
382393}
383394
384395TEST_F (ValidateData, bfloat16_good) {
@@ -438,25 +449,98 @@ TEST_F(ValidateData, float16_buffer_good) {
438449 ASSERT_EQ (SPV_SUCCESS, ValidateInstructions ());
439450}
440451
452+ TEST_F (ValidateData, float32_with_encoding_number_bad) {
453+ std::string str = header_with_bfloat16 + " %2 = OpTypeFloat 32 !9999" ;
454+ const auto & err = CompileFailure (str.c_str ());
455+ EXPECT_THAT (err, HasSubstr (" Invalid OpTypeFloat encoding" ));
456+ }
457+
458+ TEST_F (ValidateData, float32_with_encoding_enum_bad) {
459+ std::string str = header_with_bfloat16 + " %2 = OpTypeFloat 32 BFloat16" ;
460+ const auto & err = CompileFailure (str.c_str ());
461+ EXPECT_THAT (err, HasSubstr (" Invalid FP encoding 'BFloat16'" ));
462+ }
463+
464+ TEST_F (ValidateData, float64_with_encoding_number_bad) {
465+ std::string str = header_with_float64 + " %2 = OpTypeFloat 64 !9999" ;
466+ const auto & err = CompileFailure (str.c_str ());
467+ EXPECT_THAT (err, HasSubstr (" Invalid OpTypeFloat encoding" ));
468+ }
469+
470+ TEST_F (ValidateData, float64_with_encoding_enum_bad) {
471+ std::string str = header_with_float64 + " %2 = OpTypeFloat 64 BFloat16" ;
472+ const auto & err = CompileFailure (str.c_str ());
473+ EXPECT_THAT (err, HasSubstr (" Invalid FP encoding 'BFloat16'" ));
474+ }
475+
441476TEST_F (ValidateData, float16_bad) {
442477 std::string str = header + " %2 = OpTypeFloat 16" ;
443478 CompileSuccessfully (str.c_str ());
444479 ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions ());
445480 EXPECT_THAT (getDiagnosticString (), HasSubstr (missing_float16_cap_error));
446481}
447482
448- TEST_F (ValidateData, bfloat16_bad ) {
483+ TEST_F (ValidateData, bfloat16_missing_cap_bad ) {
449484 std::string str = header + " %2 = OpTypeFloat 16 BFloat16KHR" ;
450485 CompileSuccessfully (str.c_str ());
451486 ASSERT_EQ (SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions ());
452487 EXPECT_THAT (getDiagnosticString (),
453488 HasSubstr (" requires one of these capabilities: BFloat16TypeKHR" ));
454489}
455490
456- TEST_F (ValidateData, float8_bad) {
491+ TEST_F (ValidateData, bfloat16_wrong_width_15_bad) {
492+ std::stringstream ss;
493+ ss << header_with_bfloat16 << " !"
494+ << (4u << 16 | static_cast <uint32_t >(spv::Op::OpTypeFloat)) << " 99 15 "
495+ << static_cast <uint32_t >(spv::FPEncoding::BFloat16KHR);
496+ CompileSuccessfully (ss.str ().c_str ());
497+ OverwriteIdBound (100 );
498+ ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions ());
499+ EXPECT_THAT (getDiagnosticString (),
500+ HasSubstr (" Invalid number of bits (15) used for OpTypeFloat" ))
501+ << getDiagnosticString ();
502+ }
503+
504+ TEST_F (ValidateData, bfloat16_wrong_width_8_bad) {
505+ std::stringstream ss;
506+ ss << header_with_float8_and_bfloat16 << " !"
507+ << (4u << 16 | static_cast <uint32_t >(spv::Op::OpTypeFloat)) << " 99 8 "
508+ << static_cast <uint32_t >(spv::FPEncoding::BFloat16KHR);
509+ CompileSuccessfully (ss.str ().c_str ());
510+ OverwriteIdBound (100 );
511+ ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions ());
512+ EXPECT_THAT (getDiagnosticString (),
513+ HasSubstr (" Unsupported 8-bit floating point encoding" ))
514+ << getDiagnosticString ();
515+ }
516+
517+ TEST_F (ValidateData, bfloat16_too_many_operands_bad) {
518+ std::stringstream ss;
519+ ss << header_with_float8_and_bfloat16 << " !"
520+ << (5u << 16 | static_cast <uint32_t >(spv::Op::OpTypeFloat)) << " 99 16 "
521+ << static_cast <uint32_t >(spv::FPEncoding::BFloat16KHR) << " 0" ;
522+ CompileSuccessfully (ss.str ().c_str ());
523+ OverwriteIdBound (100 );
524+ ASSERT_EQ (SPV_ERROR_INVALID_BINARY, ValidateInstructions ());
525+ EXPECT_THAT (getDiagnosticString (),
526+ HasSubstr (" expected no more operands after 4 words, but stated "
527+ " word count is 5" ))
528+ << getDiagnosticString ();
529+ }
530+
531+ TEST_F (ValidateData, float8_E4M3_missing_cap_bad) {
457532 std::string str = header +
458533 R"( %2 = OpTypeFloat 8 Float8E4M3EXT
459- %3 = OpTypeFloat 8 Float8E5M2EXT
534+ )" ;
535+ CompileSuccessfully (str.c_str ());
536+ ASSERT_EQ (SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions ());
537+ EXPECT_THAT (getDiagnosticString (),
538+ HasSubstr (" requires one of these capabilities: Float8EXT" ));
539+ }
540+
541+ TEST_F (ValidateData, float8_E5M2_missing_cap_bad) {
542+ std::string str = header +
543+ R"( %2 = OpTypeFloat 8 Float8E5M2EXT
460544)" ;
461545 CompileSuccessfully (str.c_str ());
462546 ASSERT_EQ (SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions ());
@@ -480,6 +564,86 @@ TEST_F(ValidateData, float8_bad_encoding) {
480564 " BFloat16KHR; expected 16" ));
481565}
482566
567+ TEST_F (ValidateData, float8_E4M3_wrong_width_7_bad) {
568+ std::stringstream ss;
569+ ss << header_with_float8 << " !"
570+ << ((4u << 16 ) | static_cast <uint32_t >(spv::Op::OpTypeFloat)) << " 99 7 "
571+ << static_cast <uint32_t >(spv::FPEncoding::Float8E4M3EXT);
572+ CompileSuccessfully (ss.str ().c_str ());
573+ OverwriteIdBound (100 );
574+ ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions ());
575+ EXPECT_THAT (getDiagnosticString (),
576+ HasSubstr (" Invalid number of bits (7) used for OpTypeFloat" ))
577+ << getDiagnosticString ();
578+ }
579+
580+ TEST_F (ValidateData, float8_E4M3_wrong_width_16_bad) {
581+ std::stringstream ss;
582+ ss << header_with_float8 << " !"
583+ << ((4u << 16 ) | static_cast <uint32_t >(spv::Op::OpTypeFloat)) << " 99 16 "
584+ << static_cast <uint32_t >(spv::FPEncoding::Float8E4M3EXT);
585+ CompileSuccessfully (ss.str ().c_str ());
586+ OverwriteIdBound (100 );
587+ ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions ());
588+ EXPECT_THAT (getDiagnosticString (),
589+ HasSubstr (" Unsupported 16-bit floating point encoding (4214)" ))
590+ << getDiagnosticString ();
591+ }
592+
593+ TEST_F (ValidateData, float8_E5M2_wrong_width_7_bad) {
594+ std::stringstream ss;
595+ ss << header_with_float8 << " !"
596+ << ((4u << 16 ) | static_cast <uint32_t >(spv::Op::OpTypeFloat)) << " 99 7 "
597+ << static_cast <uint32_t >(spv::FPEncoding::Float8E5M2EXT);
598+ CompileSuccessfully (ss.str ().c_str ());
599+ OverwriteIdBound (100 );
600+ ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions ());
601+ EXPECT_THAT (getDiagnosticString (),
602+ HasSubstr (" Invalid number of bits (7) used for OpTypeFloat" ))
603+ << getDiagnosticString ();
604+ }
605+
606+ TEST_F (ValidateData, float8_E5M2_wrong_width_16_bad) {
607+ std::stringstream ss;
608+ ss << header_with_float8 << " !"
609+ << ((4u << 16 ) | static_cast <uint32_t >(spv::Op::OpTypeFloat)) << " 99 16 "
610+ << static_cast <uint32_t >(spv::FPEncoding::Float8E5M2EXT);
611+ CompileSuccessfully (ss.str ().c_str ());
612+ OverwriteIdBound (100 );
613+ ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions ());
614+ EXPECT_THAT (getDiagnosticString (),
615+ HasSubstr (" Unsupported 16-bit floating point encoding (4215)" ))
616+ << getDiagnosticString ();
617+ }
618+
619+ TEST_F (ValidateData, float8_e4m3_too_many_operands_bad) {
620+ std::stringstream ss;
621+ ss << header_with_float8 << " !"
622+ << (5u << 16 | static_cast <uint32_t >(spv::Op::OpTypeFloat)) << " 99 8 "
623+ << static_cast <uint32_t >(spv::FPEncoding::Float8E4M3EXT) << " 0" ;
624+ CompileSuccessfully (ss.str ().c_str ());
625+ OverwriteIdBound (100 );
626+ ASSERT_EQ (SPV_ERROR_INVALID_BINARY, ValidateInstructions ());
627+ EXPECT_THAT (getDiagnosticString (),
628+ HasSubstr (" expected no more operands after 4 words, but stated "
629+ " word count is 5" ))
630+ << getDiagnosticString ();
631+ }
632+
633+ TEST_F (ValidateData, float8_e5m2_too_many_operands_bad) {
634+ std::stringstream ss;
635+ ss << header_with_float8 << " !"
636+ << (5u << 16 | static_cast <uint32_t >(spv::Op::OpTypeFloat)) << " 99 8 "
637+ << static_cast <uint32_t >(spv::FPEncoding::Float8E5M2EXT) << " 0" ;
638+ CompileSuccessfully (ss.str ().c_str ());
639+ OverwriteIdBound (100 );
640+ ASSERT_EQ (SPV_ERROR_INVALID_BINARY, ValidateInstructions ());
641+ EXPECT_THAT (getDiagnosticString (),
642+ HasSubstr (" expected no more operands after 4 words, but stated "
643+ " word count is 5" ))
644+ << getDiagnosticString ();
645+ }
646+
483647TEST_F (ValidateData, dot_bfloat16_bad) {
484648 std::string str = R"(
485649 OpCapability Shader
@@ -526,13 +690,41 @@ TEST_F(ValidateData, float64_good) {
526690 ASSERT_EQ (SPV_SUCCESS, ValidateInstructions ());
527691}
528692
529- TEST_F (ValidateData, float64_bad ) {
693+ TEST_F (ValidateData, float64_missing_cap_bad ) {
530694 std::string str = header + " %2 = OpTypeFloat 64" ;
531695 CompileSuccessfully (str.c_str ());
532696 ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions ());
533697 EXPECT_THAT (getDiagnosticString (), HasSubstr (missing_float64_cap_error));
534698}
535699
700+ TEST_F (ValidateData, float32_encoding_param_bad) {
701+ std::stringstream ss;
702+ ss << header_with_bfloat16 << " !"
703+ << (4u << 16 | static_cast <uint32_t >(spv::Op::OpTypeFloat)) << " 99 32 "
704+ << static_cast <uint32_t >(spv::FPEncoding::BFloat16KHR);
705+ CompileSuccessfully (ss.str ().c_str ());
706+ OverwriteIdBound (100 );
707+ ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions ());
708+ EXPECT_THAT (
709+ getDiagnosticString (),
710+ HasSubstr (" 32-bit floating point type must not have encoding parameter" ))
711+ << getDiagnosticString ();
712+ }
713+
714+ TEST_F (ValidateData, float64_encoding_param_bad) {
715+ std::stringstream ss;
716+ ss << header_with_float64_bfloat16 << " !"
717+ << (4u << 16 | static_cast <uint32_t >(spv::Op::OpTypeFloat)) << " 99 64 "
718+ << static_cast <uint32_t >(spv::FPEncoding::BFloat16KHR);
719+ CompileSuccessfully (ss.str ().c_str ());
720+ OverwriteIdBound (100 );
721+ ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions ());
722+ EXPECT_THAT (
723+ getDiagnosticString (),
724+ HasSubstr (" 64-bit floating point type must not have encoding parameter" ))
725+ << getDiagnosticString ();
726+ }
727+
536728// Number of bits in a float may be only one of: {16,32,64}
537729TEST_F (ValidateData, float_invalid_num_bits) {
538730 std::string str = header + " %2 = OpTypeFloat 48" ;
0 commit comments