Skip to content

Commit 8da9786

Browse files
Remove unused target types from !dx.targetTypes metadata (#8202)
Make sure `!dx.targetTypes` metadata node in the final module only contains types that are actually used module. Introduces a new pass `DxilTrimTargetTypes` that runs towards the end of the pipeline and re-creates the `!dx.targetTypes` metadata node with only the target types used in the module. This pass runs during single shader compilation and also when multiple lib_6_x modules are linked together for SM 6.10. Fixes #8133 --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent b8dd93b commit 8da9786

11 files changed

Lines changed: 265 additions & 1 deletion

include/dxc/DXIL/DxilMetadataHelper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@ class DxilMDHelper {
648648
public:
649649
// Utility functions.
650650
static bool IsKnownNamedMetaData(const llvm::NamedMDNode &Node);
651+
static bool IsKnownGeneratedMetaData(const llvm::NamedMDNode &Node);
651652
static bool IsKnownMetadataID(llvm::LLVMContext &Ctx, unsigned ID);
652653
static void GetKnownMetadataIDs(llvm::LLVMContext &Ctx,
653654
llvm::SmallVectorImpl<unsigned> *pIDs);

include/dxc/DXIL/DxilOperations.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ class OP {
145145
static bool CheckOpCodeTable();
146146
static bool IsDxilOpFuncName(llvm::StringRef name);
147147
static bool IsDxilOpFunc(const llvm::Function *F);
148+
static bool IsDxilOpLinAlgFuncName(llvm::StringRef Name);
148149
static bool IsDxilOpFuncCallInst(const llvm::Instruction *I);
149150
static bool IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode);
150151
static bool IsDxilOpWave(OpCode C);
@@ -286,6 +287,7 @@ class OP {
286287
static const char *m_NamePrefix;
287288
static const char *m_TypePrefix;
288289
static const char *m_MatrixTypePrefix;
290+
static const char *m_LinAlgNamePrefix;
289291
static unsigned GetTypeSlot(llvm::Type *pType);
290292
static const char *GetOverloadTypeName(unsigned TypeSlot);
291293
static llvm::StringRef GetTypeName(llvm::Type *Ty,

include/dxc/DXIL/DxilUtil.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ llvm::Type *GetHLSLHitObjectType(llvm::Module *M);
166166
bool IsHLSLHitObjectType(llvm::Type *Ty);
167167
bool IsHLSLLinAlgMatrixType(llvm::Type *Ty);
168168
llvm::StringRef GetHLSLLinAlgMatrixTypeMangling(llvm::StructType *Ty);
169+
bool IsHLSLKnownTargetType(llvm::Type *Ty);
169170
bool IsHLSLResourceDescType(llvm::Type *Ty);
170171
bool IsResourceSingleComponent(llvm::Type *Ty);
171172
uint8_t GetResourceComponentCount(llvm::Type *Ty);

include/dxc/HLSL/DxilGenerationPass.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,4 +151,7 @@ void initializeDxilSimpleGVNEliminateRegionPass(llvm::PassRegistry &);
151151
ModulePass *createDxilModuleInitPass();
152152
void initializeDxilModuleInitPass(llvm::PassRegistry &);
153153

154+
ModulePass *createDxilTrimTargetTypesPass();
155+
void initializeDxilTrimTargetTypesPass(llvm::PassRegistry &);
156+
154157
} // namespace llvm

lib/DXIL/DxilMetadataHelper.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3338,6 +3338,11 @@ bool DxilMDHelper::IsKnownNamedMetaData(const llvm::NamedMDNode &Node) {
33383338
return false;
33393339
}
33403340

3341+
bool DxilMDHelper::IsKnownGeneratedMetaData(const llvm::NamedMDNode &Node) {
3342+
return IsKnownNamedMetaData(Node) &&
3343+
Node.getName() != DxilMDHelper::kDxilTargetTypesMDName;
3344+
}
3345+
33413346
bool DxilMDHelper::IsKnownMetadataID(LLVMContext &Ctx, unsigned ID) {
33423347
SmallVector<unsigned, 2> IDs;
33433348
GetKnownMetadataIDs(Ctx, &IDs);

lib/DXIL/DxilOperations.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3040,6 +3040,7 @@ const char *OP::m_OverloadTypeName[TS_BasicCount] = {
30403040
const char *OP::m_NamePrefix = "dx.op.";
30413041
const char *OP::m_TypePrefix = "dx.types.";
30423042
const char *OP::m_MatrixTypePrefix = "class.matrix."; // Allowed in library
3043+
const char *OP::m_LinAlgNamePrefix = "dx.op.linAlg";
30433044

30443045
// Keep sync with DXIL::AtomicBinOpCode
30453046
static const char *AtomicBinOpCodeName[] = {
@@ -3306,6 +3307,10 @@ bool OP::IsDxilOpFuncName(StringRef name) {
33063307
return name.startswith(OP::m_NamePrefix);
33073308
}
33083309

3310+
bool OP::IsDxilOpLinAlgFuncName(StringRef Name) {
3311+
return Name.startswith(OP::m_LinAlgNamePrefix);
3312+
}
3313+
33093314
bool OP::IsDxilOpFunc(const llvm::Function *F) {
33103315
// Test for null to allow IsDxilOpFunc(Call.getCalledFunc()) to be resilient
33113316
// to indirect calls

lib/DXIL/DxilUtil.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,11 @@ StringRef GetHLSLLinAlgMatrixTypeMangling(llvm::StructType *Ty) {
631631
return Ty->getStructName().substr(strlen(DXIL::kDxLinAlgMatrixTypePrefix));
632632
}
633633

634+
bool IsHLSLKnownTargetType(llvm::Type *Ty) {
635+
// Currently only LinAlgMatrix types are target types.
636+
return IsHLSLLinAlgMatrixType(Ty);
637+
}
638+
634639
bool IsHLSLResourceDescType(llvm::Type *Ty) {
635640
if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
636641
if (!ST->hasName())

lib/HLSL/DxilLinker.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ void DxilLinkJob::LinkNamedMDNodes(Module *pM, ValueToValueMapTy &vmap) {
575575
if (&NMD == pSrcModFlags)
576576
continue;
577577
// Skip dxil metadata which will be regenerated.
578-
if (DxilMDHelper::IsKnownNamedMetaData(NMD))
578+
if (DxilMDHelper::IsKnownGeneratedMetaData(NMD))
579579
continue;
580580
NamedMDNode *DestNMD = pM->getOrInsertNamedMetadata(NMD.getName());
581581
// Add Src elements into Dest node.
@@ -1293,6 +1293,7 @@ void DxilLinkJob::RunPreparePass(Module &M) {
12931293
PM.add(createComputeViewIdStatePass());
12941294
PM.add(createDxilDeadFunctionEliminationPass());
12951295
PM.add(createNoPausePassesPass());
1296+
PM.add(createDxilTrimTargetTypesPass());
12961297
PM.add(createDxilEmitMetadataPass());
12971298
PM.add(createDxilFinalizePreservesPass());
12981299

lib/HLSL/DxilPreparePasses.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1641,6 +1641,110 @@ INITIALIZE_PASS(DxilEmitMetadata, "hlsl-dxilemit", "HLSL DXIL Metadata Emit",
16411641

16421642
namespace {
16431643

1644+
// DxilTrimTargetTypes pass makes sure the !dx.targetTypes metadata only
1645+
// contains types that are actually used by the shader.
1646+
1647+
class DxilTrimTargetTypes : public ModulePass {
1648+
public:
1649+
static char ID; // Pass identification, replacement for typeid
1650+
explicit DxilTrimTargetTypes() : ModulePass(ID) {}
1651+
1652+
StringRef getPassName() const override {
1653+
return "HLSL DXIL Trim Target Types";
1654+
}
1655+
1656+
// Map of target type to its metadata node and usage flag.
1657+
using TargetTypesUsageMap =
1658+
SmallDenseMap<llvm::Type *, std::pair<MDTuple *, bool>, 16>;
1659+
1660+
void markTargetTypeAsUsed(TargetTypesUsageMap &Map, llvm::Type *Ty) {
1661+
auto It = Map.find(Ty);
1662+
assert(It != Map.end() &&
1663+
"used target type is not in dx.targetTypes metadata list");
1664+
(*It).second.second = true;
1665+
}
1666+
1667+
bool runOnModule(Module &M) override {
1668+
NamedMDNode *TargetTypesMDNode =
1669+
M.getNamedMetadata(DxilMDHelper::kDxilTargetTypesMDName);
1670+
if (!TargetTypesMDNode)
1671+
return false;
1672+
1673+
// Add all target types that from "dx.targetTypes" metadata to the map
1674+
// to track their usage.
1675+
TargetTypesUsageMap TargetTypesMap;
1676+
for (MDNode *Node : TargetTypesMDNode->operands()) {
1677+
MDTuple *TypeMD = dyn_cast<MDTuple>(Node);
1678+
if (!TypeMD || TypeMD->getNumOperands() == 0)
1679+
continue;
1680+
1681+
ConstantAsMetadata *ConstMD =
1682+
dyn_cast<ConstantAsMetadata>(TypeMD->getOperand(0).get());
1683+
if (!ConstMD)
1684+
continue;
1685+
1686+
Constant *TypeUndefPtr = ConstMD->getValue();
1687+
llvm::Type *Ty = TypeUndefPtr->getType();
1688+
TargetTypesMap.try_emplace(Ty, std::make_pair(TypeMD, false));
1689+
}
1690+
1691+
// Scan all LinAlgMatrix functions and check the return type and argument
1692+
// types to find all used target types.
1693+
for (const llvm::Function &F : M.functions()) {
1694+
if (!F.isDeclaration())
1695+
continue;
1696+
1697+
// Currently only LinAlgMatrix ops use target types.
1698+
if (!OP::IsDxilOpLinAlgFuncName(F.getName()))
1699+
continue;
1700+
1701+
llvm::Type *RetTy = F.getReturnType();
1702+
if (dxilutil::IsHLSLKnownTargetType(RetTy))
1703+
markTargetTypeAsUsed(TargetTypesMap, RetTy);
1704+
1705+
for (const auto &Arg : F.args()) {
1706+
llvm::Type *Ty = Arg.getType();
1707+
if (dxilutil::IsHLSLKnownTargetType(Ty))
1708+
markTargetTypeAsUsed(TargetTypesMap, Ty);
1709+
}
1710+
}
1711+
1712+
// Remove old metadata node from the module.
1713+
TargetTypesMDNode->eraseFromParent();
1714+
1715+
// Create a new one with the used target types.
1716+
NamedMDNode *NewTargetTypesMDNode =
1717+
M.getOrInsertNamedMetadata(DxilMDHelper::kDxilTargetTypesMDName);
1718+
for (auto &Entry : TargetTypesMap) {
1719+
MDTuple *Node = Entry.second.first;
1720+
bool IsUsed = Entry.second.second;
1721+
if (IsUsed)
1722+
NewTargetTypesMDNode->addOperand(Node);
1723+
}
1724+
1725+
// If no target type is used, remove the new metadata node from module.
1726+
if (NewTargetTypesMDNode->getNumOperands() == 0)
1727+
NewTargetTypesMDNode->eraseFromParent();
1728+
1729+
return true;
1730+
}
1731+
};
1732+
1733+
} // namespace
1734+
1735+
char DxilTrimTargetTypes::ID = 0;
1736+
1737+
ModulePass *llvm::createDxilTrimTargetTypesPass() {
1738+
return new DxilTrimTargetTypes();
1739+
}
1740+
1741+
INITIALIZE_PASS(DxilTrimTargetTypes, "hlsl-trim-target-types",
1742+
"HLSL DXIL Trim Target Types", false, false)
1743+
1744+
///////////////////////////////////////////////////////////////////////////////
1745+
1746+
namespace {
1747+
16441748
const StringRef UniNoWaveSensitiveGradientErrMsg =
16451749
"Gradient operations are not affected by wave-sensitive data or control "
16461750
"flow.";
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 %s | FileCheck %s
3+
4+
// This test is using 2 LinAlgMatrix operations:
5+
// - __builtin_LinAlg_FillMatrix - has the matrix as a return value
6+
// - __builtin_LinAlg_MatrixLength - has the matrix as an argument
7+
// This is done to verify that target types are correctly collected from both
8+
// return values and arguments of LinAlgMatrix operations.
9+
10+
uint useMatrix1() {
11+
// Matrix<ComponentType::I32, 4, 5, MatrixUse::A, MatrixScope::Thread> m;
12+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 4, 5, 0, 0)]] mat1;
13+
// mat1 = Matrix::Splat(5);
14+
__builtin_LinAlg_FillMatrix(mat1, 5);
15+
16+
// Matrix<ComponentType::U32, 3, 3, MatrixUse::A, MatrixScope::Thread> m;
17+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(5, 3, 3, 0, 0)]] mat2;
18+
// return mat2.Length();
19+
return __builtin_LinAlg_MatrixLength(mat2);
20+
}
21+
22+
uint useMatrix2() {
23+
// Matrix<ComponentType::F64, 2, 2, MatrixUse::B, MatrixScope::Wave> m;
24+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(10, 2, 2, 1, 1)]] mat3;
25+
// mat3 = Matrix::Splat(5);
26+
__builtin_LinAlg_FillMatrix(mat3, 5);
27+
// return mat3.Length();
28+
return __builtin_LinAlg_MatrixLength(mat3);
29+
}
30+
31+
RWBuffer<uint> Out;
32+
33+
[numthreads(4,1,1)]
34+
void main() {
35+
Out[0] = useMatrix1();
36+
}
37+
38+
// CHECK: !dx.targetTypes = !{!{{[0-9]+}}, !{{[0-9]+}}}
39+
// CHECK: !{{[0-9]+}} = !{%dx.types.LinAlgMatrixC4M4N5U0S0 undef, i32 4, i32 4, i32 5, i32 0, i32 0}
40+
// CHECK: !{{[0-9]+}} = !{%dx.types.LinAlgMatrixC5M3N3U0S0 undef, i32 5, i32 3, i32 3, i32 0, i32 0}
41+
// CHECK-NOT: !{%dx.types.LinAlgMatrixC10M2N2U1S1 undef, i32 10, i32 2, i32 2, i32 1, i32 1}

0 commit comments

Comments
 (0)