Skip to content

Commit 8ccf165

Browse files
WIP: refactor(core): hoist shader bit-to-naga conversion in call stack
TODO: split movement of `enum` mapping out TODO: explain - Prep for a new `enum` later, `naga::ShaderStage` is the closest thing.
1 parent c5e460f commit 8ccf165

3 files changed

Lines changed: 47 additions & 35 deletions

File tree

wgpu-core/src/device/resource.rs

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ use crate::{
5252
snatch::{SnatchGuard, SnatchLock, Snatchable},
5353
timestamp_normalization::TIMESTAMP_NORMALIZATION_BUFFER_USES,
5454
track::{BindGroupStates, DeviceTracker, TrackerIndexAllocators, UsageScope, UsageScopePool},
55-
validation,
55+
validation::{self, stage_bit_from_shader_stage},
5656
weak_vec::WeakVec,
5757
FastHashMap, LabelHelpers, OnceCellOrLock,
5858
};
@@ -3743,7 +3743,7 @@ impl Device {
37433743
let final_entry_point_name;
37443744

37453745
{
3746-
let stage = wgt::ShaderStages::COMPUTE;
3746+
let stage = naga::ShaderStage::Compute;
37473747

37483748
final_entry_point_name = shader_module.finalize_entry_point_name(
37493749
stage,
@@ -4230,13 +4230,16 @@ impl Device {
42304230
pipeline::RenderPipelineVertexProcessor::Vertex(ref vertex) => {
42314231
vertex_stage = {
42324232
let stage_desc = &vertex.stage;
4233-
let stage = wgt::ShaderStages::VERTEX;
4233+
let stage = naga::ShaderStage::Vertex;
4234+
let stage_bit = stage_bit_from_shader_stage(stage);
42344235

42354236
let vertex_shader_module = &stage_desc.module;
42364237
vertex_shader_module.same_device(self)?;
42374238

4238-
let stage_err =
4239-
|error| pipeline::CreateRenderPipelineError::Stage { stage, error };
4239+
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage {
4240+
stage: stage_bit,
4241+
error,
4242+
};
42404243

42414244
_vertex_entry_point_name = vertex_shader_module
42424245
.finalize_entry_point_name(
@@ -4256,7 +4259,7 @@ impl Device {
42564259
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
42574260
)
42584261
.map_err(stage_err)?;
4259-
validated_stages |= stage;
4262+
validated_stages |= stage_bit;
42604263
}
42614264
Some(hal::ProgrammableStage {
42624265
module: vertex_shader_module.raw(),
@@ -4272,12 +4275,15 @@ impl Device {
42724275

42734276
task_stage = if let Some(task) = task {
42744277
let stage_desc = &task.stage;
4275-
let stage = wgt::ShaderStages::TASK;
4278+
let stage = naga::ShaderStage::Task;
4279+
let stage_bit = stage_bit_from_shader_stage(stage);
42764280
let task_shader_module = &stage_desc.module;
42774281
task_shader_module.same_device(self)?;
42784282

4279-
let stage_err =
4280-
|error| pipeline::CreateRenderPipelineError::Stage { stage, error };
4283+
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage {
4284+
stage: stage_bit,
4285+
error,
4286+
};
42814287

42824288
_task_entry_point_name = task_shader_module
42834289
.finalize_entry_point_name(
@@ -4297,7 +4303,7 @@ impl Device {
42974303
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
42984304
)
42994305
.map_err(stage_err)?;
4300-
validated_stages |= stage;
4306+
validated_stages |= stage_bit;
43014307
}
43024308
Some(hal::ProgrammableStage {
43034309
module: task_shader_module.raw(),
@@ -4311,12 +4317,15 @@ impl Device {
43114317
};
43124318
mesh_stage = {
43134319
let stage_desc = &mesh.stage;
4314-
let stage = wgt::ShaderStages::MESH;
4320+
let stage = naga::ShaderStage::Mesh;
4321+
let stage_bit = stage_bit_from_shader_stage(stage);
43154322
let mesh_shader_module = &stage_desc.module;
43164323
mesh_shader_module.same_device(self)?;
43174324

4318-
let stage_err =
4319-
|error| pipeline::CreateRenderPipelineError::Stage { stage, error };
4325+
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage {
4326+
stage: stage_bit,
4327+
error,
4328+
};
43204329

43214330
_mesh_entry_point_name = mesh_shader_module
43224331
.finalize_entry_point_name(
@@ -4336,7 +4345,7 @@ impl Device {
43364345
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
43374346
)
43384347
.map_err(stage_err)?;
4339-
validated_stages |= stage;
4348+
validated_stages |= stage_bit;
43404349
}
43414350
Some(hal::ProgrammableStage {
43424351
module: mesh_shader_module.raw(),
@@ -4352,12 +4361,16 @@ impl Device {
43524361
let fragment_entry_point_name;
43534362
let fragment_stage = match desc.fragment {
43544363
Some(ref fragment_state) => {
4355-
let stage = wgt::ShaderStages::FRAGMENT;
4364+
let stage = naga::ShaderStage::Fragment;
4365+
let stage_bit = stage_bit_from_shader_stage(stage);
43564366

43574367
let shader_module = &fragment_state.stage.module;
43584368
shader_module.same_device(self)?;
43594369

4360-
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error };
4370+
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage {
4371+
stage: stage_bit,
4372+
error,
4373+
};
43614374

43624375
fragment_entry_point_name = shader_module
43634376
.finalize_entry_point_name(
@@ -4381,14 +4394,14 @@ impl Device {
43814394
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
43824395
)
43834396
.map_err(stage_err)?;
4384-
validated_stages |= stage;
4397+
validated_stages |= stage_bit;
43854398
}
43864399

43874400
if let Some(ref interface) = shader_module.interface {
43884401
shader_expects_dual_source_blending = interface
43894402
.fragment_uses_dual_source_blending(&fragment_entry_point_name)
43904403
.map_err(|error| pipeline::CreateRenderPipelineError::Stage {
4391-
stage,
4404+
stage: stage_bit,
43924405
error,
43934406
})?;
43944407
}

wgpu-core/src/pipeline.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,11 @@ impl ShaderModule {
8989

9090
pub(crate) fn finalize_entry_point_name(
9191
&self,
92-
stage_bit: wgt::ShaderStages,
92+
stage: naga::ShaderStage,
9393
entry_point: Option<&str>,
9494
) -> Result<String, validation::StageError> {
9595
match &self.interface {
96-
Some(interface) => interface.finalize_entry_point_name(stage_bit, entry_point),
96+
Some(interface) => interface.finalize_entry_point_name(stage, entry_point),
9797
None => entry_point
9898
.map(|ep| ep.to_string())
9999
.ok_or(validation::StageError::NoEntryPointFound),

wgpu-core/src/validation.rs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,10 +1148,9 @@ impl Interface {
11481148

11491149
pub fn finalize_entry_point_name(
11501150
&self,
1151-
stage_bit: wgt::ShaderStages,
1151+
stage: naga::ShaderStage,
11521152
entry_point_name: Option<&str>,
11531153
) -> Result<String, StageError> {
1154-
let stage = Self::shader_stage_from_stage_bit(stage_bit);
11551154
entry_point_name
11561155
.map(|ep| ep.to_string())
11571156
.map(Ok)
@@ -1168,36 +1167,26 @@ impl Interface {
11681167
})
11691168
}
11701169

1171-
pub(crate) fn shader_stage_from_stage_bit(stage_bit: wgt::ShaderStages) -> naga::ShaderStage {
1172-
match stage_bit {
1173-
wgt::ShaderStages::VERTEX => naga::ShaderStage::Vertex,
1174-
wgt::ShaderStages::FRAGMENT => naga::ShaderStage::Fragment,
1175-
wgt::ShaderStages::COMPUTE => naga::ShaderStage::Compute,
1176-
wgt::ShaderStages::MESH => naga::ShaderStage::Mesh,
1177-
wgt::ShaderStages::TASK => naga::ShaderStage::Task,
1178-
_ => unreachable!(),
1179-
}
1180-
}
1181-
11821170
pub fn check_stage(
11831171
&self,
11841172
layouts: &mut BindingLayoutSource<'_>,
11851173
shader_binding_sizes: &mut FastHashMap<naga::ResourceBinding, wgt::BufferSize>,
11861174
entry_point_name: &str,
1187-
stage_bit: wgt::ShaderStages,
1175+
shader_stage: naga::ShaderStage,
11881176
inputs: StageIo,
11891177
compare_function: Option<wgt::CompareFunction>,
11901178
) -> Result<StageIo, StageError> {
11911179
// Since a shader module can have multiple entry points with the same name,
11921180
// we need to look for one with the right execution model.
1193-
let shader_stage = Self::shader_stage_from_stage_bit(stage_bit);
11941181
let pair = (shader_stage, entry_point_name.to_string());
11951182
let entry_point = match self.entry_points.get(&pair) {
11961183
Some(some) => some,
11971184
None => return Err(StageError::MissingEntryPoint(pair.1)),
11981185
};
11991186
let (_, entry_point_name) = pair;
12001187

1188+
let stage_bit = stage_bit_from_shader_stage(shader_stage);
1189+
12011190
// check resources visibility
12021191
for &handle in entry_point.resources.iter() {
12031192
let res = &self.resources[handle];
@@ -1624,3 +1613,13 @@ pub fn validate_color_attachment_bytes_per_sample(
16241613

16251614
Ok(())
16261615
}
1616+
1617+
pub(crate) fn stage_bit_from_shader_stage(shader_stage: naga::ShaderStage) -> wgt::ShaderStages {
1618+
match shader_stage {
1619+
naga::ShaderStage::Vertex => wgt::ShaderStages::VERTEX,
1620+
naga::ShaderStage::Fragment => wgt::ShaderStages::FRAGMENT,
1621+
naga::ShaderStage::Compute => wgt::ShaderStages::COMPUTE,
1622+
naga::ShaderStage::Mesh => wgt::ShaderStages::MESH,
1623+
naga::ShaderStage::Task => wgt::ShaderStages::TASK,
1624+
}
1625+
}

0 commit comments

Comments
 (0)