Skip to content

Commit c5e460f

Browse files
Fix regressions in logical operator handling in const exprs (gfx-rs#8721)
* Fix regressions in logical operator handling in const exprs Fixes regressions introduced by gfx-rs#7339 Fixes gfx-rs#8711
1 parent b74007f commit c5e460f

13 files changed

Lines changed: 928 additions & 684 deletions

cts_runner/test.lst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ webgpu:shader,execution,flow_control,return:*
177177
// Many other vertex_buffer_access subtests also passing, but there are too many to enumerate.
178178
// Fails on Metal in CI only, not when running locally.
179179
fails-if(metal) webgpu:shader,execution,robust_access_vertex:vertex_buffer_access:indexed=true;indirect=false;drawCallTestParameter="baseVertex";type="float32x4";additionalBuffers=4;partialLastNumber=false;offsetVertexBuffer=true
180+
webgpu:shader,validation,const_assert,const_assert:*
181+
webgpu:shader,validation,expression,binary,short_circuiting_and_or:array_override:op="%26%26";a_val=1;b_val=1
182+
webgpu:shader,validation,expression,binary,short_circuiting_and_or:invalid_types:*
183+
webgpu:shader,validation,expression,binary,short_circuiting_and_or:scalar_vector:op="%26%26";lhs="bool";rhs="bool"
180184
webgpu:shader,validation,expression,call,builtin,all:arguments:test="ptr_deref"
181185
webgpu:shader,validation,expression,call,builtin,max:values:*
182186
webgpu:shader,validation,statement,statement_behavior:invalid_statements:body="break"

naga/src/front/wgsl/lower/mod.rs

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,28 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
560560
}
561561
}
562562

563+
fn const_eval_expr_to_bool(&self, handle: Handle<ir::Expression>) -> Option<bool> {
564+
match self.expr_type {
565+
ExpressionContextType::Runtime(ref ctx) => {
566+
if !ctx.local_expression_kind_tracker.is_const(handle) {
567+
return None;
568+
}
569+
570+
self.module
571+
.to_ctx()
572+
.eval_expr_to_bool_from(handle, &ctx.function.expressions)
573+
}
574+
ExpressionContextType::Constant(Some(ref ctx)) => {
575+
assert!(ctx.local_expression_kind_tracker.is_const(handle));
576+
self.module
577+
.to_ctx()
578+
.eval_expr_to_bool_from(handle, &ctx.function.expressions)
579+
}
580+
ExpressionContextType::Constant(None) => self.module.to_ctx().eval_expr_to_bool(handle),
581+
ExpressionContextType::Override => None,
582+
}
583+
}
584+
563585
/// Return `true` if `handle` is a constant expression.
564586
fn is_const(&self, handle: Handle<ir::Expression>) -> bool {
565587
use ExpressionContextType as Ect;
@@ -2538,22 +2560,27 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
25382560
result_var,
25392561
)))
25402562
} else {
2541-
let left_expr = ctx.get(left);
2542-
// Constant or override context in either function or module scope
2543-
let &crate::Expression::Literal(crate::Literal::Bool(left_val)) = left_expr else {
2544-
return Err(Box::new(Error::NotBool(span)));
2545-
};
2563+
let left_val = ctx.const_eval_expr_to_bool(left);
2564+
2565+
if left_val.is_some_and(|left_val| {
2566+
op == crate::BinaryOperator::LogicalAnd && !left_val
2567+
|| op == crate::BinaryOperator::LogicalOr && left_val
2568+
}) {
2569+
// Short-circuit behavior: don't evaluate the RHS.
25462570

2547-
if op == crate::BinaryOperator::LogicalAnd && !left_val
2548-
|| op == crate::BinaryOperator::LogicalOr && left_val
2549-
{
2550-
// Short-circuit behavior: don't evaluate the RHS. Ideally we
2551-
// would do _some_ validity checks of the RHS here, but that's
2552-
// tricky, because the RHS is allowed to have things that aren't
2553-
// legal in const contexts.
2571+
// TODO(https://github.com/gfx-rs/wgpu/issues/8440): We shouldn't ignore the
2572+
// RHS completely, it should still be type-checked. Preserving it for type
2573+
// checking is a bit tricky, because we're trying to produce an expression
2574+
// for a const context, but the RHS is allowed to have things that aren't
2575+
// const.
25542576

2555-
Ok(Typed::Plain(left_expr.clone()))
2577+
Ok(Typed::Plain(ctx.get(left).clone()))
25562578
} else {
2579+
// Evaluate the RHS and construct the entire binary expression as we
2580+
// normally would. This case applies to well-formed constant logical
2581+
// expressions that don't short-circuit (handled by the constant evaluator
2582+
// shortly), to override expressions (handled when overrides are processed)
2583+
// and to non-well-formed expressions (rejected by type checking).
25572584
let right = self.expression_for_abstract(right, ctx)?;
25582585
ctx.grow_types(right)?;
25592586

naga/src/proc/mod.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,18 @@ pub struct GlobalCtx<'a> {
440440

441441
impl GlobalCtx<'_> {
442442
/// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
443-
#[allow(dead_code)]
443+
#[cfg_attr(
444+
not(any(
445+
feature = "glsl-in",
446+
feature = "spv-in",
447+
feature = "wgsl-in",
448+
glsl_out,
449+
hlsl_out,
450+
msl_out,
451+
wgsl_out
452+
)),
453+
allow(dead_code)
454+
)]
444455
pub(super) fn eval_expr_to_u32(
445456
&self,
446457
handle: crate::Handle<crate::Expression>,
@@ -463,8 +474,17 @@ impl GlobalCtx<'_> {
463474
}
464475
}
465476

477+
/// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `bool`.
478+
#[cfg_attr(not(feature = "wgsl-in"), allow(dead_code))]
479+
pub(super) fn eval_expr_to_bool(
480+
&self,
481+
handle: crate::Handle<crate::Expression>,
482+
) -> Option<bool> {
483+
self.eval_expr_to_bool_from(handle, self.global_expressions)
484+
}
485+
466486
/// Try to evaluate the expression in the `arena` using its `handle` and return it as a `bool`.
467-
#[allow(dead_code)]
487+
#[cfg_attr(not(feature = "wgsl-in"), allow(dead_code))]
468488
pub(super) fn eval_expr_to_bool_from(
469489
&self,
470490
handle: crate::Handle<crate::Expression>,
@@ -476,7 +496,7 @@ impl GlobalCtx<'_> {
476496
}
477497
}
478498

479-
#[allow(dead_code)]
499+
#[expect(dead_code)]
480500
pub(crate) fn eval_expr_to_literal(
481501
&self,
482502
handle: crate::Handle<crate::Expression>,

naga/tests/in/wgsl/const_assert.wgsl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Sourced from https://www.w3.org/TR/WGSL/#const-assert-statement
1+
// Sourced from https://www.w3.org/TR/WGSL/#const-assert-statement and the CTS.
22
const x = 1;
33
const y = 2;
44
const_assert x < y; // valid at module-scope.
@@ -9,13 +9,20 @@ const_assert x == 1i;
99
const_assert x > 0u;
1010
const_assert x < 2.0f;
1111

12+
const g_false = false;
13+
const_assert (!((g_false) || (any(vec3(false, false, false)))));
14+
1215
@compute @workgroup_size(1)
1316
fn foo() {
1417
const z = x + y - 2;
18+
const l_false = false;
1519
const_assert z > 0; // valid in functions.
1620
const_assert(z > 0);
1721

1822
const_assert z == 1i;
1923
const_assert z > 0u;
2024
const_assert z < 2.0f;
25+
26+
const_assert (!((g_false) || (any(vec3(false, false, false)))));
27+
const_assert (!((l_false) || (any(vec3(false, false, false)))));
2128
}

naga/tests/in/wgsl/operators.wgsl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ const v_f32_zero = vec4<f32>(0.0, 0.0, 0.0, 0.0);
33
const v_f32_half = vec4<f32>(0.5, 0.5, 0.5, 0.5);
44
const v_i32_one = vec4<i32>(1, 1, 1, 1);
55

6+
const b_false = false;
7+
const b_true = true;
8+
69
fn builtins() -> vec4<f32> {
710
// select()
811
let condition = true;
@@ -45,6 +48,14 @@ fn q() -> bool { return false; }
4548
fn r() -> bool { return true; }
4649
fn s() -> bool { return false; }
4750

51+
const short_circuit_1_invalid_rhs = false && sqrt(-1) != 0;
52+
// TODO(https://github.com/gfx-rs/wgpu/issues/8440):
53+
// The following should not be accepted, but it currently is.
54+
// When fixed, move this to a wgsl_errors test.
55+
const short_circuit_2_invalid_rhs = false && (0u + 1.0f > 0);
56+
const short_circuit_3 = b_false || b_true;
57+
const short_circuit_4 = !((b_false) || (any(vec3(false, false, false))));
58+
4859
fn logical() {
4960
let t = true;
5061
let f = false;
@@ -60,7 +71,18 @@ fn logical() {
6071
let bitwise_or1 = vec3(t) | vec3(f);
6172
let bitwise_and0 = t & f;
6273
let bitwise_and1 = vec4(t) & vec4(f);
63-
let short_circuit = (p() || q()) && (r() || s());
74+
75+
const short_circuit_1_invalid = false && sqrt(-1) != 0;
76+
// TODO(https://github.com/gfx-rs/wgpu/issues/8440):
77+
// The following should not be accepted, but it currently is.
78+
// When fixed, move this to a wgsl_errors test.
79+
const short_circuit_2_invalid_rhs = false && (0u + 1.0f > 0);
80+
const short_circuit_3 = b_false || b_true;
81+
const short_circuit_4 = !((b_false) || (any(vec3(false, false, false))));
82+
83+
let short_circuit_5 = !((f) || (any(vec3(false, false, false))));
84+
let short_circuit_6 = (p() || q()) && (r() || s());
85+
let short_circuit_7 = true || q();
6486
}
6587

6688
fn arithmetic() {
@@ -330,4 +352,3 @@ fn main(@builtin(workgroup_id) id: vec3<u32>) {
330352

331353
negation_avoids_prefix_decrement();
332354
}
333-

naga/tests/out/glsl/wgsl-operators.main.Compute.glsl

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ const vec4 v_f32_one = vec4(1.0, 1.0, 1.0, 1.0);
99
const vec4 v_f32_zero = vec4(0.0, 0.0, 0.0, 0.0);
1010
const vec4 v_f32_half = vec4(0.5, 0.5, 0.5, 0.5);
1111
const ivec4 v_i32_one = ivec4(1, 1, 1, 1);
12+
const bool b_false = false;
13+
const bool b_true = true;
14+
const bool short_circuit_1_invalid_rhs = false;
15+
const bool short_circuit_2_invalid_rhs = false;
16+
const bool short_circuit_3_ = true;
17+
const bool short_circuit_4_ = true;
1218

1319

1420
vec4 builtins() {
@@ -68,6 +74,8 @@ void logical() {
6874
bool local_2 = false;
6975
bool local_3 = false;
7076
bool local_4 = false;
77+
bool local_5 = false;
78+
bool local_6 = false;
7179
bool neg0_ = !(true);
7280
bvec2 neg1_ = not(bvec2(true));
7381
if (!(true)) {
@@ -86,28 +94,42 @@ void logical() {
8694
bvec3 bitwise_or1_ = bvec3(bvec3(true).x || bvec3(false).x, bvec3(true).y || bvec3(false).y, bvec3(true).z || bvec3(false).z);
8795
bool bitwise_and0_ = (true && false);
8896
bvec4 bitwise_and1_ = bvec4(bvec4(true).x && bvec4(false).x, bvec4(true).y && bvec4(false).y, bvec4(true).z && bvec4(false).z, bvec4(true).w && bvec4(false).w);
89-
bool _e22 = p();
90-
if (!(_e22)) {
91-
bool _e26 = q();
92-
local_2 = _e26;
97+
if (!(false)) {
98+
local_2 = false;
9399
} else {
94100
local_2 = true;
95101
}
96-
bool _e28 = local_2;
97-
if (_e28) {
98-
bool _e31 = r();
99-
if (!(_e31)) {
100-
bool _e35 = s();
101-
local_4 = _e35;
102+
bool _e27 = local_2;
103+
bool short_circuit_5_ = !(_e27);
104+
bool _e29 = p();
105+
if (!(_e29)) {
106+
bool _e33 = q();
107+
local_3 = _e33;
108+
} else {
109+
local_3 = true;
110+
}
111+
bool _e35 = local_3;
112+
if (_e35) {
113+
bool _e38 = r();
114+
if (!(_e38)) {
115+
bool _e42 = s();
116+
local_5 = _e42;
102117
} else {
103-
local_4 = true;
118+
local_5 = true;
104119
}
105-
bool _e37 = local_4;
106-
local_3 = _e37;
120+
bool _e44 = local_5;
121+
local_4 = _e44;
122+
} else {
123+
local_4 = false;
124+
}
125+
bool short_circuit_6_ = local_4;
126+
if (false) {
127+
bool _e50 = q();
128+
local_6 = _e50;
107129
} else {
108-
local_3 = false;
130+
local_6 = true;
109131
}
110-
bool short_circuit = local_3;
132+
bool short_circuit_7_ = local_6;
111133
return;
112134
}
113135

naga/tests/out/hlsl/wgsl-operators.hlsl

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@ static const float4 v_f32_one = float4(1.0, 1.0, 1.0, 1.0);
22
static const float4 v_f32_zero = float4(0.0, 0.0, 0.0, 0.0);
33
static const float4 v_f32_half = float4(0.5, 0.5, 0.5, 0.5);
44
static const int4 v_i32_one = int4(int(1), int(1), int(1), int(1));
5+
static const bool b_false = false;
6+
static const bool b_true = true;
7+
static const bool short_circuit_1_invalid_rhs = false;
8+
static const bool short_circuit_2_invalid_rhs = false;
9+
static const bool short_circuit_3_ = true;
10+
static const bool short_circuit_4_ = true;
511

612
float4 builtins()
713
{
@@ -75,6 +81,8 @@ void logical()
7581
bool local_2 = (bool)0;
7682
bool local_3 = (bool)0;
7783
bool local_4 = (bool)0;
84+
bool local_5 = (bool)0;
85+
bool local_6 = (bool)0;
7886

7987
bool neg0_ = !(true);
8088
bool2 neg1_ = !((true).xx);
@@ -94,28 +102,42 @@ void logical()
94102
bool3 bitwise_or1_ = ((true).xxx | (false).xxx);
95103
bool bitwise_and0_ = (true & false);
96104
bool4 bitwise_and1_ = ((true).xxxx & (false).xxxx);
97-
const bool _e22 = p();
98-
if (!(_e22)) {
99-
const bool _e26 = q();
100-
local_2 = _e26;
105+
if (!(false)) {
106+
local_2 = false;
101107
} else {
102108
local_2 = true;
103109
}
104-
bool _e28 = local_2;
105-
if (_e28) {
106-
const bool _e31 = r();
107-
if (!(_e31)) {
108-
const bool _e35 = s();
109-
local_4 = _e35;
110+
bool _e27 = local_2;
111+
bool short_circuit_5_ = !(_e27);
112+
const bool _e29 = p();
113+
if (!(_e29)) {
114+
const bool _e33 = q();
115+
local_3 = _e33;
116+
} else {
117+
local_3 = true;
118+
}
119+
bool _e35 = local_3;
120+
if (_e35) {
121+
const bool _e38 = r();
122+
if (!(_e38)) {
123+
const bool _e42 = s();
124+
local_5 = _e42;
110125
} else {
111-
local_4 = true;
126+
local_5 = true;
112127
}
113-
bool _e37 = local_4;
114-
local_3 = _e37;
128+
bool _e44 = local_5;
129+
local_4 = _e44;
130+
} else {
131+
local_4 = false;
132+
}
133+
bool short_circuit_6_ = local_4;
134+
if (false) {
135+
const bool _e50 = q();
136+
local_6 = _e50;
115137
} else {
116-
local_3 = false;
138+
local_6 = true;
117139
}
118-
bool short_circuit = local_3;
140+
bool short_circuit_7_ = local_6;
119141
return;
120142
}
121143

0 commit comments

Comments
 (0)