@@ -174,6 +174,7 @@ struct ValidationContext {
174174 const unsigned kLLVMLoopMDKind ;
175175 unsigned m_DxilMajor, m_DxilMinor;
176176 ModuleSlotTracker slotTracker;
177+ std::unique_ptr<CallGraph> pCallGraph;
177178
178179 ValidationContext (Module &llvmModule, Module *DebugModule,
179180 DxilModule &dxilModule)
@@ -397,6 +398,12 @@ struct ValidationContext {
397398
398399 EntryStatus &GetEntryStatus (Function *F) { return *entryStatusMap[F]; }
399400
401+ CallGraph &GetCallGraph () {
402+ if (!pCallGraph)
403+ pCallGraph = llvm::make_unique<CallGraph>(M);
404+ return *pCallGraph.get ();
405+ }
406+
400407 DxilResourceProperties GetResourceFromVal (Value *resVal);
401408
402409 void EmitGlobalVariableFormatError (GlobalVariable *GV, ValidationRule rule,
@@ -5386,6 +5393,216 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx) {
53865393 }
53875394}
53885395
5396+ // CompatibilityChecker is used to identify incompatibilities in an entry
5397+ // function and any functions called by that entry function.
5398+ struct CompatibilityChecker {
5399+ ValidationContext &ValCtx;
5400+ Function *EntryFn;
5401+ const DxilFunctionProps &props;
5402+ DXIL::ShaderKind shaderKind;
5403+
5404+ // These masks identify the potential conflict flags based on the entry
5405+ // function's shader kind and properties when either UsesDerivatives or
5406+ // RequiresGroup flags are set in ShaderCompatInfo.
5407+ uint32_t maskForDeriv = 0 ;
5408+ uint32_t maskForGroup = 0 ;
5409+
5410+ enum class ConflictKind : uint32_t {
5411+ Stage,
5412+ ShaderModel,
5413+ DerivLaunch,
5414+ DerivThreadGroupDim,
5415+ DerivInComputeShaderModel,
5416+ RequiresGroup,
5417+ };
5418+ enum class ConflictFlags : uint32_t {
5419+ Stage = 1 << (uint32_t )ConflictKind::Stage,
5420+ ShaderModel = 1 << (uint32_t )ConflictKind::ShaderModel,
5421+ DerivLaunch = 1 << (uint32_t )ConflictKind::DerivLaunch,
5422+ DerivThreadGroupDim = 1 << (uint32_t )ConflictKind::DerivThreadGroupDim,
5423+ DerivInComputeShaderModel =
5424+ 1 << (uint32_t )ConflictKind::DerivInComputeShaderModel,
5425+ RequiresGroup = 1 << (uint32_t )ConflictKind::RequiresGroup,
5426+ };
5427+
5428+ CompatibilityChecker (ValidationContext &ValCtx, Function *EntryFn)
5429+ : ValCtx(ValCtx), EntryFn(EntryFn),
5430+ props (ValCtx.DxilMod.GetDxilEntryProps(EntryFn).props),
5431+ shaderKind(props.shaderKind) {
5432+
5433+ // Precompute potential incompatibilities based on shader stage, shader kind
5434+ // and entry attributes. These will turn into full conflicts if the entry
5435+ // point's shader flags indicate that they use relevant features.
5436+ if (!ValCtx.DxilMod .GetShaderModel ()->IsSM66Plus () &&
5437+ (shaderKind == DXIL::ShaderKind::Mesh ||
5438+ shaderKind == DXIL::ShaderKind::Amplification ||
5439+ shaderKind == DXIL::ShaderKind::Compute)) {
5440+ maskForDeriv |=
5441+ static_cast <uint32_t >(ConflictFlags::DerivInComputeShaderModel);
5442+ } else if (shaderKind == DXIL::ShaderKind::Node) {
5443+ // Only broadcasting launch supports derivatives.
5444+ if (props.Node .LaunchType != DXIL::NodeLaunchType::Broadcasting)
5445+ maskForDeriv |= static_cast <uint32_t >(ConflictFlags::DerivLaunch);
5446+ // Thread launch node has no group.
5447+ if (props.Node .LaunchType == DXIL::NodeLaunchType::Thread)
5448+ maskForGroup |= static_cast <uint32_t >(ConflictFlags::RequiresGroup);
5449+ }
5450+
5451+ if (shaderKind == DXIL::ShaderKind::Mesh ||
5452+ shaderKind == DXIL::ShaderKind::Amplification ||
5453+ shaderKind == DXIL::ShaderKind::Compute ||
5454+ shaderKind == DXIL::ShaderKind::Node) {
5455+ // All compute-like stages
5456+ // Thread dimensions must be either 1D and X is multiple of 4, or 2D
5457+ // and X and Y must be multiples of 2.
5458+ if (props.numThreads [1 ] == 1 && props.numThreads [2 ] == 1 ) {
5459+ if ((props.numThreads [0 ] & 0x3 ) != 0 )
5460+ maskForDeriv |=
5461+ static_cast <uint32_t >(ConflictFlags::DerivThreadGroupDim);
5462+ } else if ((props.numThreads [0 ] & 0x1 ) || (props.numThreads [1 ] & 0x1 ))
5463+ maskForDeriv |=
5464+ static_cast <uint32_t >(ConflictFlags::DerivThreadGroupDim);
5465+ } else {
5466+ // other stages have no group
5467+ maskForGroup |= static_cast <uint32_t >(ConflictFlags::RequiresGroup);
5468+ }
5469+ }
5470+
5471+ uint32_t
5472+ IdentifyConflict (const DxilModule::ShaderCompatInfo &compatInfo) const {
5473+ uint32_t conflictMask = 0 ;
5474+
5475+ // Compatibility check said this shader kind is not compatible.
5476+ if (0 == ((1 << (uint32_t )shaderKind) & compatInfo.mask ))
5477+ conflictMask |= (uint32_t )ConflictFlags::Stage;
5478+
5479+ // Compatibility check said this shader model is not compatible.
5480+ if (DXIL::CompareVersions (ValCtx.DxilMod .GetShaderModel ()->GetMajor (),
5481+ ValCtx.DxilMod .GetShaderModel ()->GetMinor (),
5482+ compatInfo.minMajor , compatInfo.minMinor ) < 0 )
5483+ conflictMask |= (uint32_t )ConflictFlags::ShaderModel;
5484+
5485+ if (compatInfo.shaderFlags .GetUsesDerivatives ())
5486+ conflictMask |= maskForDeriv;
5487+
5488+ if (compatInfo.shaderFlags .GetRequiresGroup ())
5489+ conflictMask |= maskForGroup;
5490+
5491+ return conflictMask;
5492+ }
5493+
5494+ void Diagnose (Function *F, uint32_t conflictMask, ConflictKind conflict,
5495+ ValidationRule rule, ArrayRef<StringRef> args = {}) {
5496+ if (conflictMask & (1 << (unsigned )conflict))
5497+ ValCtx.EmitFnFormatError (F, rule, args);
5498+ }
5499+
5500+ void DiagnoseConflicts (Function *F, uint32_t conflictMask) {
5501+ // Emit a diagnostic indicating that either the entry function or a function
5502+ // called by the entry function contains a disallowed operation.
5503+ if (F == EntryFn)
5504+ ValCtx.EmitFnError (EntryFn, ValidationRule::SmIncompatibleOperation);
5505+ else
5506+ ValCtx.EmitFnError (EntryFn, ValidationRule::SmIncompatibleCallInEntry);
5507+
5508+ // Emit diagnostics for each conflict found in this function.
5509+ Diagnose (F, conflictMask, ConflictKind::Stage,
5510+ ValidationRule::SmIncompatibleStage,
5511+ {ShaderModel::GetKindName (props.shaderKind )});
5512+ Diagnose (F, conflictMask, ConflictKind::ShaderModel,
5513+ ValidationRule::SmIncompatibleShaderModel);
5514+ Diagnose (F, conflictMask, ConflictKind::DerivLaunch,
5515+ ValidationRule::SmIncompatibleDerivLaunch,
5516+ {GetLaunchTypeStr (props.Node .LaunchType )});
5517+ Diagnose (F, conflictMask, ConflictKind::DerivThreadGroupDim,
5518+ ValidationRule::SmIncompatibleThreadGroupDim,
5519+ {std::to_string (props.numThreads [0 ]),
5520+ std::to_string (props.numThreads [1 ]),
5521+ std::to_string (props.numThreads [2 ])});
5522+ Diagnose (F, conflictMask, ConflictKind::DerivInComputeShaderModel,
5523+ ValidationRule::SmIncompatibleDerivInComputeShaderModel);
5524+ Diagnose (F, conflictMask, ConflictKind::RequiresGroup,
5525+ ValidationRule::SmIncompatibleRequiresGroup);
5526+ }
5527+
5528+ // Visit function and all functions called by it.
5529+ // Emit diagnostics for incompatibilities found in a function when no
5530+ // functions called by that function introduced the conflict.
5531+ // In those cases, the called functions themselves will emit the diagnostic.
5532+ // Return conflict mask for this function.
5533+ uint32_t Visit (Function *F, uint32_t &remainingMask,
5534+ llvm::SmallPtrSet<Function *, 8 > &visited, CallGraph &CG) {
5535+ // Recursive check looks for where a conflict is found and not present
5536+ // in functions called by the current function.
5537+ // - When a source is found, emit diagnostics and clear the conflict
5538+ // flags introduced by this function from the working mask so we don't
5539+ // report this conflict again.
5540+ // - When the remainingMask is 0, we are done.
5541+
5542+ if (remainingMask == 0 )
5543+ return 0 ; // Nothing left to search for.
5544+ if (!visited.insert (F).second )
5545+ return 0 ; // Already visited.
5546+
5547+ const DxilModule::ShaderCompatInfo *compatInfo =
5548+ ValCtx.DxilMod .GetCompatInfoForFunction (F);
5549+ DXASSERT (compatInfo, " otherwise, compat info not computed in module" );
5550+ if (!compatInfo)
5551+ return 0 ;
5552+ uint32_t maskForThisFunction = IdentifyConflict (*compatInfo);
5553+
5554+ uint32_t maskForCalls = 0 ;
5555+ if (CallGraphNode *CGNode = CG[F]) {
5556+ for (auto &Call : *CGNode) {
5557+ Function *called = Call.second ->getFunction ();
5558+ if (called->isDeclaration ())
5559+ continue ;
5560+ maskForCalls |= Visit (called, remainingMask, visited, CG);
5561+ if (remainingMask == 0 )
5562+ return 0 ; // Nothing left to search for.
5563+ }
5564+ }
5565+
5566+ // Mask of incompatibilities introduced by this function.
5567+ uint32_t conflictsIntroduced =
5568+ remainingMask & maskForThisFunction & ~maskForCalls;
5569+ if (conflictsIntroduced) {
5570+ // This function introduces at least one conflict.
5571+ DiagnoseConflicts (F, conflictsIntroduced);
5572+ // Mask off diagnosed incompatibilities.
5573+ remainingMask &= ~conflictsIntroduced;
5574+ }
5575+ return maskForThisFunction;
5576+ }
5577+
5578+ void FindIncompatibleCall (const DxilModule::ShaderCompatInfo &compatInfo) {
5579+ uint32_t conflictMask = IdentifyConflict (compatInfo);
5580+ if (conflictMask == 0 )
5581+ return ;
5582+
5583+ CallGraph &CG = ValCtx.GetCallGraph ();
5584+ llvm::SmallPtrSet<Function *, 8 > visited;
5585+ Visit (EntryFn, conflictMask, visited, CG);
5586+ }
5587+ };
5588+
5589+ static void ValidateEntryCompatibility (ValidationContext &ValCtx) {
5590+ // Make sure functions called from each entry are compatible with that entry.
5591+ DxilModule &DM = ValCtx.DxilMod ;
5592+ for (Function &F : DM.GetModule ()->functions ()) {
5593+ if (DM.HasDxilEntryProps (&F)) {
5594+ const DxilModule::ShaderCompatInfo *compatInfo =
5595+ DM.GetCompatInfoForFunction (&F);
5596+ DXASSERT (compatInfo, " otherwise, compat info not computed in module" );
5597+ if (!compatInfo)
5598+ continue ;
5599+
5600+ CompatibilityChecker checker (ValCtx, &F);
5601+ checker.FindIncompatibleCall (*compatInfo);
5602+ }
5603+ }
5604+ }
5605+
53895606static void CheckPatchConstantSemantic (ValidationContext &ValCtx,
53905607 const DxilEntryProps &EntryProps,
53915608 EntryStatus &Status, Function *F) {
@@ -5900,7 +6117,7 @@ CalculateCallDepth(CallGraphNode *node,
59006117
59016118static void ValidateCallGraph (ValidationContext &ValCtx) {
59026119 // Build CallGraph.
5903- CallGraph CG (* ValCtx.DxilMod . GetModule () );
6120+ CallGraph &CG = ValCtx.GetCallGraph ( );
59046121
59056122 std::unordered_map<CallGraphNode *, unsigned > depthMap;
59066123 std::unordered_set<CallGraphNode *> callStack;
@@ -6161,6 +6378,8 @@ HRESULT ValidateDxilModule(llvm::Module *pModule, llvm::Module *pDebugModule) {
61616378
61626379 ValidateShaderFlags (ValCtx);
61636380
6381+ ValidateEntryCompatibility (ValCtx);
6382+
61646383 ValidateEntrySignatures (ValCtx);
61656384
61666385 ValidateUninitializedOutput (ValCtx);
0 commit comments