Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions include/dxc/DXIL/DxilInstructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,42 @@ struct LlvmInst_VAArg {
bool isAllowed() const { return false; }
};

/// This instruction extracts from vector
struct LlvmInst_ExtractElement {
llvm::Instruction *Instr;
// Construction and identification
LlvmInst_ExtractElement(llvm::Instruction *pInstr) : Instr(pInstr) {}
operator bool() const {
return Instr->getOpcode() == llvm::Instruction::ExtractElement;
}
// Validation support
bool isAllowed() const { return true; }
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A little surprised to see `ExtractElement`/`InsertElement`/`ShuffleVector` weren't already defined, since I expected at least the first two to be required for scalarization, which DXIL already supports.

If these are new for vectorization, should isAllowed() be checking IsSM69Plus() instead of always returning true?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scalarization took place before DXIL was generated. The output has no native vectors in non-lib shaders and so had no need for these functions. Intermediate steps might have these, but were disallowed for validation of final output. For similar reasons as their inclusion in intermediate steps, library shaders would allow these ops as they preserved native vectors in function interfaces. The check for library shader allowed ops accounts for this in validation.

case Instruction::InsertElement:

The inclusion here adds them to the list of instructions that are approved for general validation by IsLLVMInstructionAllowed, a generated function that merely returns whether the opcode is in the list of approved LLVM instructions for a non-lib shader. It is generated into DxilValidationImpl.inc and called here:

if (!IsLLVMInstructionAllowed(I)) {

Similar to this function, this member is meant to return just true or false if it is allowed at all. I should add a 6.9 check to validation though.

};

/// Thin wrapper identifying the LLVM InsertElement instruction
/// (inserts a value into a vector).
struct LlvmInst_InsertElement {
  // Wrapped instruction; stored as given, never checked for null here.
  llvm::Instruction *Instr;

  // Construction and identification
  LlvmInst_InsertElement(llvm::Instruction *pInstr) : Instr(pInstr) {}

  // True when the wrapped instruction's opcode is InsertElement.
  operator bool() const {
    const unsigned Opcode = Instr->getOpcode();
    return Opcode == llvm::Instruction::InsertElement;
  }

  // Validation support: this opcode is reported as allowed unconditionally.
  bool isAllowed() const { return true; }
};

/// Thin wrapper identifying the LLVM ShuffleVector instruction
/// (shuffles two vectors).
struct LlvmInst_ShuffleVector {
  // Wrapped instruction; stored as given, never checked for null here.
  llvm::Instruction *Instr;

  // Construction and identification
  LlvmInst_ShuffleVector(llvm::Instruction *pInstr) : Instr(pInstr) {}

  // True when the wrapped instruction's opcode is ShuffleVector.
  operator bool() const {
    const unsigned Opcode = Instr->getOpcode();
    return Opcode == llvm::Instruction::ShuffleVector;
  }

  // Validation support: this opcode is reported as allowed unconditionally.
  bool isAllowed() const { return true; }
};

/// This instruction extracts from aggregate
struct LlvmInst_ExtractValue {
llvm::Instruction *Instr;
Expand Down
2 changes: 2 additions & 0 deletions include/dxc/DXIL/DxilMetadataHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,8 @@ class DxilMDHelper {
// Dxil version.
void EmitDxilVersion(unsigned Major, unsigned Minor);
void LoadDxilVersion(unsigned &Major, unsigned &Minor);
static bool LoadDxilVersion(const llvm::Module *pModule, unsigned &Major,
unsigned &Minor);

// Validator version.
void EmitValidatorVersion(unsigned Major, unsigned Minor);
Expand Down
4 changes: 4 additions & 0 deletions include/dxc/DXIL/DxilUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ bool DeleteDeadAllocas(llvm::Function &F);
llvm::Value *GEPIdxToOffset(llvm::GetElementPtrInst *GEP,
llvm::IRBuilder<> &Builder, hlsl::OP *OP,
const llvm::DataLayout &DL);

// Passes back Dxil version of the given module on true return.
bool LoadDxilVersion(const llvm::Module *M, unsigned &Major, unsigned &Minor);

} // namespace dxilutil

} // namespace hlsl
23 changes: 17 additions & 6 deletions lib/DXIL/DxilMetadataHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,28 @@ void DxilMDHelper::EmitDxilVersion(unsigned Major, unsigned Minor) {
pDxilVersionMD->addOperand(MDNode::get(m_Ctx, MDVals));
}

void DxilMDHelper::LoadDxilVersion(unsigned &Major, unsigned &Minor) {
NamedMDNode *pDxilVersionMD = m_pModule->getNamedMetadata(kDxilVersionMDName);
IFTBOOL(pDxilVersionMD != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
IFTBOOL(pDxilVersionMD->getNumOperands() == 1, DXC_E_INCORRECT_DXIL_METADATA);
// Load dxil version from metadata contained in pModule.
// Returns true and passes result through
// the dxil major/minor version params if valid.
// Returns false if metadata is missing or invalid.
bool DxilMDHelper::LoadDxilVersion(const Module *pModule, unsigned &Major,
unsigned &Minor) {
NamedMDNode *pDxilVersionMD = pModule->getNamedMetadata(kDxilVersionMDName);
IFRBOOL(pDxilVersionMD != nullptr, false);
IFRBOOL(pDxilVersionMD->getNumOperands() == 1, false);

MDNode *pVersionMD = pDxilVersionMD->getOperand(0);
IFTBOOL(pVersionMD->getNumOperands() == kDxilVersionNumFields,
DXC_E_INCORRECT_DXIL_METADATA);
IFRBOOL(pVersionMD->getNumOperands() == kDxilVersionNumFields, false);

Major = ConstMDToUint32(pVersionMD->getOperand(kDxilVersionMajorIdx));
Minor = ConstMDToUint32(pVersionMD->getOperand(kDxilVersionMinorIdx));

return true;
}

// Member overload: loads the dxil version from this helper's own module
// (m_pModule) via the static overload. A false result from that overload is
// reported through IFTBOOL with DXC_E_INCORRECT_DXIL_METADATA
// (presumably by throwing; see the IFTBOOL macro definition).
void DxilMDHelper::LoadDxilVersion(unsigned &Major, unsigned &Minor) {
  const bool Loaded = LoadDxilVersion(m_pModule, Major, Minor);
  IFTBOOL(Loaded, DXC_E_INCORRECT_DXIL_METADATA);
}

//
Expand Down
13 changes: 13 additions & 0 deletions lib/DXIL/DxilUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1394,5 +1394,18 @@ bool DeleteDeadAllocas(llvm::Function &F) {
return Changed;
}

// Fetch the dxil version of module M into Major/Minor.
// When M already carries a DxilModule, the version comes from that module's
// shader model and the function returns true.
// Otherwise the lookup is delegated to DxilMDHelper::LoadDxilVersion, which
// reads it from the module's metadata; its result (false when the metadata
// is missing or invalid) is returned as-is.
bool LoadDxilVersion(const Module *M, unsigned &Major, unsigned &Minor) {
  if (!M->HasDxilModule()) {
    // No dxil module attached; fall back to the metadata.
    return DxilMDHelper::LoadDxilVersion(M, Major, Minor);
  }

  const auto *SM = M->GetDxilModule().GetShaderModel();
  SM->GetDxilVersion(Major, Minor);
  return true;
}

} // namespace dxilutil
} // namespace hlsl
23 changes: 22 additions & 1 deletion lib/DxilValidation/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2193,6 +2193,9 @@ static bool ValidateType(Type *Ty, ValidationContext &ValCtx,
return true;

if (Ty->isVectorTy()) {
if (Ty->getVectorNumElements() > 1 &&
ValCtx.DxilMod.GetShaderModel()->IsSM69Plus())
return true;
ValCtx.EmitTypeError(Ty, ValidationRule::TypesNoVector);
return false;
}
Expand Down Expand Up @@ -2669,6 +2672,23 @@ static bool IsLLVMInstructionAllowedForLib(Instruction &I,
}
}

// Shader-model-dependent gate on LLVM instructions.
// The only rule enforced here concerns the vector operations
// InsertElement, ExtractElement and ShuffleVector: they are rejected
// (false) unless the shader model is 6.9 or later. Every other opcode is
// accepted (true) regardless of shader model.
static bool IsLLVMInstructionAllowedForShaderModel(Instruction &I,
                                                   ValidationContext &ValCtx) {
  const bool IsSM69OrLater = ValCtx.DxilMod.GetShaderModel()->IsSM69Plus();

  switch (I.getOpcode()) {
  case Instruction::InsertElement:
  case Instruction::ExtractElement:
  case Instruction::ShuffleVector:
    // Vector operations: allowed only from shader model 6.9 onward.
    return IsSM69OrLater;
  default:
    return true;
  }
}

static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
bool SupportsMinPrecision =
ValCtx.DxilMod.GetGlobalFlags() & DXIL::kEnableMinPrecision;
Expand All @@ -2691,7 +2711,8 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
}

// Instructions must be allowed.
if (!IsLLVMInstructionAllowed(I)) {
if (!IsLLVMInstructionAllowed(I) ||
!IsLLVMInstructionAllowedForShaderModel(I, ValCtx)) {
if (!IsLLVMInstructionAllowedForLib(I, ValCtx)) {
ValCtx.EmitInstrError(&I, ValidationRule::InstrAllowed);
continue;
Expand Down
6 changes: 6 additions & 0 deletions lib/HLSL/DxilLinker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,12 @@ void DxilLinkJob::RunPreparePass(Module &M) {
// For static global handle.
PM.add(createLowerStaticGlobalIntoAlloca());

// Change dynamic indexing vector to array where vectors aren't
// supported, but might be there from the initial compile.
if (!pSM->IsSM69Plus())
PM.add(
createDynamicIndexingVectorToArrayPass(false /* ReplaceAllVector */));

// Remove MultiDimArray from function call arg.
PM.add(createMultiDimArrayToOneDimArrayPass());

Expand Down
60 changes: 41 additions & 19 deletions lib/HLSL/HLMatrixBitcastLowerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,17 @@ Type *TryLowerMatTy(Type *Ty) {
}

class MatrixBitcastLowerPass : public FunctionPass {
bool SupportsVectors = false;

public:
static char ID; // Pass identification, replacement for typeid
explicit MatrixBitcastLowerPass() : FunctionPass(ID) {}

StringRef getPassName() const override { return "Matrix Bitcast lower"; }
bool runOnFunction(Function &F) override {
DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
SupportsVectors = DM.GetShaderModel()->IsSM69Plus();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to avoid full DxilModule dependency, I think it would be reasonable to have something that can look up the DXIL version in the module without loading full metadata, whether or not it's HLM or DM.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this idea. I know I mentioned it somewhere before. I'd prefer it not block this change though.


bool bUpdated = false;
std::unordered_set<BitCastInst *> matCastSet;
for (auto blkIt = F.begin(); blkIt != F.end(); ++blkIt) {
Expand All @@ -100,7 +104,6 @@ class MatrixBitcastLowerPass : public FunctionPass {
}
}

DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
// Remove bitcast which has CallInst user.
if (DM.GetShaderModel()->IsLib()) {
for (auto it = matCastSet.begin(); it != matCastSet.end();) {
Expand Down Expand Up @@ -185,18 +188,19 @@ void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
User *U = *(it++);
if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
Type *EltTy = GEP->getType()->getPointerElementType();
if (HLMatrixType::isa(EltTy)) {
if (HLMatrixType MatTy = HLMatrixType::dyn_cast(EltTy)) {
// Change gep matrixArray, 0, index
// into
// gep oneDimArray, 0, index * matSize
IRBuilder<> Builder(GEP);
SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
DXASSERT(idxList.size() == 2,
"else not one dim matrix array index to matrix");

HLMatrixType MatTy = HLMatrixType::cast(EltTy);
Value *matSize = Builder.getInt32(MatTy.getNumElements());
idxList.back() = Builder.CreateMul(idxList.back(), matSize);
unsigned NumElts = MatTy.getNumElements();
if (!SupportsVectors || NumElts == 1) {
Value *MatSize = Builder.getInt32(NumElts);
idxList.back() = Builder.CreateMul(idxList.back(), MatSize);
}
Value *NewGEP = Builder.CreateGEP(A, idxList);
lowerMatrix(GEP, NewGEP);
DXASSERT(GEP->user_empty(), "else lower matrix fail");
Expand All @@ -211,13 +215,23 @@ void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
} else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
IRBuilder<> Builder(LI);
Value *zeroIdx = Builder.getInt32(0);
unsigned vecSize = Ty->getNumElements();
Value *NewVec = UndefValue::get(LI->getType());
for (unsigned i = 0; i < vecSize; i++) {
Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
Value *Elt = Builder.CreateLoad(GEP);
NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
Value *NewVec = nullptr;
unsigned VecSize = Ty->getVectorNumElements();
if (SupportsVectors && VecSize > 1) {
// Create a replacement load using the vector pointer.
Instruction *NewLd = LI->clone();
unsigned VecIdx = NewLd->getNumOperands() - 1;
NewLd->setOperand(VecIdx, A);
Builder.Insert(NewLd);
NewVec = NewLd;
} else {
Value *zeroIdx = Builder.getInt32(0);
NewVec = UndefValue::get(LI->getType());
for (unsigned i = 0; i < VecSize; i++) {
Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
Value *Elt = Builder.CreateLoad(GEP);
NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that below this point, there's:

    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {

where it still scalarizes the store for the vector.

Did you mean to leave that scalarization in?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I omitted store support in error. Clearly we need some testing for it as it never caused any issues.

}
LI->replaceAllUsesWith(NewVec);
LI->eraseFromParent();
Expand All @@ -228,12 +242,20 @@ void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
Value *V = ST->getValueOperand();
if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
IRBuilder<> Builder(LI);
Value *zeroIdx = Builder.getInt32(0);
unsigned vecSize = Ty->getNumElements();
for (unsigned i = 0; i < vecSize; i++) {
Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
Value *Elt = Builder.CreateExtractElement(V, i);
Builder.CreateStore(Elt, GEP);
if (SupportsVectors && Ty->getVectorNumElements() > 1) {
// Create a replacement store using the vector pointer.
Instruction *NewSt = ST->clone();
unsigned VecIdx = NewSt->getNumOperands() - 1;
NewSt->setOperand(VecIdx, A);
Builder.Insert(NewSt);
} else {
Value *zeroIdx = Builder.getInt32(0);
unsigned vecSize = Ty->getNumElements();
for (unsigned i = 0; i < vecSize; i++) {
Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
Value *Elt = Builder.CreateExtractElement(V, i);
Builder.CreateStore(Elt, GEP);
}
}
ST->eraseFromParent();
} else {
Expand Down
3 changes: 3 additions & 0 deletions lib/HLSL/HLModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,9 @@ MDTuple *HLModule::EmitHLResources() {

void HLModule::LoadHLResources(const llvm::MDOperand &MDO) {
const llvm::MDTuple *pSRVs, *pUAVs, *pCBuffers, *pSamplers;
// No resources. Nothing to do.
if (MDO.get() == nullptr)
return;
m_pMDHelper->GetDxilResources(MDO, pSRVs, pUAVs, pCBuffers, pSamplers);

// Load SRV records.
Expand Down
40 changes: 31 additions & 9 deletions lib/Transforms/Scalar/LowerTypePasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "dxc/DXIL/DxilConstants.h"
#include "dxc/DXIL/DxilModule.h"
#include "dxc/DXIL/DxilOperations.h"
#include "dxc/DXIL/DxilUtil.h"
#include "dxc/HLSL/HLModule.h"
Expand Down Expand Up @@ -180,10 +181,12 @@ bool LowerTypePass::runOnModule(Module &M) {
namespace {
class DynamicIndexingVectorToArray : public LowerTypePass {
bool ReplaceAllVectors;
bool SupportsVectors;

public:
explicit DynamicIndexingVectorToArray(bool ReplaceAll = false)
: LowerTypePass(ID), ReplaceAllVectors(ReplaceAll) {}
: LowerTypePass(ID), ReplaceAllVectors(ReplaceAll),
SupportsVectors(false) {}
static char ID; // Pass identification, replacement for typeid
void applyOptions(PassOptions O) override;
void dumpConfig(raw_ostream &OS) override;
Expand All @@ -194,6 +197,7 @@ class DynamicIndexingVectorToArray : public LowerTypePass {
Type *lowerType(Type *Ty) override;
Constant *lowerInitVal(Constant *InitVal, Type *NewTy) override;
StringRef getGlobalPrefix() override { return ".v"; }
void initialize(Module &M) override;

private:
bool HasVectorDynamicIndexing(Value *V);
Expand All @@ -207,6 +211,18 @@ class DynamicIndexingVectorToArray : public LowerTypePass {
void ReplaceAddrSpaceCast(ConstantExpr *CE, Value *A, IRBuilder<> &Builder);
};

// Decide whether this pass may leave vectors in place, based on the dxil
// version of module M.
// The version is taken from the HLModule's shader model when M has an
// HLModule; otherwise dxilutil::LoadDxilVersion is asked for it.
// Vectors are treated as supported for dxil version 1.9 and later 1.x.
void DynamicIndexingVectorToArray::initialize(Module &M) {
  unsigned Major = 0;
  unsigned Minor = 0;

  if (!M.HasHLModule()) {
    // Return value deliberately unchecked: if no version is written, the
    // 0.0 defaults remain and vectors are treated as unsupported.
    dxilutil::LoadDxilVersion(&M, Major, Minor);
  } else {
    M.GetHLModule().GetShaderModel()->GetDxilVersion(Major, Minor);
  }

  const bool IsDxil19OrLater = (Major == 1) && (Minor >= 9);
  SupportsVectors = IsDxil19OrLater;
}

void DynamicIndexingVectorToArray::applyOptions(PassOptions O) {
GetPassOptionBool(O, "ReplaceAllVectors", &ReplaceAllVectors,
ReplaceAllVectors);
Expand Down Expand Up @@ -306,9 +322,21 @@ void DynamicIndexingVectorToArray::ReplaceStaticIndexingOnVector(Value *V) {
}

bool DynamicIndexingVectorToArray::needToLower(Value *V) {
bool MustReplaceVector = ReplaceAllVectors;
Type *Ty = V->getType()->getPointerElementType();
if (dyn_cast<VectorType>(Ty)) {
if (isa<GlobalVariable>(V) || ReplaceAllVectors) {

if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
// Array must be replaced even without dynamic indexing to remove vector
// type in dxil.
MustReplaceVector = true;
Ty = dxilutil::GetArrayEltTy(AT);
}

if (isa<VectorType>(Ty)) {
// Only needed for 2+ vectors where native vectors unsupported.
if (SupportsVectors && Ty->getVectorNumElements() > 1)
return false;
if (isa<GlobalVariable>(V) || MustReplaceVector) {
return true;
}
// Don't lower local vector which only static indexing.
Expand All @@ -319,12 +347,6 @@ bool DynamicIndexingVectorToArray::needToLower(Value *V) {
ReplaceStaticIndexingOnVector(V);
return false;
}
} else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
// Array must be replaced even without dynamic indexing to remove vector
// type in dxil.
// TODO: optimize static array index in later pass.
Type *EltTy = dxilutil::GetArrayEltTy(AT);
return isa<VectorType>(EltTy);
}
return false;
}
Expand Down
Loading