Skip to content

Commit 157cf36

Browse files
authored
[wgsl-out] Ray tracing pipelines (gfx-rs#8970)
* wgsl-out rt pipelines * Format. * Enable the ray tracing pipeline extension. * Add more checks for ray tracing pipelines. * Changelog. * Remove ir target.
1 parent dc29f72 commit 157cf36

7 files changed

Lines changed: 161 additions & 725 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ By @cwfitzgerald in [#8999](https://github.com/gfx-rs/wgpu/pull/8999).
7676
#### naga
7777

7878
- Initial wgsl-in ray tracing pipelines. By @Vecvec in [#8570](https://github.com/gfx-rs/wgpu/pull/8570).
79+
- wgsl-out ray tracing pipelines. By @Vecvec in [#8970](https://github.com/gfx-rs/wgpu/pull/8970).
7980
- Allow parsing shaders which make use of `SPV_KHR_non_semantic_info` for debug info. Also removes `naga::front::spv::SUPPORTED_EXT_SETS`. By @inner-daemons in #8827.
8081

8182
#### GLES

naga/src/back/wgsl/writer.rs

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ enum Attribute {
3636
MeshStage(String),
3737
TaskPayload(String),
3838
PerPrimitive,
39+
IncomingRayPayload(String),
3940
}
4041

4142
/// The WGSL form that `write_expr_with_indirection` should use to render a Naga
@@ -243,10 +244,17 @@ impl<W: Write> Writer<W> {
243244
Attribute::WorkGroupSize(ep.workgroup_size),
244245
]
245246
}
246-
ShaderStage::RayGeneration
247-
| ShaderStage::AnyHit
248-
| ShaderStage::ClosestHit
249-
| ShaderStage::Miss => unreachable!(),
247+
ShaderStage::RayGeneration => vec![Attribute::Stage(ShaderStage::RayGeneration)],
248+
ShaderStage::AnyHit | ShaderStage::ClosestHit | ShaderStage::Miss => {
249+
let payload_name = module.global_variables[ep.incoming_ray_payload.unwrap()]
250+
.name
251+
.clone()
252+
.unwrap();
253+
vec![
254+
Attribute::Stage(ep.stage),
255+
Attribute::IncomingRayPayload(payload_name),
256+
]
257+
}
250258
};
251259
self.write_attributes(&attributes)?;
252260
// Add a newline after attribute
@@ -287,6 +295,7 @@ impl<W: Write> Writer<W> {
287295
primitive_index: bool,
288296
cooperative_matrix: bool,
289297
draw_index: bool,
298+
ray_tracing_pipeline: bool,
290299
}
291300
let mut needed = RequiredEnabled {
292301
mesh_shaders: module.uses_mesh_shaders(),
@@ -313,6 +322,22 @@ impl<W: Write> Writer<W> {
313322
needed.mesh_shaders = true;
314323
}
315324
crate::Binding::BuiltIn(crate::BuiltIn::DrawIndex) => needed.draw_index = true,
325+
crate::Binding::BuiltIn(
326+
crate::BuiltIn::RayInvocationId
327+
| crate::BuiltIn::NumRayInvocations
328+
| crate::BuiltIn::InstanceCustomData
329+
| crate::BuiltIn::GeometryIndex
330+
| crate::BuiltIn::WorldRayOrigin
331+
| crate::BuiltIn::WorldRayDirection
332+
| crate::BuiltIn::ObjectRayOrigin
333+
| crate::BuiltIn::ObjectRayDirection
334+
| crate::BuiltIn::RayTmin
335+
| crate::BuiltIn::RayTCurrentMax
336+
| crate::BuiltIn::ObjectToWorld
337+
| crate::BuiltIn::WorldToObject,
338+
) => {
339+
needed.ray_tracing_pipeline = true;
340+
}
316341
_ => {}
317342
};
318343

@@ -332,6 +357,9 @@ impl<W: Write> Writer<W> {
332357
TypeInner::CooperativeMatrix { .. } => {
333358
needed.cooperative_matrix = true;
334359
}
360+
TypeInner::AccelerationStructure { .. } => {
361+
needed.ray_tracing_pipeline = true;
362+
}
335363
_ => {}
336364
}
337365
}
@@ -350,6 +378,44 @@ impl<W: Write> Writer<W> {
350378
}
351379
}
352380

381+
if module.global_variables.iter().any(|gv| {
382+
gv.1.space == crate::AddressSpace::IncomingRayPayload
383+
|| gv.1.space == crate::AddressSpace::RayPayload
384+
}) {
385+
needed.ray_tracing_pipeline = true;
386+
}
387+
388+
if module.entry_points.iter().any(|ep| {
389+
matches!(
390+
ep.stage,
391+
ShaderStage::RayGeneration
392+
| ShaderStage::AnyHit
393+
| ShaderStage::ClosestHit
394+
| ShaderStage::Miss
395+
)
396+
}) {
397+
needed.ray_tracing_pipeline = true;
398+
}
399+
400+
if module.global_variables.iter().any(|gv| {
401+
gv.1.space == crate::AddressSpace::IncomingRayPayload
402+
|| gv.1.space == crate::AddressSpace::RayPayload
403+
}) {
404+
needed.ray_tracing_pipeline = true;
405+
}
406+
407+
if module.entry_points.iter().any(|ep| {
408+
matches!(
409+
ep.stage,
410+
ShaderStage::RayGeneration
411+
| ShaderStage::AnyHit
412+
| ShaderStage::ClosestHit
413+
| ShaderStage::Miss
414+
)
415+
}) {
416+
needed.ray_tracing_pipeline = true;
417+
}
418+
353419
// Write required declarations
354420
let mut any_written = false;
355421
if needed.f16 {
@@ -380,6 +446,10 @@ impl<W: Write> Writer<W> {
380446
writeln!(self.out, "enable wgpu_cooperative_matrix;")?;
381447
any_written = true;
382448
}
449+
if needed.ray_tracing_pipeline {
450+
writeln!(self.out, "enable wgpu_ray_tracing_pipeline;")?;
451+
any_written = true;
452+
}
383453
if any_written {
384454
// Empty line for readability
385455
writeln!(self.out)?;
@@ -501,10 +571,10 @@ impl<W: Write> Writer<W> {
501571
ShaderStage::Task => "task",
502572
//Handled by another variant in the Attribute enum, so this code should never be hit.
503573
ShaderStage::Mesh => unreachable!(),
504-
ShaderStage::RayGeneration
505-
| ShaderStage::AnyHit
506-
| ShaderStage::ClosestHit
507-
| ShaderStage::Miss => unreachable!(),
574+
ShaderStage::RayGeneration => "ray_generation",
575+
ShaderStage::AnyHit => "any_hit",
576+
ShaderStage::ClosestHit => "closest_hit",
577+
ShaderStage::Miss => "miss",
508578
};
509579

510580
write!(self.out, "@{stage_str} ")?;
@@ -542,6 +612,9 @@ impl<W: Write> Writer<W> {
542612
write!(self.out, "@payload({payload_name}) ")?;
543613
}
544614
Attribute::PerPrimitive => write!(self.out, "@per_primitive ")?,
615+
Attribute::IncomingRayPayload(ref payload_name) => {
616+
write!(self.out, "@incoming_payload({payload_name}) ")?;
617+
}
545618
};
546619
}
547620
Ok(())
@@ -1103,7 +1176,21 @@ impl<W: Write> Writer<W> {
11031176
self.write_expr(module, data.stride, func_ctx)?;
11041177
writeln!(self.out, ");")?
11051178
}
1106-
Statement::RayPipelineFunction(_) => unreachable!(),
1179+
Statement::RayPipelineFunction(fun) => match fun {
1180+
crate::RayPipelineFunction::TraceRay {
1181+
acceleration_structure,
1182+
descriptor,
1183+
payload,
1184+
} => {
1185+
write!(self.out, "{level}traceRay(")?;
1186+
self.write_expr(module, acceleration_structure, func_ctx)?;
1187+
write!(self.out, ", ")?;
1188+
self.write_expr(module, descriptor, func_ctx)?;
1189+
write!(self.out, ", ")?;
1190+
self.write_expr(module, payload, func_ctx)?;
1191+
writeln!(self.out, ");")?
1192+
}
1193+
},
11071194
}
11081195

11091196
Ok(())

naga/src/common/wgsl/to_wgsl.rs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -195,25 +195,26 @@ impl TryToWgsl for crate::BuiltIn {
195195
Bi::PrimitiveCount => "primitive_count",
196196
Bi::CullPrimitive => "cull_primitive",
197197

198+
Bi::RayInvocationId => "ray_invocation_id",
199+
Bi::NumRayInvocations => "num_ray_invocations",
200+
Bi::InstanceCustomData => "instance_custom_data",
201+
Bi::GeometryIndex => "geometry_index",
202+
Bi::WorldRayOrigin => "world_ray_origin",
203+
Bi::WorldRayDirection => "world_ray_direction",
204+
Bi::ObjectRayOrigin => "object_ray_origin",
205+
Bi::ObjectRayDirection => "object_ray_direction",
206+
Bi::RayTmin => "ray_t_min",
207+
Bi::RayTCurrentMax => "ray_t_current_max",
208+
Bi::ObjectToWorld => "object_to_world",
209+
Bi::WorldToObject => "world_to_object",
210+
Bi::HitKind => "hit_kind",
211+
198212
Bi::BaseInstance
199213
| Bi::BaseVertex
200214
| Bi::CullDistance
201215
| Bi::PointSize
202216
| Bi::PointCoord
203-
| Bi::WorkGroupSize
204-
| Bi::RayInvocationId
205-
| Bi::NumRayInvocations
206-
| Bi::InstanceCustomData
207-
| Bi::GeometryIndex
208-
| Bi::WorldRayOrigin
209-
| Bi::WorldRayDirection
210-
| Bi::ObjectRayOrigin
211-
| Bi::ObjectRayDirection
212-
| Bi::RayTmin
213-
| Bi::RayTCurrentMax
214-
| Bi::ObjectToWorld
215-
| Bi::WorldToObject
216-
| Bi::HitKind => return None,
217+
| Bi::WorkGroupSize => return None,
217218
})
218219
}
219220
}
@@ -387,7 +388,8 @@ pub const fn address_space_str(
387388
As::Handle => return (None, None),
388389
As::Function => "function",
389390
As::TaskPayload => "task_payload",
390-
As::IncomingRayPayload | As::RayPayload => return (None, None),
391+
As::IncomingRayPayload => "incoming_ray_payload",
392+
As::RayPayload => "ray_payload",
391393
}),
392394
None,
393395
)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
capabilities = "RAY_TRACING_PIPELINE"
2-
targets = "IR"
2+
targets = "WGSL"

0 commit comments

Comments
 (0)