Skip to content

Commit 2f271ab

Browse files
fix(naga): properly impl. auto. type conv. for select
1 parent 90104c5 commit 2f271ab

6 files changed

Lines changed: 68 additions & 44 deletions

File tree

naga/src/front/wgsl/lower/mod.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2483,8 +2483,23 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
24832483
"select" => {
24842484
let mut args = ctx.prepare_args(arguments, 3, span);
24852485

2486-
let reject = self.expression(args.next()?, ctx)?;
2487-
let accept = self.expression(args.next()?, ctx)?;
2486+
let mut values = [
2487+
self.expression_for_abstract(args.next()?, ctx)?,
2488+
self.expression_for_abstract(args.next()?, ctx)?,
2489+
];
2490+
for &value in &values {
2491+
ctx.grow_types(value)?;
2492+
}
2493+
if let Ok(consensus_scalar) =
2494+
ctx.automatic_conversion_consensus(&values)
2495+
{
2496+
ctx.convert_slice_to_common_leaf_scalar(
2497+
&mut values,
2498+
consensus_scalar,
2499+
)?;
2500+
}
2501+
2502+
let [reject, accept] = values;
24882503
let condition = self.expression(args.next()?, ctx)?;
24892504

24902505
args.finish()?;

naga/src/proc/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pub use emitter::Emitter;
1919
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
2020
pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
2121
pub use namer::{EntryPointIndex, NameKey, Namer};
22-
pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
22+
pub use overloads::{select, Conclusion, MissingSpecialType, OverloadSet, Rule};
2323
pub use terminator::ensure_block_returns;
2424
use thiserror::Error;
2525
pub use type_methods::min_max_float_representable_by;

naga/src/proc/overloads/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mod list;
2121
mod mathfunction;
2222
mod one_bits_iter;
2323
mod rule;
24+
pub mod select;
2425
mod utils;
2526

2627
pub use rule::{Conclusion, MissingSpecialType, Rule};

naga/src/proc/overloads/select.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use crate::common::wgsl::{ToWgsl, TryToWgsl};
2+
use crate::ir;
3+
use crate::proc::overloads::utils::{list, rule, scalar_or_vecn, scalars};
4+
use crate::proc::overloads::OverloadSet;
5+
6+
pub fn overloads() -> impl OverloadSet {
7+
list(scalars().flat_map(|scalar| {
8+
scalar_or_vecn(scalar).map(|input| {
9+
let bool_arg = match input.clone() {
10+
ir::TypeInner::Scalar(_) => ir::TypeInner::Scalar(ir::Scalar::BOOL),
11+
ir::TypeInner::Vector { size, scalar: _ } => ir::TypeInner::Vector {
12+
size,
13+
scalar: ir::Scalar::BOOL,
14+
},
15+
_ => unreachable!(),
16+
};
17+
rule([input.clone(), input.clone(), bool_arg], input)
18+
})
19+
}))
20+
}

naga/src/proc/overloads/utils.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,25 @@ pub fn float_scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
3434
.into_iter()
3535
}
3636

37+
/// Produce all [`ir::Scalar`]s.
38+
///
39+
/// Note that `*32` and `F16` must appear before other sizes; this is how we
40+
/// represent conversion rank.
41+
pub fn scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
42+
[
43+
ir::Scalar::ABSTRACT_INT,
44+
ir::Scalar::ABSTRACT_FLOAT,
45+
ir::Scalar::I32,
46+
ir::Scalar::U32,
47+
ir::Scalar::F32,
48+
ir::Scalar::F16,
49+
ir::Scalar::I64,
50+
ir::Scalar::U64,
51+
ir::Scalar::F64,
52+
]
53+
.into_iter()
54+
}
55+
3756
/// Produce all the floating-point [`ir::Scalar`]s, but omit
3857
/// abstract types, for #7405.
3958
pub fn float_scalars_unimplemented_abstract() -> impl Iterator<Item = ir::Scalar> + Clone {

naga/src/valid/expression.rs

Lines changed: 10 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -975,47 +975,16 @@ impl super::Validator {
975975
accept,
976976
reject,
977977
} => {
978-
let accept_inner = &resolver[accept];
979-
let reject_inner = &resolver[reject];
980-
let condition_ty = &resolver[condition];
981-
let condition_good = match *condition_ty {
982-
Ti::Scalar(Sc {
983-
kind: Sk::Bool,
984-
width: _,
985-
}) => {
986-
// When `condition` is a single boolean, `accept` and
987-
// `reject` can be vectors or scalars.
988-
match *accept_inner {
989-
Ti::Scalar { .. } | Ti::Vector { .. } => true,
990-
_ => false,
991-
}
992-
}
993-
Ti::Vector {
994-
size,
995-
scalar:
996-
Sc {
997-
kind: Sk::Bool,
998-
width: _,
999-
},
1000-
} => match *accept_inner {
1001-
Ti::Vector {
1002-
size: other_size, ..
1003-
} => size == other_size,
1004-
_ => false,
1005-
},
1006-
_ => false,
1007-
};
1008-
if accept_inner != reject_inner {
1009-
return Err(ExpressionError::SelectValuesTypeMismatch {
1010-
accept: accept_inner.clone(),
1011-
reject: reject_inner.clone(),
1012-
});
1013-
}
1014-
if !condition_good {
1015-
return Err(ExpressionError::SelectConditionNotABool {
1016-
actual: condition_ty.clone(),
1017-
});
1018-
}
978+
self.validate_func_call_with_overloads(
979+
module,
980+
"Select",
981+
proc::select::overloads(),
982+
[reject, accept, condition]
983+
.iter()
984+
.copied()
985+
.map(|arg| (arg, &resolver[arg])),
986+
)?;
987+
1019988
ShaderStages::all()
1020989
}
1021990
E::Derivative { expr, .. } => {

0 commit comments

Comments
 (0)