Skip to content

Commit dd1a4b7

Browse files
committed
[draft] Allows constants as attribute arguments.
1 parent 34b6d0f commit dd1a4b7

12 files changed

Lines changed: 378 additions & 145 deletions

File tree

tools/clang/include/clang/Basic/Attr.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ def HLSLMaxTessFactor: InheritableAttr {
671671
}
672672
def HLSLNumThreads: InheritableAttr {
673673
let Spellings = [CXX11<"", "numthreads", 2015>];
674-
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
674+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
675675
let Documentation = [Undocumented];
676676
}
677677
def HLSLRootSignature: InheritableAttr {
@@ -1007,7 +1007,7 @@ def HLSLNodeIsProgramEntry : InheritableAttr {
10071007

10081008
def HLSLNodeId : InheritableAttr {
10091009
let Spellings = [CXX11<"", "nodeid", 2017>];
1010-
let Args = [StringArgument<"Name">,DefaultIntArgument<"ArrayIndex", 0>];
1010+
let Args = [StringArgument<"Name">, ExprArgument<"ArrayIndex", 1>];
10111011
let Documentation = [Undocumented];
10121012
}
10131013

@@ -1019,25 +1019,25 @@ def HLSLNodeLocalRootArgumentsTableIndex : InheritableAttr {
10191019

10201020
def HLSLNodeShareInputOf : InheritableAttr {
10211021
let Spellings = [CXX11<"", "nodeshareinputof", 2017>];
1022-
let Args = [StringArgument<"Name">,UnsignedArgument<"ArrayIndex", 1>];
1022+
let Args = [StringArgument<"Name">,ExprArgument<"ArrayIndex", 1>];
10231023
let Documentation = [Undocumented];
10241024
}
10251025

10261026
def HLSLNodeDispatchGrid: InheritableAttr {
10271027
let Spellings = [CXX11<"", "nodedispatchgrid", 2015>];
1028-
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
1028+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
10291029
let Documentation = [Undocumented];
10301030
}
10311031

10321032
def HLSLNodeMaxDispatchGrid: InheritableAttr {
10331033
let Spellings = [CXX11<"", "nodemaxdispatchgrid", 2015>];
1034-
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
1034+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
10351035
let Documentation = [Undocumented];
10361036
}
10371037

10381038
def HLSLNodeMaxRecursionDepth : InheritableAttr {
10391039
let Spellings = [CXX11<"", "nodemaxrecursiondepth", 2017>];
1040-
let Args = [UnsignedArgument<"Count">];
1040+
let Args = [ExprArgument<"Count">];
10411041
let Documentation = [Undocumented];
10421042
}
10431043

@@ -1185,7 +1185,7 @@ def HLSLHitObject : InheritableAttr {
11851185

11861186
def HLSLMaxRecords : InheritableAttr {
11871187
let Spellings = [CXX11<"", "MaxRecords", 2015>];
1188-
let Args = [IntArgument<"maxCount">];
1188+
let Args = [ExprArgument<"maxCount">];
11891189
let Documentation = [Undocumented];
11901190
}
11911191

tools/clang/include/clang/Sema/Sema.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8838,6 +8838,8 @@ class Sema {
88388838
bool HasVAListArg;
88398839
};
88408840

8841+
bool evaluateSpecConstInt(ASTContext &astContext, const Expr *E,
8842+
int64_t *value);
88418843
bool getFormatStringInfo(const FormatAttr *Format, bool IsCXXMember,
88428844
FormatStringInfo *FSI);
88438845
bool CheckFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall,

tools/clang/lib/CodeGen/CGHLSLMS.cpp

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,17 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
325325
};
326326
} // namespace
327327

328+
static bool GetIntConstAttrArg(ASTContext &astContext, const Expr *expr,
329+
unsigned *n) {
330+
llvm::APSInt val;
331+
if (!expr)
332+
return true;
333+
if (!expr->isIntegerConstantExpr(val, astContext))
334+
return false;
335+
*n = val.getLimitedValue();
336+
return true;
337+
}
338+
328339
//------------------------------------------------------------------------------
329340
//
330341
// CGMSHLSLRuntime methods.
@@ -1419,6 +1430,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
14191430
}
14201431

14211432
DiagnosticsEngine &Diags = CGM.getDiags();
1433+
ASTContext &astContext = CGM.getTypes().getContext();
14221434

14231435
std::unique_ptr<DxilFunctionProps> funcProps =
14241436
llvm::make_unique<DxilFunctionProps>();
@@ -1629,10 +1641,12 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
16291641

16301642
// Populate numThreads
16311643
if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {
1632-
1633-
funcProps->numThreads[0] = Attr->getX();
1634-
funcProps->numThreads[1] = Attr->getY();
1635-
funcProps->numThreads[2] = Attr->getZ();
1644+
funcProps->numThreads[0] = 1;
1645+
funcProps->numThreads[1] = 1;
1646+
funcProps->numThreads[2] = 1;
1647+
GetIntConstAttrArg(astContext, Attr->getX(), &funcProps->numThreads[0]);
1648+
GetIntConstAttrArg(astContext, Attr->getY(), &funcProps->numThreads[1]);
1649+
GetIntConstAttrArg(astContext, Attr->getZ(), &funcProps->numThreads[2]);
16361650

16371651
if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
16381652
unsigned DiagID = Diags.getCustomDiagID(
@@ -1803,33 +1817,50 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18031817
funcProps->Node.IsProgramEntry = true;
18041818
}
18051819

1820+
funcProps->NodeShaderID.Index = 0;
18061821
if (const auto *pAttr = FD->getAttr<HLSLNodeIdAttr>()) {
18071822
funcProps->NodeShaderID.Name = pAttr->getName().str();
1808-
funcProps->NodeShaderID.Index = pAttr->getArrayIndex();
1823+
GetIntConstAttrArg(astContext, pAttr->getArrayIndex(),
1824+
&funcProps->NodeShaderID.Index);
18091825
} else {
18101826
funcProps->NodeShaderID.Name = FD->getName().str();
1811-
funcProps->NodeShaderID.Index = 0;
18121827
}
18131828
if (const auto *pAttr =
18141829
FD->getAttr<HLSLNodeLocalRootArgumentsTableIndexAttr>()) {
18151830
funcProps->Node.LocalRootArgumentsTableIndex = pAttr->getIndex();
18161831
}
18171832
if (const auto *pAttr = FD->getAttr<HLSLNodeShareInputOfAttr>()) {
18181833
funcProps->NodeShaderSharedInput.Name = pAttr->getName().str();
1819-
funcProps->NodeShaderSharedInput.Index = pAttr->getArrayIndex();
1834+
funcProps->NodeShaderSharedInput.Index = 0;
1835+
GetIntConstAttrArg(astContext, pAttr->getArrayIndex(),
1836+
&funcProps->NodeShaderSharedInput.Index);
18201837
}
18211838
if (const auto *pAttr = FD->getAttr<HLSLNodeDispatchGridAttr>()) {
1822-
funcProps->Node.DispatchGrid[0] = pAttr->getX();
1823-
funcProps->Node.DispatchGrid[1] = pAttr->getY();
1824-
funcProps->Node.DispatchGrid[2] = pAttr->getZ();
1839+
funcProps->Node.DispatchGrid[0] = 1;
1840+
funcProps->Node.DispatchGrid[1] = 1;
1841+
funcProps->Node.DispatchGrid[2] = 1;
1842+
GetIntConstAttrArg(astContext, pAttr->getX(),
1843+
&funcProps->Node.DispatchGrid[0]);
1844+
GetIntConstAttrArg(astContext, pAttr->getY(),
1845+
&funcProps->Node.DispatchGrid[1]);
1846+
GetIntConstAttrArg(astContext, pAttr->getZ(),
1847+
&funcProps->Node.DispatchGrid[2]);
18251848
}
18261849
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxDispatchGridAttr>()) {
1827-
funcProps->Node.MaxDispatchGrid[0] = pAttr->getX();
1828-
funcProps->Node.MaxDispatchGrid[1] = pAttr->getY();
1829-
funcProps->Node.MaxDispatchGrid[2] = pAttr->getZ();
1850+
funcProps->Node.MaxDispatchGrid[0] = 1;
1851+
funcProps->Node.MaxDispatchGrid[1] = 1;
1852+
funcProps->Node.MaxDispatchGrid[2] = 1;
1853+
GetIntConstAttrArg(astContext, pAttr->getX(),
1854+
&funcProps->Node.MaxDispatchGrid[0]);
1855+
GetIntConstAttrArg(astContext, pAttr->getY(),
1856+
&funcProps->Node.MaxDispatchGrid[1]);
1857+
GetIntConstAttrArg(astContext, pAttr->getZ(),
1858+
&funcProps->Node.MaxDispatchGrid[2]);
18301859
}
18311860
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxRecursionDepthAttr>()) {
1832-
funcProps->Node.MaxRecursionDepth = pAttr->getCount();
1861+
funcProps->Node.MaxRecursionDepth = 0;
1862+
GetIntConstAttrArg(astContext, pAttr->getCount(),
1863+
&funcProps->Node.MaxRecursionDepth);
18331864
}
18341865
if (!FD->getAttr<HLSLNumThreadsAttr>()) {
18351866
// NumThreads wasn't specified.
@@ -2343,8 +2374,11 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23432374
NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;
23442375

23452376
if (parmDecl->hasAttr<HLSLMaxRecordsAttr>()) {
2346-
node.MaxRecords =
2347-
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount();
2377+
node.MaxRecords = 1;
2378+
GetIntConstAttrArg(
2379+
astContext,
2380+
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount(),
2381+
&node.MaxRecords);
23482382
}
23492383
if (parmDecl->hasAttr<HLSLGloballyCoherentAttr>())
23502384
node.Flags.SetGloballyCoherent();
@@ -2373,12 +2407,13 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23732407
}
23742408

23752409
// OutputID from attribute
2410+
node.OutputID.Index = 0;
23762411
if (const auto *Attr = parmDecl->getAttr<HLSLNodeIdAttr>()) {
23772412
node.OutputID.Name = Attr->getName().str();
2378-
node.OutputID.Index = Attr->getArrayIndex();
2413+
GetIntConstAttrArg(astContext, Attr->getArrayIndex(),
2414+
&node.OutputID.Index);
23792415
} else {
23802416
node.OutputID.Name = parmDecl->getName().str();
2381-
node.OutputID.Index = 0;
23822417
}
23832418

23842419
// Insert output decls for cross referencing once all info is
@@ -2433,8 +2468,10 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
24332468
}
24342469
node.MaxRecordsSharedWith = ix;
24352470
}
2436-
if (const auto *Attr = parmDecl->getAttr<HLSLMaxRecordsAttr>())
2437-
node.MaxRecords = Attr->getMaxCount();
2471+
if (const auto *Attr = parmDecl->getAttr<HLSLMaxRecordsAttr>()) {
2472+
node.MaxRecords = 1;
2473+
GetIntConstAttrArg(astContext, Attr->getMaxCount(), &node.MaxRecords);
2474+
}
24382475
}
24392476

24402477
if (inputPatchCount > 1) {

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13245,6 +13245,36 @@ void SpirvEmitter::processInlineSpirvAttributes(const FunctionDecl *decl) {
1324513245
}
1324613246
}
1324713247

13248+
bool SpirvEmitter::processNumThreadsAttr(const FunctionDecl *decl) {
13249+
auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>();
13250+
if (!numThreadsAttr)
13251+
return false;
13252+
13253+
auto f = [this](Expr *E) {
13254+
APValue A;
13255+
if (DeclRefExpr *D = dyn_cast<DeclRefExpr>(E)) {
13256+
if (auto *V = llvm::dyn_cast<VarDecl>(D->getDecl())) {
13257+
if (const Expr *I = V->getAnyInitializer()) {
13258+
if (I->isCXX11ConstantExpr(astContext, &A) && A.isInt()) {
13259+
return (uint32_t)A.getInt().getSExtValue();
13260+
}
13261+
}
13262+
}
13263+
}
13264+
13265+
llvm::APSInt S(32, 1);
13266+
(void)E->isIntegerConstantExpr(S, astContext);
13267+
return (uint32_t)S.getSExtValue();
13268+
};
13269+
13270+
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
13271+
{f(numThreadsAttr->getX()),
13272+
f(numThreadsAttr->getY()),
13273+
f(numThreadsAttr->getZ())},
13274+
decl->getLocation());
13275+
return true;
13276+
}
13277+
1324813278
bool SpirvEmitter::processGeometryShaderAttributes(const FunctionDecl *decl,
1324913279
uint32_t *arraySize) {
1325013280
bool success = true;
@@ -13421,15 +13451,9 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
1342113451
}
1342213452

1342313453
void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
13424-
auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>();
13425-
assert(numThreadsAttr && "thread group size missing from entry-point");
13426-
13427-
uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX());
13428-
uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY());
13429-
uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ());
13430-
13431-
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
13432-
{x, y, z}, decl->getLocation());
13454+
if (!processNumThreadsAttr(decl)) {
13455+
assert(false && "thread group size missing from entry-point");
13456+
}
1343313457

1343413458
auto *waveSizeAttr = decl->getAttr<HLSLWaveSizeAttr>();
1343513459
if (waveSizeAttr) {
@@ -13650,14 +13674,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
1365013674

1365113675
bool SpirvEmitter::processMeshOrAmplificationShaderAttributes(
1365213676
const FunctionDecl *decl, uint32_t *outVerticesArraySize) {
13653-
if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
13654-
uint32_t x, y, z;
13655-
x = static_cast<uint32_t>(numThreadsAttr->getX());
13656-
y = static_cast<uint32_t>(numThreadsAttr->getY());
13657-
z = static_cast<uint32_t>(numThreadsAttr->getZ());
13658-
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
13659-
{x, y, z}, decl->getLocation());
13660-
}
13677+
processNumThreadsAttr(decl);
1366113678

1366213679
// Early return for amplification shaders as they only take the 'numthreads'
1366313680
// attribute.

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,8 @@ class SpirvEmitter : public ASTConsumer {
837837
/// \brief Handle inline SPIR-V attributes for the entry function.
838838
void processInlineSpirvAttributes(const FunctionDecl *entryFunction);
839839

840+
bool processNumThreadsAttr(const FunctionDecl *decl);
841+
840842
/// \brief Adds necessary execution modes for the hull/domain shaders based on
841843
/// the HLSL attributes of the entry point function.
842844
/// In the case of hull shaders, also writes the number of output control

0 commit comments

Comments
 (0)