Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions include/dxc/DXIL/DxilInstructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -10651,14 +10651,14 @@ struct DxilInst_LinAlgMatrixLoadFromMemory {
bool requiresUniformInputs() const { return false; }
// Operand indexes
enum OperandIdx {
arg_groupsharedArr = 1,
arg_memory = 1,
arg_offset = 2,
arg_stride = 3,
arg_layout = 4,
};
// Accessors
llvm::Value *get_groupsharedArr() const { return Instr->getOperand(1); }
void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(1, val); }
llvm::Value *get_memory() const { return Instr->getOperand(1); }
void set_memory(llvm::Value *val) { Instr->setOperand(1, val); }
llvm::Value *get_offset() const { return Instr->getOperand(2); }
void set_offset(llvm::Value *val) { Instr->setOperand(2, val); }
llvm::Value *get_stride() const { return Instr->getOperand(3); }
Expand Down Expand Up @@ -10854,16 +10854,16 @@ struct DxilInst_LinAlgMatrixStoreToMemory {
// Operand indexes
enum OperandIdx {
arg_matrix = 1,
arg_groupsharedArr = 2,
arg_memory = 2,
arg_offset = 3,
arg_stride = 4,
arg_layout = 5,
};
// Accessors
llvm::Value *get_matrix() const { return Instr->getOperand(1); }
void set_matrix(llvm::Value *val) { Instr->setOperand(1, val); }
llvm::Value *get_groupsharedArr() const { return Instr->getOperand(2); }
void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(2, val); }
llvm::Value *get_memory() const { return Instr->getOperand(2); }
void set_memory(llvm::Value *val) { Instr->setOperand(2, val); }
llvm::Value *get_offset() const { return Instr->getOperand(3); }
void set_offset(llvm::Value *val) { Instr->setOperand(3, val); }
llvm::Value *get_stride() const { return Instr->getOperand(4); }
Expand Down Expand Up @@ -11091,16 +11091,16 @@ struct DxilInst_LinAlgMatrixAccumulateToMemory {
// Operand indexes
enum OperandIdx {
arg_matrix = 1,
arg_groupsharedArr = 2,
arg_memory = 2,
arg_offset = 3,
arg_stride = 4,
arg_layout = 5,
};
// Accessors
llvm::Value *get_matrix() const { return Instr->getOperand(1); }
void set_matrix(llvm::Value *val) { Instr->setOperand(1, val); }
llvm::Value *get_groupsharedArr() const { return Instr->getOperand(2); }
void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(2, val); }
llvm::Value *get_memory() const { return Instr->getOperand(2); }
void set_memory(llvm::Value *val) { Instr->setOperand(2, val); }
llvm::Value *get_offset() const { return Instr->getOperand(3); }
void set_offset(llvm::Value *val) { Instr->setOperand(3, val); }
llvm::Value *get_stride() const { return Instr->getOperand(4); }
Expand Down
129 changes: 76 additions & 53 deletions lib/DXIL/DxilOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3166,26 +3166,33 @@ const char *OP::GetOverloadTypeName(unsigned TypeSlot) {
StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl<char> &Storage) {
DXASSERT(!Ty->isVoidTy(), "must not pass void type here");
unsigned TypeSlot = OP::GetTypeSlot(Ty);

if (TypeSlot < TS_BasicCount) {
return GetOverloadTypeName(TypeSlot);
} else if (TypeSlot == TS_UDT) {
}

switch (TypeSlot) {
case TS_UDT: {
if (Ty->isPointerTy())
Ty = Ty->getPointerElementType();
StructType *ST = cast<StructType>(Ty);
return ST->getStructName();
} else if (TypeSlot == TS_Object) {
}
case TS_Object: {
StructType *ST = cast<StructType>(Ty);
if (dxilutil::IsHLSLLinAlgMatrixType(Ty))
return (Twine("m") + Twine(dxilutil::GetHLSLLinAlgMatrixTypeMangling(ST)))
.toStringRef(Storage);
return ST->getStructName();
} else if (TypeSlot == TS_Vector) {
}
case TS_Vector: {
VectorType *VecTy = cast<VectorType>(Ty);
return (Twine("v") + Twine(VecTy->getNumElements()) +
Twine(
GetOverloadTypeName(OP::GetTypeSlot(VecTy->getElementType()))))
.toStringRef(Storage);
} else if (TypeSlot == TS_Extended) {
}
case TS_Extended: {
DXASSERT(isa<StructType>(Ty),
"otherwise, extended overload type not wrapped in struct type.");
StructType *ST = cast<StructType>(Ty);
Expand All @@ -3200,11 +3207,14 @@ StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl<char> &Storage) {
OS << GetTypeName(ST->getElementType(I), TempStr);
}
return OS.str();
} else {
raw_svector_ostream OS(Storage);
Ty->print(OS);
return OS.str();
}
default:
break;
}

raw_svector_ostream OS(Storage);
Ty->print(OS);
return OS.str();
}

StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode,
Expand Down Expand Up @@ -4314,9 +4324,10 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
#define VEC2(_y) A(VectorType::get(_y, 2))
#define VEC4(_y) A(GetStructVectorType(4, _y))
#define VEC9(_y) A(VectorType::get(_y, 9))
#define TGSM(_y) A(PointerType::get(_y, DXIL::kTGSMAddrSpace))

// Extended Overload types are wrapped in an anonymous struct
#define EXT(_y) A(cast<StructType>(pOverloadType)->getElementType(_y))
#define EXT(_y) cast<StructType>(pOverloadType)->getElementType(_y)

/* <py::lines('OPCODE-OLOAD-FUNCS')>hctdb_instrhelp.get_oloads_funcs()</py>*/
switch (opCode) { // return opCode
Expand Down Expand Up @@ -6427,9 +6438,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {

// Linear Algebra Operations
case OpCode::MatVecMul:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
A(EXT(1));
A(pI1);
A(pI32);
A(pRes);
Expand All @@ -6443,9 +6454,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
A(pI1);
break;
case OpCode::MatVecMulAdd:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
A(EXT(1));
A(pI1);
A(pI32);
A(pRes);
Expand All @@ -6464,8 +6475,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
case OpCode::OuterProductAccumulate:
A(pV);
A(pI32);
EXT(0);
EXT(1);
A(EXT(0));
A(EXT(1));
A(pRes);
A(pI32);
A(pI32);
Expand Down Expand Up @@ -6568,21 +6579,21 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {

// Linear Algebra Operations
case OpCode::LinAlgMatrixMultiplyAccumulate:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
EXT(2);
EXT(3);
A(EXT(1));
A(EXT(2));
A(EXT(3));
break;
case OpCode::LinAlgFillMatrix:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
A(EXT(1));
break;
case OpCode::LinAlgCopyConvertMatrix:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
A(EXT(1));
A(pI1);
break;
case OpCode::LinAlgMatrixLoadFromDescriptor:
Expand All @@ -6594,9 +6605,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
A(pI32);
break;
case OpCode::LinAlgMatrixLoadFromMemory:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
TGSM(EXT(1));
A(pI32);
A(pI32);
A(pI32);
Expand All @@ -6613,17 +6624,17 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
A(pI32);
break;
case OpCode::LinAlgMatrixGetElement:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
A(EXT(1));
A(pI32);
break;
case OpCode::LinAlgMatrixSetElement:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
A(EXT(1));
A(pI32);
EXT(2);
A(EXT(2));
break;
case OpCode::LinAlgMatrixStoreToDescriptor:
A(pV);
Expand All @@ -6637,8 +6648,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
case OpCode::LinAlgMatrixStoreToMemory:
A(pV);
A(pI32);
EXT(0);
EXT(1);
A(EXT(0));
TGSM(EXT(1));
A(pI32);
A(pI32);
A(pI32);
Expand All @@ -6648,31 +6659,31 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
A(pI32);
break;
case OpCode::LinAlgMatrixMultiply:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
EXT(2);
A(EXT(1));
A(EXT(2));
break;
case OpCode::LinAlgMatrixAccumulate:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
EXT(2);
A(EXT(1));
A(EXT(2));
break;
case OpCode::LinAlgMatVecMul:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
EXT(2);
A(EXT(1));
A(EXT(2));
A(pI32);
break;
case OpCode::LinAlgMatVecMulAdd:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
EXT(2);
A(EXT(1));
A(EXT(2));
A(pI32);
EXT(3);
A(EXT(3));
A(pI32);
break;
case OpCode::LinAlgMatrixAccumulateToDescriptor:
Expand All @@ -6687,17 +6698,17 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
case OpCode::LinAlgMatrixAccumulateToMemory:
A(pV);
A(pI32);
EXT(0);
EXT(1);
A(EXT(0));
TGSM(EXT(1));
A(pI32);
A(pI32);
A(pI32);
break;
case OpCode::LinAlgMatrixOuterProduct:
EXT(0);
A(EXT(0));
A(pI32);
EXT(1);
EXT(2);
A(EXT(1));
A(EXT(2));
break;

//
Expand Down Expand Up @@ -7064,16 +7075,13 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::MatVecMulAdd:
case OpCode::LinAlgFillMatrix:
case OpCode::LinAlgCopyConvertMatrix:
case OpCode::LinAlgMatrixLoadFromMemory:
case OpCode::LinAlgMatrixGetElement:
if (FT->getNumParams() < 2)
return nullptr;
return llvm::StructType::get(Ctx,
{FT->getReturnType(), FT->getParamType(1)});

case OpCode::OuterProductAccumulate:
case OpCode::LinAlgMatrixStoreToMemory:
case OpCode::LinAlgMatrixAccumulateToMemory:
if (FT->getNumParams() < 3)
return nullptr;
return llvm::StructType::get(Ctx,
Expand All @@ -7086,12 +7094,27 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
{FT->getReturnType(), FT->getParamType(1),
FT->getParamType(2), FT->getParamType(3)});

case OpCode::LinAlgMatrixLoadFromMemory:
if (FT->getNumParams() < 2)
return nullptr;
return llvm::StructType::get(
Ctx,
{FT->getReturnType(), FT->getParamType(1)->getPointerElementType()});

case OpCode::LinAlgMatrixSetElement:
if (FT->getNumParams() < 4)
return nullptr;
return llvm::StructType::get(
Ctx, {FT->getReturnType(), FT->getParamType(1), FT->getParamType(3)});

case OpCode::LinAlgMatrixStoreToMemory:
case OpCode::LinAlgMatrixAccumulateToMemory:
if (FT->getNumParams() < 3)
return nullptr;
return llvm::StructType::get(
Ctx,
{FT->getParamType(1), FT->getParamType(2)->getPointerElementType()});

case OpCode::LinAlgMatrixMultiply:
case OpCode::LinAlgMatrixAccumulate:
case OpCode::LinAlgMatVecMul:
Expand Down
Loading
Loading