Skip to content

Commit 7d5a24a

Browse files
feat(naga): constant evaluation for select
1 parent bb83976 commit 7d5a24a

1 file changed

Lines changed: 141 additions & 3 deletions

File tree

naga/src/proc/constant_evaluator.rs

Lines changed: 141 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,27 @@ pub enum ConstantEvaluatorError {
563563
RuntimeExpr,
564564
#[error("Unexpected override-expression")]
565565
OverrideExpr,
566+
#[error("Expected boolean expression for condition argument of `select`, got something else")]
567+
SelectScalarConditionNotABool,
568+
#[error(
569+
"Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
570+
reject,
571+
accept
572+
)]
573+
SelectVecRejectAcceptSizeMismatch {
574+
reject: crate::VectorSize,
575+
accept: crate::VectorSize,
576+
},
577+
#[error("Expected boolean vector for condition arg., got something else")]
578+
SelectConditionNotAVecBool,
579+
#[error(
580+
"Expected same number of vector components between condition, accept, and reject args., got something else",
581+
)]
582+
SelectConditionVecSizeMismatch,
583+
#[error(
584+
"Expected reject and accept args. to be scalars of vectors of the same type, got something else",
585+
)]
586+
SelectAcceptRejectTypeMismatch,
566587
}
567588

568589
impl<'a> ConstantEvaluator<'a> {
@@ -904,9 +925,11 @@ impl<'a> ConstantEvaluator<'a> {
904925
)),
905926
}
906927
}
907-
Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
908-
"select built-in function".into(),
909-
)),
928+
Expression::Select {
929+
reject,
930+
accept,
931+
condition,
932+
} => self.select(reject, accept, condition, span),
910933
Expression::Relational { fun, argument } => {
911934
let argument = self.check_and_get(argument)?;
912935
self.relational(fun, argument, span)
@@ -2497,6 +2520,121 @@ impl<'a> ConstantEvaluator<'a> {
24972520

24982521
Ok(resolution)
24992522
}
2523+
2524+
fn select(
2525+
&mut self,
2526+
reject: Handle<Expression>,
2527+
accept: Handle<Expression>,
2528+
condition: Handle<Expression>,
2529+
span: Span,
2530+
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2531+
let arg = |this: &mut Self, expr| {
2532+
let expr = this.check_and_get(expr)?;
2533+
let expr = this.eval_zero_value_and_splat(expr, span)?;
2534+
let expr: &Expression = &this.expressions[expr];
2535+
this.try_eval_and_append(expr.clone(), span)
2536+
};
2537+
2538+
let reject = arg(self, reject)?;
2539+
let accept = arg(self, accept)?;
2540+
let condition = arg(self, condition)?;
2541+
2542+
let select_single_component =
2543+
|this: &mut Self, reject_scalar, reject, accept, condition| {
2544+
let accept = this.cast(accept, reject_scalar, span)?;
2545+
if condition {
2546+
Ok(accept)
2547+
} else {
2548+
Ok(reject)
2549+
}
2550+
};
2551+
2552+
match (&self.expressions[reject], &self.expressions[accept]) {
2553+
(&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
2554+
let reject_scalar = reject_lit.scalar();
2555+
let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
2556+
else {
2557+
return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
2558+
};
2559+
select_single_component(self, reject_scalar, reject, accept, condition)
2560+
}
2561+
(
2562+
&Expression::Compose {
2563+
ty: reject_ty,
2564+
components: ref reject_components,
2565+
},
2566+
&Expression::Compose {
2567+
ty: accept_ty,
2568+
components: ref accept_components,
2569+
},
2570+
) => {
2571+
let ty_deets = |ty| {
2572+
let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
2573+
(size.unwrap(), scalar)
2574+
};
2575+
2576+
let expected_vec_size = {
2577+
let [(reject_vec_size, _), (accept_vec_size, _)] =
2578+
[reject_ty, accept_ty].map(ty_deets);
2579+
2580+
if reject_vec_size != accept_vec_size {
2581+
return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
2582+
reject: reject_vec_size,
2583+
accept: accept_vec_size,
2584+
});
2585+
}
2586+
reject_vec_size
2587+
};
2588+
2589+
let condition_components = match self.expressions[condition] {
2590+
Expression::Literal(Literal::Bool(condition)) => {
2591+
vec![condition; (expected_vec_size as u8).into()]
2592+
}
2593+
Expression::Compose {
2594+
ty: condition_ty,
2595+
components: ref condition_components,
2596+
} => {
2597+
let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
2598+
if condition_scalar.kind != ScalarKind::Bool {
2599+
return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
2600+
}
2601+
if condition_vec_size != expected_vec_size {
2602+
return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
2603+
}
2604+
condition_components
2605+
.iter()
2606+
.copied()
2607+
.map(|component| match &self.expressions[component] {
2608+
&Expression::Literal(Literal::Bool(condition)) => condition,
2609+
_ => unreachable!(),
2610+
})
2611+
.collect()
2612+
}
2613+
2614+
_ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
2615+
};
2616+
2617+
let evaluated = Expression::Compose {
2618+
ty: reject_ty,
2619+
components: reject_components
2620+
.clone()
2621+
.into_iter()
2622+
.zip(accept_components.clone().into_iter())
2623+
.zip(condition_components.into_iter())
2624+
.map(|((reject, accept), condition)| {
2625+
let reject_scalar = match &self.expressions[reject] {
2626+
&Expression::Literal(lit) => lit.scalar(),
2627+
_ => unreachable!(),
2628+
};
2629+
select_single_component(self, reject_scalar, reject, accept, condition)
2630+
})
2631+
.collect::<Result<_, _>>()?,
2632+
};
2633+
self.register_evaluated_expr(evaluated, span)
2634+
}
2635+
_ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
2636+
}
2637+
}
25002638
}
25012639

25022640
fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {

0 commit comments

Comments
 (0)