@@ -278,96 +278,120 @@ impl<W: Write> Writer<W> {
278278 /// Helper method which writes all the `enable` declarations
279279 /// needed for a module.
280280 fn write_enable_declarations ( & mut self , module : & Module ) -> BackendResult {
281- let mut needs_f16 = false ;
282- let mut needs_dual_source_blending = false ;
283- let mut needs_clip_distances = false ;
284- let mut needs_mesh_shaders = false ;
285- let mut needs_cooperative_matrix = false ;
281+ #[ derive( Default ) ]
282+ struct RequiredEnabled {
283+ f16 : bool ,
284+ dual_source_blending : bool ,
285+ clip_distances : bool ,
286+ mesh_shaders : bool ,
287+ primitive_index : bool ,
288+ cooperative_matrix : bool ,
289+ }
290+ let mut needed = RequiredEnabled :: default ( ) ;
291+
292+ let check_binding = |binding : & crate :: Binding , needed : & mut RequiredEnabled | match * binding
293+ {
294+ crate :: Binding :: Location {
295+ blend_src : Some ( _) , ..
296+ } => {
297+ needed. dual_source_blending = true ;
298+ }
299+ crate :: Binding :: BuiltIn ( crate :: BuiltIn :: ClipDistance ) => {
300+ needed. clip_distances = true ;
301+ }
302+ crate :: Binding :: BuiltIn ( crate :: BuiltIn :: PrimitiveIndex ) => {
303+ needed. primitive_index = true ;
304+ }
305+ crate :: Binding :: Location {
306+ per_primitive : true ,
307+ ..
308+ } => {
309+ needed. mesh_shaders = true ;
310+ }
311+ crate :: Binding :: BuiltIn (
312+ crate :: BuiltIn :: MeshTaskSize
313+ | crate :: BuiltIn :: CullPrimitive
314+ | crate :: BuiltIn :: PointIndex
315+ | crate :: BuiltIn :: LineIndices
316+ | crate :: BuiltIn :: TriangleIndices
317+ | crate :: BuiltIn :: VertexCount
318+ | crate :: BuiltIn :: Vertices
319+ | crate :: BuiltIn :: PrimitiveCount
320+ | crate :: BuiltIn :: Primitives ,
321+ ) => {
322+ needed. mesh_shaders = true ;
323+ }
324+ _ => { }
325+ } ;
286326
287327 // Determine which `enable` declarations are needed
288328 for ( _, ty) in module. types . iter ( ) {
289329 match ty. inner {
290330 TypeInner :: Scalar ( scalar)
291331 | TypeInner :: Vector { scalar, .. }
292332 | TypeInner :: Matrix { scalar, .. } => {
293- needs_f16 |= scalar == crate :: Scalar :: F16 ;
333+ needed . f16 |= scalar == crate :: Scalar :: F16 ;
294334 }
295335 TypeInner :: Struct { ref members, .. } => {
296336 for binding in members. iter ( ) . filter_map ( |m| m. binding . as_ref ( ) ) {
297- match * binding {
298- crate :: Binding :: Location {
299- blend_src : Some ( _) , ..
300- } => {
301- needs_dual_source_blending = true ;
302- }
303- crate :: Binding :: BuiltIn ( crate :: BuiltIn :: ClipDistance ) => {
304- needs_clip_distances = true ;
305- }
306- crate :: Binding :: Location {
307- per_primitive : true ,
308- ..
309- } => {
310- needs_mesh_shaders = true ;
311- }
312- crate :: Binding :: BuiltIn (
313- crate :: BuiltIn :: MeshTaskSize
314- | crate :: BuiltIn :: CullPrimitive
315- | crate :: BuiltIn :: PointIndex
316- | crate :: BuiltIn :: LineIndices
317- | crate :: BuiltIn :: TriangleIndices
318- | crate :: BuiltIn :: VertexCount
319- | crate :: BuiltIn :: Vertices
320- | crate :: BuiltIn :: PrimitiveCount
321- | crate :: BuiltIn :: Primitives ,
322- ) => {
323- needs_mesh_shaders = true ;
324- }
325- _ => { }
326- }
337+ check_binding ( binding, & mut needed) ;
327338 }
328339 }
329340 TypeInner :: CooperativeMatrix { .. } => {
330- needs_cooperative_matrix = true ;
341+ needed . cooperative_matrix = true ;
331342 }
332343 _ => { }
333344 }
334345 }
335346
336- if module
337- . entry_points
338- . iter ( )
339- . any ( |ep| matches ! ( ep. stage, ShaderStage :: Mesh | ShaderStage :: Task ) )
340- {
341- needs_mesh_shaders = true ;
347+ for ep in & module. entry_points {
348+ if matches ! ( ep. stage, ShaderStage :: Mesh | ShaderStage :: Task ) {
349+ needed. mesh_shaders = true ;
350+ }
351+ if let Some ( res) = ep. function . result . as_ref ( ) . and_then ( |a| a. binding . as_ref ( ) ) {
352+ check_binding ( res, & mut needed) ;
353+ }
354+ for arg in ep
355+ . function
356+ . arguments
357+ . iter ( )
358+ . filter_map ( |a| a. binding . as_ref ( ) )
359+ {
360+ check_binding ( arg, & mut needed) ;
361+ }
342362 }
343363
344364 if module
345365 . global_variables
346366 . iter ( )
347367 . any ( |gv| gv. 1 . space == crate :: AddressSpace :: TaskPayload )
348368 {
349- needs_mesh_shaders = true ;
369+ needed . mesh_shaders = true ;
350370 }
351371
352372 // Write required declarations
353373 let mut any_written = false ;
354- if needs_f16 {
374+ if needed . f16 {
355375 writeln ! ( self . out, "enable f16;" ) ?;
356376 any_written = true ;
357377 }
358- if needs_dual_source_blending {
378+ if needed . dual_source_blending {
359379 writeln ! ( self . out, "enable dual_source_blending;" ) ?;
360380 any_written = true ;
361381 }
362- if needs_clip_distances {
382+ if needed . clip_distances {
363383 writeln ! ( self . out, "enable clip_distances;" ) ?;
364384 any_written = true ;
365385 }
366- if needs_mesh_shaders {
386+ if needed . mesh_shaders {
367387 writeln ! ( self . out, "enable wgpu_mesh_shader;" ) ?;
368388 any_written = true ;
369389 }
370- if needs_cooperative_matrix {
390+ if needed. primitive_index {
391+ writeln ! ( self . out, "enable primitive_index;" ) ?;
392+ any_written = true ;
393+ }
394+ if needed. cooperative_matrix {
371395 writeln ! ( self . out, "enable wgpu_cooperative_matrix;" ) ?;
372396 any_written = true ;
373397 }
0 commit comments