Skip to content

Commit 97bafd2

Browse files
authored
[SPIR-V] Add support for the extension VK_EXT_mesh_shader (#4725)
* Support VK_EXT_mesh_shader * Fix errors when compiling with SPV_NV_mesh_shader * Minor tweak onto amplification shaders * Fix pre-checkin failure * Add amplification and mesh tests for EXT_mesh_shader * Add some comments * Fix primitive indices Co-authored-by: Tianyuan <[email protected]>
1 parent 55be211 commit 97bafd2

20 files changed

Lines changed: 777 additions & 124 deletions

docs/SPIR-V.rst

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ Supported extensions
294294
* SPV_KHR_shader_draw_parameters
295295
* SPV_EXT_descriptor_indexing
296296
* SPV_EXT_fragment_fully_covered
297+
* SPV_EXT_mesh_shader
297298
* SPV_EXT_shader_stencil_support
298299
* SPV_AMD_shader_early_and_late_fragment_tests
299300
* SPV_AMD_shader_explicit_vertex_parameter
@@ -1522,13 +1523,15 @@ some system-value (SV) semantic strings will be translated into SPIR-V
15221523
| +-------------+----------------------------------------+-----------------------+-----------------------------+
15231524
| | DsIn | ``PrimitiveId`` | N/A | ``Tessellation`` |
15241525
| +-------------+----------------------------------------+-----------------------+-----------------------------+
1525-
| SV_PrimitiveID | GSIn | ``PrimitiveId`` | N/A | ``Geometry`` |
1526-
| +-------------+----------------------------------------+-----------------------+-----------------------------+
1526+
| | GSIn | ``PrimitiveId`` | N/A | ``Geometry`` |
1527+
| SV_PrimitiveID +-------------+----------------------------------------+-----------------------+-----------------------------+
15271528
| | GSOut | ``PrimitiveId`` | N/A | ``Geometry`` |
15281529
| +-------------+----------------------------------------+-----------------------+-----------------------------+
15291530
| | PSIn | ``PrimitiveId`` | N/A | ``Geometry`` |
15301531
| +-------------+----------------------------------------+-----------------------+-----------------------------+
1531-
| | MSOut | ``PrimitiveId`` | N/A | ``MeshShadingNV`` |
1532+
| | | | | ``MeshShadingNV`` |
1533+
| | MSOut | ``PrimitiveId`` | N/A | |
1534+
| | | | | ``MeshShadingEXT`` |
15321535
+---------------------------+-------------+----------------------------------------+-----------------------+-----------------------------+
15331536
| | PCOut | ``TessLevelOuter`` | N/A | ``Tessellation`` |
15341537
| SV_TessFactor +-------------+----------------------------------------+-----------------------+-----------------------------+
@@ -1546,15 +1549,19 @@ some system-value (SV) semantic strings will be translated into SPIR-V
15461549
+---------------------------+-------------+----------------------------------------+-----------------------+-----------------------------+
15471550
| | GSOut | ``Layer`` | N/A | ``Geometry`` |
15481551
| +-------------+----------------------------------------+-----------------------+-----------------------------+
1549-
| SV_RenderTargetArrayIndex | PSIn | ``Layer`` | N/A | ``Geometry`` |
1550-
| +-------------+----------------------------------------+-----------------------+-----------------------------+
1551-
| | MSOut | ``Layer`` | N/A | ``MeshShadingNV`` |
1552+
| | PSIn | ``Layer`` | N/A | ``Geometry`` |
1553+
| SV_RenderTargetArrayIndex +-------------+----------------------------------------+-----------------------+-----------------------------+
1554+
| | | | | ``MeshShadingNV`` |
1555+
| | MSOut | ``Layer`` | N/A | |
1556+
| | | | | ``MeshShadingEXT`` |
15521557
+---------------------------+-------------+----------------------------------------+-----------------------+-----------------------------+
15531558
| | GSOut | ``ViewportIndex`` | N/A | ``MultiViewport`` |
15541559
| +-------------+----------------------------------------+-----------------------+-----------------------------+
1555-
| SV_ViewportArrayIndex | PSIn | ``ViewportIndex`` | N/A | ``MultiViewport`` |
1556-
| +-------------+----------------------------------------+-----------------------+-----------------------------+
1557-
| | MSOut | ``ViewportIndex`` | N/A | ``MeshShadingNV`` |
1560+
| | PSIn | ``ViewportIndex`` | N/A | ``MultiViewport`` |
1561+
| SV_ViewportArrayIndex +-------------+----------------------------------------+-----------------------+-----------------------------+
1562+
| | | | | ``MeshShadingNV`` |
1563+
| | MSOut | ``ViewportIndex`` | N/A | |
1564+
| | | | | ``MeshShadingEXT`` |
15581565
+---------------------------+-------------+----------------------------------------+-----------------------+-----------------------------+
15591566
| | PSIn | ``SampleMask`` | N/A | ``Shader`` |
15601567
| SV_Coverage +-------------+----------------------------------------+-----------------------+-----------------------------+
@@ -1582,6 +1589,9 @@ some system-value (SV) semantic strings will be translated into SPIR-V
15821589
| +-------------+----------------------------------------+-----------------------+-----------------------------+
15831590
| | MSOut | ``PrimitiveShadingRateKHR`` | N/A | ``FragmentShadingRate`` |
15841591
+---------------------------+-------------+----------------------------------------+-----------------------+-----------------------------+
1592+
| SV_CullPrimitive | MSOut | ``CullPrimitiveEXT`` | N/A | ``MeshShadingEXT `` |
1593+
+---------------------------+-------------+----------------------------------------+-----------------------+-----------------------------+
1594+
15851595

15861596
For entities (function parameters, function return values, struct fields) with
15871597
the above SV semantic strings attached, SPIR-V variables of the
@@ -3409,26 +3419,34 @@ shaders and are translated to SPIR-V execution modes according to the table belo
34093419

34103420
.. table:: Mapping from HLSL attribute to SPIR-V execution mode
34113421

3412-
+-------------------+--------------------+-------------------------+
3413-
| HLSL Attribute | Value | SPIR-V Execution Mode |
3414-
+===================+====================+=========================+
3415-
|``outputtopology`` | ``point`` | ``OutputPoints`` |
3416-
| +--------------------+-------------------------+
3417-
|``(Mesh shader)`` | ``line`` | ``OutputLinesNV`` |
3418-
| +--------------------+-------------------------+
3419-
| | ``triangle`` | ``OutputTrianglesNV`` |
3420-
+-------------------+--------------------+-------------------------+
3421-
| ``numthreads`` | ``X, Y, Z`` | ``LocalSize X, Y, Z`` |
3422-
| | | |
3423-
| | ``(X*Y*Z <= 128)`` | |
3424-
+-------------------+--------------------+-------------------------+
3422+
+-----------------------+--------------------+-------------------------+
3423+
| HLSL Attribute | Value | SPIR-V Execution Mode |
3424+
+=======================+====================+=========================+
3425+
|``outputtopology`` | ``point`` | ``OutputPoints`` |
3426+
| +--------------------+-------------------------+
3427+
| (SPV_NV_mesh_shader) | ``line`` | ``OutputLinesNV`` |
3428+
| | | |
3429+
| +--------------------+-------------------------+
3430+
| | ``triangle`` | ``OutputTrianglesNV`` |
3431+
+-----------------------+--------------------+-------------------------+
3432+
|``outputtopology`` | ``point`` | ``OutputPoints`` |
3433+
| +--------------------+-------------------------+
3434+
| (SPV_EXT_mesh_shader) | ``line`` | ``OutputLinesEXT`` |
3435+
| | | |
3436+
| +--------------------+-------------------------+
3437+
| | ``triangle`` | ``OutputTrianglesEXT`` |
3438+
+-----------------------+--------------------+-------------------------+
3439+
| ``numthreads`` | ``X, Y, Z`` | ``LocalSize X, Y, Z`` |
3440+
| | | |
3441+
| | ``(X*Y*Z <= 128)`` | |
3442+
+-----------------------+--------------------+-------------------------+
34253443

34263444
Intrinsics
34273445
~~~~~~~~~~
34283446
The following HLSL intrinsics are used in Mesh or Amplification shaders
34293447
and are translated to SPIR-V intrinsics according to the table below:
34303448

3431-
.. table:: Mapping from HLSL intrinsics to SPIR-V intrinsics
3449+
.. table:: Mapping from HLSL intrinsics to SPIR-V intrinsics for SPV_NV_mesh_shader
34323450

34333451
+---------------------------+--------------------+-----------------------------------------+
34343452
| HLSL Intrinsic | Parameters | SPIR-V Intrinsic |
@@ -3446,6 +3464,24 @@ and are translated to SPIR-V intrinsics according to the table below:
34463464
| | ``MeshPayload`` | |
34473465
+---------------------------+--------------------+-----------------------------------------+
34483466

3467+
.. table:: Mapping from HLSL intrinsics to SPIR-V intrinsics for SPV_EXT_mesh_shader
3468+
3469+
+---------------------------+--------------------+--------------------------------------------------------------+
3470+
| HLSL Intrinsic | Parameters | SPIR-V Intrinsic |
3471+
+===========================+====================+==============================================================+
3472+
| ``SetMeshOutputCounts`` | ``numVertices`` | ``OpSetMeshOutputsEXT`` |
3473+
| | | |
3474+
| ``(Mesh shader)`` | ``numPrimitives`` | |
3475+
+---------------------------+--------------------+--------------------------------------------------------------+
3476+
| ``DispatchMesh`` | ``ThreadX`` | ``OpEmitMeshTasksEXT ThreadX ThreadY ThreadZ MeshPayload`` |
3477+
| | | |
3478+
| ``(Amplification shader)``| ``ThreadY`` | ``TaskCountNV ThreadX*ThreadY*ThreadZ`` |
3479+
| | | |
3480+
| | ``ThreadZ`` | |
3481+
| | | |
3482+
| | ``MeshPayload`` | |
3483+
+---------------------------+--------------------+--------------------------------------------------------------+
3484+
34493485
| Note : For ``DispatchMesh`` intrinsic, we also emit ``MeshPayload`` as output block with ``PerTaskNV`` decoration
34503486
34513487
Mesh Interface Variables

tools/clang/include/clang/SPIRV/FeatureManager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ enum class Extension {
4343
EXT_descriptor_indexing,
4444
EXT_fragment_fully_covered,
4545
EXT_fragment_invocation_density,
46+
EXT_mesh_shader,
4647
EXT_shader_stencil_export,
4748
EXT_shader_viewport_index_layer,
4849
AMD_gpu_shader_half_float,

tools/clang/include/clang/SPIRV/SpirvBuilder.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,20 @@ class SpirvBuilder {
459459
/// \brief Creates an OpEndPrimitive instruction.
460460
void createEndPrimitive(SourceLocation, SourceRange range = {});
461461

462+
/// \brief Creates an OpEmitMeshTasksEXT instruction.
463+
void createEmitMeshTasksEXT(SpirvInstruction* xDim,
464+
SpirvInstruction* yDim,
465+
SpirvInstruction* zDim,
466+
SourceLocation loc,
467+
SpirvInstruction *payload = nullptr,
468+
SourceRange range = {});
469+
470+
/// \brief Creates an OpSetMeshOutputsEXT instruction.
471+
void createSetMeshOutputsEXT(SpirvInstruction* vertCount,
472+
SpirvInstruction* primCount,
473+
SourceLocation loc,
474+
SourceRange range = {});
475+
462476
/// \brief Creates an OpArrayLength instruction.
463477
SpirvArrayLength *createArrayLength(QualType resultType, SourceLocation loc,
464478
SpirvInstruction *structure,

tools/clang/include/clang/SPIRV/SpirvInstruction.h

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ class SpirvInstruction {
8484
IK_Switch, // OpSwitch
8585
IK_Unreachable, // OpUnreachable
8686
IK_RayTracingTerminate, // OpIgnoreIntersectionKHR/OpTerminateRayKHR
87+
IK_EmitMeshTasksEXT, // OpEmitMeshTasksEXT
8788

8889
// Normal instruction kinds
8990
// In alphabetical order
@@ -107,6 +108,8 @@ class SpirvInstruction {
107108
IK_EndPrimitive, // OpEndPrimitive
108109
IK_EmitVertex, // OpEmitVertex
109110

111+
IK_SetMeshOutputsEXT, // OpSetMeshOutputsEXT
112+
110113
// The following section is for group non-uniform instructions.
111114
// Used by LLVM-style RTTI; order matters.
112115
IK_GroupNonUniformBinaryOp, // Group non-uniform binary operations
@@ -664,7 +667,7 @@ class SpirvTerminator : public SpirvInstruction {
664667
// For LLVM-style RTTI
665668
static bool classof(const SpirvInstruction *inst) {
666669
return inst->getKind() >= IK_Branch &&
667-
inst->getKind() <= IK_RayTracingTerminate;
670+
inst->getKind() <= IK_EmitMeshTasksEXT;
668671
}
669672

670673
protected:
@@ -2153,6 +2156,61 @@ class SpirvDebugInstruction : public SpirvInstruction {
21532156
SpirvExtInstImport *instructionSet;
21542157
};
21552158

2159+
/// \brief OpEmitMeshTasksEXT instruction.
2160+
class SpirvEmitMeshTasksEXT : public SpirvInstruction {
2161+
public:
2162+
SpirvEmitMeshTasksEXT(SpirvInstruction* xDim,
2163+
SpirvInstruction* yDim,
2164+
SpirvInstruction* zDim,
2165+
SpirvInstruction* payload,
2166+
SourceLocation loc, SourceRange range = {});
2167+
2168+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvEmitMeshTasksEXT)
2169+
2170+
// For LLVM-style RTTI
2171+
static bool classof(const SpirvInstruction *inst) {
2172+
return inst->getKind() == IK_EmitMeshTasksEXT;
2173+
}
2174+
2175+
bool invokeVisitor(Visitor *v) override;
2176+
2177+
SpirvInstruction *getXDimension() const { return xDim; }
2178+
SpirvInstruction *getYDimension() const { return yDim; }
2179+
SpirvInstruction *getZDimension() const { return zDim; }
2180+
SpirvInstruction *getPayload() const { return payload; }
2181+
2182+
private:
2183+
SpirvInstruction *xDim;
2184+
SpirvInstruction *yDim;
2185+
SpirvInstruction *zDim;
2186+
SpirvInstruction *payload;
2187+
};
2188+
2189+
/// \brief OpSetMeshOutputsEXT instruction.
2190+
class SpirvSetMeshOutputsEXT : public SpirvInstruction {
2191+
public:
2192+
SpirvSetMeshOutputsEXT(SpirvInstruction* vertCount,
2193+
SpirvInstruction* primCount,
2194+
SourceLocation loc,
2195+
SourceRange range = {});
2196+
2197+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvSetMeshOutputsEXT)
2198+
2199+
// For LLVM-style RTTI
2200+
static bool classof(const SpirvInstruction *inst) {
2201+
return inst->getKind() == IK_SetMeshOutputsEXT;
2202+
}
2203+
2204+
bool invokeVisitor(Visitor *v) override;
2205+
2206+
SpirvInstruction *getVertexCount() const { return vertCount; }
2207+
SpirvInstruction *getPrimitiveCount() const { return primCount; }
2208+
2209+
private:
2210+
SpirvInstruction *vertCount;
2211+
SpirvInstruction *primCount;
2212+
};
2213+
21562214
class SpirvDebugInfoNone : public SpirvDebugInstruction {
21572215
public:
21582216
SpirvDebugInfoNone();

tools/clang/include/clang/SPIRV/SpirvVisitor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ class Visitor {
143143
DEFINE_VISIT_METHOD(SpirvReadClock)
144144
DEFINE_VISIT_METHOD(SpirvRayTracingTerminateOpKHR)
145145
DEFINE_VISIT_METHOD(SpirvIntrinsicInstruction)
146+
147+
DEFINE_VISIT_METHOD(SpirvEmitMeshTasksEXT)
148+
DEFINE_VISIT_METHOD(SpirvSetMeshOutputsEXT)
146149
#undef DEFINE_VISIT_METHOD
147150

148151
const SpirvCodeGenOptions &getCodeGenOptions() const { return spvOptions; }

tools/clang/lib/SPIRV/CapabilityVisitor.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
307307
case spv::BuiltIn::PrimitiveId: {
308308
// PrimitiveID can be used as PSIn or MSPOut.
309309
if (shaderModel == spv::ExecutionModel::Fragment ||
310-
shaderModel == spv::ExecutionModel::MeshNV)
310+
shaderModel == spv::ExecutionModel::MeshNV ||
311+
shaderModel == spv::ExecutionModel::MeshEXT)
311312
addCapability(spv::Capability::Geometry);
312313
break;
313314
}
@@ -324,7 +325,8 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
324325
addCapability(spv::Capability::ShaderViewportIndexLayerEXT);
325326
}
326327
} else if (shaderModel == spv::ExecutionModel::Fragment ||
327-
shaderModel == spv::ExecutionModel::MeshNV) {
328+
shaderModel == spv::ExecutionModel::MeshNV ||
329+
shaderModel == spv::ExecutionModel::MeshEXT) {
328330
// SV_RenderTargetArrayIndex can be used as PSIn or MSPOut.
329331
addCapability(spv::Capability::Geometry);
330332
}
@@ -343,7 +345,8 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
343345
}
344346
} else if (shaderModel == spv::ExecutionModel::Fragment ||
345347
shaderModel == spv::ExecutionModel::Geometry ||
346-
shaderModel == spv::ExecutionModel::MeshNV) {
348+
shaderModel == spv::ExecutionModel::MeshNV ||
349+
shaderModel == spv::ExecutionModel::MeshEXT) {
347350
// SV_ViewportArrayIndex can be used as PSIn or GSOut or MSPOut.
348351
addCapability(spv::Capability::MultiViewport);
349352
}
@@ -558,6 +561,17 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
558561
}
559562
}
560563

564+
case spv::Op::OpSetMeshOutputsEXT:
565+
case spv::Op::OpEmitMeshTasksEXT: {
566+
if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
567+
featureManager.requestTargetEnv(SPV_ENV_UNIVERSAL_1_4, "MeshShader",
568+
{});
569+
addCapability(spv::Capability::MeshShadingEXT);
570+
addExtension(Extension::EXT_mesh_shader, "SPV_EXT_mesh_shader", {});
571+
}
572+
break;
573+
}
574+
561575
default:
562576
break;
563577
}
@@ -603,6 +617,11 @@ bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
603617
addCapability(spv::Capability::MeshShadingNV);
604618
addExtension(Extension::NV_mesh_shader, "SPV_NV_mesh_shader", {});
605619
break;
620+
case spv::ExecutionModel::MeshEXT:
621+
case spv::ExecutionModel::TaskEXT:
622+
addCapability(spv::Capability::MeshShadingEXT);
623+
addExtension(Extension::EXT_mesh_shader, "SPV_EXT_mesh_shader", {});
624+
break;
606625
default:
607626
llvm_unreachable("found unknown shader model");
608627
break;

0 commit comments

Comments
 (0)