Skip to content

Commit 6d74cf0

Browse files
fix(naga): properly impl. auto. type conv. for select
1 parent 158f2ad commit 6d74cf0

6 files changed

Lines changed: 91 additions & 44 deletions

File tree

naga/src/front/wgsl/lower/mod.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2483,8 +2483,23 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
24832483
"select" => {
24842484
let mut args = ctx.prepare_args(arguments, 3, span);
24852485

2486-
let reject = self.expression(args.next()?, ctx)?;
2487-
let accept = self.expression(args.next()?, ctx)?;
2486+
let mut values = [
2487+
self.expression_for_abstract(args.next()?, ctx)?,
2488+
self.expression_for_abstract(args.next()?, ctx)?,
2489+
];
2490+
for &value in &values {
2491+
ctx.grow_types(value)?;
2492+
}
2493+
if let Ok(consensus_scalar) =
2494+
ctx.automatic_conversion_consensus(&values)
2495+
{
2496+
ctx.convert_slice_to_common_leaf_scalar(
2497+
&mut values,
2498+
consensus_scalar,
2499+
)?;
2500+
}
2501+
2502+
let [reject, accept] = values;
24882503
let condition = self.expression(args.next()?, ctx)?;
24892504

24902505
args.finish()?;

naga/src/proc/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pub use emitter::Emitter;
1919
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
2020
pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
2121
pub use namer::{EntryPointIndex, NameKey, Namer};
22-
pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
22+
pub use overloads::{select, Conclusion, MissingSpecialType, OverloadSet, Rule};
2323
pub use terminator::ensure_block_returns;
2424
use thiserror::Error;
2525
pub use type_methods::min_max_float_representable_by;

naga/src/proc/overloads/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mod list;
2121
mod mathfunction;
2222
mod one_bits_iter;
2323
mod rule;
24+
pub mod select;
2425
mod utils;
2526

2627
pub use rule::{Conclusion, MissingSpecialType, Rule};

naga/src/proc/overloads/select.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use crate::common::wgsl::{ToWgsl, TryToWgsl};
2+
use crate::ir;
3+
use crate::proc::overloads::utils::{list, rule, scalar_or_vecn, scalars};
4+
use crate::proc::overloads::OverloadSet;
5+
6+
pub fn overloads() -> impl OverloadSet {
7+
list(scalars().flat_map(|scalar| {
8+
scalar_or_vecn(scalar).map(|input| {
9+
let bool_arg = match input.clone() {
10+
ir::TypeInner::Scalar(_) => ir::TypeInner::Scalar(ir::Scalar::BOOL),
11+
ir::TypeInner::Vector { size, scalar: _ } => ir::TypeInner::Vector {
12+
size,
13+
scalar: ir::Scalar::BOOL,
14+
},
15+
_ => unreachable!(),
16+
};
17+
rule([input.clone(), input.clone(), bool_arg], input)
18+
})
19+
}))
20+
}
21+
22+
#[derive(Clone, Copy)]
23+
pub struct WgslSymbol;
24+
25+
impl ToWgsl for WgslSymbol {
26+
fn to_wgsl(self) -> &'static str {
27+
"select"
28+
}
29+
}
30+
31+
impl TryToWgsl for WgslSymbol {
32+
fn try_to_wgsl(self) -> Option<&'static str> {
33+
Some(self.to_wgsl())
34+
}
35+
36+
const DESCRIPTION: &'static str = "`select` built-in";
37+
}
38+
39+
impl core::fmt::Debug for WgslSymbol {
40+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
41+
f.write_str(Self::DESCRIPTION)
42+
}
43+
}

naga/src/proc/overloads/utils.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,25 @@ pub fn float_scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
3434
.into_iter()
3535
}
3636

37+
/// Produce all [`ir::Scalar`]s.
38+
///
39+
/// Note that `*32` and `F16` must appear before other sizes; this is how we
40+
/// represent conversion rank.
41+
pub fn scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
42+
[
43+
ir::Scalar::ABSTRACT_INT,
44+
ir::Scalar::ABSTRACT_FLOAT,
45+
ir::Scalar::I32,
46+
ir::Scalar::U32,
47+
ir::Scalar::F32,
48+
ir::Scalar::F16,
49+
ir::Scalar::I64,
50+
ir::Scalar::U64,
51+
ir::Scalar::F64,
52+
]
53+
.into_iter()
54+
}
55+
3756
/// Produce all the floating-point [`ir::Scalar`]s, but omit
3857
/// abstract types, for #7405.
3958
pub fn float_scalars_unimplemented_abstract() -> impl Iterator<Item = ir::Scalar> + Clone {

naga/src/valid/expression.rs

Lines changed: 10 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -972,47 +972,16 @@ impl super::Validator {
972972
accept,
973973
reject,
974974
} => {
975-
let accept_inner = &resolver[accept];
976-
let reject_inner = &resolver[reject];
977-
let condition_ty = &resolver[condition];
978-
let condition_good = match *condition_ty {
979-
Ti::Scalar(Sc {
980-
kind: Sk::Bool,
981-
width: _,
982-
}) => {
983-
// When `condition` is a single boolean, `accept` and
984-
// `reject` can be vectors or scalars.
985-
match *accept_inner {
986-
Ti::Scalar { .. } | Ti::Vector { .. } => true,
987-
_ => false,
988-
}
989-
}
990-
Ti::Vector {
991-
size,
992-
scalar:
993-
Sc {
994-
kind: Sk::Bool,
995-
width: _,
996-
},
997-
} => match *accept_inner {
998-
Ti::Vector {
999-
size: other_size, ..
1000-
} => size == other_size,
1001-
_ => false,
1002-
},
1003-
_ => false,
1004-
};
1005-
if accept_inner != reject_inner {
1006-
return Err(ExpressionError::SelectValuesTypeMismatch {
1007-
accept: accept_inner.clone(),
1008-
reject: reject_inner.clone(),
1009-
});
1010-
}
1011-
if !condition_good {
1012-
return Err(ExpressionError::SelectConditionNotABool {
1013-
actual: condition_ty.clone(),
1014-
});
1015-
}
975+
self.validate_func_call_with_overloads(
976+
module,
977+
proc::select::WgslSymbol,
978+
proc::select::overloads(),
979+
[reject, accept, condition]
980+
.iter()
981+
.copied()
982+
.map(|arg| (arg, &resolver[arg])),
983+
)?;
984+
1016985
ShaderStages::all()
1017986
}
1018987
E::Derivative { expr, .. } => {

0 commit comments

Comments
 (0)