@@ -3166,26 +3166,33 @@ const char *OP::GetOverloadTypeName(unsigned TypeSlot) {
31663166StringRef OP::GetTypeName (Type *Ty, SmallVectorImpl<char > &Storage) {
31673167 DXASSERT (!Ty->isVoidTy (), " must not pass void type here" );
31683168 unsigned TypeSlot = OP::GetTypeSlot (Ty);
3169+
31693170 if (TypeSlot < TS_BasicCount) {
31703171 return GetOverloadTypeName (TypeSlot);
3171- } else if (TypeSlot == TS_UDT) {
3172+ }
3173+
3174+ switch (TypeSlot) {
3175+ case TS_UDT: {
31723176 if (Ty->isPointerTy ())
31733177 Ty = Ty->getPointerElementType ();
31743178 StructType *ST = cast<StructType>(Ty);
31753179 return ST->getStructName ();
3176- } else if (TypeSlot == TS_Object) {
3180+ }
3181+ case TS_Object: {
31773182 StructType *ST = cast<StructType>(Ty);
31783183 if (dxilutil::IsHLSLLinAlgMatrixType (Ty))
31793184 return (Twine (" m" ) + Twine (dxilutil::GetHLSLLinAlgMatrixTypeMangling (ST)))
31803185 .toStringRef (Storage);
31813186 return ST->getStructName ();
3182- } else if (TypeSlot == TS_Vector) {
3187+ }
3188+ case TS_Vector: {
31833189 VectorType *VecTy = cast<VectorType>(Ty);
31843190 return (Twine (" v" ) + Twine (VecTy->getNumElements ()) +
31853191 Twine (
31863192 GetOverloadTypeName (OP::GetTypeSlot (VecTy->getElementType ()))))
31873193 .toStringRef (Storage);
3188- } else if (TypeSlot == TS_Extended) {
3194+ }
3195+ case TS_Extended: {
31893196 DXASSERT (isa<StructType>(Ty),
31903197 " otherwise, extended overload type not wrapped in struct type." );
31913198 StructType *ST = cast<StructType>(Ty);
@@ -3200,11 +3207,14 @@ StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl<char> &Storage) {
32003207 OS << GetTypeName (ST->getElementType (I), TempStr);
32013208 }
32023209 return OS.str ();
3203- } else {
3204- raw_svector_ostream OS (Storage);
3205- Ty->print (OS);
3206- return OS.str ();
32073210 }
3211+ default :
3212+ break ;
3213+ }
3214+
3215+ raw_svector_ostream OS (Storage);
3216+ Ty->print (OS);
3217+ return OS.str ();
32083218}
32093219
32103220StringRef OP::ConstructOverloadName (Type *Ty, DXIL::OpCode opCode,
@@ -4314,9 +4324,10 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
43144324#define VEC2 (_y ) A(VectorType::get(_y, 2 ))
43154325#define VEC4 (_y ) A(GetStructVectorType(4 , _y))
43164326#define VEC9 (_y ) A(VectorType::get(_y, 9 ))
4327+ #define TGSM (_y ) A(PointerType::get(_y, DXIL::kTGSMAddrSpace ))
43174328
43184329// Extended Overload types are wrapped in an anonymous struct
4319- #define EXT (_y ) A( cast<StructType>(pOverloadType)->getElementType (_y) )
4330+ #define EXT (_y ) cast<StructType>(pOverloadType)->getElementType (_y)
43204331
43214332 /* <py::lines('OPCODE-OLOAD-FUNCS')>hctdb_instrhelp.get_oloads_funcs()</py>*/
43224333 switch (opCode) { // return opCode
@@ -6427,9 +6438,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
64276438
64286439 // Linear Algebra Operations
64296440 case OpCode::MatVecMul:
6430- EXT (0 );
6441+ A ( EXT (0 ) );
64316442 A (pI32);
6432- EXT (1 );
6443+ A ( EXT (1 ) );
64336444 A (pI1);
64346445 A (pI32);
64356446 A (pRes);
@@ -6443,9 +6454,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
64436454 A (pI1);
64446455 break ;
64456456 case OpCode::MatVecMulAdd:
6446- EXT (0 );
6457+ A ( EXT (0 ) );
64476458 A (pI32);
6448- EXT (1 );
6459+ A ( EXT (1 ) );
64496460 A (pI1);
64506461 A (pI32);
64516462 A (pRes);
@@ -6464,8 +6475,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
64646475 case OpCode::OuterProductAccumulate:
64656476 A (pV);
64666477 A (pI32);
6467- EXT (0 );
6468- EXT (1 );
6478+ A ( EXT (0 ) );
6479+ A ( EXT (1 ) );
64696480 A (pRes);
64706481 A (pI32);
64716482 A (pI32);
@@ -6568,21 +6579,21 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
65686579
65696580 // Linear Algebra Operations
65706581 case OpCode::LinAlgMatrixMultiplyAccumulate:
6571- EXT (0 );
6582+ A ( EXT (0 ) );
65726583 A (pI32);
6573- EXT (1 );
6574- EXT (2 );
6575- EXT (3 );
6584+ A ( EXT (1 ) );
6585+ A ( EXT (2 ) );
6586+ A ( EXT (3 ) );
65766587 break ;
65776588 case OpCode::LinAlgFillMatrix:
6578- EXT (0 );
6589+ A ( EXT (0 ) );
65796590 A (pI32);
6580- EXT (1 );
6591+ A ( EXT (1 ) );
65816592 break ;
65826593 case OpCode::LinAlgCopyConvertMatrix:
6583- EXT (0 );
6594+ A ( EXT (0 ) );
65846595 A (pI32);
6585- EXT (1 );
6596+ A ( EXT (1 ) );
65866597 A (pI1);
65876598 break ;
65886599 case OpCode::LinAlgMatrixLoadFromDescriptor:
@@ -6594,9 +6605,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
65946605 A (pI32);
65956606 break ;
65966607 case OpCode::LinAlgMatrixLoadFromMemory:
6597- EXT (0 );
6608+ A ( EXT (0 ) );
65986609 A (pI32);
6599- EXT (1 );
6610+ TGSM ( EXT (1 ) );
66006611 A (pI32);
66016612 A (pI32);
66026613 A (pI32);
@@ -6613,17 +6624,17 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66136624 A (pI32);
66146625 break ;
66156626 case OpCode::LinAlgMatrixGetElement:
6616- EXT (0 );
6627+ A ( EXT (0 ) );
66176628 A (pI32);
6618- EXT (1 );
6629+ A ( EXT (1 ) );
66196630 A (pI32);
66206631 break ;
66216632 case OpCode::LinAlgMatrixSetElement:
6622- EXT (0 );
6633+ A ( EXT (0 ) );
66236634 A (pI32);
6624- EXT (1 );
6635+ A ( EXT (1 ) );
66256636 A (pI32);
6626- EXT (2 );
6637+ A ( EXT (2 ) );
66276638 break ;
66286639 case OpCode::LinAlgMatrixStoreToDescriptor:
66296640 A (pV);
@@ -6637,8 +6648,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66376648 case OpCode::LinAlgMatrixStoreToMemory:
66386649 A (pV);
66396650 A (pI32);
6640- EXT (0 );
6641- EXT (1 );
6651+ A ( EXT (0 ) );
6652+ TGSM ( EXT (1 ) );
66426653 A (pI32);
66436654 A (pI32);
66446655 A (pI32);
@@ -6648,31 +6659,31 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66486659 A (pI32);
66496660 break ;
66506661 case OpCode::LinAlgMatrixMultiply:
6651- EXT (0 );
6662+ A ( EXT (0 ) );
66526663 A (pI32);
6653- EXT (1 );
6654- EXT (2 );
6664+ A ( EXT (1 ) );
6665+ A ( EXT (2 ) );
66556666 break ;
66566667 case OpCode::LinAlgMatrixAccumulate:
6657- EXT (0 );
6668+ A ( EXT (0 ) );
66586669 A (pI32);
6659- EXT (1 );
6660- EXT (2 );
6670+ A ( EXT (1 ) );
6671+ A ( EXT (2 ) );
66616672 break ;
66626673 case OpCode::LinAlgMatVecMul:
6663- EXT (0 );
6674+ A ( EXT (0 ) );
66646675 A (pI32);
6665- EXT (1 );
6666- EXT (2 );
6676+ A ( EXT (1 ) );
6677+ A ( EXT (2 ) );
66676678 A (pI32);
66686679 break ;
66696680 case OpCode::LinAlgMatVecMulAdd:
6670- EXT (0 );
6681+ A ( EXT (0 ) );
66716682 A (pI32);
6672- EXT (1 );
6673- EXT (2 );
6683+ A ( EXT (1 ) );
6684+ A ( EXT (2 ) );
66746685 A (pI32);
6675- EXT (3 );
6686+ A ( EXT (3 ) );
66766687 A (pI32);
66776688 break ;
66786689 case OpCode::LinAlgMatrixAccumulateToDescriptor:
@@ -6687,17 +6698,17 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66876698 case OpCode::LinAlgMatrixAccumulateToMemory:
66886699 A (pV);
66896700 A (pI32);
6690- EXT (0 );
6691- EXT (1 );
6701+ A ( EXT (0 ) );
6702+ TGSM ( EXT (1 ) );
66926703 A (pI32);
66936704 A (pI32);
66946705 A (pI32);
66956706 break ;
66966707 case OpCode::LinAlgMatrixOuterProduct:
6697- EXT (0 );
6708+ A ( EXT (0 ) );
66986709 A (pI32);
6699- EXT (1 );
6700- EXT (2 );
6710+ A ( EXT (1 ) );
6711+ A ( EXT (2 ) );
67016712 break ;
67026713
67036714 //
@@ -7064,16 +7075,13 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70647075 case OpCode::MatVecMulAdd:
70657076 case OpCode::LinAlgFillMatrix:
70667077 case OpCode::LinAlgCopyConvertMatrix:
7067- case OpCode::LinAlgMatrixLoadFromMemory:
70687078 case OpCode::LinAlgMatrixGetElement:
70697079 if (FT->getNumParams () < 2 )
70707080 return nullptr ;
70717081 return llvm::StructType::get (Ctx,
70727082 {FT->getReturnType (), FT->getParamType (1 )});
70737083
70747084 case OpCode::OuterProductAccumulate:
7075- case OpCode::LinAlgMatrixStoreToMemory:
7076- case OpCode::LinAlgMatrixAccumulateToMemory:
70777085 if (FT->getNumParams () < 3 )
70787086 return nullptr ;
70797087 return llvm::StructType::get (Ctx,
@@ -7086,12 +7094,27 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70867094 {FT->getReturnType (), FT->getParamType (1 ),
70877095 FT->getParamType (2 ), FT->getParamType (3 )});
70887096
7097+ case OpCode::LinAlgMatrixLoadFromMemory:
7098+ if (FT->getNumParams () < 2 )
7099+ return nullptr ;
7100+ return llvm::StructType::get (
7101+ Ctx,
7102+ {FT->getReturnType (), FT->getParamType (1 )->getPointerElementType ()});
7103+
70897104 case OpCode::LinAlgMatrixSetElement:
70907105 if (FT->getNumParams () < 4 )
70917106 return nullptr ;
70927107 return llvm::StructType::get (
70937108 Ctx, {FT->getReturnType (), FT->getParamType (1 ), FT->getParamType (3 )});
70947109
7110+ case OpCode::LinAlgMatrixStoreToMemory:
7111+ case OpCode::LinAlgMatrixAccumulateToMemory:
7112+ if (FT->getNumParams () < 3 )
7113+ return nullptr ;
7114+ return llvm::StructType::get (
7115+ Ctx,
7116+ {FT->getParamType (1 ), FT->getParamType (2 )->getPointerElementType ()});
7117+
70957118 case OpCode::LinAlgMatrixMultiply:
70967119 case OpCode::LinAlgMatrixAccumulate:
70977120 case OpCode::LinAlgMatVecMul:
0 commit comments