Skip to content

Commit 05de2c9

Browse files
authored
[SM6.10] Implement LinAlg groupshared Builtins (microsoft#8254)
Implements the LinAlg Load/Store/Accumulate to memory groupshared builtins following the pattern of the previous builtins in alignment with the spec. https://github.com/microsoft/hlsl-specs/blob/main/proposals/0035-linalg-matrix.md Fixes microsoft#7903 Fixes microsoft#7905 Fixes microsoft#7907
1 parent 8e949f8 commit 05de2c9

17 files changed

Lines changed: 344 additions & 96 deletions

File tree

include/dxc/DXIL/DxilInstructions.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10651,14 +10651,14 @@ struct DxilInst_LinAlgMatrixLoadFromMemory {
1065110651
bool requiresUniformInputs() const { return false; }
1065210652
// Operand indexes
1065310653
enum OperandIdx {
10654-
arg_groupsharedArr = 1,
10654+
arg_memory = 1,
1065510655
arg_offset = 2,
1065610656
arg_stride = 3,
1065710657
arg_layout = 4,
1065810658
};
1065910659
// Accessors
10660-
llvm::Value *get_groupsharedArr() const { return Instr->getOperand(1); }
10661-
void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(1, val); }
10660+
llvm::Value *get_memory() const { return Instr->getOperand(1); }
10661+
void set_memory(llvm::Value *val) { Instr->setOperand(1, val); }
1066210662
llvm::Value *get_offset() const { return Instr->getOperand(2); }
1066310663
void set_offset(llvm::Value *val) { Instr->setOperand(2, val); }
1066410664
llvm::Value *get_stride() const { return Instr->getOperand(3); }
@@ -10854,16 +10854,16 @@ struct DxilInst_LinAlgMatrixStoreToMemory {
1085410854
// Operand indexes
1085510855
enum OperandIdx {
1085610856
arg_matrix = 1,
10857-
arg_groupsharedArr = 2,
10857+
arg_memory = 2,
1085810858
arg_offset = 3,
1085910859
arg_stride = 4,
1086010860
arg_layout = 5,
1086110861
};
1086210862
// Accessors
1086310863
llvm::Value *get_matrix() const { return Instr->getOperand(1); }
1086410864
void set_matrix(llvm::Value *val) { Instr->setOperand(1, val); }
10865-
llvm::Value *get_groupsharedArr() const { return Instr->getOperand(2); }
10866-
void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(2, val); }
10865+
llvm::Value *get_memory() const { return Instr->getOperand(2); }
10866+
void set_memory(llvm::Value *val) { Instr->setOperand(2, val); }
1086710867
llvm::Value *get_offset() const { return Instr->getOperand(3); }
1086810868
void set_offset(llvm::Value *val) { Instr->setOperand(3, val); }
1086910869
llvm::Value *get_stride() const { return Instr->getOperand(4); }
@@ -11091,16 +11091,16 @@ struct DxilInst_LinAlgMatrixAccumulateToMemory {
1109111091
// Operand indexes
1109211092
enum OperandIdx {
1109311093
arg_matrix = 1,
11094-
arg_groupsharedArr = 2,
11094+
arg_memory = 2,
1109511095
arg_offset = 3,
1109611096
arg_stride = 4,
1109711097
arg_layout = 5,
1109811098
};
1109911099
// Accessors
1110011100
llvm::Value *get_matrix() const { return Instr->getOperand(1); }
1110111101
void set_matrix(llvm::Value *val) { Instr->setOperand(1, val); }
11102-
llvm::Value *get_groupsharedArr() const { return Instr->getOperand(2); }
11103-
void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(2, val); }
11102+
llvm::Value *get_memory() const { return Instr->getOperand(2); }
11103+
void set_memory(llvm::Value *val) { Instr->setOperand(2, val); }
1110411104
llvm::Value *get_offset() const { return Instr->getOperand(3); }
1110511105
void set_offset(llvm::Value *val) { Instr->setOperand(3, val); }
1110611106
llvm::Value *get_stride() const { return Instr->getOperand(4); }

lib/DXIL/DxilOperations.cpp

Lines changed: 76 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3166,26 +3166,33 @@ const char *OP::GetOverloadTypeName(unsigned TypeSlot) {
31663166
StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl<char> &Storage) {
31673167
DXASSERT(!Ty->isVoidTy(), "must not pass void type here");
31683168
unsigned TypeSlot = OP::GetTypeSlot(Ty);
3169+
31693170
if (TypeSlot < TS_BasicCount) {
31703171
return GetOverloadTypeName(TypeSlot);
3171-
} else if (TypeSlot == TS_UDT) {
3172+
}
3173+
3174+
switch (TypeSlot) {
3175+
case TS_UDT: {
31723176
if (Ty->isPointerTy())
31733177
Ty = Ty->getPointerElementType();
31743178
StructType *ST = cast<StructType>(Ty);
31753179
return ST->getStructName();
3176-
} else if (TypeSlot == TS_Object) {
3180+
}
3181+
case TS_Object: {
31773182
StructType *ST = cast<StructType>(Ty);
31783183
if (dxilutil::IsHLSLLinAlgMatrixType(Ty))
31793184
return (Twine("m") + Twine(dxilutil::GetHLSLLinAlgMatrixTypeMangling(ST)))
31803185
.toStringRef(Storage);
31813186
return ST->getStructName();
3182-
} else if (TypeSlot == TS_Vector) {
3187+
}
3188+
case TS_Vector: {
31833189
VectorType *VecTy = cast<VectorType>(Ty);
31843190
return (Twine("v") + Twine(VecTy->getNumElements()) +
31853191
Twine(
31863192
GetOverloadTypeName(OP::GetTypeSlot(VecTy->getElementType()))))
31873193
.toStringRef(Storage);
3188-
} else if (TypeSlot == TS_Extended) {
3194+
}
3195+
case TS_Extended: {
31893196
DXASSERT(isa<StructType>(Ty),
31903197
"otherwise, extended overload type not wrapped in struct type.");
31913198
StructType *ST = cast<StructType>(Ty);
@@ -3200,11 +3207,14 @@ StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl<char> &Storage) {
32003207
OS << GetTypeName(ST->getElementType(I), TempStr);
32013208
}
32023209
return OS.str();
3203-
} else {
3204-
raw_svector_ostream OS(Storage);
3205-
Ty->print(OS);
3206-
return OS.str();
32073210
}
3211+
default:
3212+
break;
3213+
}
3214+
3215+
raw_svector_ostream OS(Storage);
3216+
Ty->print(OS);
3217+
return OS.str();
32083218
}
32093219

32103220
StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode,
@@ -4314,9 +4324,10 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
43144324
#define VEC2(_y) A(VectorType::get(_y, 2))
43154325
#define VEC4(_y) A(GetStructVectorType(4, _y))
43164326
#define VEC9(_y) A(VectorType::get(_y, 9))
4327+
#define TGSM(_y) A(PointerType::get(_y, DXIL::kTGSMAddrSpace))
43174328

43184329
// Extended Overload types are wrapped in an anonymous struct
4319-
#define EXT(_y) A(cast<StructType>(pOverloadType)->getElementType(_y))
4330+
#define EXT(_y) cast<StructType>(pOverloadType)->getElementType(_y)
43204331

43214332
/* <py::lines('OPCODE-OLOAD-FUNCS')>hctdb_instrhelp.get_oloads_funcs()</py>*/
43224333
switch (opCode) { // return opCode
@@ -6427,9 +6438,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
64276438

64286439
// Linear Algebra Operations
64296440
case OpCode::MatVecMul:
6430-
EXT(0);
6441+
A(EXT(0));
64316442
A(pI32);
6432-
EXT(1);
6443+
A(EXT(1));
64336444
A(pI1);
64346445
A(pI32);
64356446
A(pRes);
@@ -6443,9 +6454,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
64436454
A(pI1);
64446455
break;
64456456
case OpCode::MatVecMulAdd:
6446-
EXT(0);
6457+
A(EXT(0));
64476458
A(pI32);
6448-
EXT(1);
6459+
A(EXT(1));
64496460
A(pI1);
64506461
A(pI32);
64516462
A(pRes);
@@ -6464,8 +6475,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
64646475
case OpCode::OuterProductAccumulate:
64656476
A(pV);
64666477
A(pI32);
6467-
EXT(0);
6468-
EXT(1);
6478+
A(EXT(0));
6479+
A(EXT(1));
64696480
A(pRes);
64706481
A(pI32);
64716482
A(pI32);
@@ -6568,21 +6579,21 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
65686579

65696580
// Linear Algebra Operations
65706581
case OpCode::LinAlgMatrixMultiplyAccumulate:
6571-
EXT(0);
6582+
A(EXT(0));
65726583
A(pI32);
6573-
EXT(1);
6574-
EXT(2);
6575-
EXT(3);
6584+
A(EXT(1));
6585+
A(EXT(2));
6586+
A(EXT(3));
65766587
break;
65776588
case OpCode::LinAlgFillMatrix:
6578-
EXT(0);
6589+
A(EXT(0));
65796590
A(pI32);
6580-
EXT(1);
6591+
A(EXT(1));
65816592
break;
65826593
case OpCode::LinAlgCopyConvertMatrix:
6583-
EXT(0);
6594+
A(EXT(0));
65846595
A(pI32);
6585-
EXT(1);
6596+
A(EXT(1));
65866597
A(pI1);
65876598
break;
65886599
case OpCode::LinAlgMatrixLoadFromDescriptor:
@@ -6594,9 +6605,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
65946605
A(pI32);
65956606
break;
65966607
case OpCode::LinAlgMatrixLoadFromMemory:
6597-
EXT(0);
6608+
A(EXT(0));
65986609
A(pI32);
6599-
EXT(1);
6610+
TGSM(EXT(1));
66006611
A(pI32);
66016612
A(pI32);
66026613
A(pI32);
@@ -6613,17 +6624,17 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66136624
A(pI32);
66146625
break;
66156626
case OpCode::LinAlgMatrixGetElement:
6616-
EXT(0);
6627+
A(EXT(0));
66176628
A(pI32);
6618-
EXT(1);
6629+
A(EXT(1));
66196630
A(pI32);
66206631
break;
66216632
case OpCode::LinAlgMatrixSetElement:
6622-
EXT(0);
6633+
A(EXT(0));
66236634
A(pI32);
6624-
EXT(1);
6635+
A(EXT(1));
66256636
A(pI32);
6626-
EXT(2);
6637+
A(EXT(2));
66276638
break;
66286639
case OpCode::LinAlgMatrixStoreToDescriptor:
66296640
A(pV);
@@ -6637,8 +6648,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66376648
case OpCode::LinAlgMatrixStoreToMemory:
66386649
A(pV);
66396650
A(pI32);
6640-
EXT(0);
6641-
EXT(1);
6651+
A(EXT(0));
6652+
TGSM(EXT(1));
66426653
A(pI32);
66436654
A(pI32);
66446655
A(pI32);
@@ -6648,31 +6659,31 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66486659
A(pI32);
66496660
break;
66506661
case OpCode::LinAlgMatrixMultiply:
6651-
EXT(0);
6662+
A(EXT(0));
66526663
A(pI32);
6653-
EXT(1);
6654-
EXT(2);
6664+
A(EXT(1));
6665+
A(EXT(2));
66556666
break;
66566667
case OpCode::LinAlgMatrixAccumulate:
6657-
EXT(0);
6668+
A(EXT(0));
66586669
A(pI32);
6659-
EXT(1);
6660-
EXT(2);
6670+
A(EXT(1));
6671+
A(EXT(2));
66616672
break;
66626673
case OpCode::LinAlgMatVecMul:
6663-
EXT(0);
6674+
A(EXT(0));
66646675
A(pI32);
6665-
EXT(1);
6666-
EXT(2);
6676+
A(EXT(1));
6677+
A(EXT(2));
66676678
A(pI32);
66686679
break;
66696680
case OpCode::LinAlgMatVecMulAdd:
6670-
EXT(0);
6681+
A(EXT(0));
66716682
A(pI32);
6672-
EXT(1);
6673-
EXT(2);
6683+
A(EXT(1));
6684+
A(EXT(2));
66746685
A(pI32);
6675-
EXT(3);
6686+
A(EXT(3));
66766687
A(pI32);
66776688
break;
66786689
case OpCode::LinAlgMatrixAccumulateToDescriptor:
@@ -6687,17 +6698,17 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66876698
case OpCode::LinAlgMatrixAccumulateToMemory:
66886699
A(pV);
66896700
A(pI32);
6690-
EXT(0);
6691-
EXT(1);
6701+
A(EXT(0));
6702+
TGSM(EXT(1));
66926703
A(pI32);
66936704
A(pI32);
66946705
A(pI32);
66956706
break;
66966707
case OpCode::LinAlgMatrixOuterProduct:
6697-
EXT(0);
6708+
A(EXT(0));
66986709
A(pI32);
6699-
EXT(1);
6700-
EXT(2);
6710+
A(EXT(1));
6711+
A(EXT(2));
67016712
break;
67026713

67036714
//
@@ -7064,16 +7075,13 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70647075
case OpCode::MatVecMulAdd:
70657076
case OpCode::LinAlgFillMatrix:
70667077
case OpCode::LinAlgCopyConvertMatrix:
7067-
case OpCode::LinAlgMatrixLoadFromMemory:
70687078
case OpCode::LinAlgMatrixGetElement:
70697079
if (FT->getNumParams() < 2)
70707080
return nullptr;
70717081
return llvm::StructType::get(Ctx,
70727082
{FT->getReturnType(), FT->getParamType(1)});
70737083

70747084
case OpCode::OuterProductAccumulate:
7075-
case OpCode::LinAlgMatrixStoreToMemory:
7076-
case OpCode::LinAlgMatrixAccumulateToMemory:
70777085
if (FT->getNumParams() < 3)
70787086
return nullptr;
70797087
return llvm::StructType::get(Ctx,
@@ -7086,12 +7094,27 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70867094
{FT->getReturnType(), FT->getParamType(1),
70877095
FT->getParamType(2), FT->getParamType(3)});
70887096

7097+
case OpCode::LinAlgMatrixLoadFromMemory:
7098+
if (FT->getNumParams() < 2)
7099+
return nullptr;
7100+
return llvm::StructType::get(
7101+
Ctx,
7102+
{FT->getReturnType(), FT->getParamType(1)->getPointerElementType()});
7103+
70897104
case OpCode::LinAlgMatrixSetElement:
70907105
if (FT->getNumParams() < 4)
70917106
return nullptr;
70927107
return llvm::StructType::get(
70937108
Ctx, {FT->getReturnType(), FT->getParamType(1), FT->getParamType(3)});
70947109

7110+
case OpCode::LinAlgMatrixStoreToMemory:
7111+
case OpCode::LinAlgMatrixAccumulateToMemory:
7112+
if (FT->getNumParams() < 3)
7113+
return nullptr;
7114+
return llvm::StructType::get(
7115+
Ctx,
7116+
{FT->getParamType(1), FT->getParamType(2)->getPointerElementType()});
7117+
70957118
case OpCode::LinAlgMatrixMultiply:
70967119
case OpCode::LinAlgMatrixAccumulate:
70977120
case OpCode::LinAlgMatVecMul:

0 commit comments

Comments
 (0)