@@ -36,6 +36,7 @@ enum Attribute {
3636 MeshStage ( String ) ,
3737 TaskPayload ( String ) ,
3838 PerPrimitive ,
39+ IncomingRayPayload ( String ) ,
3940}
4041
4142/// The WGSL form that `write_expr_with_indirection` should use to render a Naga
@@ -243,10 +244,17 @@ impl<W: Write> Writer<W> {
243244 Attribute :: WorkGroupSize ( ep. workgroup_size) ,
244245 ]
245246 }
246- ShaderStage :: RayGeneration
247- | ShaderStage :: AnyHit
248- | ShaderStage :: ClosestHit
249- | ShaderStage :: Miss => unreachable ! ( ) ,
247+ ShaderStage :: RayGeneration => vec ! [ Attribute :: Stage ( ShaderStage :: RayGeneration ) ] ,
248+ ShaderStage :: AnyHit | ShaderStage :: ClosestHit | ShaderStage :: Miss => {
249+ let payload_name = module. global_variables [ ep. incoming_ray_payload . unwrap ( ) ]
250+ . name
251+ . clone ( )
252+ . unwrap ( ) ;
253+ vec ! [
254+ Attribute :: Stage ( ep. stage) ,
255+ Attribute :: IncomingRayPayload ( payload_name) ,
256+ ]
257+ }
250258 } ;
251259 self . write_attributes ( & attributes) ?;
252260 // Add a newline after attribute
@@ -287,6 +295,7 @@ impl<W: Write> Writer<W> {
287295 primitive_index : bool ,
288296 cooperative_matrix : bool ,
289297 draw_index : bool ,
298+ ray_tracing_pipeline : bool ,
290299 }
291300 let mut needed = RequiredEnabled {
292301 mesh_shaders : module. uses_mesh_shaders ( ) ,
@@ -313,6 +322,22 @@ impl<W: Write> Writer<W> {
313322 needed. mesh_shaders = true ;
314323 }
315324 crate :: Binding :: BuiltIn ( crate :: BuiltIn :: DrawIndex ) => needed. draw_index = true ,
325+ crate :: Binding :: BuiltIn (
326+ crate :: BuiltIn :: RayInvocationId
327+ | crate :: BuiltIn :: NumRayInvocations
328+ | crate :: BuiltIn :: InstanceCustomData
329+ | crate :: BuiltIn :: GeometryIndex
330+ | crate :: BuiltIn :: WorldRayOrigin
331+ | crate :: BuiltIn :: WorldRayDirection
332+ | crate :: BuiltIn :: ObjectRayOrigin
333+ | crate :: BuiltIn :: ObjectRayDirection
334+ | crate :: BuiltIn :: RayTmin
335+ | crate :: BuiltIn :: RayTCurrentMax
336+ | crate :: BuiltIn :: ObjectToWorld
337+ | crate :: BuiltIn :: WorldToObject ,
338+ ) => {
339+ needed. ray_tracing_pipeline = true ;
340+ }
316341 _ => { }
317342 } ;
318343
@@ -332,6 +357,9 @@ impl<W: Write> Writer<W> {
332357 TypeInner :: CooperativeMatrix { .. } => {
333358 needed. cooperative_matrix = true ;
334359 }
360+ TypeInner :: AccelerationStructure { .. } => {
361+ needed. ray_tracing_pipeline = true ;
362+ }
335363 _ => { }
336364 }
337365 }
@@ -350,6 +378,44 @@ impl<W: Write> Writer<W> {
350378 }
351379 }
352380
381+ if module. global_variables . iter ( ) . any ( |gv| {
382+ gv. 1 . space == crate :: AddressSpace :: IncomingRayPayload
383+ || gv. 1 . space == crate :: AddressSpace :: RayPayload
384+ } ) {
385+ needed. ray_tracing_pipeline = true ;
386+ }
387+
388+ if module. entry_points . iter ( ) . any ( |ep| {
389+ matches ! (
390+ ep. stage,
391+ ShaderStage :: RayGeneration
392+ | ShaderStage :: AnyHit
393+ | ShaderStage :: ClosestHit
394+ | ShaderStage :: Miss
395+ )
396+ } ) {
397+ needed. ray_tracing_pipeline = true ;
398+ }
399+
400+ if module. global_variables . iter ( ) . any ( |gv| {
401+ gv. 1 . space == crate :: AddressSpace :: IncomingRayPayload
402+ || gv. 1 . space == crate :: AddressSpace :: RayPayload
403+ } ) {
404+ needed. ray_tracing_pipeline = true ;
405+ }
406+
407+ if module. entry_points . iter ( ) . any ( |ep| {
408+ matches ! (
409+ ep. stage,
410+ ShaderStage :: RayGeneration
411+ | ShaderStage :: AnyHit
412+ | ShaderStage :: ClosestHit
413+ | ShaderStage :: Miss
414+ )
415+ } ) {
416+ needed. ray_tracing_pipeline = true ;
417+ }
418+
353419 // Write required declarations
354420 let mut any_written = false ;
355421 if needed. f16 {
@@ -380,6 +446,10 @@ impl<W: Write> Writer<W> {
380446 writeln ! ( self . out, "enable wgpu_cooperative_matrix;" ) ?;
381447 any_written = true ;
382448 }
449+ if needed. ray_tracing_pipeline {
450+ writeln ! ( self . out, "enable wgpu_ray_tracing_pipeline;" ) ?;
451+ any_written = true ;
452+ }
383453 if any_written {
384454 // Empty line for readability
385455 writeln ! ( self . out) ?;
@@ -501,10 +571,10 @@ impl<W: Write> Writer<W> {
501571 ShaderStage :: Task => "task" ,
502572 //Handled by another variant in the Attribute enum, so this code should never be hit.
503573 ShaderStage :: Mesh => unreachable ! ( ) ,
504- ShaderStage :: RayGeneration
505- | ShaderStage :: AnyHit
506- | ShaderStage :: ClosestHit
507- | ShaderStage :: Miss => unreachable ! ( ) ,
574+ ShaderStage :: RayGeneration => "ray_generation" ,
575+ ShaderStage :: AnyHit => "any_hit" ,
576+ ShaderStage :: ClosestHit => "closest_hit" ,
577+ ShaderStage :: Miss => "miss" ,
508578 } ;
509579
510580 write ! ( self . out, "@{stage_str} " ) ?;
@@ -542,6 +612,9 @@ impl<W: Write> Writer<W> {
542612 write ! ( self . out, "@payload({payload_name}) " ) ?;
543613 }
544614 Attribute :: PerPrimitive => write ! ( self . out, "@per_primitive " ) ?,
615+ Attribute :: IncomingRayPayload ( ref payload_name) => {
616+ write ! ( self . out, "@incoming_payload({payload_name}) " ) ?;
617+ }
545618 } ;
546619 }
547620 Ok ( ( ) )
@@ -1103,7 +1176,21 @@ impl<W: Write> Writer<W> {
11031176 self . write_expr ( module, data. stride , func_ctx) ?;
11041177 writeln ! ( self . out, ");" ) ?
11051178 }
1106- Statement :: RayPipelineFunction ( _) => unreachable ! ( ) ,
1179+ Statement :: RayPipelineFunction ( fun) => match fun {
1180+ crate :: RayPipelineFunction :: TraceRay {
1181+ acceleration_structure,
1182+ descriptor,
1183+ payload,
1184+ } => {
1185+ write ! ( self . out, "{level}traceRay(" ) ?;
1186+ self . write_expr ( module, acceleration_structure, func_ctx) ?;
1187+ write ! ( self . out, ", " ) ?;
1188+ self . write_expr ( module, descriptor, func_ctx) ?;
1189+ write ! ( self . out, ", " ) ?;
1190+ self . write_expr ( module, payload, func_ctx) ?;
1191+ writeln ! ( self . out, ");" ) ?
1192+ }
1193+ } ,
11071194 }
11081195
11091196 Ok ( ( ) )
0 commit comments