Commit 153723a

Fix workgroupUniformLoad returning an atomic type

cryvosh authored and ErichDonGubler committed
1 parent 02f8d9a · commit 153723a

11 files changed: 371 additions & 36 deletions

naga/src/back/spv/block.rs

Lines changed: 6 additions & 34 deletions
@@ -3710,43 +3710,15 @@ impl BlockContext<'_> {
                 self.writer
                     .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
                 let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
-                // Embed the body of
-                match self.write_access_chain(
+                // Match `Expression::Load` behavior, including `OpAtomicLoad` when
+                // loading from a pointer to `atomic<T>`.
+                let id = self.write_checked_load(
                     pointer,
                     &mut block,
                     AccessTypeAdjustment::None,
-                )? {
-                    ExpressionPointer::Ready { pointer_id } => {
-                        let id = self.gen_id();
-                        block.body.push(Instruction::load(
-                            result_type_id,
-                            id,
-                            pointer_id,
-                            None,
-                        ));
-                        self.cached[result] = id;
-                    }
-                    ExpressionPointer::Conditional { condition, access } => {
-                        self.cached[result] = self.write_conditional_indexed_load(
-                            result_type_id,
-                            condition,
-                            &mut block,
-                            move |id_gen, block| {
-                                // The in-bounds path. Perform the access and the load.
-                                let pointer_id = access.result_id.unwrap();
-                                let value_id = id_gen.next();
-                                block.body.push(access);
-                                block.body.push(Instruction::load(
-                                    result_type_id,
-                                    value_id,
-                                    pointer_id,
-                                    None,
-                                ));
-                                value_id
-                            },
-                        )
-                    }
-                }
+                    result_type_id,
+                )?;
+                self.cached[result] = id;
                 self.writer
                     .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
             }
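The hunk above swaps the open-coded OpLoad for a call to write_checked_load, the same helper that Expression::Load goes through, so that loads from a pointer to atomic<T> come out as atomic loads. The instruction choice itself lives inside write_checked_load and is not shown in this diff; as an illustration only, here is a minimal sketch of the rule at the IR level (an assumption for this write-up, not backend code from the commit):

use naga::TypeInner;

/// Sketch only: should a load through a pointer with this pointee be emitted
/// as an atomic load (OpAtomicLoad) rather than a plain OpLoad?
fn needs_atomic_load(pointee: &TypeInner) -> bool {
    matches!(pointee, TypeInner::Atomic(_))
}

fn main() {
    // e.g. atomic<u32> needs an atomic load; plain u32 does not.
    assert!(needs_atomic_load(&TypeInner::Atomic(naga::Scalar::U32)));
    assert!(!needs_atomic_load(&TypeInner::Scalar(naga::Scalar::U32)));
}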

naga/src/front/wgsl/lower/mod.rs

Lines changed: 26 additions & 1 deletion
@@ -3059,7 +3059,32 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
                 ir::TypeInner::Pointer {
                     base,
                     space: ir::AddressSpace::WorkGroup,
-                } => base,
+                } => match ctx.module.types[base].inner {
+                    // Match `Expression::Load` semantics:
+                    // loading through a pointer to `atomic<T>` produces a `T`.
+                    ir::TypeInner::Atomic(scalar) => ctx.module.types.insert(
+                        ir::Type {
+                            name: None,
+                            inner: ir::TypeInner::Scalar(scalar),
+                        },
+                        span,
+                    ),
+                    _ => base,
+                },
+                ir::TypeInner::ValuePointer {
+                    size,
+                    scalar,
+                    space: ir::AddressSpace::WorkGroup,
+                } => ctx.module.types.insert(
+                    ir::Type {
+                        name: None,
+                        inner: match size {
+                            Some(size) => ir::TypeInner::Vector { size, scalar },
+                            None => ir::TypeInner::Scalar(scalar),
+                        },
+                    },
+                    span,
+                ),
                 ref other => {
                     log::error!("Type {other:?} passed to workgroupUniformLoad");
                     let span = ctx.ast_expressions.get_span(expr);
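This front-end hunk is where the result type of the built-in is computed during lowering: a workgroup pointer to atomic<T> now yields T, and a workgroup ValuePointer yields the matching scalar or vector. Restated as a standalone helper over naga's public types, a sketch only (it assumes the TypeInner::Atomic(Scalar) representation visible in the hunk and is not code from the commit):

use naga::TypeInner;

/// Sketch only: the result-type rule applied by the lowering change above.
/// Loading through `ptr<workgroup, T>` yields `T`, except that `atomic<S>`
/// is unwrapped to the plain scalar `S`, matching `Expression::Load`.
fn uniform_load_result_inner(pointee: &TypeInner) -> TypeInner {
    match *pointee {
        TypeInner::Atomic(scalar) => TypeInner::Scalar(scalar),
        ref other => other.clone(),
    }
}

fn main() {
    let loaded = uniform_load_result_inner(&TypeInner::Atomic(naga::Scalar::U32));
    assert_eq!(loaded, TypeInner::Scalar(naga::Scalar::U32));
}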

naga/src/valid/function.rs

Lines changed: 18 additions & 1 deletion
@@ -1479,7 +1479,24 @@ impl super::Validator {
                     base: ty,
                     space: AddressSpace::WorkGroup,
                 };
-                if !expected_pointer_inner.non_struct_equivalent(pointer_inner, context.types) {
+                let atomic_specialization_ok = match *pointer_inner {
+                    Ti::Pointer {
+                        base: pointer_base,
+                        space: AddressSpace::WorkGroup,
+                    } => match (
+                        &context.types[pointer_base].inner,
+                        &context.types[ty].inner,
+                    ) {
+                        (&Ti::Atomic(pointer_scalar), &Ti::Scalar(result_scalar)) => {
+                            pointer_scalar == result_scalar
+                        }
+                        _ => false,
+                    },
+                    _ => false,
+                };
+                if !expected_pointer_inner.non_struct_equivalent(pointer_inner, context.types)
+                    && !atomic_specialization_ok
+                {
                     return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
                         .with_span_static(span, "WorkGroupUniformLoad"));
                 }
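The validator keeps the existing non_struct_equivalent check and now additionally tolerates exactly one mismatch: the argument is ptr<workgroup, atomic<S>> while the recorded result type is the plain scalar S. The same condition, pulled out as a standalone predicate for illustration (a sketch; the helper name is hypothetical and not part of the commit):

use naga::{AddressSpace, Handle, Type, TypeInner, UniqueArena};

/// Sketch only: does `pointer_inner` point at workgroup `atomic<S>` while the
/// declared result type handle resolves to the matching plain scalar `S`?
fn atomic_specialization_ok(
    types: &UniqueArena<Type>,
    pointer_inner: &TypeInner,
    result_ty: Handle<Type>,
) -> bool {
    match *pointer_inner {
        TypeInner::Pointer {
            base,
            space: AddressSpace::WorkGroup,
        } => matches!(
            (&types[base].inner, &types[result_ty].inner),
            (&TypeInner::Atomic(p), &TypeInner::Scalar(r)) if p == r
        ),
        _ => false,
    }
}

fn main() {
    let mut types = UniqueArena::<Type>::default();
    let scalar_ty = types.insert(
        Type { name: None, inner: TypeInner::Scalar(naga::Scalar::U32) },
        naga::Span::UNDEFINED,
    );
    let atomic_ty = types.insert(
        Type { name: None, inner: TypeInner::Atomic(naga::Scalar::U32) },
        naga::Span::UNDEFINED,
    );
    let pointer = TypeInner::Pointer { base: atomic_ty, space: AddressSpace::WorkGroup };
    assert!(atomic_specialization_ok(&types, &pointer, scalar_ty));
}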
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+// Test workgroupUniformLoad specialization for atomic<T> -> T
+// Issue: https://github.com/gfx-rs/wgpu/issues/8785
+
+var<workgroup> wg_scalar: atomic<u32>;
+var<workgroup> wg_signed: atomic<i32>;
+
+@compute @workgroup_size(64)
+fn test_atomic_workgroup_uniform_load(
+    @builtin(workgroup_id) workgroup_id: vec3u,
+    @builtin(local_invocation_id) local_id: vec3u
+) {
+    let active_tile_index = workgroup_id.x + workgroup_id.y * 32768;
+
+    // Each thread may set the atomic
+    atomicOr(&wg_scalar, u32(active_tile_index >= 64));
+    atomicAdd(&wg_signed, 1i);
+
+    workgroupBarrier();
+
+    // workgroupUniformLoad on atomic<u32> should return u32
+    let scalar_val: u32 = workgroupUniformLoad(&wg_scalar);
+
+    // workgroupUniformLoad on atomic<i32> should return i32
+    let signed_val: i32 = workgroupUniformLoad(&wg_signed);
+
+    // Should be able to use the result in comparisons
+    if scalar_val == 0u && signed_val > 0i {
+        return;
+    }
+}

naga/tests/naga/wgsl_errors.rs

Lines changed: 31 additions & 0 deletions
@@ -47,6 +47,37 @@ fn check_success(input: &str) {
     }
 }
 
+#[test]
+fn workgroup_uniform_load_atomic_returns_scalar() {
+    let input = r#"
+        var<workgroup> wg_scratch: atomic<u32>;
+
+        @compute @workgroup_size(4, 4, 4)
+        fn interval_tile_main(
+            @builtin(workgroup_id) workgroup_id: vec3u,
+            @builtin(local_invocation_id) local_id: vec3u
+        ) {
+            let active_tile_index = workgroup_id.x + workgroup_id.y * 32768u;
+            atomicOr(&wg_scratch, u32(active_tile_index >= 64u));
+            workgroupBarrier();
+            if workgroupUniformLoad(&wg_scratch) == 0 {
+                return;
+            }
+        }
+    "#;
+
+    let module = naga::front::wgsl::parse_str(input).unwrap_or_else(|err| {
+        panic!(
+            "expected success, but parsing failed with:\n{}",
+            err.emit_to_string(input)
+        )
+    });
+
+    naga::valid::Validator::new(valid::ValidationFlags::default(), Capabilities::all())
+        .validate(&module)
+        .unwrap();
+}
+
 #[test]
 fn very_negative_integers() {
     // wgpu#4492
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+#version 310 es
+
+precision highp float;
+precision highp int;
+
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+
+shared uint wg_scalar;
+
+shared int wg_signed;
+
+
+void main() {
+    if (gl_LocalInvocationID == uvec3(0u)) {
+        wg_scalar = 0u;
+        wg_signed = 0;
+    }
+    memoryBarrierShared();
+    barrier();
+    uvec3 workgroup_id = gl_WorkGroupID;
+    uvec3 local_id = gl_LocalInvocationID;
+    bool local = false;
+    uint active_tile_index = (workgroup_id.x + (workgroup_id.y * 32768u));
+    uint _e11 = atomicOr(wg_scalar, uint((active_tile_index >= 64u)));
+    int _e14 = atomicAdd(wg_signed, 1);
+    memoryBarrierShared();
+    barrier();
+    memoryBarrierShared();
+    barrier();
+    uint _e16 = wg_scalar;
+    memoryBarrierShared();
+    barrier();
+    memoryBarrierShared();
+    barrier();
+    int _e18 = wg_signed;
+    memoryBarrierShared();
+    barrier();
+    if ((_e16 == 0u)) {
+        local = (_e18 > 0);
+    } else {
+        local = false;
+    }
+    bool _e26 = local;
+    if (_e26) {
+        return;
+    } else {
+        return;
+    }
+}
+
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+groupshared uint wg_scalar;
+groupshared int wg_signed;
+
+[numthreads(64, 1, 1)]
+void test_atomic_workgroup_uniform_load(uint3 workgroup_id : SV_GroupID, uint3 local_id : SV_GroupThreadID, uint3 __local_invocation_id : SV_GroupThreadID)
+{
+    if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {
+        wg_scalar = (uint)0;
+        wg_signed = (int)0;
+    }
+    GroupMemoryBarrierWithGroupSync();
+    bool local = (bool)0;
+
+    uint active_tile_index = (workgroup_id.x + (workgroup_id.y * 32768u));
+    uint _e11; InterlockedOr(wg_scalar, uint((active_tile_index >= 64u)), _e11);
+    int _e14; InterlockedAdd(wg_signed, int(1), _e14);
+    GroupMemoryBarrierWithGroupSync();
+    GroupMemoryBarrierWithGroupSync();
+    uint _e16 = wg_scalar;
+    GroupMemoryBarrierWithGroupSync();
+    GroupMemoryBarrierWithGroupSync();
+    int _e18 = wg_signed;
+    GroupMemoryBarrierWithGroupSync();
+    if ((_e16 == 0u)) {
+        local = (_e18 > int(0));
+    } else {
+        local = false;
+    }
+    bool _e26 = local;
+    if (_e26) {
+        return;
+    } else {
+        return;
+    }
+}
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+(
+    vertex:[
+    ],
+    fragment:[
+    ],
+    compute:[
+        (
+            entry_point:"test_atomic_workgroup_uniform_load",
+            target_profile:"cs_5_1",
+        ),
+    ],
+)
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+// language: metal1.0
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using metal::uint;
+
+
+struct test_atomic_workgroup_uniform_loadInput {
+};
+kernel void test_atomic_workgroup_uniform_load(
+  metal::uint3 workgroup_id [[threadgroup_position_in_grid]]
+, metal::uint3 local_id [[thread_position_in_threadgroup]]
+, threadgroup metal::atomic_uint& wg_scalar
+, threadgroup metal::atomic_int& wg_signed
+) {
+    if (metal::all(local_id == metal::uint3(0u))) {
+        metal::atomic_store_explicit(&wg_scalar, 0, metal::memory_order_relaxed);
+        metal::atomic_store_explicit(&wg_signed, 0, metal::memory_order_relaxed);
+    }
+    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+    bool local = {};
+    uint active_tile_index = workgroup_id.x + (workgroup_id.y * 32768u);
+    uint _e11 = metal::atomic_fetch_or_explicit(&wg_scalar, static_cast<uint>(active_tile_index >= 64u), metal::memory_order_relaxed);
+    int _e14 = metal::atomic_fetch_add_explicit(&wg_signed, 1, metal::memory_order_relaxed);
+    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+    uint unnamed = metal::atomic_load_explicit(&wg_scalar, metal::memory_order_relaxed);
+    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+    int unnamed_1 = metal::atomic_load_explicit(&wg_signed, metal::memory_order_relaxed);
+    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+    if (unnamed == 0u) {
+        local = unnamed_1 > 0;
+    } else {
+        local = false;
+    }
+    bool _e26 = local;
+    if (_e26) {
+        return;
+    } else {
+        return;
+    }
+}
