@@ -2644,80 +2644,15 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
26442644 // NodePayloadArray types
26452645 else if (const auto *npaType = dyn_cast<NodePayloadArrayType>(type)) {
26462646 const uint32_t elemTypeId = emitType (npaType->getElementType ());
2647+
2648+ // Output the decorations for the type first. This will create other values
2649+ // that are on the decorations, and they must appear before the type.
2650+ emitDecorationsForNodePayloadArrayTypes (npaType, id);
2651+
26472652 initTypeInstruction (spv::Op::OpTypeNodePayloadArrayAMDX);
26482653 curTypeInst.push_back (id);
26492654 curTypeInst.push_back (elemTypeId);
26502655 finalizeTypeInstruction ();
2651-
2652- // Emit decorations
2653- const ParmVarDecl *nodeDecl = npaType->getNodeDecl ();
2654- if (hlsl::IsHLSLNodeOutputType (nodeDecl->getType ())) {
2655- StringRef name = nodeDecl->getName ();
2656- unsigned index = 0 ;
2657- if (auto nodeID = nodeDecl->getAttr <HLSLNodeIdAttr>()) {
2658- name = nodeID->getName ();
2659- index = nodeID->getArrayIndex ();
2660- }
2661-
2662- auto *str = new (context) SpirvConstantString (name);
2663- uint32_t nodeName = getOrCreateConstantString (str);
2664- emitDecoration (id, spv::Decoration::PayloadNodeNameAMDX, {nodeName},
2665- llvm::None, true );
2666- if (index) {
2667- uint32_t baseIndex = getOrCreateConstantInt (
2668- llvm::APInt (32 , index), context.getUIntType (32 ), false );
2669- emitDecoration (id, spv::Decoration::PayloadNodeBaseIndexAMDX,
2670- {baseIndex}, llvm::None, true );
2671- }
2672- }
2673-
2674- uint32_t maxRecords;
2675- if (const auto *attr = nodeDecl->getAttr <HLSLMaxRecordsAttr>()) {
2676- maxRecords = getOrCreateConstantInt (llvm::APInt (32 , attr->getMaxCount ()),
2677- context.getUIntType (32 ), false );
2678- } else {
2679- maxRecords = getOrCreateConstantInt (llvm::APInt (32 , 1 ),
2680- context.getUIntType (32 ), false );
2681- }
2682- emitDecoration (id, spv::Decoration::NodeMaxPayloadsAMDX, {maxRecords},
2683- llvm::None, true );
2684-
2685- if (const auto *attr = nodeDecl->getAttr <HLSLMaxRecordsSharedWithAttr>()) {
2686- const DeclContext *dc = nodeDecl->getParentFunctionOrMethod ();
2687- if (const auto *funDecl = dyn_cast_or_null<FunctionDecl>(dc)) {
2688- IdentifierInfo *ii = attr->getName ();
2689- bool alreadyExists = false ;
2690- for (auto *paramDecl : funDecl->params ()) {
2691- if (paramDecl->getIdentifier () == ii) {
2692- assert (paramDecl != nodeDecl);
2693- auto otherType = context.getNodeDeclPayloadType (paramDecl);
2694- const uint32_t otherId =
2695- getResultIdForType (otherType, &alreadyExists);
2696- assert (alreadyExists && " forward references not allowed in "
2697- " MaxRecordsSharedWith attribute" );
2698- emitDecoration (id, spv::Decoration::NodeSharesPayloadLimitsWithAMDX,
2699- {otherId}, llvm::None, true );
2700- break ;
2701- }
2702- }
2703- assert (alreadyExists &&
2704- " invalid reference in MaxRecordsSharedWith attribute" );
2705- }
2706- }
2707- if (const auto *attr = nodeDecl->getAttr <HLSLAllowSparseNodesAttr>()) {
2708- emitDecoration (id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
2709- llvm::None);
2710- }
2711- if (const auto *attr = nodeDecl->getAttr <HLSLUnboundedSparseNodesAttr>()) {
2712- emitDecoration (id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
2713- llvm::None);
2714- }
2715- if (const auto *attr = nodeDecl->getAttr <HLSLNodeArraySizeAttr>()) {
2716- uint32_t arraySize = getOrCreateConstantInt (
2717- llvm::APInt (32 , attr->getCount ()), context.getUIntType (32 ), false );
2718- emitDecoration (id, spv::Decoration::PayloadNodeArraySizeAMDX, {arraySize},
2719- llvm::None, true );
2720- }
27212656 }
27222657 // Structure types
27232658 else if (const auto *structType = dyn_cast<StructType>(type)) {
@@ -2998,5 +2933,78 @@ void EmitTypeHandler::emitNameForType(llvm::StringRef name,
29982933 nameInstr.end ());
29992934}
30002935
2936+ void EmitTypeHandler::emitDecorationsForNodePayloadArrayTypes (
2937+ const NodePayloadArrayType *npaType, uint32_t id) {
2938+ // Emit decorations
2939+ const ParmVarDecl *nodeDecl = npaType->getNodeDecl ();
2940+ if (hlsl::IsHLSLNodeOutputType (nodeDecl->getType ())) {
2941+ StringRef name = nodeDecl->getName ();
2942+ unsigned index = 0 ;
2943+ if (auto nodeID = nodeDecl->getAttr <HLSLNodeIdAttr>()) {
2944+ name = nodeID->getName ();
2945+ index = nodeID->getArrayIndex ();
2946+ }
2947+
2948+ auto *str = new (context) SpirvConstantString (name);
2949+ uint32_t nodeName = getOrCreateConstantString (str);
2950+ emitDecoration (id, spv::Decoration::PayloadNodeNameAMDX, {nodeName},
2951+ llvm::None, true );
2952+ if (index) {
2953+ uint32_t baseIndex = getOrCreateConstantInt (
2954+ llvm::APInt (32 , index), context.getUIntType (32 ), false );
2955+ emitDecoration (id, spv::Decoration::PayloadNodeBaseIndexAMDX, {baseIndex},
2956+ llvm::None, true );
2957+ }
2958+ }
2959+
2960+ uint32_t maxRecords;
2961+ if (const auto *attr = nodeDecl->getAttr <HLSLMaxRecordsAttr>()) {
2962+ maxRecords = getOrCreateConstantInt (llvm::APInt (32 , attr->getMaxCount ()),
2963+ context.getUIntType (32 ), false );
2964+ } else {
2965+ maxRecords = getOrCreateConstantInt (llvm::APInt (32 , 1 ),
2966+ context.getUIntType (32 ), false );
2967+ }
2968+ emitDecoration (id, spv::Decoration::NodeMaxPayloadsAMDX, {maxRecords},
2969+ llvm::None, true );
2970+
2971+ if (const auto *attr = nodeDecl->getAttr <HLSLMaxRecordsSharedWithAttr>()) {
2972+ const DeclContext *dc = nodeDecl->getParentFunctionOrMethod ();
2973+ if (const auto *funDecl = dyn_cast_or_null<FunctionDecl>(dc)) {
2974+ IdentifierInfo *ii = attr->getName ();
2975+ bool alreadyExists = false ;
2976+ for (auto *paramDecl : funDecl->params ()) {
2977+ if (paramDecl->getIdentifier () == ii) {
2978+ assert (paramDecl != nodeDecl);
2979+ auto otherType = context.getNodeDeclPayloadType (paramDecl);
2980+ const uint32_t otherId =
2981+ getResultIdForType (otherType, &alreadyExists);
2982+ assert (alreadyExists && " forward references not allowed in "
2983+ " MaxRecordsSharedWith attribute" );
2984+ emitDecoration (id, spv::Decoration::NodeSharesPayloadLimitsWithAMDX,
2985+ {otherId}, llvm::None, true );
2986+ break ;
2987+ }
2988+ }
2989+ assert (alreadyExists &&
2990+ " invalid reference in MaxRecordsSharedWith attribute" );
2991+ }
2992+ }
2993+ if (const auto *attr = nodeDecl->getAttr <HLSLAllowSparseNodesAttr>()) {
2994+ emitDecoration (id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
2995+ llvm::None);
2996+ }
2997+ if (const auto *attr = nodeDecl->getAttr <HLSLUnboundedSparseNodesAttr>()) {
2998+ emitDecoration (id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
2999+ llvm::None);
3000+ }
3001+ if (const auto *attr = nodeDecl->getAttr <HLSLNodeArraySizeAttr>()) {
3002+ uint32_t arraySize = getOrCreateConstantInt (
3003+ llvm::APInt (32 , attr->getCount ()), context.getUIntType (32 ), false );
3004+ emitDecoration (id, spv::Decoration::PayloadNodeArraySizeAMDX, {arraySize},
3005+ llvm::None, true );
3006+ }
3007+ }
3008+
30013009} // end namespace spirv
30023010} // end namespace clang
0 commit comments