Skip to content

Commit e84e1c7

Browse files
refactor(core): use new ShaderStageForValidation
TODO: explain - bundle state needed into same place - Discovery: `compare_function` was being provided, but not checked, in other places. Intentional?
1 parent 8ccf165 commit e84e1c7

2 files changed

Lines changed: 61 additions & 41 deletions

File tree

wgpu-core/src/device/resource.rs

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ use crate::{
5252
snatch::{SnatchGuard, SnatchLock, Snatchable},
5353
timestamp_normalization::TIMESTAMP_NORMALIZATION_BUFFER_USES,
5454
track::{BindGroupStates, DeviceTracker, TrackerIndexAllocators, UsageScope, UsageScopePool},
55-
validation::{self, stage_bit_from_shader_stage},
55+
validation,
5656
weak_vec::WeakVec,
5757
FastHashMap, LabelHelpers, OnceCellOrLock,
5858
};
@@ -3743,10 +3743,10 @@ impl Device {
37433743
let final_entry_point_name;
37443744

37453745
{
3746-
let stage = naga::ShaderStage::Compute;
3746+
let stage = validation::ShaderStageForValidation::Compute;
37473747

37483748
final_entry_point_name = shader_module.finalize_entry_point_name(
3749-
stage,
3749+
stage.to_naga(),
37503750
desc.stage.entry_point.as_ref().map(|ep| ep.as_ref()),
37513751
)?;
37523752

@@ -3757,7 +3757,6 @@ impl Device {
37573757
&final_entry_point_name,
37583758
stage,
37593759
io,
3760-
None,
37613760
)?;
37623761
}
37633762
}
@@ -4230,8 +4229,10 @@ impl Device {
42304229
pipeline::RenderPipelineVertexProcessor::Vertex(ref vertex) => {
42314230
vertex_stage = {
42324231
let stage_desc = &vertex.stage;
4233-
let stage = naga::ShaderStage::Vertex;
4234-
let stage_bit = stage_bit_from_shader_stage(stage);
4232+
let stage = validation::ShaderStageForValidation::Vertex {
4233+
compare_function: desc.depth_stencil.as_ref().map(|d| d.depth_compare),
4234+
};
4235+
let stage_bit = stage.to_wgt_bit();
42354236

42364237
let vertex_shader_module = &stage_desc.module;
42374238
vertex_shader_module.same_device(self)?;
@@ -4243,7 +4244,7 @@ impl Device {
42434244

42444245
_vertex_entry_point_name = vertex_shader_module
42454246
.finalize_entry_point_name(
4246-
stage,
4247+
stage.to_naga(),
42474248
stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()),
42484249
)
42494250
.map_err(stage_err)?;
@@ -4256,7 +4257,6 @@ impl Device {
42564257
&_vertex_entry_point_name,
42574258
stage,
42584259
io,
4259-
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
42604260
)
42614261
.map_err(stage_err)?;
42624262
validated_stages |= stage_bit;
@@ -4275,8 +4275,8 @@ impl Device {
42754275

42764276
task_stage = if let Some(task) = task {
42774277
let stage_desc = &task.stage;
4278-
let stage = naga::ShaderStage::Task;
4279-
let stage_bit = stage_bit_from_shader_stage(stage);
4278+
let stage = validation::ShaderStageForValidation::Task;
4279+
let stage_bit = stage.to_wgt_bit();
42804280
let task_shader_module = &stage_desc.module;
42814281
task_shader_module.same_device(self)?;
42824282

@@ -4287,7 +4287,7 @@ impl Device {
42874287

42884288
_task_entry_point_name = task_shader_module
42894289
.finalize_entry_point_name(
4290-
stage,
4290+
stage.to_naga(),
42914291
stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()),
42924292
)
42934293
.map_err(stage_err)?;
@@ -4300,7 +4300,6 @@ impl Device {
43004300
&_task_entry_point_name,
43014301
stage,
43024302
io,
4303-
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
43044303
)
43054304
.map_err(stage_err)?;
43064305
validated_stages |= stage_bit;
@@ -4317,8 +4316,8 @@ impl Device {
43174316
};
43184317
mesh_stage = {
43194318
let stage_desc = &mesh.stage;
4320-
let stage = naga::ShaderStage::Mesh;
4321-
let stage_bit = stage_bit_from_shader_stage(stage);
4319+
let stage = validation::ShaderStageForValidation::Mesh;
4320+
let stage_bit = stage.to_wgt_bit();
43224321
let mesh_shader_module = &stage_desc.module;
43234322
mesh_shader_module.same_device(self)?;
43244323

@@ -4329,7 +4328,7 @@ impl Device {
43294328

43304329
_mesh_entry_point_name = mesh_shader_module
43314330
.finalize_entry_point_name(
4332-
stage,
4331+
stage.to_naga(),
43334332
stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()),
43344333
)
43354334
.map_err(stage_err)?;
@@ -4342,7 +4341,6 @@ impl Device {
43424341
&_mesh_entry_point_name,
43434342
stage,
43444343
io,
4345-
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
43464344
)
43474345
.map_err(stage_err)?;
43484346
validated_stages |= stage_bit;
@@ -4361,8 +4359,8 @@ impl Device {
43614359
let fragment_entry_point_name;
43624360
let fragment_stage = match desc.fragment {
43634361
Some(ref fragment_state) => {
4364-
let stage = naga::ShaderStage::Fragment;
4365-
let stage_bit = stage_bit_from_shader_stage(stage);
4362+
let stage = validation::ShaderStageForValidation::Fragment;
4363+
let stage_bit = stage.to_wgt_bit();
43664364

43674365
let shader_module = &fragment_state.stage.module;
43684366
shader_module.same_device(self)?;
@@ -4374,7 +4372,7 @@ impl Device {
43744372

43754373
fragment_entry_point_name = shader_module
43764374
.finalize_entry_point_name(
4377-
stage,
4375+
stage.to_naga(),
43784376
fragment_state
43794377
.stage
43804378
.entry_point
@@ -4391,7 +4389,6 @@ impl Device {
43914389
&fragment_entry_point_name,
43924390
stage,
43934391
io,
4394-
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
43954392
)
43964393
.map_err(stage_err)?;
43974394
validated_stages |= stage_bit;

wgpu-core/src/validation.rs

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,20 +1172,19 @@ impl Interface {
11721172
layouts: &mut BindingLayoutSource<'_>,
11731173
shader_binding_sizes: &mut FastHashMap<naga::ResourceBinding, wgt::BufferSize>,
11741174
entry_point_name: &str,
1175-
shader_stage: naga::ShaderStage,
1175+
shader_stage: ShaderStageForValidation,
11761176
inputs: StageIo,
1177-
compare_function: Option<wgt::CompareFunction>,
11781177
) -> Result<StageIo, StageError> {
11791178
// Since a shader module can have multiple entry points with the same name,
11801179
// we need to look for one with the right execution model.
1181-
let pair = (shader_stage, entry_point_name.to_string());
1180+
let pair = (shader_stage.to_naga(), entry_point_name.to_string());
11821181
let entry_point = match self.entry_points.get(&pair) {
11831182
Some(some) => some,
11841183
None => return Err(StageError::MissingEntryPoint(pair.1)),
11851184
};
11861185
let (_, entry_point_name) = pair;
11871186

1188-
let stage_bit = stage_bit_from_shader_stage(shader_stage);
1187+
let stage_bit = shader_stage.to_wgt_bit();
11891188

11901189
// check resources visibility
11911190
for &handle in entry_point.resources.iter() {
@@ -1308,13 +1307,13 @@ impl Interface {
13081307
}
13091308

13101309
// check workgroup size limits
1311-
if shader_stage.compute_like() {
1310+
if shader_stage.to_naga().compute_like() {
13121311
let (
13131312
max_workgroup_size_limits,
13141313
max_workgroup_size_total,
13151314
per_dimension_limit,
13161315
total_limit,
1317-
) = match shader_stage {
1316+
) = match shader_stage.to_naga() {
13181317
naga::ShaderStage::Compute => (
13191318
[
13201319
self.limits.max_compute_workgroup_size_x,
@@ -1380,7 +1379,7 @@ impl Interface {
13801379
.ok_or(InputError::Missing)
13811380
.and_then(|provided| {
13821381
let (compatible, num_components, per_primitive_correct) =
1383-
match shader_stage {
1382+
match shader_stage.to_naga() {
13841383
// For vertex attributes, there are defaults filled out
13851384
// by the driver if data is not provided.
13861385
naga::ShaderStage::Vertex => {
@@ -1445,7 +1444,9 @@ impl Interface {
14451444
}
14461445

14471446
match shader_stage {
1448-
naga::ShaderStage::Vertex => {
1447+
ShaderStageForValidation::Vertex {
1448+
compare_function,
1449+
} => {
14491450
for output in entry_point.outputs.iter() {
14501451
//TODO: count builtins towards the limit?
14511452
inter_stage_components += match *output {
@@ -1478,7 +1479,7 @@ impl Interface {
14781479
}
14791480
}
14801481
}
1481-
naga::ShaderStage::Fragment => {
1482+
ShaderStageForValidation::Fragment => {
14821483
for output in &entry_point.outputs {
14831484
let &Varying::Local { location, ref iv } = output else {
14841485
continue;
@@ -1524,7 +1525,7 @@ impl Interface {
15241525
});
15251526
}
15261527
}
1527-
if shader_stage == naga::ShaderStage::Mesh
1528+
if shader_stage.to_naga() == naga::ShaderStage::Mesh
15281529
&& entry_point.task_payload_size != inputs.task_payload_size
15291530
{
15301531
return Err(StageError::TaskPayloadMustMatch {
@@ -1534,18 +1535,18 @@ impl Interface {
15341535
}
15351536

15361537
// Fragment shader primitive index is treated like a varying
1537-
if shader_stage == naga::ShaderStage::Fragment
1538+
if shader_stage.to_naga() == naga::ShaderStage::Fragment
15381539
&& this_stage_primitive_index
15391540
&& inputs.primitive_index == Some(false)
15401541
{
15411542
return Err(StageError::InvalidPrimitiveIndex);
1542-
} else if shader_stage == naga::ShaderStage::Fragment
1543+
} else if shader_stage.to_naga() == naga::ShaderStage::Fragment
15431544
&& !this_stage_primitive_index
15441545
&& inputs.primitive_index == Some(true)
15451546
{
15461547
return Err(StageError::MissingPrimitiveIndex);
15471548
}
1548-
if shader_stage == naga::ShaderStage::Mesh
1549+
if shader_stage.to_naga() == naga::ShaderStage::Mesh
15491550
&& inputs.task_payload_size.is_some()
15501551
&& has_draw_id
15511552
{
@@ -1564,7 +1565,7 @@ impl Interface {
15641565
Ok(StageIo {
15651566
task_payload_size: entry_point.task_payload_size,
15661567
varyings: outputs,
1567-
primitive_index: if shader_stage == naga::ShaderStage::Mesh {
1568+
primitive_index: if shader_stage.to_naga() == naga::ShaderStage::Mesh {
15681569
Some(this_stage_primitive_index)
15691570
} else {
15701571
None
@@ -1614,12 +1615,34 @@ pub fn validate_color_attachment_bytes_per_sample(
16141615
Ok(())
16151616
}
16161617

1617-
pub(crate) fn stage_bit_from_shader_stage(shader_stage: naga::ShaderStage) -> wgt::ShaderStages {
1618-
match shader_stage {
1619-
naga::ShaderStage::Vertex => wgt::ShaderStages::VERTEX,
1620-
naga::ShaderStage::Fragment => wgt::ShaderStages::FRAGMENT,
1621-
naga::ShaderStage::Compute => wgt::ShaderStages::COMPUTE,
1622-
naga::ShaderStage::Mesh => wgt::ShaderStages::MESH,
1623-
naga::ShaderStage::Task => wgt::ShaderStages::TASK,
1618+
pub enum ShaderStageForValidation {
1619+
Vertex {
1620+
compare_function: Option<wgt::CompareFunction>,
1621+
},
1622+
Mesh,
1623+
Fragment,
1624+
Compute,
1625+
Task,
1626+
}
1627+
1628+
impl ShaderStageForValidation {
1629+
pub fn to_naga(&self) -> naga::ShaderStage {
1630+
match self {
1631+
Self::Vertex { .. } => naga::ShaderStage::Vertex,
1632+
Self::Mesh { .. } => naga::ShaderStage::Mesh,
1633+
Self::Fragment { .. } => naga::ShaderStage::Fragment,
1634+
Self::Compute => naga::ShaderStage::Compute,
1635+
Self::Task => naga::ShaderStage::Task,
1636+
}
1637+
}
1638+
1639+
pub fn to_wgt_bit(&self) -> wgt::ShaderStages {
1640+
match self {
1641+
Self::Vertex { .. } => wgt::ShaderStages::VERTEX,
1642+
Self::Mesh { .. } => wgt::ShaderStages::MESH,
1643+
Self::Fragment { .. } => wgt::ShaderStages::FRAGMENT,
1644+
Self::Compute => wgt::ShaderStages::COMPUTE,
1645+
Self::Task => wgt::ShaderStages::TASK,
1646+
}
16241647
}
16251648
}

0 commit comments

Comments
 (0)