diff --git a/tools/clang/lib/Sema/SemaDXR.cpp b/tools/clang/lib/Sema/SemaDXR.cpp index 36ab55ea10..e5b2140cca 100644 --- a/tools/clang/lib/Sema/SemaDXR.cpp +++ b/tools/clang/lib/Sema/SemaDXR.cpp @@ -28,6 +28,7 @@ #include "dxc/DXIL/DxilConstants.h" #include "dxc/DXIL/DxilShaderModel.h" +#include "dxc/HlslIntrinsicOp.h" using namespace clang; using namespace sema; @@ -49,9 +50,9 @@ struct PayloadUse { const MemberExpr *Member = nullptr; }; -struct TraceRayCall { - TraceRayCall() = default; - TraceRayCall(const CallExpr *Call, const CFGBlock *Parent) +struct PayloadBuiltinCall { + PayloadBuiltinCall() = default; + PayloadBuiltinCall(const CallExpr *Call, const CFGBlock *Parent) : Call(Call), Parent(Parent) {} const CallExpr *Call = nullptr; const CFGBlock *Parent = nullptr; @@ -71,7 +72,7 @@ struct DxrShaderDiagnoseInfo { const FunctionDecl *funcDecl; const VarDecl *Payload; DXIL::PayloadAccessShaderStage Stage; - std::vector TraceCalls; + std::vector PayloadBuiltinCalls; std::map> WritesPerField; std::map> ReadsPerField; std::vector PayloadAsCallArg; @@ -121,24 +122,42 @@ GetPayloadQualifierForStage(FieldDecl *Field, return DXIL::PayloadAccessQualifier::NoAccess; } -// Returns the declaration of the payload used in a TraceRay call -const VarDecl *GetPayloadParameterForTraceCall(const CallExpr *Trace) { - const Decl *callee = Trace->getCalleeDecl(); - if (!callee) +static int GetPayloadParamIdxForIntrinsic(const FunctionDecl *FD) { + HLSLIntrinsicAttr *IntrinAttr = FD->getAttr(); + if (!IntrinAttr) + return -1; + switch ((IntrinsicOp)IntrinAttr->getOpcode()) { + default: + return -1; + case IntrinsicOp::IOP_TraceRay: + case IntrinsicOp::MOP_DxHitObject_TraceRay: + case IntrinsicOp::MOP_DxHitObject_Invoke: + return FD->getNumParams() - 1; + } +} + +static bool IsBuiltinWithPayload(const FunctionDecl *FD) { + return GetPayloadParamIdxForIntrinsic(FD) >= 0; +} + +// Returns the declaration of the payload used in a call to TraceRay, +// HitObject::TraceRay or HitObject::Invoke. +const VarDecl *GetPayloadParameterForBuiltinCall(const CallExpr *Call) { + const Decl *Callee = Call->getCalleeDecl(); + if (!Callee) return nullptr; - if (!isa(callee)) + if (!isa(Callee)) return nullptr; - const FunctionDecl *FD = cast(callee); + int PldParamIdx = GetPayloadParamIdxForIntrinsic(cast(Callee)); + if (PldParamIdx < 0) + return nullptr; - if (FD->isImplicit() && FD->getName() == "TraceRay") { - const Stmt *Param = IgnoreParensAndDecay(Trace->getArg(7)); - if (const DeclRefExpr *ParamRef = dyn_cast(Param)) { - if (const VarDecl *Decl = dyn_cast(ParamRef->getDecl())) - return Decl; - } - } + const Stmt *Param = IgnoreParensAndDecay(Call->getArg(PldParamIdx)); + if (const DeclRefExpr *ParamRef = dyn_cast(Param)) + if (const VarDecl *Decl = dyn_cast(ParamRef->getDecl())) + return Decl; return nullptr; } @@ -190,12 +209,9 @@ void CollectReadsWritesAndCallsForPayload(const Stmt *S, } } -// Collects all TraceRay calls. -void CollectTraceRayCalls(const Stmt *S, DxrShaderDiagnoseInfo &Info, - const CFGBlock *Block) { - // TraceRay has void as return type so it should never be something else - // than a plain CallExpr. - +// Collects all calls to TraceRay, HitObject::TraceRay and HitObject::Invoke. +void CollectBuiltinCallsWithPayload(const Stmt *S, DxrShaderDiagnoseInfo &Info, + const CFGBlock *Block) { if (const CallExpr *Call = dyn_cast(S)) { const Decl *Callee = Call->getCalleeDecl(); @@ -204,11 +220,8 @@ void CollectTraceRayCalls(const Stmt *S, DxrShaderDiagnoseInfo &Info, const FunctionDecl *CalledFunction = cast(Callee); - // Ignore trace calls here. - if (CalledFunction->isImplicit() && - CalledFunction->getName() == "TraceRay") { - Info.TraceCalls.push_back({Call, Block}); - } + if (IsBuiltinWithPayload(CalledFunction)) + Info.PayloadBuiltinCalls.push_back({Call, Block}); } } @@ -528,13 +541,14 @@ void TraverseCFG(const CFGBlock &Block, Action PerElementAction, } } -// Forward traverse the CFG and collect calls to TraceRay. -void ForwardTraverseCFGAndCollectTraceCalls( +// Forward traverse the CFG and collect calls to TraceRay, HitObject::TraceRay +// and HitObject::Invoke. +void ForwardTraverseCFGAndCollectBuiltinCallsWithPayload( const CFGBlock &Block, DxrShaderDiagnoseInfo &Info, std::set &Visited) { auto Action = [&Info](const CFGBlock &Block, const CFGElement &Element) { if (Optional S = Element.getAs()) { - CollectTraceRayCalls(S->getStmt(), Info, &Block); + CollectBuiltinCallsWithPayload(S->getStmt(), Info, &Block); } }; @@ -664,9 +678,9 @@ DiagnosePayloadAsFunctionArg( const FunctionDecl *CalledFunction = cast(Callee); // Ignore trace calls here. - if (CalledFunction->isImplicit() && - CalledFunction->getName() == "TraceRay") { - Info.TraceCalls.push_back(TraceRayCall{Call, Use.Parent}); + if (IsBuiltinWithPayload(CalledFunction)) { + Info.PayloadBuiltinCalls.push_back( + PayloadBuiltinCall{Call, Use.Parent}); continue; } @@ -789,10 +803,12 @@ void HandlePayloadInitializer(DxrShaderDiagnoseInfo &Info) { } } -// Emit diagnostics for a TraceRay call. -void DiagnoseTraceCall(Sema &S, const VarDecl *Payload, - const TraceRayCall &Trace, DominatorTree &DT) { - // For each TraceRay call check if write(caller) fields are written. +// Emit diagnostics for this call to either TraceRay, HitObject::TraceRay or +// HitObject::Invoke. +void DiagnoseBuiltinCallWithPayload(Sema &S, const VarDecl *Payload, + const PayloadBuiltinCall &PldCall, + DominatorTree &DT) { + // For each call check if write(caller) fields are written. const DXIL::PayloadAccessShaderStage CallerStage = DXIL::PayloadAccessShaderStage::Caller; @@ -810,6 +826,13 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload, return; } + // Verify that the payload type is legal + if (!hlsl::IsHLSLCopyableAnnotatableRecord(Payload->getType())) { + S.Diag(Payload->getLocation(), diag::err_payload_attrs_must_be_udt) + << /*payload|attributes|callable*/ 0 << Payload; + return; + } + if (ContainsLongVector(Payload->getType())) { const unsigned PayloadParametersIdx = 10; S.Diag(Payload->getLocation(), diag::err_hlsl_unsupported_long_vector) @@ -832,12 +855,12 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload, std::set Visited; - const CFGBlock *Parent = Trace.Parent; + const CFGBlock *Parent = PldCall.Parent; Visited.insert(Parent); - // Collect payload accesses in the same block until we reach the TraceRay call + // Collect payload accesses in the same block until we reach the call for (auto Element : *Parent) { if (Optional S = Element.getAs()) { - if (S->getStmt() == Trace.Call) + if (S->getStmt() == PldCall.Call) break; CollectReadsWritesAndCallsForPayload(S->getStmt(), TraceInfo, Parent); } @@ -850,10 +873,12 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload, BackwardTraverseCFGAndCollectReadsWrites(*Pred, TraceInfo, Visited); } + int PldArgIdx = PldCall.Call->getNumArgs() - 1; + // Warn if a writeable field has not been written. for (const FieldDecl *Field : WriteableFields) { if (!TraceInfo.WritesPerField.count(Field)) { - S.Diag(Trace.Call->getArg(7)->getExprLoc(), + S.Diag(PldCall.Call->getArg(PldArgIdx)->getExprLoc(), diag::warn_hlsl_payload_access_no_write_for_trace_payload) << Field->getName(); } @@ -862,7 +887,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload, for (const FieldDecl *Field : NonWriteableFields) { if (TraceInfo.WritesPerField.count(Field)) { S.Diag( - Trace.Call->getArg(7)->getExprLoc(), + PldCall.Call->getArg(PldArgIdx)->getExprLoc(), diag::warn_hlsl_payload_access_write_but_no_write_for_trace_payload) << Field->getName(); } @@ -878,7 +903,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload, bool CallFound = false; for (auto Element : *Parent) { // TODO: reverse iterate? if (Optional S = Element.getAs()) { - if (S->getStmt() == Trace.Call) { + if (S->getStmt() == PldCall.Call) { CallFound = true; continue; } @@ -895,7 +920,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload, for (const FieldDecl *Field : ReadableFields) { if (!TraceInfo.ReadsPerField.count(Field)) { - S.Diag(Trace.Call->getArg(7)->getExprLoc(), + S.Diag(PldCall.Call->getArg(PldArgIdx)->getExprLoc(), diag::warn_hlsl_payload_access_read_but_no_read_after_trace) << Field->getName(); } @@ -928,27 +953,29 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload, } } -// Emit diagnostics for all TraceRay calls. -void DiagnoseTraceCalls(Sema &S, CFG &ShaderCFG, DominatorTree &DT, - DxrShaderDiagnoseInfo &Info) { - // Collect TraceRay calls in the shader. +// Emit diagnostics for all calls to TraceRay, HitObject::TraceRay or +// HitObject::Invoke. +void DiagnoseBuiltinCallsWithPayload(Sema &S, CFG &ShaderCFG, DominatorTree &DT, + DxrShaderDiagnoseInfo &Info) { + // Collect calls with payload in the shader. std::set Visited; - ForwardTraverseCFGAndCollectTraceCalls(ShaderCFG.getEntry(), Info, Visited); + ForwardTraverseCFGAndCollectBuiltinCallsWithPayload(ShaderCFG.getEntry(), + Info, Visited); std::set Diagnosed; - for (const TraceRayCall &TraceCall : Info.TraceCalls) { - if (Diagnosed.count(TraceCall.Call)) + for (const PayloadBuiltinCall &PldCall : Info.PayloadBuiltinCalls) { + if (Diagnosed.count(PldCall.Call)) continue; - Diagnosed.insert(TraceCall.Call); + Diagnosed.insert(PldCall.Call); - const VarDecl *Payload = GetPayloadParameterForTraceCall(TraceCall.Call); - DiagnoseTraceCall(S, Payload, TraceCall, DT); + const VarDecl *Payload = GetPayloadParameterForBuiltinCall(PldCall.Call); + DiagnoseBuiltinCallWithPayload(S, Payload, PldCall, DT); } } // Emit diagnostics for all access to the payload of a shader, -// and the input to TraceRay calls. +// and the input to TraceRay, HitObject::TraceRay or HitObject::Invoke calls. std::vector DiagnosePayloadAccess(Sema &S, DxrShaderDiagnoseInfo &Info, const std::set &FieldsToIgnoreRead, @@ -1012,7 +1039,7 @@ DiagnosePayloadAccess(Sema &S, DxrShaderDiagnoseInfo &Info, DiagnosePayloadReads(S, TheCFG, DT, Info, NonReadableFields); } - DiagnoseTraceCalls(S, TheCFG, DT, Info); + DiagnoseBuiltinCallsWithPayload(S, TheCFG, DT, Info); return WrittenFields; } diff --git a/tools/clang/test/SemaHLSL/hlsl/objects/HitObject/hitobject_traceinvoke_payload.hlsl b/tools/clang/test/SemaHLSL/hlsl/objects/HitObject/hitobject_traceinvoke_payload.hlsl new file mode 100644 index 0000000000..f4781bc796 --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/objects/HitObject/hitobject_traceinvoke_payload.hlsl @@ -0,0 +1,27 @@ +// RUN: %dxc -T lib_6_9 %s -D TEST_NUM=0 %s -verify +// RUN: %dxc -T lib_6_9 %s -D TEST_NUM=1 %s -verify + +RaytracingAccelerationStructure scene : register(t0); + +struct Payload +{ + int a : read (caller, closesthit, miss) : write(caller, closesthit, miss); +}; + +struct Attribs +{ + float2 barys; +}; + +[shader("raygeneration")] +void RayGen() +{ +// expected-error@+1{{type 'Payload' used as payload requires that it is annotated with the [raypayload] attribute}} + Payload payload_in_rg; + RayDesc ray; +#if TEST_NUM == 0 + dx::HitObject::TraceRay( scene, RAY_FLAG_NONE, 0xff, 0, 1, 0, ray, payload_in_rg ); +#else + dx::HitObject::Invoke( dx::HitObject(), payload_in_rg ); +#endif +} \ No newline at end of file diff --git a/tools/clang/test/SemaHLSL/hlsl/objects/HitObject/hitobject_traceinvoke_payload_udt.hlsl b/tools/clang/test/SemaHLSL/hlsl/objects/HitObject/hitobject_traceinvoke_payload_udt.hlsl new file mode 100644 index 0000000000..e89e33a78f --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/objects/HitObject/hitobject_traceinvoke_payload_udt.hlsl @@ -0,0 +1,22 @@ +// RUN: %dxc -T lib_6_9 %s -verify + +struct +[raypayload] +Payload +{ + int a : read(caller, closesthit, miss) : write(caller, closesthit, miss); + dx::HitObject hit; +}; + +struct Attribs +{ + float2 barys; +}; + +[shader("raygeneration")] +void RayGen() +{ + // expected-error@+1{{payload parameter 'payload_in_rg' must be a user-defined type composed of only numeric types}} + Payload payload_in_rg; + dx::HitObject::Invoke( dx::HitObject(), payload_in_rg ); +} \ No newline at end of file