Skip to content

Commit 6e16e04

Browse files
andyleisersonteoxoy
authored andcommitted
fix(naga): Detect overflowing const shift amounts at compile time
1 parent afbb31e commit 6e16e04

7 files changed

Lines changed: 177 additions & 13 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ Bottom level categories:
9999
- Fixed `workgroupUniformLoad` incorrectly returning an atomic when called on an atomic, it now returns the inner `T` as per the spec. By @cryvosh in [#8791](https://github.com/gfx-rs/wgpu/pull/8791).
100100
- Fixed constant evaluation for `sign()` builtin to return zero when the argument is zero. By @mandryskowski in [#8942](https://github.com/gfx-rs/wgpu/pull/8942).
101101
- Allow array generation to compile with the macOS 10.12 Metal compiler. By @madsmtm in [#8953](https://github.com/gfx-rs/wgpu/pull/8953)
102+
- Naga now detects bitwise shifts by a constant exceeding the operand bit width at compile time, and disallows scalar-by-vector and vector-by-scalar shifts in constant evaluation. By @andyleiserson in [#8907](https://github.com/gfx-rs/wgpu/pull/8907).
102103

103104
#### Validation
104105

cts_runner/fail.lst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ webgpu:shader,validation,expression,access,matrix:* // 93%, runtime OOB matrix a
9191
webgpu:shader,validation,expression,access,vector:* // 52%, https://github.com/gfx-rs/wgpu/issues/4390, and missing swizzle validation
9292
webgpu:shader,validation,expression,binary,add_sub_mul:* // 95%, u32 const-eval overflow incorrectly rejected, f16 const-eval overflow not rejected, atomics #5474
9393
webgpu:shader,validation,expression,binary,and_or_xor:* // 96%, https://github.com/gfx-rs/wgpu/issues/5474
94-
webgpu:shader,validation,expression,binary,bitwise_shift:* // 97%, atomics https://github.com/gfx-rs/wgpu/issues/5474, partial eval errors
94+
webgpu:shader,validation,expression,binary,bitwise_shift:invalid_types:* // 93%, atomics #5474
9595
webgpu:shader,validation,expression,binary,comparison:* // 74%, https://github.com/gfx-rs/wgpu/issues/5474
9696
webgpu:shader,validation,expression,binary,div_rem:* // 86%, https://github.com/gfx-rs/wgpu/issues/5474
9797
webgpu:shader,validation,expression,binary,short_circuiting_and_or:* // 92%, https://github.com/gfx-rs/wgpu/issues/8440

cts_runner/test.lst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,12 @@ webgpu:shader,validation,expression,access,array:early_eval_errors:case="overrid
326326
webgpu:shader,validation,expression,access,structure:*
327327
webgpu:shader,validation,expression,binary,add_sub_mul:scalar_vector_out_of_range:lhs="i32";*
328328
webgpu:shader,validation,expression,binary,add_sub_mul:scalar_vector_out_of_range:lhs="u32";*
329+
webgpu:shader,validation,expression,binary,bitwise_shift:partial_eval_errors:*
329330
webgpu:shader,validation,expression,binary,bitwise_shift:scalar_vector:*
331+
webgpu:shader,validation,expression,binary,bitwise_shift:shift_left_abstract:*
332+
webgpu:shader,validation,expression,binary,bitwise_shift:shift_left_concrete:*
333+
webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_abstract:*
334+
webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_concrete:*
330335
webgpu:shader,validation,expression,binary,parse:*
331336
webgpu:shader,validation,expression,binary,short_circuiting_and_or:array_override:op="%26%26";a_val=1;b_val=1
332337
webgpu:shader,validation,expression,binary,short_circuiting_and_or:invalid_types:*

naga/src/proc/constant_evaluator.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3040,6 +3040,10 @@ impl<'a> ConstantEvaluator<'a> {
30403040
h
30413041
}
30423042

3043+
/// Resolve the type of `expr` if it is a constant expression.
3044+
///
3045+
/// If `expr` was evaluated to a constant, returns its type.
3046+
/// Otherwise, returns an error.
30433047
fn resolve_type(
30443048
&self,
30453049
expr: Handle<Expression>,

naga/src/proc/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,12 @@ pub struct GlobalCtx<'a> {
475475
}
476476

477477
impl GlobalCtx<'_> {
478-
/// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
478+
/// Try to evaluate the expression in `self.global_expressions` using its `handle`
479+
/// and return it as a `T: TryFrom<ir::Literal>`.
480+
///
481+
/// This currently only evaluates scalar expressions. If adding support for vectors,
482+
/// consider changing `valid::expression::validate_constant_shift_amounts` to use that
483+
/// support.
479484
#[cfg_attr(
480485
not(any(
481486
feature = "glsl-in",

naga/src/valid/expression.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ pub enum ExpressionError {
148148
UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
149149
#[error("Invalid operand for cooperative op")]
150150
InvalidCooperativeOperand(Handle<crate::Expression>),
151+
#[error("Shift amount exceeds the bit width of {lhs_type:?}")]
152+
ShiftAmountTooLarge {
153+
lhs_type: crate::TypeInner,
154+
rhs_expr: Handle<crate::Expression>,
155+
},
151156
}
152157

153158
#[derive(Clone, Debug, thiserror::Error)]
@@ -243,6 +248,74 @@ impl super::Validator {
243248
Ok(())
244249
}
245250

251+
/// Return an error if a constant shift amount in `right` exceeds the bit
252+
/// width of `left_ty`.
253+
///
254+
/// This function promises to return an error in cases where (1) the
255+
/// expression is well-typed, (2) `left_ty` is a concrete integer, and
256+
/// (3) the shift will overflow. It does not return an error in cases where
257+
/// the expression is not well-typed (e.g. vector dimension mismatch),
258+
/// because those will be rejected elsewhere.
259+
fn validate_constant_shift_amounts(
260+
left_ty: &crate::TypeInner,
261+
right: Handle<crate::Expression>,
262+
module: &crate::Module,
263+
function: &crate::Function,
264+
) -> Result<(), ExpressionError> {
265+
fn is_overflowing_shift(
266+
left_ty: &crate::TypeInner,
267+
right: Handle<crate::Expression>,
268+
module: &crate::Module,
269+
function: &crate::Function,
270+
) -> bool {
271+
let Some((vec_size, scalar)) = left_ty.vector_size_and_scalar() else {
272+
return false;
273+
};
274+
if !matches!(
275+
scalar.kind,
276+
crate::ScalarKind::Sint | crate::ScalarKind::Uint
277+
) {
278+
return false;
279+
}
280+
let lhs_bits = u32::from(8 * scalar.width);
281+
if vec_size.is_none() {
282+
let shift_amount = module
283+
.to_ctx()
284+
.get_const_val_from::<u32, _>(right, &function.expressions);
285+
shift_amount.ok().is_some_and(|s| s >= lhs_bits)
286+
} else {
287+
match function.expressions[right] {
288+
crate::Expression::ZeroValue(_) => false, // zero shift does not overflow
289+
crate::Expression::Splat { value, .. } => module
290+
.to_ctx()
291+
.get_const_val_from::<u32, _>(value, &function.expressions)
292+
.ok()
293+
.is_some_and(|s| s >= lhs_bits),
294+
crate::Expression::Compose {
295+
ty: _,
296+
ref components,
297+
} => components.iter().any(|comp| {
298+
module
299+
.to_ctx()
300+
.get_const_val_from::<u32, _>(*comp, &function.expressions)
301+
.ok()
302+
.is_some_and(|s| s >= lhs_bits)
303+
}),
304+
_ => false,
305+
}
306+
}
307+
}
308+
309+
if is_overflowing_shift(left_ty, right, module, function) {
310+
Err(ExpressionError::ShiftAmountTooLarge {
311+
lhs_type: left_ty.clone(),
312+
rhs_expr: right,
313+
})
314+
} else {
315+
Ok(())
316+
}
317+
}
318+
246319
#[allow(clippy::too_many_arguments)]
247320
pub(super) fn validate_expression(
248321
&self,
@@ -984,6 +1057,10 @@ impl super::Validator {
9841057
rhs_type: right_inner.clone(),
9851058
});
9861059
}
1060+
// For shift operations, check if the constant shift amount exceeds the bit width
1061+
if matches!(op, Bo::ShiftLeft | Bo::ShiftRight) {
1062+
Self::validate_constant_shift_amounts(left_inner, right, module, function)?;
1063+
}
9871064
ShaderStages::all()
9881065
}
9891066
E::Select {

naga/tests/naga/wgsl_errors.rs

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ fn check(input: &str, snapshot: &str) {
3131
}
3232
}
3333

34+
#[track_caller]
35+
fn check_error_matches(input: &str, expected_substring: &str) {
36+
let result = naga::front::wgsl::parse_str(input);
37+
let Err(ref err) = result else {
38+
panic!("expected ParseError, got {result:#?}");
39+
};
40+
let message = err.message();
41+
if !message.contains(expected_substring) {
42+
panic!("expected error containing '{expected_substring}', got '{message}'",);
43+
}
44+
}
45+
3446
#[track_caller]
3547
fn check_success(input: &str) {
3648
match naga::front::wgsl::parse_str(input) {
@@ -3816,18 +3828,10 @@ fn inconsistent_type() {
38163828
fn more_inconsistent_type() {
38173829
#[track_caller]
38183830
fn variant(call: &str) {
3819-
let input = format!(
3820-
r#"
3821-
fn f() {{ var x = {call}; }}
3822-
"#
3831+
check_error_matches(
3832+
&format!("fn f() {{ var x = {call}; }}"),
3833+
"inconsistent type",
38233834
);
3824-
let result = naga::front::wgsl::parse_str(&input);
3825-
let Err(ref err) = result else {
3826-
panic!("expected ParseError, got {result:#?}");
3827-
};
3828-
if !err.message().contains("inconsistent type") {
3829-
panic!("expected 'inconsistent type' error, got {result:#?}");
3830-
}
38313835
}
38323836

38333837
variant("min(1.0, 1i)");
@@ -4986,3 +4990,71 @@ fn enable_without_capability() {
49864990
);
49874991
}
49884992
}
4993+
4994+
#[test]
4995+
fn bitwise_shift_errors() {
4996+
// 32-bit const by const >= bitwidth
4997+
check_error_matches(
4998+
"const N: u32 = 1u >> 32;",
4999+
"RHS of shift operation is greater than or equal to 32",
5000+
);
5001+
check_error_matches(
5002+
"const N: i32 = 1i >> 32;",
5003+
"RHS of shift operation is greater than or equal to 32",
5004+
);
5005+
5006+
// 32-bit const by const overflow
5007+
check_error_matches("const N: u32 = 0xFFFFFFFFu << 1;", "overflowed");
5008+
check_error_matches("const N: i32 = 1i << 31;", "overflowed");
5009+
5010+
// 32-bit const by const negative shift
5011+
check_error_matches("const N: u32 = 1u << -1;", "cannot represent");
5012+
check_error_matches("const N: i32 = 1i << -1;", "cannot represent");
5013+
5014+
// 32-bit runtime by const < bitwidth
5015+
check_success("fn foo() { var x: u32; var n = x << 31; }");
5016+
check_success("fn foo() { var x: i32; var n = x << 31; }");
5017+
check_success("fn foo() { var x: u32; var n = x >> 31; }");
5018+
check_success("fn foo() { var x: i32; var n = x >> 31; }");
5019+
5020+
// 32-bit runtime by const >= bitwidth
5021+
check_validation! {
5022+
"fn foo() { var x: u32; var n = x >> 32; }",
5023+
"fn foo() { var x: i32; var n = x >> 32; }",
5024+
"fn foo() { var x: u32; var n = x << 32; }",
5025+
"fn foo() { var x: i32; var n = x << 32; }":
5026+
Err(naga::valid::ValidationError::Function {
5027+
source: naga::valid::FunctionError::Expression {
5028+
source: naga::valid::ExpressionError::ShiftAmountTooLarge { .. },
5029+
..
5030+
},
5031+
..
5032+
})
5033+
}
5034+
5035+
// (CTS has more 32-bit test cases)
5036+
5037+
// Const evaluation of `i64` and `u64` is not implemented, https://github.com/gfx-rs/wgpu/issues/8972
5038+
5039+
// 64-bit runtime by const < bitwidth
5040+
check_success("fn foo() { var x: u64; var n = x << 63; }");
5041+
check_success("fn foo() { var x: i64; var n = x << 63; }");
5042+
check_success("fn foo() { var x: u64; var n = x >> 63; }");
5043+
check_success("fn foo() { var x: i64; var n = x >> 63; }");
5044+
5045+
// 64-bit runtime by const >= bitwidth
5046+
check_validation! {
5047+
"fn foo() { var x: u64; var n = x << 64; }",
5048+
"fn foo() { var x: i64; var n = x << 64; }",
5049+
"fn foo() { var x: u64; var n = x >> 64; }",
5050+
"fn foo() { var x: i64; var n = x >> 64; }":
5051+
Err(naga::valid::ValidationError::Function {
5052+
source: naga::valid::FunctionError::Expression {
5053+
source: naga::valid::ExpressionError::ShiftAmountTooLarge { .. },
5054+
..
5055+
},
5056+
..
5057+
}),
5058+
naga::valid::Capabilities::SHADER_INT64
5059+
}
5060+
}

0 commit comments

Comments
 (0)