Skip to content

Commit fb74718

Browse files
alelenvs-perron
andauthored
Add validation support for SPV_EXT_shader_invocation_reorder. (#6401)
Co-authored-by: Steven Perron <[email protected]>
1 parent 3b94e14 commit fb74718

15 files changed

Lines changed: 1542 additions & 11 deletions

source/opcode.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
267267
// spv::Op::OpTypeAccelerationStructureNV
268268
case spv::Op::OpTypeRayQueryKHR:
269269
case spv::Op::OpTypeHitObjectNV:
270+
case spv::Op::OpTypeHitObjectEXT:
270271
case spv::Op::OpTypeUntypedPointerKHR:
271272
case spv::Op::OpTypeNodePayloadArrayAMDX:
272273
case spv::Op::OpTypeTensorLayoutNV:

source/opt/ir_context.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ void IRContext::AddCombinatorsForCapability(uint32_t capability) {
557557
(uint32_t)spv::Op::OpTypeAccelerationStructureKHR,
558558
(uint32_t)spv::Op::OpTypeRayQueryKHR,
559559
(uint32_t)spv::Op::OpTypeHitObjectNV,
560+
(uint32_t)spv::Op::OpTypeHitObjectEXT,
560561
(uint32_t)spv::Op::OpTypeArray,
561562
(uint32_t)spv::Op::OpTypeRuntimeArray,
562563
(uint32_t)spv::Op::OpTypeNodePayloadArrayAMDX,

source/opt/type_manager.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
237237
DefineParameterlessCase(AccelerationStructureNV);
238238
DefineParameterlessCase(RayQueryKHR);
239239
DefineParameterlessCase(HitObjectNV);
240+
DefineParameterlessCase(HitObjectEXT);
240241
#undef DefineParameterlessCase
241242
case Type::kInteger:
242243
typeInst = MakeUnique<Instruction>(
@@ -654,6 +655,7 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
654655
DefineNoSubtypeCase(AccelerationStructureNV);
655656
DefineNoSubtypeCase(RayQueryKHR);
656657
DefineNoSubtypeCase(HitObjectNV);
658+
DefineNoSubtypeCase(HitObjectEXT);
657659
#undef DefineNoSubtypeCase
658660
case Type::kVector: {
659661
const Vector* vec_ty = type.AsVector();
@@ -1082,6 +1084,9 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
10821084
case spv::Op::OpTypeHitObjectNV:
10831085
type = new HitObjectNV();
10841086
break;
1087+
case spv::Op::OpTypeHitObjectEXT:
1088+
type = new HitObjectEXT();
1089+
break;
10851090
case spv::Op::OpTypeTensorLayoutNV:
10861091
type = new TensorLayoutNV(inst.GetSingleWordInOperand(0),
10871092
inst.GetSingleWordInOperand(1));

source/opt/types.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ std::unique_ptr<Type> Type::Clone() const {
135135
DeclareKindCase(CooperativeVectorNV);
136136
DeclareKindCase(RayQueryKHR);
137137
DeclareKindCase(HitObjectNV);
138+
DeclareKindCase(HitObjectEXT);
138139
DeclareKindCase(TensorARM);
139140
DeclareKindCase(GraphARM);
140141
#undef DeclareKindCase
@@ -187,6 +188,7 @@ bool Type::operator==(const Type& other) const {
187188
DeclareKindCase(CooperativeVectorNV);
188189
DeclareKindCase(RayQueryKHR);
189190
DeclareKindCase(HitObjectNV);
191+
DeclareKindCase(HitObjectEXT);
190192
DeclareKindCase(TensorLayoutNV);
191193
DeclareKindCase(TensorViewNV);
192194
DeclareKindCase(TensorARM);
@@ -249,6 +251,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
249251
DeclareKindCase(CooperativeVectorNV);
250252
DeclareKindCase(RayQueryKHR);
251253
DeclareKindCase(HitObjectNV);
254+
DeclareKindCase(HitObjectEXT);
252255
DeclareKindCase(TensorLayoutNV);
253256
DeclareKindCase(TensorViewNV);
254257
DeclareKindCase(TensorARM);

source/opt/types.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class CooperativeMatrixKHR;
6767
class CooperativeVectorNV;
6868
class RayQueryKHR;
6969
class HitObjectNV;
70+
class HitObjectEXT;
7071
class TensorLayoutNV;
7172
class TensorViewNV;
7273
class TensorARM;
@@ -114,6 +115,7 @@ class Type {
114115
kCooperativeVectorNV,
115116
kRayQueryKHR,
116117
kHitObjectNV,
118+
kHitObjectEXT,
117119
kTensorLayoutNV,
118120
kTensorViewNV,
119121
kTensorARM,
@@ -222,6 +224,7 @@ class Type {
222224
DeclareCastMethod(CooperativeVectorNV)
223225
DeclareCastMethod(RayQueryKHR)
224226
DeclareCastMethod(HitObjectNV)
227+
DeclareCastMethod(HitObjectEXT)
225228
DeclareCastMethod(TensorLayoutNV)
226229
DeclareCastMethod(TensorViewNV)
227230
DeclareCastMethod(TensorARM)
@@ -862,6 +865,7 @@ DefineParameterlessType(NamedBarrier, named_barrier);
862865
DefineParameterlessType(AccelerationStructureNV, accelerationStructureNV);
863866
DefineParameterlessType(RayQueryKHR, rayQueryKHR);
864867
DefineParameterlessType(HitObjectNV, hitObjectNV);
868+
DefineParameterlessType(HitObjectEXT, hitObjectEXT);
865869
#undef DefineParameterlessType
866870

867871
} // namespace analysis

source/val/validate.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
399399
if (auto error = RayQueryPass(*vstate, &instruction)) return error;
400400
if (auto error = RayTracingPass(*vstate, &instruction)) return error;
401401
if (auto error = RayReorderNVPass(*vstate, &instruction)) return error;
402+
if (auto error = RayReorderEXTPass(*vstate, &instruction)) return error;
402403
if (auto error = MeshShadingPass(*vstate, &instruction)) return error;
403404
if (auto error = TensorLayoutPass(*vstate, &instruction)) return error;
404405
if (auto error = TensorPass(*vstate, &instruction)) return error;

source/val/validate.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst);
220220
/// Validates correctness of shader execution reorder instructions.
221221
spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst);
222222

223+
/// Validates correctness of shader execution reorder EXT instructions.
224+
spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst);
225+
223226
/// Validates correctness of mesh shading instructions.
224227
spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst);
225228

source/val/validate_annotation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ spv_result_t ValidateDecorationTarget(ValidationState_t& _, spv::Decoration dec,
217217
sc != spv::StorageClass::IncomingCallableDataKHR &&
218218
sc != spv::StorageClass::ShaderRecordBufferKHR &&
219219
sc != spv::StorageClass::HitObjectAttributeNV &&
220+
sc != spv::StorageClass::HitObjectAttributeEXT &&
220221
sc != spv::StorageClass::TileImageEXT) {
221222
return _.diag(SPV_ERROR_INVALID_ID, target)
222223
<< _.VkErrorID(6672) << _.SpvDecorationString(dec)

source/val/validate_extensions.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,7 @@ spv_result_t ValidateExtension(ValidationState_t& _, const Instruction* inst) {
10881088
ExtensionToString(kSPV_KHR_workgroup_memory_explicit_layout) ||
10891089
extension == ExtensionToString(kSPV_EXT_mesh_shader) ||
10901090
extension == ExtensionToString(kSPV_NV_shader_invocation_reorder) ||
1091+
extension == ExtensionToString(kSPV_EXT_shader_invocation_reorder) ||
10911092
extension ==
10921093
ExtensionToString(kSPV_NV_cluster_acceleration_structure) ||
10931094
extension == ExtensionToString(kSPV_NV_linear_swept_spheres) ||

source/val/validate_logical_pointers.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,42 @@ spv_result_t ValidateLogicalPointerOperands(ValidationState_t& _,
247247
case spv::Op::OpHitObjectIsEmptyNV:
248248
case spv::Op::OpHitObjectIsHitNV:
249249
case spv::Op::OpHitObjectIsMissNV:
250+
// SPV_EXT_shader_invocation_reorder
251+
case spv::Op::OpHitObjectRecordFromQueryEXT:
252+
case spv::Op::OpHitObjectRecordMissEXT:
253+
case spv::Op::OpHitObjectRecordMissMotionEXT:
254+
case spv::Op::OpHitObjectGetIntersectionTriangleVertexPositionsEXT:
255+
case spv::Op::OpHitObjectGetRayFlagsEXT:
256+
case spv::Op::OpHitObjectSetShaderBindingTableRecordIndexEXT:
257+
case spv::Op::OpHitObjectReorderExecuteShaderEXT:
258+
case spv::Op::OpHitObjectTraceReorderExecuteEXT:
259+
case spv::Op::OpHitObjectTraceMotionReorderExecuteEXT:
260+
case spv::Op::OpReorderThreadWithHintEXT:
261+
case spv::Op::OpReorderThreadWithHitObjectEXT:
262+
case spv::Op::OpHitObjectTraceRayEXT:
263+
case spv::Op::OpHitObjectTraceRayMotionEXT:
264+
case spv::Op::OpHitObjectRecordEmptyEXT:
265+
case spv::Op::OpHitObjectExecuteShaderEXT:
266+
case spv::Op::OpHitObjectGetCurrentTimeEXT:
267+
case spv::Op::OpHitObjectGetAttributesEXT:
268+
case spv::Op::OpHitObjectGetHitKindEXT:
269+
case spv::Op::OpHitObjectGetPrimitiveIndexEXT:
270+
case spv::Op::OpHitObjectGetGeometryIndexEXT:
271+
case spv::Op::OpHitObjectGetInstanceIdEXT:
272+
case spv::Op::OpHitObjectGetInstanceCustomIndexEXT:
273+
case spv::Op::OpHitObjectGetObjectRayOriginEXT:
274+
case spv::Op::OpHitObjectGetObjectRayDirectionEXT:
275+
case spv::Op::OpHitObjectGetWorldRayDirectionEXT:
276+
case spv::Op::OpHitObjectGetWorldRayOriginEXT:
277+
case spv::Op::OpHitObjectGetObjectToWorldEXT:
278+
case spv::Op::OpHitObjectGetWorldToObjectEXT:
279+
case spv::Op::OpHitObjectGetRayTMaxEXT:
280+
case spv::Op::OpHitObjectGetRayTMinEXT:
281+
case spv::Op::OpHitObjectGetShaderBindingTableRecordIndexEXT:
282+
case spv::Op::OpHitObjectGetShaderRecordBufferHandleEXT:
283+
case spv::Op::OpHitObjectIsEmptyEXT:
284+
case spv::Op::OpHitObjectIsHitEXT:
285+
case spv::Op::OpHitObjectIsMissEXT:
250286
// SPV_NV_raw_access_chains
251287
case spv::Op::OpRawAccessChainNV:
252288
// SPV_NV_cooperative_matrix2

0 commit comments

Comments
 (0)