Skip to content

Commit 4972c69

Browse files
authored
spirv-as: Validate bit width of float types with explicit encodings (KhronosGroup#6562)
Do this in the assembler because it's silly to let it get any farther. Bug: crbug.com/465892071
1 parent 1c8c845 commit 4972c69

3 files changed

Lines changed: 95 additions & 4 deletions

File tree

source/text_handler.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <cassert>
1919
#include <cstdlib>
2020
#include <cstring>
21+
#include <string_view>
2122
#include <tuple>
2223

2324
#include "source/assembly_grammar.h"
@@ -152,6 +153,31 @@ bool startsWithOp(spv_text text, spv_position position) {
152153
return ('O' == ch0 && 'p' == ch1 && ('A' <= ch2 && ch2 <= 'Z'));
153154
}
154155

156+
// Returns false if the the floating point encoding requires a bit width
157+
// different from the given width. Write the expected bit width via *expected.
158+
bool validBitWidthForFPEncoding(spv_fp_encoding_t enc, uint32_t width,
159+
uint32_t* expected) {
160+
switch (enc) {
161+
case SPV_FP_ENCODING_IEEE754_BINARY16:
162+
case SPV_FP_ENCODING_BFLOAT16:
163+
*expected = 16;
164+
break;
165+
case SPV_FP_ENCODING_IEEE754_BINARY32:
166+
*expected = 32;
167+
break;
168+
case SPV_FP_ENCODING_IEEE754_BINARY64:
169+
*expected = 64;
170+
break;
171+
case SPV_FP_ENCODING_FLOAT8_E5M2:
172+
case SPV_FP_ENCODING_FLOAT8_E4M3:
173+
*expected = 8;
174+
break;
175+
default:
176+
return true;
177+
}
178+
return width == *expected;
179+
}
180+
155181
} // namespace
156182

157183
const IdType kUnknownType = {0, false, IdTypeClass::kBottom};
@@ -342,6 +368,15 @@ spv_result_t AssemblyContext::recordTypeDefinition(
342368
if (status == SPV_SUCCESS) {
343369
enc = spvFPEncodingFromOperandFPEncoding(
344370
static_cast<spv::FPEncoding>(desc->value));
371+
uint32_t expected_width;
372+
if (!validBitWidthForFPEncoding(enc, pInst->words[2],
373+
&expected_width)) {
374+
const auto& name_span = desc->name();
375+
const std::string_view name(name_span.data(), name_span.size() - 1);
376+
return diagnostic() << "Invalid bit width " << pInst->words[2]
377+
<< " for floating point encoding " << name
378+
<< "; expected " << expected_width;
379+
}
345380
} else {
346381
return diagnostic() << "Invalid OpTypeFloat encoding";
347382
}

test/text_to_binary.type_declaration_test.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,63 @@ TEST_F(OpSizeOfTest, ArgumentTypes) {
278278
Eq("Expected id to start with %."));
279279
}
280280

281+
struct FloatEncodingWidthCase {
282+
std::string input;
283+
bool expect_pass;
284+
};
285+
286+
using FloatEncodingWidthTest = spvtest::TextToBinaryTestBase<
287+
::testing::TestWithParam<FloatEncodingWidthCase>>;
288+
289+
TEST_P(FloatEncodingWidthTest, Samples) {
290+
const auto& param = GetParam();
291+
if (param.expect_pass) {
292+
CompileSuccessfully(param.input);
293+
} else {
294+
auto err = CompileFailure(param.input);
295+
EXPECT_THAT(err, testing::HasSubstr("Invalid bit width"));
296+
EXPECT_THAT(err, testing::HasSubstr("for floating point encoding"));
297+
}
298+
}
299+
300+
INSTANTIATE_TEST_SUITE_P(
301+
TextToBinaryFloatWidth, FloatEncodingWidthTest,
302+
::testing::ValuesIn(std::vector<FloatEncodingWidthCase>{
303+
{"%1 = OpTypeFloat 32", true},
304+
{"%1 = OpTypeFloat 64", true},
305+
{"%1 = OpTypeFloat 16", true},
306+
{"%1 = OpTypeFloat 8", true},
307+
// bfloat16
308+
{"%1 = OpTypeFloat 0 BFloat16KHR", false},
309+
{"%1 = OpTypeFloat 1 BFloat16KHR", false},
310+
{"%1 = OpTypeFloat 15 BFloat16KHR", false},
311+
{"%1 = OpTypeFloat 16 BFloat16KHR", true},
312+
{"%1 = OpTypeFloat 17 BFloat16KHR", false},
313+
{"%1 = OpTypeFloat 32 BFloat16KHR", false},
314+
{"%1 = OpTypeFloat 64 BFloat16KHR", false},
315+
{"%1 = OpTypeFloat 100 BFloat16KHR", false},
316+
// fp8 E5M2
317+
{"%1 = OpTypeFloat 0 Float8E5M2EXT", false},
318+
{"%1 = OpTypeFloat 1 Float8E5M2EXT", false},
319+
{"%1 = OpTypeFloat 7 Float8E5M2EXT", false},
320+
{"%1 = OpTypeFloat 8 Float8E5M2EXT", true},
321+
{"%1 = OpTypeFloat 9 Float8E5M2EXT", false},
322+
{"%1 = OpTypeFloat 16 Float8E5M2EXT", false},
323+
{"%1 = OpTypeFloat 32 Float8E5M2EXT", false},
324+
{"%1 = OpTypeFloat 64 Float8E5M2EXT", false},
325+
{"%1 = OpTypeFloat 100 Float8E4M3EXT", false},
326+
// fp8 E4M3
327+
{"%1 = OpTypeFloat 0 Float8E4M3EXT", false},
328+
{"%1 = OpTypeFloat 1 Float8E4M3EXT", false},
329+
{"%1 = OpTypeFloat 7 Float8E4M3EXT", false},
330+
{"%1 = OpTypeFloat 8 Float8E4M3EXT", true},
331+
{"%1 = OpTypeFloat 9 Float8E4M3EXT", false},
332+
{"%1 = OpTypeFloat 16 Float8E4M3EXT", false},
333+
{"%1 = OpTypeFloat 32 Float8E4M3EXT", false},
334+
{"%1 = OpTypeFloat 64 Float8E4M3EXT", false},
335+
{"%1 = OpTypeFloat 100 Float8E4M3EXT", false},
336+
}));
337+
281338
// TODO(dneto): OpTypeVoid
282339
// TODO(dneto): OpTypeBool
283340
// TODO(dneto): OpTypeInt

test/val/val_data_test.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -475,10 +475,9 @@ TEST_F(ValidateData, float8_no_encoding_bad) {
475475
TEST_F(ValidateData, float8_bad_encoding) {
476476
std::string str =
477477
header_with_float8_and_bfloat16 + "%2 = OpTypeFloat 8 BFloat16KHR";
478-
CompileSuccessfully(str.c_str());
479-
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
480-
EXPECT_THAT(getDiagnosticString(),
481-
HasSubstr("Unsupported 8-bit floating point encoding"));
478+
const auto& err = CompileFailure(str.c_str());
479+
EXPECT_THAT(err, HasSubstr("Invalid bit width 8 for floating point encoding "
480+
"BFloat16KHR; expected 16"));
482481
}
483482

484483
TEST_F(ValidateData, dot_bfloat16_bad) {

0 commit comments

Comments
 (0)