Skip to content

Commit fc9912e

Browse files
refactor(metal): Use descriptive names in update_bind_group_state (gfx-rs#8628)
1 parent 500e2ba commit fc9912e

1 file changed

Lines changed: 115 additions & 93 deletions

File tree

wgpu-hal/src/metal/command.rs

Lines changed: 115 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use alloc::{
77
use core::ops::Range;
88
use metal::{
99
MTLIndexType, MTLLoadAction, MTLPrimitiveType, MTLScissorRect, MTLSize, MTLStoreAction,
10-
MTLViewport, MTLVisibilityResultMode, NSRange,
10+
MTLViewport, MTLVisibilityResultMode, NSRange, NSUInteger,
1111
};
1212
use smallvec::SmallVec;
1313

@@ -31,6 +31,75 @@ impl Default for super::CommandState {
3131
}
3232
}
3333

34+
/// Helper for passing encoders to `update_bind_group_state`.
35+
///
36+
/// Combines [`naga::ShaderStage`] and an encoder of the appropriate type for
37+
/// that stage.
38+
enum Encoder<'e> {
39+
Vertex(&'e metal::RenderCommandEncoder),
40+
Fragment(&'e metal::RenderCommandEncoder),
41+
Task(&'e metal::RenderCommandEncoder),
42+
Mesh(&'e metal::RenderCommandEncoder),
43+
Compute(&'e metal::ComputeCommandEncoder),
44+
}
45+
46+
impl Encoder<'_> {
47+
fn stage(&self) -> naga::ShaderStage {
48+
match self {
49+
Self::Vertex(_) => naga::ShaderStage::Vertex,
50+
Self::Fragment(_) => naga::ShaderStage::Fragment,
51+
Self::Task(_) => naga::ShaderStage::Task,
52+
Self::Mesh(_) => naga::ShaderStage::Mesh,
53+
Self::Compute(_) => naga::ShaderStage::Compute,
54+
}
55+
}
56+
57+
fn set_buffer(
58+
&self,
59+
index: NSUInteger,
60+
buffer: Option<&metal::BufferRef>,
61+
offset: wgt::BufferAddress,
62+
) {
63+
match *self {
64+
Self::Vertex(enc) => enc.set_vertex_buffer(index, buffer, offset),
65+
Self::Fragment(enc) => enc.set_fragment_buffer(index, buffer, offset),
66+
Self::Task(enc) => enc.set_object_buffer(index, buffer, offset),
67+
Self::Mesh(enc) => enc.set_mesh_buffer(index, buffer, offset),
68+
Self::Compute(enc) => enc.set_buffer(index, buffer, offset),
69+
}
70+
}
71+
72+
fn set_bytes(&self, index: NSUInteger, length: u64, bytes: *const core::ffi::c_void) {
73+
match *self {
74+
Self::Vertex(enc) => enc.set_vertex_bytes(index, length, bytes),
75+
Self::Fragment(enc) => enc.set_fragment_bytes(index, length, bytes),
76+
Self::Task(enc) => enc.set_object_bytes(index, length, bytes),
77+
Self::Mesh(enc) => enc.set_mesh_bytes(index, length, bytes),
78+
Self::Compute(enc) => enc.set_bytes(index, length, bytes),
79+
}
80+
}
81+
82+
fn set_sampler_state(&self, index: NSUInteger, state: Option<&metal::SamplerStateRef>) {
83+
match *self {
84+
Self::Vertex(enc) => enc.set_vertex_sampler_state(index, state),
85+
Self::Fragment(enc) => enc.set_fragment_sampler_state(index, state),
86+
Self::Task(enc) => enc.set_object_sampler_state(index, state),
87+
Self::Mesh(enc) => enc.set_mesh_sampler_state(index, state),
88+
Self::Compute(enc) => enc.set_sampler_state(index, state),
89+
}
90+
}
91+
92+
fn set_texture(&self, index: NSUInteger, texture: Option<&metal::TextureRef>) {
93+
match *self {
94+
Self::Vertex(enc) => enc.set_vertex_texture(index, texture),
95+
Self::Fragment(enc) => enc.set_fragment_texture(index, texture),
96+
Self::Task(enc) => enc.set_object_texture(index, texture),
97+
Self::Mesh(enc) => enc.set_mesh_texture(index, texture),
98+
Self::Compute(enc) => enc.set_texture(index, texture),
99+
}
100+
}
101+
}
102+
34103
impl super::CommandEncoder {
35104
pub fn raw_command_buffer(&self) -> Option<&metal::CommandBuffer> {
36105
self.raw_cmd_buf.as_ref()
@@ -146,31 +215,29 @@ impl super::CommandEncoder {
146215
}
147216

148217
/// Updates the bindings for a single shader stage, called in `set_bind_group`.
149-
#[expect(clippy::too_many_arguments)]
150218
fn update_bind_group_state(
151219
&mut self,
152-
stage: naga::ShaderStage,
153-
render_encoder: Option<&metal::RenderCommandEncoder>,
154-
compute_encoder: Option<&metal::ComputeCommandEncoder>,
220+
encoder: Encoder<'_>,
155221
index_base: super::ResourceData<u32>,
156222
bg_info: &super::BindGroupLayoutInfo,
157223
dynamic_offsets: &[wgt::DynamicOffset],
158224
group_index: u32,
159225
group: &super::BindGroup,
160226
) {
161-
let resource_indices = match stage {
162-
naga::ShaderStage::Vertex => &bg_info.base_resource_indices.vs,
163-
naga::ShaderStage::Fragment => &bg_info.base_resource_indices.fs,
164-
naga::ShaderStage::Task => &bg_info.base_resource_indices.ts,
165-
naga::ShaderStage::Mesh => &bg_info.base_resource_indices.ms,
166-
naga::ShaderStage::Compute => &bg_info.base_resource_indices.cs,
227+
use naga::ShaderStage as S;
228+
let resource_indices = match encoder.stage() {
229+
S::Vertex => &bg_info.base_resource_indices.vs,
230+
S::Fragment => &bg_info.base_resource_indices.fs,
231+
S::Task => &bg_info.base_resource_indices.ts,
232+
S::Mesh => &bg_info.base_resource_indices.ms,
233+
S::Compute => &bg_info.base_resource_indices.cs,
167234
};
168-
let buffers = match stage {
169-
naga::ShaderStage::Vertex => group.counters.vs.buffers,
170-
naga::ShaderStage::Fragment => group.counters.fs.buffers,
171-
naga::ShaderStage::Task => group.counters.ts.buffers,
172-
naga::ShaderStage::Mesh => group.counters.ms.buffers,
173-
naga::ShaderStage::Compute => group.counters.cs.buffers,
235+
let buffers = match encoder.stage() {
236+
S::Vertex => group.counters.vs.buffers,
237+
S::Fragment => group.counters.fs.buffers,
238+
S::Task => group.counters.ts.buffers,
239+
S::Mesh => group.counters.ms.buffers,
240+
S::Compute => group.counters.cs.buffers,
174241
};
175242
let mut changes_sizes_buffer = false;
176243
for index in 0..buffers {
@@ -179,18 +246,9 @@ impl super::CommandEncoder {
179246
if let Some(dyn_index) = buf.dynamic_index {
180247
offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress;
181248
}
182-
let a1 = (resource_indices.buffers + index) as u64;
183-
let a2 = Some(buf.ptr.as_native());
184-
let a3 = offset;
185-
match stage {
186-
naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_buffer(a1, a2, a3),
187-
naga::ShaderStage::Fragment => {
188-
render_encoder.unwrap().set_fragment_buffer(a1, a2, a3)
189-
}
190-
naga::ShaderStage::Task => render_encoder.unwrap().set_object_buffer(a1, a2, a3),
191-
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_buffer(a1, a2, a3),
192-
naga::ShaderStage::Compute => compute_encoder.unwrap().set_buffer(a1, a2, a3),
193-
}
249+
let index = (resource_indices.buffers + index) as u64;
250+
let buffer = Some(buf.ptr.as_native());
251+
encoder.set_buffer(index, buffer, offset);
194252
if let Some(size) = buf.binding_size {
195253
let br = naga::ResourceBinding {
196254
group: group_index,
@@ -203,66 +261,40 @@ impl super::CommandEncoder {
203261
if changes_sizes_buffer {
204262
if let Some((index, sizes)) = self
205263
.state
206-
.make_sizes_buffer_update(stage, &mut self.temp.binding_sizes)
264+
.make_sizes_buffer_update(encoder.stage(), &mut self.temp.binding_sizes)
207265
{
208-
let a1 = index as _;
209-
let a2 = (sizes.len() * WORD_SIZE) as u64;
210-
let a3 = sizes.as_ptr().cast();
211-
match stage {
212-
naga::ShaderStage::Vertex => {
213-
render_encoder.unwrap().set_vertex_bytes(a1, a2, a3)
214-
}
215-
naga::ShaderStage::Fragment => {
216-
render_encoder.unwrap().set_fragment_bytes(a1, a2, a3)
217-
}
218-
naga::ShaderStage::Task => render_encoder.unwrap().set_object_bytes(a1, a2, a3),
219-
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_bytes(a1, a2, a3),
220-
naga::ShaderStage::Compute => compute_encoder.unwrap().set_bytes(a1, a2, a3),
221-
}
266+
let index = index as _;
267+
let length = (sizes.len() * WORD_SIZE) as u64;
268+
let bytes_ptr = sizes.as_ptr().cast();
269+
encoder.set_bytes(index, length, bytes_ptr);
222270
}
223271
}
224-
let samplers = match stage {
225-
naga::ShaderStage::Vertex => group.counters.vs.samplers,
226-
naga::ShaderStage::Fragment => group.counters.fs.samplers,
227-
naga::ShaderStage::Task => group.counters.ts.samplers,
228-
naga::ShaderStage::Mesh => group.counters.ms.samplers,
229-
naga::ShaderStage::Compute => group.counters.cs.samplers,
272+
let samplers = match encoder.stage() {
273+
S::Vertex => group.counters.vs.samplers,
274+
S::Fragment => group.counters.fs.samplers,
275+
S::Task => group.counters.ts.samplers,
276+
S::Mesh => group.counters.ms.samplers,
277+
S::Compute => group.counters.cs.samplers,
230278
};
231279
for index in 0..samplers {
232280
let res = group.samplers[(index_base.samplers + index) as usize];
233-
let a1 = (resource_indices.samplers + index) as u64;
234-
let a2 = Some(res.as_native());
235-
match stage {
236-
naga::ShaderStage::Vertex => {
237-
render_encoder.unwrap().set_vertex_sampler_state(a1, a2)
238-
}
239-
naga::ShaderStage::Fragment => {
240-
render_encoder.unwrap().set_fragment_sampler_state(a1, a2)
241-
}
242-
naga::ShaderStage::Task => render_encoder.unwrap().set_object_sampler_state(a1, a2),
243-
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_sampler_state(a1, a2),
244-
naga::ShaderStage::Compute => compute_encoder.unwrap().set_sampler_state(a1, a2),
245-
}
281+
let index = (resource_indices.samplers + index) as u64;
282+
let state = Some(res.as_native());
283+
encoder.set_sampler_state(index, state);
246284
}
247285

248-
let textures = match stage {
249-
naga::ShaderStage::Vertex => group.counters.vs.textures,
250-
naga::ShaderStage::Fragment => group.counters.fs.textures,
251-
naga::ShaderStage::Task => group.counters.ts.textures,
252-
naga::ShaderStage::Mesh => group.counters.ms.textures,
253-
naga::ShaderStage::Compute => group.counters.cs.textures,
286+
let textures = match encoder.stage() {
287+
S::Vertex => group.counters.vs.textures,
288+
S::Fragment => group.counters.fs.textures,
289+
S::Task => group.counters.ts.textures,
290+
S::Mesh => group.counters.ms.textures,
291+
S::Compute => group.counters.cs.textures,
254292
};
255293
for index in 0..textures {
256294
let res = group.textures[(index_base.textures + index) as usize];
257-
let a1 = (resource_indices.textures + index) as u64;
258-
let a2 = Some(res.as_native());
259-
match stage {
260-
naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_texture(a1, a2),
261-
naga::ShaderStage::Fragment => render_encoder.unwrap().set_fragment_texture(a1, a2),
262-
naga::ShaderStage::Task => render_encoder.unwrap().set_object_texture(a1, a2),
263-
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_texture(a1, a2),
264-
naga::ShaderStage::Compute => compute_encoder.unwrap().set_texture(a1, a2),
265-
}
295+
let index = (resource_indices.textures + index) as u64;
296+
let texture = Some(res.as_native());
297+
encoder.set_texture(index, texture);
266298
}
267299
}
268300
}
@@ -841,9 +873,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
841873
let compute_encoder = self.state.compute.clone();
842874
if let Some(encoder) = render_encoder {
843875
self.update_bind_group_state(
844-
naga::ShaderStage::Vertex,
845-
Some(&encoder),
846-
None,
876+
Encoder::Vertex(&encoder),
847877
// All zeros, as vs comes first
848878
super::ResourceData::default(),
849879
bg_info,
@@ -852,9 +882,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
852882
group,
853883
);
854884
self.update_bind_group_state(
855-
naga::ShaderStage::Task,
856-
Some(&encoder),
857-
None,
885+
Encoder::Task(&encoder),
858886
// All zeros, as ts comes first
859887
super::ResourceData::default(),
860888
bg_info,
@@ -863,19 +891,15 @@ impl crate::CommandEncoder for super::CommandEncoder {
863891
group,
864892
);
865893
self.update_bind_group_state(
866-
naga::ShaderStage::Mesh,
867-
Some(&encoder),
868-
None,
894+
Encoder::Mesh(&encoder),
869895
group.counters.ts.clone(),
870896
bg_info,
871897
dynamic_offsets,
872898
group_index,
873899
group,
874900
);
875901
self.update_bind_group_state(
876-
naga::ShaderStage::Fragment,
877-
Some(&encoder),
878-
None,
902+
Encoder::Fragment(&encoder),
879903
super::ResourceData {
880904
buffers: group.counters.vs.buffers
881905
+ group.counters.ts.buffers
@@ -899,9 +923,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
899923
}
900924
if let Some(encoder) = compute_encoder {
901925
self.update_bind_group_state(
902-
naga::ShaderStage::Compute,
903-
None,
904-
Some(&encoder),
926+
Encoder::Compute(&encoder),
905927
super::ResourceData {
906928
buffers: group.counters.vs.buffers
907929
+ group.counters.ts.buffers

0 commit comments

Comments
 (0)