@@ -136,6 +136,8 @@ uint32_t getHeaderVersion(spv_target_env env) {
136136std::string
137137ReadSourceCode (llvm::StringRef filePath,
138138 const clang::spirv::SpirvCodeGenOptions &spvOptions) {
139+
140+ std::string localFilePath (filePath.begin (), filePath.end ());
139141 try {
140142 dxc::DxcDllSupport dllSupport;
141143 IFT (dllSupport.Initialize ());
@@ -154,7 +156,10 @@ ReadSourceCode(llvm::StringRef filePath,
154156 } catch (...) {
155157 // An exception has occurred while reading the file
156158 // return the original source (which may have been supplied directly)
157- if (!spvOptions.origSource .empty ()) {
159+ // only for the main input file
160+ if ((!strcmp (localFilePath.c_str (), " hlsl.hlsl" ) &&
161+ spvOptions.inputFile .empty ()) ||
162+ !strcmp (localFilePath.c_str (), spvOptions.inputFile .c_str ())) {
158163 return spvOptions.origSource .c_str ();
159164 }
160165 return " " ;
@@ -2639,80 +2644,15 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
26392644 // NodePayloadArray types
26402645 else if (const auto *npaType = dyn_cast<NodePayloadArrayType>(type)) {
26412646 const uint32_t elemTypeId = emitType (npaType->getElementType ());
2647+
2648+ // Output the decorations for the type first. This will create other values
2649+ // that are on the decorations, and they must appear before the type.
2650+ emitDecorationsForNodePayloadArrayTypes (npaType, id);
2651+
26422652 initTypeInstruction (spv::Op::OpTypeNodePayloadArrayAMDX);
26432653 curTypeInst.push_back (id);
26442654 curTypeInst.push_back (elemTypeId);
26452655 finalizeTypeInstruction ();
2646-
2647- // Emit decorations
2648- const ParmVarDecl *nodeDecl = npaType->getNodeDecl ();
2649- if (hlsl::IsHLSLNodeOutputType (nodeDecl->getType ())) {
2650- StringRef name = nodeDecl->getName ();
2651- unsigned index = 0 ;
2652- if (auto nodeID = nodeDecl->getAttr <HLSLNodeIdAttr>()) {
2653- name = nodeID->getName ();
2654- index = nodeID->getArrayIndex ();
2655- }
2656-
2657- auto *str = new (context) SpirvConstantString (name);
2658- uint32_t nodeName = getOrCreateConstantString (str);
2659- emitDecoration (id, spv::Decoration::PayloadNodeNameAMDX, {nodeName},
2660- llvm::None, true );
2661- if (index) {
2662- uint32_t baseIndex = getOrCreateConstantInt (
2663- llvm::APInt (32 , index), context.getUIntType (32 ), false );
2664- emitDecoration (id, spv::Decoration::PayloadNodeBaseIndexAMDX,
2665- {baseIndex}, llvm::None, true );
2666- }
2667- }
2668-
2669- uint32_t maxRecords;
2670- if (const auto *attr = nodeDecl->getAttr <HLSLMaxRecordsAttr>()) {
2671- maxRecords = getOrCreateConstantInt (llvm::APInt (32 , attr->getMaxCount ()),
2672- context.getUIntType (32 ), false );
2673- } else {
2674- maxRecords = getOrCreateConstantInt (llvm::APInt (32 , 1 ),
2675- context.getUIntType (32 ), false );
2676- }
2677- emitDecoration (id, spv::Decoration::NodeMaxPayloadsAMDX, {maxRecords},
2678- llvm::None, true );
2679-
2680- if (const auto *attr = nodeDecl->getAttr <HLSLMaxRecordsSharedWithAttr>()) {
2681- const DeclContext *dc = nodeDecl->getParentFunctionOrMethod ();
2682- if (const auto *funDecl = dyn_cast_or_null<FunctionDecl>(dc)) {
2683- IdentifierInfo *ii = attr->getName ();
2684- bool alreadyExists = false ;
2685- for (auto *paramDecl : funDecl->params ()) {
2686- if (paramDecl->getIdentifier () == ii) {
2687- assert (paramDecl != nodeDecl);
2688- auto otherType = context.getNodeDeclPayloadType (paramDecl);
2689- const uint32_t otherId =
2690- getResultIdForType (otherType, &alreadyExists);
2691- assert (alreadyExists && " forward references not allowed in "
2692- " MaxRecordsSharedWith attribute" );
2693- emitDecoration (id, spv::Decoration::NodeSharesPayloadLimitsWithAMDX,
2694- {otherId}, llvm::None, true );
2695- break ;
2696- }
2697- }
2698- assert (alreadyExists &&
2699- " invalid reference in MaxRecordsSharedWith attribute" );
2700- }
2701- }
2702- if (const auto *attr = nodeDecl->getAttr <HLSLAllowSparseNodesAttr>()) {
2703- emitDecoration (id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
2704- llvm::None);
2705- }
2706- if (const auto *attr = nodeDecl->getAttr <HLSLUnboundedSparseNodesAttr>()) {
2707- emitDecoration (id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
2708- llvm::None);
2709- }
2710- if (const auto *attr = nodeDecl->getAttr <HLSLNodeArraySizeAttr>()) {
2711- uint32_t arraySize = getOrCreateConstantInt (
2712- llvm::APInt (32 , attr->getCount ()), context.getUIntType (32 ), false );
2713- emitDecoration (id, spv::Decoration::PayloadNodeArraySizeAMDX, {arraySize},
2714- llvm::None, true );
2715- }
27162656 }
27172657 // Structure types
27182658 else if (const auto *structType = dyn_cast<StructType>(type)) {
@@ -2993,5 +2933,78 @@ void EmitTypeHandler::emitNameForType(llvm::StringRef name,
29932933 nameInstr.end ());
29942934}
29952935
2936+ void EmitTypeHandler::emitDecorationsForNodePayloadArrayTypes (
2937+ const NodePayloadArrayType *npaType, uint32_t id) {
2938+ // Emit decorations
2939+ const ParmVarDecl *nodeDecl = npaType->getNodeDecl ();
2940+ if (hlsl::IsHLSLNodeOutputType (nodeDecl->getType ())) {
2941+ StringRef name = nodeDecl->getName ();
2942+ unsigned index = 0 ;
2943+ if (auto nodeID = nodeDecl->getAttr <HLSLNodeIdAttr>()) {
2944+ name = nodeID->getName ();
2945+ index = nodeID->getArrayIndex ();
2946+ }
2947+
2948+ auto *str = new (context) SpirvConstantString (name);
2949+ uint32_t nodeName = getOrCreateConstantString (str);
2950+ emitDecoration (id, spv::Decoration::PayloadNodeNameAMDX, {nodeName},
2951+ llvm::None, true );
2952+ if (index) {
2953+ uint32_t baseIndex = getOrCreateConstantInt (
2954+ llvm::APInt (32 , index), context.getUIntType (32 ), false );
2955+ emitDecoration (id, spv::Decoration::PayloadNodeBaseIndexAMDX, {baseIndex},
2956+ llvm::None, true );
2957+ }
2958+ }
2959+
2960+ uint32_t maxRecords;
2961+ if (const auto *attr = nodeDecl->getAttr <HLSLMaxRecordsAttr>()) {
2962+ maxRecords = getOrCreateConstantInt (llvm::APInt (32 , attr->getMaxCount ()),
2963+ context.getUIntType (32 ), false );
2964+ } else {
2965+ maxRecords = getOrCreateConstantInt (llvm::APInt (32 , 1 ),
2966+ context.getUIntType (32 ), false );
2967+ }
2968+ emitDecoration (id, spv::Decoration::NodeMaxPayloadsAMDX, {maxRecords},
2969+ llvm::None, true );
2970+
2971+ if (const auto *attr = nodeDecl->getAttr <HLSLMaxRecordsSharedWithAttr>()) {
2972+ const DeclContext *dc = nodeDecl->getParentFunctionOrMethod ();
2973+ if (const auto *funDecl = dyn_cast_or_null<FunctionDecl>(dc)) {
2974+ IdentifierInfo *ii = attr->getName ();
2975+ bool alreadyExists = false ;
2976+ for (auto *paramDecl : funDecl->params ()) {
2977+ if (paramDecl->getIdentifier () == ii) {
2978+ assert (paramDecl != nodeDecl);
2979+ auto otherType = context.getNodeDeclPayloadType (paramDecl);
2980+ const uint32_t otherId =
2981+ getResultIdForType (otherType, &alreadyExists);
2982+ assert (alreadyExists && " forward references not allowed in "
2983+ " MaxRecordsSharedWith attribute" );
2984+ emitDecoration (id, spv::Decoration::NodeSharesPayloadLimitsWithAMDX,
2985+ {otherId}, llvm::None, true );
2986+ break ;
2987+ }
2988+ }
2989+ assert (alreadyExists &&
2990+ " invalid reference in MaxRecordsSharedWith attribute" );
2991+ }
2992+ }
2993+ if (const auto *attr = nodeDecl->getAttr <HLSLAllowSparseNodesAttr>()) {
2994+ emitDecoration (id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
2995+ llvm::None);
2996+ }
2997+ if (const auto *attr = nodeDecl->getAttr <HLSLUnboundedSparseNodesAttr>()) {
2998+ emitDecoration (id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
2999+ llvm::None);
3000+ }
3001+ if (const auto *attr = nodeDecl->getAttr <HLSLNodeArraySizeAttr>()) {
3002+ uint32_t arraySize = getOrCreateConstantInt (
3003+ llvm::APInt (32 , attr->getCount ()), context.getUIntType (32 ), false );
3004+ emitDecoration (id, spv::Decoration::PayloadNodeArraySizeAMDX, {arraySize},
3005+ llvm::None, true );
3006+ }
3007+ }
3008+
29963009} // end namespace spirv
29973010} // end namespace clang
0 commit comments