@@ -288,6 +288,9 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
288288 llvm::Value *DestPtr,
289289 clang::QualType DestTy) override ;
290290 void AddHLSLFunctionInfo (llvm::Function *, const FunctionDecl *FD) override ;
291+ bool FindDispatchGridSemantic (const CXXRecordDecl *RD,
292+ hlsl::SVDispatchGrid &SDGRec,
293+ CharUnits Offset = CharUnits());
291294 void AddHLSLNodeRecordTypeInfo (const clang::ParmVarDecl *parmDecl,
292295 hlsl::NodeIOProperties &node);
293296 void EmitHLSLFunctionProlog (llvm::Function *,
@@ -2560,6 +2563,66 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
25602563 m_ScopeMap[F] = ScopeInfo (F, FD->getLocation ());
25612564}
25622565
2566+ // Find the input node record field with the SV_DispatchGrid semantic.
2567+ // We have already diagnosed any error conditions in Sema, so we
2568+ // expect valid size and types, and use the first occurance found.
2569+ // We return true if we have populated the SV_DispatchGrid values.
2570+ bool CGMSHLSLRuntime::FindDispatchGridSemantic (const CXXRecordDecl *RD,
2571+ hlsl::SVDispatchGrid &SDGRec,
2572+ CharUnits Offset) {
2573+ const ASTRecordLayout &Layout = CGM.getContext ().getASTRecordLayout (RD);
2574+
2575+ // Check (non-virtual) bases
2576+ for (const CXXBaseSpecifier &Base : RD->bases ()) {
2577+ DXASSERT (!Base.getType ()->isDependentType (),
2578+ " Node Record with dependent base class not caught by Sema" );
2579+ if (Base.getType ()->isDependentType ())
2580+ continue ;
2581+ CXXRecordDecl *BaseDecl = Base.getType ()->getAsCXXRecordDecl ();
2582+ CharUnits BaseOffset = Offset + Layout.getBaseClassOffset (BaseDecl);
2583+ if (FindDispatchGridSemantic (BaseDecl, SDGRec, BaseOffset))
2584+ return true ;
2585+ }
2586+
2587+ // Check each field in this record.
2588+ for (FieldDecl *Field : RD->fields ()) {
2589+ uint64_t FieldNo = Field->getFieldIndex ();
2590+ CharUnits FieldOffset = Offset + CGM.getContext ().toCharUnitsFromBits (
2591+ Layout.getFieldOffset (FieldNo));
2592+
2593+ // If this field is a record check its fields
2594+ if (const CXXRecordDecl *D = Field->getType ()->getAsCXXRecordDecl ()) {
2595+ if (FindDispatchGridSemantic (D, SDGRec, FieldOffset))
2596+ return true ;
2597+ }
2598+ // Otherwise check this field for the SV_DispatchGrid semantic annotation
2599+ for (const hlsl::UnusualAnnotation *UA : Field->getUnusualAnnotations ()) {
2600+ if (UA->getKind () == hlsl::UnusualAnnotation::UA_SemanticDecl) {
2601+ const hlsl::SemanticDecl *SD = cast<hlsl::SemanticDecl>(UA);
2602+ if (SD->SemanticName .equals (" SV_DispatchGrid" )) {
2603+ const llvm::Type *FTy = CGM.getTypes ().ConvertType (Field->getType ());
2604+ const llvm::Type *ElTy = FTy;
2605+ SDGRec.NumComponents = 1 ;
2606+ SDGRec.ByteOffset = (unsigned )FieldOffset.getQuantity ();
2607+ if (const llvm::VectorType *VT = dyn_cast<llvm::VectorType>(FTy)) {
2608+ SDGRec.NumComponents = VT->getNumElements ();
2609+ ElTy = VT->getElementType ();
2610+ } else if (const llvm::ArrayType *AT =
2611+ dyn_cast<llvm::ArrayType>(FTy)) {
2612+ SDGRec.NumComponents = AT->getNumElements ();
2613+ ElTy = AT->getElementType ();
2614+ }
2615+ SDGRec.ComponentType = (ElTy->getIntegerBitWidth () == 16 )
2616+ ? DXIL::ComponentType::U16
2617+ : DXIL::ComponentType::U32;
2618+ return true ;
2619+ }
2620+ }
2621+ }
2622+ }
2623+ return false ;
2624+ }
2625+
25632626void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo (
25642627 const clang::ParmVarDecl *parmDecl, hlsl::NodeIOProperties &node) {
25652628 clang::QualType paramTy = parmDecl->getType ().getCanonicalType ();
@@ -2577,7 +2640,6 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
25772640 DiagnosticsEngine &Diags = CGM.getDiags ();
25782641 auto &Rec = TemplateArgs.get (0 );
25792642 clang::QualType RecType = Rec.getAsType ();
2580- llvm::Type *Type = CGM.getTypes ().ConvertType (RecType);
25812643 CXXRecordDecl *RD = RecType->getAsCXXRecordDecl ();
25822644
25832645 // Get the TrackRWInputSharing flag from the record attribute
@@ -2597,63 +2659,12 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
25972659
25982660 // Ex: For DispatchNodeInputRecord<MY_RECORD>, set size =
25992661 // size(MY_RECORD), alignment = alignof(MY_RECORD)
2662+ llvm::Type *Type = CGM.getTypes ().ConvertType (RecType);
26002663 node.RecordType .size = CGM.getDataLayout ().getTypeAllocSize (Type);
26012664 node.RecordType .alignment =
26022665 CGM.getDataLayout ().getABITypeAlignment (Type);
2603- // Iterate over fields of the MY_RECORD(example) struct
2604- for (auto fieldDecl : RD->fields ()) {
2605- // Check if any of the fields have a semantic annotation =
2606- // SV_DispatchGrid
2607- for (const hlsl::UnusualAnnotation *it :
2608- fieldDecl->getUnusualAnnotations ()) {
2609- if (it->getKind () == hlsl::UnusualAnnotation::UA_SemanticDecl) {
2610- const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
2611- // if we find a field with SV_DispatchGrid, fill out the
2612- // SV_DispatchGrid member with byteoffset of the field,
2613- // NumComponents (3 for uint3 etc) and U32 vs U16 types, which are
2614- // the only types allowed
2615- if (sd->SemanticName .equals (" SV_DispatchGrid" )) {
2616- clang::QualType FT = fieldDecl->getType ();
2617- auto &DL = CGM.getDataLayout ();
2618- auto &SDGRec = node.RecordType .SV_DispatchGrid ;
2619-
2620- DXASSERT_NOMSG (SDGRec.NumComponents == 0 );
2621-
2622- unsigned fieldIdx = fieldDecl->getFieldIndex ();
2623- if (StructType *ST = dyn_cast<StructType>(Type)) {
2624- SDGRec.ByteOffset =
2625- DL.getStructLayout (ST)->getElementOffset (fieldIdx);
2626- }
2627- const llvm::Type *lTy = CGM.getTypes ().ConvertType (FT);
2628- if (const llvm::VectorType *VT =
2629- dyn_cast<llvm::VectorType>(lTy)) {
2630- DXASSERT (VT->getElementType ()->isIntegerTy (), " invalid type" );
2631- SDGRec.NumComponents = VT->getNumElements ();
2632- SDGRec.ComponentType =
2633- (VT->getElementType ()->getIntegerBitWidth () == 16 )
2634- ? DXIL::ComponentType::U16
2635- : DXIL::ComponentType::U32;
2636- } else if (const llvm::ArrayType *AT =
2637- dyn_cast<llvm::ArrayType>(lTy)) {
2638- DXASSERT (AT->getElementType ()->isIntegerTy (), " invalid type" );
2639- DXASSERT_NOMSG (AT->getNumElements () <= 3 );
2640- SDGRec.NumComponents = AT->getNumElements ();
2641- SDGRec.ComponentType =
2642- (AT->getElementType ()->getIntegerBitWidth () == 16 )
2643- ? DXIL::ComponentType::U16
2644- : DXIL::ComponentType::U32;
2645- } else {
2646- // Scalar U16 or U32
2647- DXASSERT (lTy->isIntegerTy (), " invalid type" );
2648- SDGRec.NumComponents = 1 ;
2649- SDGRec.ComponentType = (lTy->getIntegerBitWidth () == 16 )
2650- ? DXIL::ComponentType::U16
2651- : DXIL::ComponentType::U32;
2652- }
2653- }
2654- }
2655- }
2656- }
2666+
2667+ FindDispatchGridSemantic (RD, node.RecordType .SV_DispatchGrid );
26572668 }
26582669 }
26592670 }
0 commit comments