2828
2929#include " dxc/DXIL/DxilConstants.h"
3030#include " dxc/DXIL/DxilShaderModel.h"
31+ #include " dxc/HlslIntrinsicOp.h"
3132
3233using namespace clang ;
3334using namespace sema ;
@@ -49,9 +50,9 @@ struct PayloadUse {
4950 const MemberExpr *Member = nullptr ;
5051};
5152
52- struct TraceRayCall {
53- TraceRayCall () = default ;
54- TraceRayCall (const CallExpr *Call, const CFGBlock *Parent)
53+ struct PayloadBuiltinCall {
54+ PayloadBuiltinCall () = default ;
55+ PayloadBuiltinCall (const CallExpr *Call, const CFGBlock *Parent)
5556 : Call(Call), Parent(Parent) {}
5657 const CallExpr *Call = nullptr ;
5758 const CFGBlock *Parent = nullptr ;
@@ -71,7 +72,7 @@ struct DxrShaderDiagnoseInfo {
7172 const FunctionDecl *funcDecl;
7273 const VarDecl *Payload;
7374 DXIL::PayloadAccessShaderStage Stage;
74- std::vector<TraceRayCall> TraceCalls ;
75+ std::vector<PayloadBuiltinCall> PayloadBuiltinCalls ;
7576 std::map<const FieldDecl *, std::vector<PayloadUse>> WritesPerField;
7677 std::map<const FieldDecl *, std::vector<PayloadUse>> ReadsPerField;
7778 std::vector<PayloadUse> PayloadAsCallArg;
@@ -121,24 +122,42 @@ GetPayloadQualifierForStage(FieldDecl *Field,
121122 return DXIL::PayloadAccessQualifier::NoAccess;
122123}
123124
124- // Returns the declaration of the payload used in a TraceRay call
125- const VarDecl *GetPayloadParameterForTraceCall (const CallExpr *Trace) {
126- const Decl *callee = Trace->getCalleeDecl ();
127- if (!callee)
125+ static int GetPayloadParamIdxForIntrinsic (const FunctionDecl *FD) {
126+ HLSLIntrinsicAttr *IntrinAttr = FD->getAttr <HLSLIntrinsicAttr>();
127+ if (!IntrinAttr)
128+ return -1 ;
129+ switch ((IntrinsicOp)IntrinAttr->getOpcode ()) {
130+ default :
131+ return -1 ;
132+ case IntrinsicOp::IOP_TraceRay:
133+ case IntrinsicOp::MOP_DxHitObject_TraceRay:
134+ case IntrinsicOp::MOP_DxHitObject_Invoke:
135+ return FD->getNumParams () - 1 ;
136+ }
137+ }
138+
139+ static bool IsBuiltinWithPayload (const FunctionDecl *FD) {
140+ return GetPayloadParamIdxForIntrinsic (FD) >= 0 ;
141+ }
142+
143+ // Returns the declaration of the payload used in a call to TraceRay,
144+ // HitObject::TraceRay or HitObject::Invoke.
145+ const VarDecl *GetPayloadParameterForBuiltinCall (const CallExpr *Call) {
146+ const Decl *Callee = Call->getCalleeDecl ();
147+ if (!Callee)
128148 return nullptr ;
129149
130- if (!isa<FunctionDecl>(callee ))
150+ if (!isa<FunctionDecl>(Callee ))
131151 return nullptr ;
132152
133- const FunctionDecl *FD = cast<FunctionDecl>(callee);
153+ int PldParamIdx = GetPayloadParamIdxForIntrinsic (cast<FunctionDecl>(Callee));
154+ if (PldParamIdx < 0 )
155+ return nullptr ;
134156
135- if (FD->isImplicit () && FD->getName () == " TraceRay" ) {
136- const Stmt *Param = IgnoreParensAndDecay (Trace->getArg (7 ));
137- if (const DeclRefExpr *ParamRef = dyn_cast<DeclRefExpr>(Param)) {
138- if (const VarDecl *Decl = dyn_cast<VarDecl>(ParamRef->getDecl ()))
139- return Decl;
140- }
141- }
157+ const Stmt *Param = IgnoreParensAndDecay (Call->getArg (PldParamIdx));
158+ if (const DeclRefExpr *ParamRef = dyn_cast<DeclRefExpr>(Param))
159+ if (const VarDecl *Decl = dyn_cast<VarDecl>(ParamRef->getDecl ()))
160+ return Decl;
142161 return nullptr ;
143162}
144163
@@ -190,12 +209,9 @@ void CollectReadsWritesAndCallsForPayload(const Stmt *S,
190209 }
191210}
192211
193- // Collects all TraceRay calls.
194- void CollectTraceRayCalls (const Stmt *S, DxrShaderDiagnoseInfo &Info,
195- const CFGBlock *Block) {
196- // TraceRay has void as return type so it should never be something else
197- // than a plain CallExpr.
198-
212+ // Collects all calls to TraceRay, HitObject::TraceRay and HitObject::Invoke.
213+ void CollectBuiltinCallsWithPayload (const Stmt *S, DxrShaderDiagnoseInfo &Info,
214+ const CFGBlock *Block) {
199215 if (const CallExpr *Call = dyn_cast<CallExpr>(S)) {
200216
201217 const Decl *Callee = Call->getCalleeDecl ();
@@ -204,11 +220,8 @@ void CollectTraceRayCalls(const Stmt *S, DxrShaderDiagnoseInfo &Info,
204220
205221 const FunctionDecl *CalledFunction = cast<FunctionDecl>(Callee);
206222
207- // Ignore trace calls here.
208- if (CalledFunction->isImplicit () &&
209- CalledFunction->getName () == " TraceRay" ) {
210- Info.TraceCalls .push_back ({Call, Block});
211- }
223+ if (IsBuiltinWithPayload (CalledFunction))
224+ Info.PayloadBuiltinCalls .push_back ({Call, Block});
212225 }
213226}
214227
@@ -528,13 +541,14 @@ void TraverseCFG(const CFGBlock &Block, Action PerElementAction,
528541 }
529542}
530543
531- // Forward traverse the CFG and collect calls to TraceRay.
532- void ForwardTraverseCFGAndCollectTraceCalls (
544+ // Forward traverse the CFG and collect calls to TraceRay, HitObject::TraceRay
545+ // and HitObject::Invoke.
546+ void ForwardTraverseCFGAndCollectBuiltinCallsWithPayload (
533547 const CFGBlock &Block, DxrShaderDiagnoseInfo &Info,
534548 std::set<const CFGBlock *> &Visited) {
535549 auto Action = [&Info](const CFGBlock &Block, const CFGElement &Element) {
536550 if (Optional<CFGStmt> S = Element.getAs <CFGStmt>()) {
537- CollectTraceRayCalls (S->getStmt (), Info, &Block);
551+ CollectBuiltinCallsWithPayload (S->getStmt (), Info, &Block);
538552 }
539553 };
540554
@@ -664,9 +678,9 @@ DiagnosePayloadAsFunctionArg(
664678 const FunctionDecl *CalledFunction = cast<FunctionDecl>(Callee);
665679
666680 // Ignore trace calls here.
667- if (CalledFunction-> isImplicit () &&
668- CalledFunction-> getName () == " TraceRay " ) {
669- Info. TraceCalls . push_back (TraceRayCall {Call, Use.Parent });
681+ if (IsBuiltinWithPayload (CalledFunction)) {
682+ Info. PayloadBuiltinCalls . push_back (
683+ PayloadBuiltinCall {Call, Use.Parent });
670684 continue ;
671685 }
672686
@@ -789,10 +803,12 @@ void HandlePayloadInitializer(DxrShaderDiagnoseInfo &Info) {
789803 }
790804}
791805
792- // Emit diagnostics for a TraceRay call.
793- void DiagnoseTraceCall (Sema &S, const VarDecl *Payload,
794- const TraceRayCall &Trace, DominatorTree &DT) {
795- // For each TraceRay call check if write(caller) fields are written.
806+ // Emit diagnostics for this call to either TraceRay, HitObject::TraceRay or
807+ // HitObject::Invoke.
808+ void DiagnoseBuiltinCallWithPayload (Sema &S, const VarDecl *Payload,
809+ const PayloadBuiltinCall &PldCall,
810+ DominatorTree &DT) {
811+ // For each call check if write(caller) fields are written.
796812 const DXIL::PayloadAccessShaderStage CallerStage =
797813 DXIL::PayloadAccessShaderStage::Caller;
798814
@@ -810,6 +826,13 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
810826 return ;
811827 }
812828
829+ // Verify that the payload type is legal
830+ if (!hlsl::IsHLSLCopyableAnnotatableRecord (Payload->getType ())) {
831+ S.Diag (Payload->getLocation (), diag::err_payload_attrs_must_be_udt)
832+ << /* payload|attributes|callable*/ 0 << Payload;
833+ return ;
834+ }
835+
813836 if (ContainsLongVector (Payload->getType ())) {
814837 const unsigned PayloadParametersIdx = 10 ;
815838 S.Diag (Payload->getLocation (), diag::err_hlsl_unsupported_long_vector)
@@ -832,12 +855,12 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
832855
833856 std::set<const CFGBlock *> Visited;
834857
835- const CFGBlock *Parent = Trace .Parent ;
858+ const CFGBlock *Parent = PldCall .Parent ;
836859 Visited.insert (Parent);
837- // Collect payload accesses in the same block until we reach the TraceRay call
860+ // Collect payload accesses in the same block until we reach the call
838861 for (auto Element : *Parent) {
839862 if (Optional<CFGStmt> S = Element.getAs <CFGStmt>()) {
840- if (S->getStmt () == Trace .Call )
863+ if (S->getStmt () == PldCall .Call )
841864 break ;
842865 CollectReadsWritesAndCallsForPayload (S->getStmt (), TraceInfo, Parent);
843866 }
@@ -850,10 +873,12 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
850873 BackwardTraverseCFGAndCollectReadsWrites (*Pred, TraceInfo, Visited);
851874 }
852875
876+ int PldArgIdx = PldCall.Call ->getNumArgs () - 1 ;
877+
853878 // Warn if a writeable field has not been written.
854879 for (const FieldDecl *Field : WriteableFields) {
855880 if (!TraceInfo.WritesPerField .count (Field)) {
856- S.Diag (Trace .Call ->getArg (7 )->getExprLoc (),
881+ S.Diag (PldCall .Call ->getArg (PldArgIdx )->getExprLoc (),
857882 diag::warn_hlsl_payload_access_no_write_for_trace_payload)
858883 << Field->getName ();
859884 }
@@ -862,7 +887,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
862887 for (const FieldDecl *Field : NonWriteableFields) {
863888 if (TraceInfo.WritesPerField .count (Field)) {
864889 S.Diag (
865- Trace .Call ->getArg (7 )->getExprLoc (),
890+ PldCall .Call ->getArg (PldArgIdx )->getExprLoc (),
866891 diag::warn_hlsl_payload_access_write_but_no_write_for_trace_payload)
867892 << Field->getName ();
868893 }
@@ -878,7 +903,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
878903 bool CallFound = false ;
879904 for (auto Element : *Parent) { // TODO: reverse iterate?
880905 if (Optional<CFGStmt> S = Element.getAs <CFGStmt>()) {
881- if (S->getStmt () == Trace .Call ) {
906+ if (S->getStmt () == PldCall .Call ) {
882907 CallFound = true ;
883908 continue ;
884909 }
@@ -895,7 +920,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
895920
896921 for (const FieldDecl *Field : ReadableFields) {
897922 if (!TraceInfo.ReadsPerField .count (Field)) {
898- S.Diag (Trace .Call ->getArg (7 )->getExprLoc (),
923+ S.Diag (PldCall .Call ->getArg (PldArgIdx )->getExprLoc (),
899924 diag::warn_hlsl_payload_access_read_but_no_read_after_trace)
900925 << Field->getName ();
901926 }
@@ -928,27 +953,29 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
928953 }
929954}
930955
931- // Emit diagnostics for all TraceRay calls.
932- void DiagnoseTraceCalls (Sema &S, CFG &ShaderCFG, DominatorTree &DT,
933- DxrShaderDiagnoseInfo &Info) {
934- // Collect TraceRay calls in the shader.
956+ // Emit diagnostics for all calls to TraceRay, HitObject::TraceRay or
957+ // HitObject::Invoke.
958+ void DiagnoseBuiltinCallsWithPayload (Sema &S, CFG &ShaderCFG, DominatorTree &DT,
959+ DxrShaderDiagnoseInfo &Info) {
960+ // Collect calls with payload in the shader.
935961 std::set<const CFGBlock *> Visited;
936- ForwardTraverseCFGAndCollectTraceCalls (ShaderCFG.getEntry (), Info, Visited);
962+ ForwardTraverseCFGAndCollectBuiltinCallsWithPayload (ShaderCFG.getEntry (),
963+ Info, Visited);
937964
938965 std::set<const CallExpr *> Diagnosed;
939966
940- for (const TraceRayCall &TraceCall : Info.TraceCalls ) {
941- if (Diagnosed.count (TraceCall .Call ))
967+ for (const PayloadBuiltinCall &PldCall : Info.PayloadBuiltinCalls ) {
968+ if (Diagnosed.count (PldCall .Call ))
942969 continue ;
943- Diagnosed.insert (TraceCall .Call );
970+ Diagnosed.insert (PldCall .Call );
944971
945- const VarDecl *Payload = GetPayloadParameterForTraceCall (TraceCall .Call );
946- DiagnoseTraceCall (S, Payload, TraceCall , DT);
972+ const VarDecl *Payload = GetPayloadParameterForBuiltinCall (PldCall .Call );
973+ DiagnoseBuiltinCallWithPayload (S, Payload, PldCall , DT);
947974 }
948975}
949976
950977// Emit diagnostics for all access to the payload of a shader,
951- // and the input to TraceRay calls.
978+ // and the input to TraceRay, HitObject::TraceRay or HitObject::Invoke calls.
952979std::vector<const FieldDecl *>
953980DiagnosePayloadAccess (Sema &S, DxrShaderDiagnoseInfo &Info,
954981 const std::set<const FieldDecl *> &FieldsToIgnoreRead,
@@ -1012,7 +1039,7 @@ DiagnosePayloadAccess(Sema &S, DxrShaderDiagnoseInfo &Info,
10121039 DiagnosePayloadReads (S, TheCFG, DT, Info, NonReadableFields);
10131040 }
10141041
1015- DiagnoseTraceCalls (S, TheCFG, DT, Info);
1042+ DiagnoseBuiltinCallsWithPayload (S, TheCFG, DT, Info);
10161043
10171044 return WrittenFields;
10181045}
0 commit comments