Skip to content

Commit 6495ef3

Browse files
fix(naga): Calculate 1D distance using abs (gfx-rs#8903)
1 parent 723cb5d commit 6495ef3

2 files changed

Lines changed: 19 additions & 7 deletions

File tree

cts_runner/test.lst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ webgpu:shader,validation,expression,binary,short_circuiting_and_or:array_overrid
223223
webgpu:shader,validation,expression,binary,short_circuiting_and_or:invalid_types:*
224224
webgpu:shader,validation,expression,binary,short_circuiting_and_or:scalar_vector:op="%26%26";lhs="bool";rhs="bool"
225225
webgpu:shader,validation,expression,call,builtin,all:arguments:test="ptr_deref"
226+
webgpu:shader,validation,expression,call,builtin,distance:values:stage="constant";type="f32"
227+
webgpu:shader,validation,expression,call,builtin,length:scalar:stage="constant";type="f32"
226228
webgpu:shader,validation,expression,call,builtin,max:values:*
227229
// FAIL: others in `value_constructor` due to https://github.com/gfx-rs/wgpu/issues/4720, possibly more
228230
webgpu:shader,validation,expression,call,builtin,value_constructor:array_value:*

naga/src/proc/constant_evaluator.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,7 +1791,12 @@ impl<'a> ConstantEvaluator<'a> {
17911791
F: core::ops::Mul<F>,
17921792
F: num_traits::Float + iter::Sum,
17931793
{
1794-
e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
1794+
if e.len() == 1 {
1795+
// Avoids possible overflow in squaring
1796+
e[0].abs()
1797+
} else {
1798+
e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
1799+
}
17951800
}
17961801

17971802
let result = match_literal_vector!(match e1 => Literal {
@@ -1812,12 +1817,17 @@ impl<'a> ConstantEvaluator<'a> {
18121817
F: core::ops::Mul<F>,
18131818
F: num_traits::Float + iter::Sum + core::ops::Sub,
18141819
{
1815-
a.iter()
1816-
.zip(b.iter())
1817-
.map(|(&aa, &bb)| aa - bb)
1818-
.map(|ei| ei * ei)
1819-
.sum::<F>()
1820-
.sqrt()
1820+
if a.len() == 1 {
1821+
// Avoids possible overflow in squaring
1822+
(a[0] - b[0]).abs()
1823+
} else {
1824+
a.iter()
1825+
.zip(b.iter())
1826+
.map(|(&aa, &bb)| aa - bb)
1827+
.map(|ei| ei * ei)
1828+
.sum::<F>()
1829+
.sqrt()
1830+
}
18211831
}
18221832
let result = match_literal_vector!(match (e1, e2) => Literal {
18231833
Float => |e1, e2| { float_distance(e1, e2) },

0 commit comments

Comments
 (0)