@@ -7,7 +7,7 @@ use alloc::{
77use core:: ops:: Range ;
88use metal:: {
99 MTLIndexType , MTLLoadAction , MTLPrimitiveType , MTLScissorRect , MTLSize , MTLStoreAction ,
10- MTLViewport , MTLVisibilityResultMode , NSRange ,
10+ MTLViewport , MTLVisibilityResultMode , NSRange , NSUInteger ,
1111} ;
1212use smallvec:: SmallVec ;
1313
@@ -31,6 +31,75 @@ impl Default for super::CommandState {
3131 }
3232}
3333
34+ /// Helper for passing encoders to `update_bind_group_state`.
35+ ///
36+ /// Combines [`naga::ShaderStage`] and an encoder of the appropriate type for
37+ /// that stage.
38+ enum Encoder < ' e > {
39+ Vertex ( & ' e metal:: RenderCommandEncoder ) ,
40+ Fragment ( & ' e metal:: RenderCommandEncoder ) ,
41+ Task ( & ' e metal:: RenderCommandEncoder ) ,
42+ Mesh ( & ' e metal:: RenderCommandEncoder ) ,
43+ Compute ( & ' e metal:: ComputeCommandEncoder ) ,
44+ }
45+
46+ impl Encoder < ' _ > {
47+ fn stage ( & self ) -> naga:: ShaderStage {
48+ match self {
49+ Self :: Vertex ( _) => naga:: ShaderStage :: Vertex ,
50+ Self :: Fragment ( _) => naga:: ShaderStage :: Fragment ,
51+ Self :: Task ( _) => naga:: ShaderStage :: Task ,
52+ Self :: Mesh ( _) => naga:: ShaderStage :: Mesh ,
53+ Self :: Compute ( _) => naga:: ShaderStage :: Compute ,
54+ }
55+ }
56+
57+ fn set_buffer (
58+ & self ,
59+ index : NSUInteger ,
60+ buffer : Option < & metal:: BufferRef > ,
61+ offset : wgt:: BufferAddress ,
62+ ) {
63+ match * self {
64+ Self :: Vertex ( enc) => enc. set_vertex_buffer ( index, buffer, offset) ,
65+ Self :: Fragment ( enc) => enc. set_fragment_buffer ( index, buffer, offset) ,
66+ Self :: Task ( enc) => enc. set_object_buffer ( index, buffer, offset) ,
67+ Self :: Mesh ( enc) => enc. set_mesh_buffer ( index, buffer, offset) ,
68+ Self :: Compute ( enc) => enc. set_buffer ( index, buffer, offset) ,
69+ }
70+ }
71+
72+ fn set_bytes ( & self , index : NSUInteger , length : u64 , bytes : * const core:: ffi:: c_void ) {
73+ match * self {
74+ Self :: Vertex ( enc) => enc. set_vertex_bytes ( index, length, bytes) ,
75+ Self :: Fragment ( enc) => enc. set_fragment_bytes ( index, length, bytes) ,
76+ Self :: Task ( enc) => enc. set_object_bytes ( index, length, bytes) ,
77+ Self :: Mesh ( enc) => enc. set_mesh_bytes ( index, length, bytes) ,
78+ Self :: Compute ( enc) => enc. set_bytes ( index, length, bytes) ,
79+ }
80+ }
81+
82+ fn set_sampler_state ( & self , index : NSUInteger , state : Option < & metal:: SamplerStateRef > ) {
83+ match * self {
84+ Self :: Vertex ( enc) => enc. set_vertex_sampler_state ( index, state) ,
85+ Self :: Fragment ( enc) => enc. set_fragment_sampler_state ( index, state) ,
86+ Self :: Task ( enc) => enc. set_object_sampler_state ( index, state) ,
87+ Self :: Mesh ( enc) => enc. set_mesh_sampler_state ( index, state) ,
88+ Self :: Compute ( enc) => enc. set_sampler_state ( index, state) ,
89+ }
90+ }
91+
92+ fn set_texture ( & self , index : NSUInteger , texture : Option < & metal:: TextureRef > ) {
93+ match * self {
94+ Self :: Vertex ( enc) => enc. set_vertex_texture ( index, texture) ,
95+ Self :: Fragment ( enc) => enc. set_fragment_texture ( index, texture) ,
96+ Self :: Task ( enc) => enc. set_object_texture ( index, texture) ,
97+ Self :: Mesh ( enc) => enc. set_mesh_texture ( index, texture) ,
98+ Self :: Compute ( enc) => enc. set_texture ( index, texture) ,
99+ }
100+ }
101+ }
102+
34103impl super :: CommandEncoder {
35104 pub fn raw_command_buffer ( & self ) -> Option < & metal:: CommandBuffer > {
36105 self . raw_cmd_buf . as_ref ( )
@@ -146,31 +215,29 @@ impl super::CommandEncoder {
146215 }
147216
148217 /// Updates the bindings for a single shader stage, called in `set_bind_group`.
149- #[ expect( clippy:: too_many_arguments) ]
150218 fn update_bind_group_state (
151219 & mut self ,
152- stage : naga:: ShaderStage ,
153- render_encoder : Option < & metal:: RenderCommandEncoder > ,
154- compute_encoder : Option < & metal:: ComputeCommandEncoder > ,
220+ encoder : Encoder < ' _ > ,
155221 index_base : super :: ResourceData < u32 > ,
156222 bg_info : & super :: BindGroupLayoutInfo ,
157223 dynamic_offsets : & [ wgt:: DynamicOffset ] ,
158224 group_index : u32 ,
159225 group : & super :: BindGroup ,
160226 ) {
161- let resource_indices = match stage {
162- naga:: ShaderStage :: Vertex => & bg_info. base_resource_indices . vs ,
163- naga:: ShaderStage :: Fragment => & bg_info. base_resource_indices . fs ,
164- naga:: ShaderStage :: Task => & bg_info. base_resource_indices . ts ,
165- naga:: ShaderStage :: Mesh => & bg_info. base_resource_indices . ms ,
166- naga:: ShaderStage :: Compute => & bg_info. base_resource_indices . cs ,
227+ use naga:: ShaderStage as S ;
228+ let resource_indices = match encoder. stage ( ) {
229+ S :: Vertex => & bg_info. base_resource_indices . vs ,
230+ S :: Fragment => & bg_info. base_resource_indices . fs ,
231+ S :: Task => & bg_info. base_resource_indices . ts ,
232+ S :: Mesh => & bg_info. base_resource_indices . ms ,
233+ S :: Compute => & bg_info. base_resource_indices . cs ,
167234 } ;
168- let buffers = match stage {
169- naga :: ShaderStage :: Vertex => group. counters . vs . buffers ,
170- naga :: ShaderStage :: Fragment => group. counters . fs . buffers ,
171- naga :: ShaderStage :: Task => group. counters . ts . buffers ,
172- naga :: ShaderStage :: Mesh => group. counters . ms . buffers ,
173- naga :: ShaderStage :: Compute => group. counters . cs . buffers ,
235+ let buffers = match encoder . stage ( ) {
236+ S :: Vertex => group. counters . vs . buffers ,
237+ S :: Fragment => group. counters . fs . buffers ,
238+ S :: Task => group. counters . ts . buffers ,
239+ S :: Mesh => group. counters . ms . buffers ,
240+ S :: Compute => group. counters . cs . buffers ,
174241 } ;
175242 let mut changes_sizes_buffer = false ;
176243 for index in 0 ..buffers {
@@ -179,18 +246,9 @@ impl super::CommandEncoder {
179246 if let Some ( dyn_index) = buf. dynamic_index {
180247 offset += dynamic_offsets[ dyn_index as usize ] as wgt:: BufferAddress ;
181248 }
182- let a1 = ( resource_indices. buffers + index) as u64 ;
183- let a2 = Some ( buf. ptr . as_native ( ) ) ;
184- let a3 = offset;
185- match stage {
186- naga:: ShaderStage :: Vertex => render_encoder. unwrap ( ) . set_vertex_buffer ( a1, a2, a3) ,
187- naga:: ShaderStage :: Fragment => {
188- render_encoder. unwrap ( ) . set_fragment_buffer ( a1, a2, a3)
189- }
190- naga:: ShaderStage :: Task => render_encoder. unwrap ( ) . set_object_buffer ( a1, a2, a3) ,
191- naga:: ShaderStage :: Mesh => render_encoder. unwrap ( ) . set_mesh_buffer ( a1, a2, a3) ,
192- naga:: ShaderStage :: Compute => compute_encoder. unwrap ( ) . set_buffer ( a1, a2, a3) ,
193- }
249+ let index = ( resource_indices. buffers + index) as u64 ;
250+ let buffer = Some ( buf. ptr . as_native ( ) ) ;
251+ encoder. set_buffer ( index, buffer, offset) ;
194252 if let Some ( size) = buf. binding_size {
195253 let br = naga:: ResourceBinding {
196254 group : group_index,
@@ -203,66 +261,40 @@ impl super::CommandEncoder {
203261 if changes_sizes_buffer {
204262 if let Some ( ( index, sizes) ) = self
205263 . state
206- . make_sizes_buffer_update ( stage, & mut self . temp . binding_sizes )
264+ . make_sizes_buffer_update ( encoder . stage ( ) , & mut self . temp . binding_sizes )
207265 {
208- let a1 = index as _ ;
209- let a2 = ( sizes. len ( ) * WORD_SIZE ) as u64 ;
210- let a3 = sizes. as_ptr ( ) . cast ( ) ;
211- match stage {
212- naga:: ShaderStage :: Vertex => {
213- render_encoder. unwrap ( ) . set_vertex_bytes ( a1, a2, a3)
214- }
215- naga:: ShaderStage :: Fragment => {
216- render_encoder. unwrap ( ) . set_fragment_bytes ( a1, a2, a3)
217- }
218- naga:: ShaderStage :: Task => render_encoder. unwrap ( ) . set_object_bytes ( a1, a2, a3) ,
219- naga:: ShaderStage :: Mesh => render_encoder. unwrap ( ) . set_mesh_bytes ( a1, a2, a3) ,
220- naga:: ShaderStage :: Compute => compute_encoder. unwrap ( ) . set_bytes ( a1, a2, a3) ,
221- }
266+ let index = index as _ ;
267+ let length = ( sizes. len ( ) * WORD_SIZE ) as u64 ;
268+ let bytes_ptr = sizes. as_ptr ( ) . cast ( ) ;
269+ encoder. set_bytes ( index, length, bytes_ptr) ;
222270 }
223271 }
224- let samplers = match stage {
225- naga :: ShaderStage :: Vertex => group. counters . vs . samplers ,
226- naga :: ShaderStage :: Fragment => group. counters . fs . samplers ,
227- naga :: ShaderStage :: Task => group. counters . ts . samplers ,
228- naga :: ShaderStage :: Mesh => group. counters . ms . samplers ,
229- naga :: ShaderStage :: Compute => group. counters . cs . samplers ,
272+ let samplers = match encoder . stage ( ) {
273+ S :: Vertex => group. counters . vs . samplers ,
274+ S :: Fragment => group. counters . fs . samplers ,
275+ S :: Task => group. counters . ts . samplers ,
276+ S :: Mesh => group. counters . ms . samplers ,
277+ S :: Compute => group. counters . cs . samplers ,
230278 } ;
231279 for index in 0 ..samplers {
232280 let res = group. samplers [ ( index_base. samplers + index) as usize ] ;
233- let a1 = ( resource_indices. samplers + index) as u64 ;
234- let a2 = Some ( res. as_native ( ) ) ;
235- match stage {
236- naga:: ShaderStage :: Vertex => {
237- render_encoder. unwrap ( ) . set_vertex_sampler_state ( a1, a2)
238- }
239- naga:: ShaderStage :: Fragment => {
240- render_encoder. unwrap ( ) . set_fragment_sampler_state ( a1, a2)
241- }
242- naga:: ShaderStage :: Task => render_encoder. unwrap ( ) . set_object_sampler_state ( a1, a2) ,
243- naga:: ShaderStage :: Mesh => render_encoder. unwrap ( ) . set_mesh_sampler_state ( a1, a2) ,
244- naga:: ShaderStage :: Compute => compute_encoder. unwrap ( ) . set_sampler_state ( a1, a2) ,
245- }
281+ let index = ( resource_indices. samplers + index) as u64 ;
282+ let state = Some ( res. as_native ( ) ) ;
283+ encoder. set_sampler_state ( index, state) ;
246284 }
247285
248- let textures = match stage {
249- naga :: ShaderStage :: Vertex => group. counters . vs . textures ,
250- naga :: ShaderStage :: Fragment => group. counters . fs . textures ,
251- naga :: ShaderStage :: Task => group. counters . ts . textures ,
252- naga :: ShaderStage :: Mesh => group. counters . ms . textures ,
253- naga :: ShaderStage :: Compute => group. counters . cs . textures ,
286+ let textures = match encoder . stage ( ) {
287+ S :: Vertex => group. counters . vs . textures ,
288+ S :: Fragment => group. counters . fs . textures ,
289+ S :: Task => group. counters . ts . textures ,
290+ S :: Mesh => group. counters . ms . textures ,
291+ S :: Compute => group. counters . cs . textures ,
254292 } ;
255293 for index in 0 ..textures {
256294 let res = group. textures [ ( index_base. textures + index) as usize ] ;
257- let a1 = ( resource_indices. textures + index) as u64 ;
258- let a2 = Some ( res. as_native ( ) ) ;
259- match stage {
260- naga:: ShaderStage :: Vertex => render_encoder. unwrap ( ) . set_vertex_texture ( a1, a2) ,
261- naga:: ShaderStage :: Fragment => render_encoder. unwrap ( ) . set_fragment_texture ( a1, a2) ,
262- naga:: ShaderStage :: Task => render_encoder. unwrap ( ) . set_object_texture ( a1, a2) ,
263- naga:: ShaderStage :: Mesh => render_encoder. unwrap ( ) . set_mesh_texture ( a1, a2) ,
264- naga:: ShaderStage :: Compute => compute_encoder. unwrap ( ) . set_texture ( a1, a2) ,
265- }
295+ let index = ( resource_indices. textures + index) as u64 ;
296+ let texture = Some ( res. as_native ( ) ) ;
297+ encoder. set_texture ( index, texture) ;
266298 }
267299 }
268300}
@@ -841,9 +873,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
841873 let compute_encoder = self . state . compute . clone ( ) ;
842874 if let Some ( encoder) = render_encoder {
843875 self . update_bind_group_state (
844- naga:: ShaderStage :: Vertex ,
845- Some ( & encoder) ,
846- None ,
876+ Encoder :: Vertex ( & encoder) ,
847877 // All zeros, as vs comes first
848878 super :: ResourceData :: default ( ) ,
849879 bg_info,
@@ -852,9 +882,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
852882 group,
853883 ) ;
854884 self . update_bind_group_state (
855- naga:: ShaderStage :: Task ,
856- Some ( & encoder) ,
857- None ,
885+ Encoder :: Task ( & encoder) ,
858886 // All zeros, as ts comes first
859887 super :: ResourceData :: default ( ) ,
860888 bg_info,
@@ -863,19 +891,15 @@ impl crate::CommandEncoder for super::CommandEncoder {
863891 group,
864892 ) ;
865893 self . update_bind_group_state (
866- naga:: ShaderStage :: Mesh ,
867- Some ( & encoder) ,
868- None ,
894+ Encoder :: Mesh ( & encoder) ,
869895 group. counters . ts . clone ( ) ,
870896 bg_info,
871897 dynamic_offsets,
872898 group_index,
873899 group,
874900 ) ;
875901 self . update_bind_group_state (
876- naga:: ShaderStage :: Fragment ,
877- Some ( & encoder) ,
878- None ,
902+ Encoder :: Fragment ( & encoder) ,
879903 super :: ResourceData {
880904 buffers : group. counters . vs . buffers
881905 + group. counters . ts . buffers
@@ -899,9 +923,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
899923 }
900924 if let Some ( encoder) = compute_encoder {
901925 self . update_bind_group_state (
902- naga:: ShaderStage :: Compute ,
903- None ,
904- Some ( & encoder) ,
926+ Encoder :: Compute ( & encoder) ,
905927 super :: ResourceData {
906928 buffers : group. counters . vs . buffers
907929 + group. counters . ts . buffers
0 commit comments