Skip to content

Commit 6bca101

Browse files
Various SPIR-V fixes for mesh shaders (gfx-rs#8756)
* Initial commit adding miscellaneous changes from msl-write and hlsl-write * Same as previous commit * Fixed divergence issue * Removed some unnecessary barriers * Zero init workgroup memory * Added limits validation * Added changelog * Handled overflow, removed todo * Lets see if this fixes llvmpipe * Also this commit fixes llvmpipe maybe * Unfortunate but not too unexpected at this point * Updated feature to say to use ShaderRuntimeChecks::unchecked() * Updated snapshots and took some changes from the hlsl writer * Snapshots * 2 quick tweaks * Updated framework with suggestions by Connor in gfx-rs#8752 * Moved the task runtime limits into naga::back * Fixed soem stuff * Fixed checks: >= into > * Cant believe I forgot this * Removed code using u64 to check stuff * Removed thing with limiting it to 2<<21 * Some more work * Fixed compiles * Added new field to spv options * Updated some stuff to pass around the task runtime limits in more ways * New PR started * Did another quick fix * Another quick fix * Added changelog entry * Added some explanation docs * Fixed a warning * Fixed shader * Fixed some things & added docs * Reverted dxc thing * Fixed thing * Refactored to TaskDispatchLimits * Fixed compile error
1 parent c70b53c commit 6bca101

44 files changed

Lines changed: 2056 additions & 1499 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ By @cwfitzgerald in [#8999](https://github.com/gfx-rs/wgpu/pull/8999).
6969
- Added `Dx12BackendOptions::force_shader_model` to allow using advanced features in passthrough shaders without bundling DXC. By @inner-daemons in [#8984](https://github.com/gfx-rs/wgpu/pull/8984).
7070
- Changed passthrough shaders to not require an entry point parameter, so that the same shader module may be used in multiple entry points. Also added support for metallib passthrough. By @inner-daemons in #8886.
7171
- Added `Dx12Compiler::Auto` to automatically use static or dynamic DXC if available, before falling back to FXC. By @inner-daemons in [#8882](https://github.com/gfx-rs/wgpu/pull/8882).
72+
- Added support for `insert_debug_marker`, `push_debug_group` and `pop_debug_group` on WebGPU. By @evilpie in [#9017](https://github.com/gfx-rs/wgpu/pull/9017).
7273
- Added support for `@builtin(draw_index)` to the vulkan backend. By @inner-daemons in #8883.
7374
- Added support for `enable primitive_index` and `@builtin(primitive_index)` with support on all platforms. By @inner-daemons in #8879.
7475

@@ -127,6 +128,9 @@ By @cwfitzgerald in [#8999](https://github.com/gfx-rs/wgpu/pull/8999).
127128
- Fix incorrect acceptance of some swizzle selectors that are not valid for their operand, e.g. `const v = vec2<i32>(); let r = v.xyz`. By @andyleiserson in [#8949](https://github.com/gfx-rs/wgpu/pull/8949).
128129
- Fixed calculation of the total number of bindings in a pipeline layout when validating against device limits. By @andyleiserson in [#8997](https://github.com/gfx-rs/wgpu/pull/8997).
129130

131+
#### Vulkan
132+
- Fixed a variety of mesh shader SPIR-V writer issues from the original implementation. By @inner-daemons in [#8756](https://github.com/gfx-rs/wgpu/pull/8756)
133+
130134
#### GLES
131135

132136
- `DisplayHandle` should now be passed to `InstanceDescriptor` for correct EGL initialization on Wayland. By @MarijnS95 in [#8012](https://github.com/gfx-rs/wgpu/pull/8012)

benches/benches/wgpu-benchmark/shader.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,14 +366,16 @@ pub fn backends(ctx: BenchmarkContext) -> anyhow::Result<Vec<SubBenchResult>> {
366366
let mut data = Vec::new();
367367
let mut writer = naga::back::spv::Writer::new(&Default::default()).unwrap();
368368
for input in &inputs.inner {
369+
let shared_info = WriterSharedOptions {
370+
mesh_output_validation: input.options.mesh_output_validation,
371+
task_limits: input.options.task_limits,
372+
bounds_checks_policies: input.options.bounds_check_policies,
373+
};
369374
if input.options.targets.unwrap().contains(Targets::SPIRV) {
370375
if input.filename().contains("pointer-function-arg") {
371376
continue;
372377
}
373-
let opt = input
374-
.options
375-
.spv
376-
.to_options(input.options.bounds_check_policies, None);
378+
let opt = input.options.spv.to_options(&shared_info, None);
377379
if writer.set_options(&opt).is_ok() {
378380
let _ = writer.write(
379381
input.module.as_ref().unwrap(),

docs/api-specs/mesh_shading.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,11 @@ A task shader entry point must have a `@workgroup_size` attribute, meeting the s
130130

131131
A task shader entry point must also have a `@payload(G)` property, where `G` is the name of a global variable in the `task_payload` address space. Each task shader workgroup has its own instance of this variable, visible to all invocations in the workgroup. Whatever value the workgroup collectively stores in that global variable becomes the **task payload**, and is provided to all invocations in the mesh shader grid dispatched for the workgroup. A task payload variable must be at least 4 bytes in size.
132132

133-
A task shader entry point must return a `vec3<u32>` value. The return value of each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) is taken as the size of a **mesh shader grid** to dispatch, measured in workgroups. (If the task shader entry point returns `vec3(0, 0, 0)`, then no mesh shaders are dispatched.) Mesh shader grids are described in the next section.
133+
A task shader entry point must return a `vec3<u32>` value decorated with `@builtin(mesh_task_size)`. The return value of each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) is taken as the size of a **mesh shader grid** to dispatch, measured in workgroups. (If the task shader entry point returns `vec3(0, 0, 0)`, then no mesh shaders are dispatched.) Mesh shader grids are described in the next section.
134134

135135
Each task shader workgroup dispatches an independent mesh shader grid: in mesh shader invocations, `@builtin` values like `workgroup_id` and `global_invocation_id` describe the position of the workgroup and invocation within that grid;
136136
and `@builtin(num_workgroups)` matches the task shader workgroup's return value. Mesh shaders dispatched for other task shader workgroups are not included in the count. If it is necessary for a mesh shader to know which task shader workgroup dispatched it, the task shader can include its own workgroup id in the task payload.
137137

138-
Task shaders must return a value of type `vec3<u32>` decorated with `@builtin(mesh_task_size)`.
139-
140138
Task shaders can use compute and subgroup builtin inputs, in addition to `view_index` and `draw_id`.
141139

142140
### Mesh shader

examples/features/src/mesh_shader/mod.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
// Same as in mesh shader tests
22
fn compile_wgsl(device: &wgpu::Device) -> wgpu::ShaderModule {
3-
device.create_shader_module(wgpu::ShaderModuleDescriptor {
4-
label: None,
5-
source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
6-
})
3+
// Workgroup memory zero initialization can be expensive for mesh shaders
4+
unsafe {
5+
device.create_shader_module_trusted(
6+
wgpu::ShaderModuleDescriptor {
7+
label: None,
8+
source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
9+
},
10+
wgpu::ShaderRuntimeChecks::unchecked(),
11+
)
12+
}
713
}
814
fn compile_hlsl(device: &wgpu::Device, entry: &str, stage_str: &str) -> wgpu::ShaderModule {
915
let out_path = format!(

examples/features/src/mesh_shader/shader.wgsl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,15 @@ var<workgroup> workgroupData: f32;
3333

3434
@task
3535
@payload(taskPayload)
36-
@workgroup_size(1)
37-
fn ts_main() -> @builtin(mesh_task_size) vec3<u32> {
38-
workgroupData = 1.0;
39-
taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0);
40-
taskPayload.visible = true;
41-
return vec3(1, 1, 1);
36+
@workgroup_size(64)
37+
fn ts_main(@builtin(local_invocation_id) thread_id: vec3<u32>) -> @builtin(mesh_task_size) vec3<u32> {
38+
if thread_id.x == 0 {
39+
workgroupData = 1.0;
40+
taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0);
41+
taskPayload.visible = true;
42+
return vec3(1, 1, 1);
43+
}
44+
return vec3(0, 0, 0);
4245
}
4346

4447
struct MeshOutput {
@@ -52,24 +55,21 @@ var<workgroup> mesh_output: MeshOutput;
5255

5356
@mesh(mesh_output)
5457
@payload(taskPayload)
55-
@workgroup_size(1)
56-
fn ms_main() {
57-
mesh_output.vertex_count = 3;
58-
mesh_output.primitive_count = 1;
59-
workgroupData = 2.0;
58+
@workgroup_size(64)
59+
fn ms_main(@builtin(local_invocation_id) thread_id: vec3<u32>) {
60+
if thread_id.x == 0 {
61+
mesh_output.vertex_count = 3;
62+
mesh_output.primitive_count = 1;
63+
workgroupData = 2.0;
6064

61-
mesh_output.vertices[0].position = positions[0];
62-
mesh_output.vertices[0].color = colors[0] * taskPayload.colorMask;
63-
64-
mesh_output.vertices[1].position = positions[1];
65-
mesh_output.vertices[1].color = colors[1] * taskPayload.colorMask;
66-
67-
mesh_output.vertices[2].position = positions[2];
68-
mesh_output.vertices[2].color = colors[2] * taskPayload.colorMask;
69-
70-
mesh_output.primitives[0].indices = vec3<u32>(0, 1, 2);
71-
mesh_output.primitives[0].cull = !taskPayload.visible;
72-
mesh_output.primitives[0].colorMask = vec4<f32>(1.0, 0.0, 1.0, 1.0);
65+
mesh_output.primitives[0].indices = vec3<u32>(0, 1, 2);
66+
mesh_output.primitives[0].cull = !taskPayload.visible;
67+
mesh_output.primitives[0].colorMask = vec4<f32>(1.0, 0.0, 1.0, 1.0);
68+
}
69+
if thread_id.x < 3 {
70+
mesh_output.vertices[thread_id.x].position = positions[thread_id.x];
71+
mesh_output.vertices[thread_id.x].color = colors[thread_id.x] * taskPayload.colorMask;
72+
}
7373
}
7474

7575
@fragment

naga-cli/src/bin/naga.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,34 @@ struct Args {
152152
/// "67108864"), the string "none", or the string "all".
153153
#[argh(option, default = "CapabilitiesArg(naga::valid::Capabilities::all())")]
154154
capabilities: CapabilitiesArg,
155+
156+
/// the limits on the task shader dispatch size
157+
#[argh(option, default = "TaskDispatchLimitsArg(None)")]
158+
task_limits: TaskDispatchLimitsArg,
159+
160+
/// whether or not the mesh shader output should be validated.
161+
#[argh(option, default = "true")]
162+
validate_mesh_output: bool,
163+
}
164+
165+
/// Newtype so we can implement [`FromStr`] for `Option<TaskDispatchLimits>`.
166+
#[derive(Debug, Clone, Copy)]
167+
struct TaskDispatchLimitsArg(Option<naga::back::TaskDispatchLimits>);
168+
169+
impl FromStr for TaskDispatchLimitsArg {
170+
type Err = String;
171+
172+
fn from_str(s: &str) -> Result<Self, Self::Err> {
173+
let values = s
174+
.split_once(",")
175+
.ok_or_else(|| format!("No comma present for --task-limits value: {s}"))?;
176+
let x = values.0.parse::<u32>().map_err(|e| e.to_string())?;
177+
let y = values.1.parse::<u32>().map_err(|e| e.to_string())?;
178+
Ok(Self(Some(naga::back::TaskDispatchLimits {
179+
max_mesh_workgroups_per_dim: x,
180+
max_mesh_workgroups_total: y,
181+
})))
182+
}
155183
}
156184

157185
/// Newtype so we can implement [`FromStr`] for `BoundsCheckPolicy`.
@@ -545,6 +573,9 @@ fn run() -> anyhow::Result<()> {
545573
params.compact = args.compact;
546574
params.capabilities = args.capabilities.0;
547575

576+
params.spv_out.mesh_shader_primitive_indices_clamp = args.validate_mesh_output;
577+
params.spv_out.task_dispatch_limits = args.task_limits.0;
578+
548579
if args.bulk_validate {
549580
return bulk_validate(&args, &params);
550581
}

naga-test/src/lib.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,14 @@ where
7575
Ok(map)
7676
}
7777

78+
#[derive(Default, serde::Deserialize)]
79+
#[serde(default)]
80+
pub struct WriterSharedOptions {
81+
pub mesh_output_validation: bool,
82+
pub task_limits: Option<naga::back::TaskDispatchLimits>,
83+
pub bounds_checks_policies: naga::proc::BoundsCheckPolicies,
84+
}
85+
7886
#[derive(Default, serde::Deserialize)]
7987
#[serde(default)]
8088
pub struct WgslInParameters {
@@ -137,7 +145,7 @@ impl Default for SpirvOutParameters {
137145
impl SpirvOutParameters {
138146
pub fn to_options<'a>(
139147
&'a self,
140-
bounds_check_policies: naga::proc::BoundsCheckPolicies,
148+
shared_info: &WriterSharedOptions,
141149
debug_info: Option<naga::back::spv::DebugInfo<'a>>,
142150
) -> naga::back::spv::Options<'a> {
143151
use naga::back::spv;
@@ -157,14 +165,16 @@ impl SpirvOutParameters {
157165
} else {
158166
Some(self.capabilities.clone())
159167
},
160-
bounds_check_policies,
168+
bounds_check_policies: shared_info.bounds_checks_policies,
161169
fake_missing_bindings: true,
162170
binding_map: self.binding_map.clone(),
163171
zero_initialize_workgroup_memory: spv::ZeroInitializeWorkgroupMemoryMode::Polyfill,
164172
force_loop_bounding: true,
165173
ray_query_initialization_tracking: true,
166174
debug_info,
167175
use_storage_input_output_16: self.use_storage_input_output_16,
176+
task_dispatch_limits: shared_info.task_limits,
177+
mesh_shader_primitive_indices_clamp: shared_info.mesh_output_validation,
168178
}
169179
}
170180
}
@@ -235,6 +245,17 @@ pub struct Parameters {
235245

236246
pub bounds_check_policies: naga::proc::BoundsCheckPolicies,
237247
pub pipeline_constants: naga::back::PipelineConstants,
248+
249+
pub mesh_output_validation: bool,
250+
#[serde(default = "default_task_limits")]
251+
pub task_limits: Option<naga::back::TaskDispatchLimits>,
252+
}
253+
254+
fn default_task_limits() -> Option<naga::back::TaskDispatchLimits> {
255+
Some(naga::back::TaskDispatchLimits {
256+
max_mesh_workgroups_per_dim: 256,
257+
max_mesh_workgroups_total: 1024,
258+
})
238259
}
239260

240261
/// Information about a shader input file.

naga/src/back/mod.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ fn get_entry_points(
148148
/// [`EntryPoint`]: crate::EntryPoint
149149
/// [`Module`]: crate::Module
150150
/// [`Module::entry_points`]: crate::Module::entry_points
151+
#[derive(Clone, Copy, Debug)]
151152
pub enum FunctionType {
152153
/// A regular function.
153154
Function(crate::Handle<crate::Function>),
@@ -391,3 +392,11 @@ pub enum RayIntersectionType {
391392
Triangle = 1,
392393
BoundingBox = 4,
393394
}
395+
396+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
397+
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
398+
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
399+
pub struct TaskDispatchLimits {
400+
pub max_mesh_workgroups_per_dim: u32,
401+
pub max_mesh_workgroups_total: u32,
402+
}

naga/src/back/spv/block.rs

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,11 @@ impl Writer {
237237
ir_result: &crate::FunctionResult,
238238
result_members: &[ResultMember],
239239
body: &mut Vec<Instruction>,
240-
task_payload: Option<Word>,
241240
) -> Result<Instruction, Error> {
242241
for (index, res_member) in result_members.iter().enumerate() {
243242
// This isn't a real builtin, and is handled elsewhere
244243
if res_member.built_in == Some(crate::BuiltIn::MeshTaskSize) {
245-
continue;
244+
return Ok(Instruction::return_value(value_id));
246245
}
247246
let member_value_id = match ir_result.binding {
248247
Some(_) => value_id,
@@ -274,13 +273,7 @@ impl Writer {
274273
_ => {}
275274
}
276275
}
277-
self.try_write_entry_point_task_return(
278-
value_id,
279-
ir_result,
280-
result_members,
281-
body,
282-
task_payload,
283-
)
276+
Ok(Instruction::return_void())
284277
}
285278
}
286279

@@ -3754,26 +3747,14 @@ impl BlockContext<'_> {
37543747
self.ir_function.result.as_ref().unwrap(),
37553748
&context.results,
37563749
&mut block.body,
3757-
context.task_payload_variable_id,
37583750
)?,
37593751
None => Instruction::return_value(value_id),
37603752
};
37613753
self.function.consume(block, instruction);
37623754
return Ok(BlockExitDisposition::Discarded);
37633755
}
37643756
Statement::Return { value: None } => {
3765-
if let Some(super::EntryPointContext {
3766-
mesh_state: Some(ref mesh_state),
3767-
..
3768-
}) = self.function.entry_point_context
3769-
{
3770-
self.function.consume(
3771-
block,
3772-
Instruction::branch(mesh_state.entry_point_epilogue_id),
3773-
);
3774-
} else {
3775-
self.function.consume(block, Instruction::return_void());
3776-
}
3757+
self.function.consume(block, Instruction::return_void());
37773758
return Ok(BlockExitDisposition::Discarded);
37783759
}
37793760
Statement::Kill => {
@@ -4242,16 +4223,6 @@ impl BlockContext<'_> {
42424223
LoopContext::default(),
42434224
debug_info,
42444225
)?;
4245-
if let Some(super::EntryPointContext {
4246-
mesh_state: Some(ref mesh_state),
4247-
..
4248-
}) = self.function.entry_point_context
4249-
{
4250-
let mut block = Block::new(mesh_state.entry_point_epilogue_id);
4251-
self.writer
4252-
.write_mesh_shader_return(mesh_state, &mut block)?;
4253-
self.function.consume(block, Instruction::return_void());
4254-
}
42554226

42564227
Ok(())
42574228
}

0 commit comments

Comments
 (0)