@@ -325,6 +325,17 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
325325};
326326} // namespace
327327
328+ static bool GetIntConstAttrArg (ASTContext &astContext, const Expr *expr,
329+ unsigned *n) {
330+ llvm::APSInt val;
331+ if (!expr)
332+ return true ;
333+ if (!expr->isIntegerConstantExpr (val, astContext))
334+ return false ;
335+ *n = val.getLimitedValue ();
336+ return true ;
337+ }
338+
328339// ------------------------------------------------------------------------------
329340//
330341// CGMSHLSLRuntime methods.
@@ -1419,6 +1430,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
14191430 }
14201431
14211432 DiagnosticsEngine &Diags = CGM.getDiags ();
1433+ ASTContext &astContext = CGM.getTypes ().getContext ();
14221434
14231435 std::unique_ptr<DxilFunctionProps> funcProps =
14241436 llvm::make_unique<DxilFunctionProps>();
@@ -1629,10 +1641,12 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
16291641
16301642 // Populate numThreads
16311643 if (const HLSLNumThreadsAttr *Attr = FD->getAttr <HLSLNumThreadsAttr>()) {
1632-
1633- funcProps->numThreads [0 ] = Attr->getX ();
1634- funcProps->numThreads [1 ] = Attr->getY ();
1635- funcProps->numThreads [2 ] = Attr->getZ ();
1644+ funcProps->numThreads [0 ] = 1 ;
1645+ funcProps->numThreads [1 ] = 1 ;
1646+ funcProps->numThreads [2 ] = 1 ;
1647+ GetIntConstAttrArg (astContext, Attr->getX (), &funcProps->numThreads [0 ]);
1648+ GetIntConstAttrArg (astContext, Attr->getY (), &funcProps->numThreads [1 ]);
1649+ GetIntConstAttrArg (astContext, Attr->getZ (), &funcProps->numThreads [2 ]);
16361650
16371651 if (isEntry && !SM->IsCS () && !SM->IsMS () && !SM->IsAS ()) {
16381652 unsigned DiagID = Diags.getCustomDiagID (
@@ -1803,33 +1817,50 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18031817 funcProps->Node .IsProgramEntry = true ;
18041818 }
18051819
1820+ funcProps->NodeShaderID .Index = 0 ;
18061821 if (const auto *pAttr = FD->getAttr <HLSLNodeIdAttr>()) {
18071822 funcProps->NodeShaderID .Name = pAttr->getName ().str ();
1808- funcProps->NodeShaderID .Index = pAttr->getArrayIndex ();
1823+ GetIntConstAttrArg (astContext, pAttr->getArrayIndex (),
1824+ &funcProps->NodeShaderID .Index );
18091825 } else {
18101826 funcProps->NodeShaderID .Name = FD->getName ().str ();
1811- funcProps->NodeShaderID .Index = 0 ;
18121827 }
18131828 if (const auto *pAttr =
18141829 FD->getAttr <HLSLNodeLocalRootArgumentsTableIndexAttr>()) {
18151830 funcProps->Node .LocalRootArgumentsTableIndex = pAttr->getIndex ();
18161831 }
18171832 if (const auto *pAttr = FD->getAttr <HLSLNodeShareInputOfAttr>()) {
18181833 funcProps->NodeShaderSharedInput .Name = pAttr->getName ().str ();
1819- funcProps->NodeShaderSharedInput .Index = pAttr->getArrayIndex ();
1834+ funcProps->NodeShaderSharedInput .Index = 0 ;
1835+ GetIntConstAttrArg (astContext, pAttr->getArrayIndex (),
1836+ &funcProps->NodeShaderSharedInput .Index );
18201837 }
18211838 if (const auto *pAttr = FD->getAttr <HLSLNodeDispatchGridAttr>()) {
1822- funcProps->Node .DispatchGrid [0 ] = pAttr->getX ();
1823- funcProps->Node .DispatchGrid [1 ] = pAttr->getY ();
1824- funcProps->Node .DispatchGrid [2 ] = pAttr->getZ ();
1839+ funcProps->Node .DispatchGrid [0 ] = 1 ;
1840+ funcProps->Node .DispatchGrid [1 ] = 1 ;
1841+ funcProps->Node .DispatchGrid [2 ] = 1 ;
1842+ GetIntConstAttrArg (astContext, pAttr->getX (),
1843+ &funcProps->Node .DispatchGrid [0 ]);
1844+ GetIntConstAttrArg (astContext, pAttr->getY (),
1845+ &funcProps->Node .DispatchGrid [1 ]);
1846+ GetIntConstAttrArg (astContext, pAttr->getZ (),
1847+ &funcProps->Node .DispatchGrid [2 ]);
18251848 }
18261849 if (const auto *pAttr = FD->getAttr <HLSLNodeMaxDispatchGridAttr>()) {
1827- funcProps->Node .MaxDispatchGrid [0 ] = pAttr->getX ();
1828- funcProps->Node .MaxDispatchGrid [1 ] = pAttr->getY ();
1829- funcProps->Node .MaxDispatchGrid [2 ] = pAttr->getZ ();
1850+ funcProps->Node .MaxDispatchGrid [0 ] = 1 ;
1851+ funcProps->Node .MaxDispatchGrid [1 ] = 1 ;
1852+ funcProps->Node .MaxDispatchGrid [2 ] = 1 ;
1853+ GetIntConstAttrArg (astContext, pAttr->getX (),
1854+ &funcProps->Node .MaxDispatchGrid [0 ]);
1855+ GetIntConstAttrArg (astContext, pAttr->getY (),
1856+ &funcProps->Node .MaxDispatchGrid [1 ]);
1857+ GetIntConstAttrArg (astContext, pAttr->getZ (),
1858+ &funcProps->Node .MaxDispatchGrid [2 ]);
18301859 }
18311860 if (const auto *pAttr = FD->getAttr <HLSLNodeMaxRecursionDepthAttr>()) {
1832- funcProps->Node .MaxRecursionDepth = pAttr->getCount ();
1861+ funcProps->Node .MaxRecursionDepth = 0 ;
1862+ GetIntConstAttrArg (astContext, pAttr->getCount (),
1863+ &funcProps->Node .MaxRecursionDepth );
18331864 }
18341865 if (!FD->getAttr <HLSLNumThreadsAttr>()) {
18351866 // NumThreads wasn't specified.
@@ -2343,8 +2374,11 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23432374 NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;
23442375
23452376 if (parmDecl->hasAttr <HLSLMaxRecordsAttr>()) {
2346- node.MaxRecords =
2347- parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount ();
2377+ node.MaxRecords = 1 ;
2378+ GetIntConstAttrArg (
2379+ astContext,
2380+ parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount (),
2381+ &node.MaxRecords );
23482382 }
23492383 if (parmDecl->hasAttr <HLSLGloballyCoherentAttr>())
23502384 node.Flags .SetGloballyCoherent ();
@@ -2373,12 +2407,13 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23732407 }
23742408
23752409 // OutputID from attribute
2410+ node.OutputID .Index = 0 ;
23762411 if (const auto *Attr = parmDecl->getAttr <HLSLNodeIdAttr>()) {
23772412 node.OutputID .Name = Attr->getName ().str ();
2378- node.OutputID .Index = Attr->getArrayIndex ();
2413+ GetIntConstAttrArg (astContext, Attr->getArrayIndex (),
2414+ &node.OutputID .Index );
23792415 } else {
23802416 node.OutputID .Name = parmDecl->getName ().str ();
2381- node.OutputID .Index = 0 ;
23822417 }
23832418
23842419 // Insert output decls for cross referencing once all info is
@@ -2433,8 +2468,10 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
24332468 }
24342469 node.MaxRecordsSharedWith = ix;
24352470 }
2436- if (const auto *Attr = parmDecl->getAttr <HLSLMaxRecordsAttr>())
2437- node.MaxRecords = Attr->getMaxCount ();
2471+ if (const auto *Attr = parmDecl->getAttr <HLSLMaxRecordsAttr>()) {
2472+ node.MaxRecords = 1 ;
2473+ GetIntConstAttrArg (astContext, Attr->getMaxCount (), &node.MaxRecords );
2474+ }
24382475 }
24392476
24402477 if (inputPatchCount > 1 ) {
0 commit comments