@@ -2863,8 +2863,8 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
28632863 " linAlgMatrixLoadFromMemory" ,
28642864 Attribute::None,
28652865 2 ,
2866- {{0x200 }, {0x800 }},
2867- {{0x0 }, {0x0 }}}, // Overloads: o,a
2866+ {{0x200 }, {0x63 }},
2867+ {{0x0 }, {0x0 }}}, // Overloads: o,hfwi
28682868 {OC::LinAlgMatrixLength,
28692869 " LinAlgMatrixLength" ,
28702870 OCC::LinAlgMatrixLength,
@@ -2911,8 +2911,8 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
29112911 " linAlgMatrixStoreToMemory" ,
29122912 Attribute::None,
29132913 2 ,
2914- {{0x200 }, {0x800 }},
2915- {{0x0 }, {0x0 }}}, // Overloads: o,a
2914+ {{0x200 }, {0x63 }},
2915+ {{0x0 }, {0x0 }}}, // Overloads: o,hfwi
29162916 {OC::LinAlgMatrixQueryAccumulatorLayout,
29172917 " LinAlgMatrixQueryAccumulatorLayout" ,
29182918 OCC::LinAlgMatrixQueryAccumulatorLayout,
@@ -2967,8 +2967,8 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
29672967 " linAlgMatrixAccumulateToMemory" ,
29682968 Attribute::None,
29692969 2 ,
2970- {{0x200 }, {0x800 }},
2971- {{0x0 }, {0x0 }}}, // Overloads: o,a
2970+ {{0x200 }, {0x63 }},
2971+ {{0x0 }, {0x0 }}}, // Overloads: o,hfwi
29722972 {OC::LinAlgMatrixOuterProduct,
29732973 " LinAlgMatrixOuterProduct" ,
29742974 OCC::LinAlgMatrixOuterProduct,
@@ -3152,8 +3152,6 @@ unsigned OP::GetTypeSlot(Type *pType) {
31523152 return TS_Extended;
31533153 case Type::VectorTyID:
31543154 return TS_Vector;
3155- case Type::ArrayTyID:
3156- return TS_Array;
31573155 default :
31583156 break ;
31593157 }
@@ -3194,12 +3192,6 @@ StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl<char> &Storage) {
31943192 GetOverloadTypeName (OP::GetTypeSlot (VecTy->getElementType ()))))
31953193 .toStringRef (Storage);
31963194 }
3197- case TS_Array: {
3198- if (Ty->isPointerTy ())
3199- Ty = Ty->getPointerElementType ();
3200- ArrayType *ArrTy = cast<ArrayType>(Ty);
3201- return GetOverloadTypeName (OP::GetTypeSlot (ArrTy->getArrayElementType ()));
3202- }
32033195 case TS_Extended: {
32043196 DXASSERT (isa<StructType>(Ty),
32053197 " otherwise, extended overload type not wrapped in struct type." );
@@ -4332,9 +4324,10 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
43324324#define VEC2 (_y ) A(VectorType::get(_y, 2 ))
43334325#define VEC4 (_y ) A(GetStructVectorType(4 , _y))
43344326#define VEC9 (_y ) A(VectorType::get(_y, 9 ))
4327+ #define TGSM (_y ) A(PointerType::get(_y, DXIL::kTGSMAddrSpace ))
43354328
43364329// Extended Overload types are wrapped in an anonymous struct
4337- #define EXT (_y ) A( cast<StructType>(pOverloadType)->getElementType (_y) )
4330+ #define EXT (_y ) cast<StructType>(pOverloadType)->getElementType (_y)
43384331
43394332 /* <py::lines('OPCODE-OLOAD-FUNCS')>hctdb_instrhelp.get_oloads_funcs()</py>*/
43404333 switch (opCode) { // return opCode
@@ -6445,9 +6438,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
64456438
64466439 // Linear Algebra Operations
64476440 case OpCode::MatVecMul:
6448- EXT (0 );
6441+ A ( EXT (0 ) );
64496442 A (pI32);
6450- EXT (1 );
6443+ A ( EXT (1 ) );
64516444 A (pI1);
64526445 A (pI32);
64536446 A (pRes);
@@ -6461,9 +6454,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
64616454 A (pI1);
64626455 break ;
64636456 case OpCode::MatVecMulAdd:
6464- EXT (0 );
6457+ A ( EXT (0 ) );
64656458 A (pI32);
6466- EXT (1 );
6459+ A ( EXT (1 ) );
64676460 A (pI1);
64686461 A (pI32);
64696462 A (pRes);
@@ -6482,8 +6475,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
64826475 case OpCode::OuterProductAccumulate:
64836476 A (pV);
64846477 A (pI32);
6485- EXT (0 );
6486- EXT (1 );
6478+ A ( EXT (0 ) );
6479+ A ( EXT (1 ) );
64876480 A (pRes);
64886481 A (pI32);
64896482 A (pI32);
@@ -6586,21 +6579,21 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
65866579
65876580 // Linear Algebra Operations
65886581 case OpCode::LinAlgMatrixMultiplyAccumulate:
6589- EXT (0 );
6582+ A ( EXT (0 ) );
65906583 A (pI32);
6591- EXT (1 );
6592- EXT (2 );
6593- EXT (3 );
6584+ A ( EXT (1 ) );
6585+ A ( EXT (2 ) );
6586+ A ( EXT (3 ) );
65946587 break ;
65956588 case OpCode::LinAlgFillMatrix:
6596- EXT (0 );
6589+ A ( EXT (0 ) );
65976590 A (pI32);
6598- EXT (1 );
6591+ A ( EXT (1 ) );
65996592 break ;
66006593 case OpCode::LinAlgCopyConvertMatrix:
6601- EXT (0 );
6594+ A ( EXT (0 ) );
66026595 A (pI32);
6603- EXT (1 );
6596+ A ( EXT (1 ) );
66046597 A (pI1);
66056598 break ;
66066599 case OpCode::LinAlgMatrixLoadFromDescriptor:
@@ -6612,9 +6605,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66126605 A (pI32);
66136606 break ;
66146607 case OpCode::LinAlgMatrixLoadFromMemory:
6615- EXT (0 );
6608+ A ( EXT (0 ) );
66166609 A (pI32);
6617- EXT (1 );
6610+ TGSM ( EXT (1 ) );
66186611 A (pI32);
66196612 A (pI32);
66206613 A (pI32);
@@ -6631,17 +6624,17 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66316624 A (pI32);
66326625 break ;
66336626 case OpCode::LinAlgMatrixGetElement:
6634- EXT (0 );
6627+ A ( EXT (0 ) );
66356628 A (pI32);
6636- EXT (1 );
6629+ A ( EXT (1 ) );
66376630 A (pI32);
66386631 break ;
66396632 case OpCode::LinAlgMatrixSetElement:
6640- EXT (0 );
6633+ A ( EXT (0 ) );
66416634 A (pI32);
6642- EXT (1 );
6635+ A ( EXT (1 ) );
66436636 A (pI32);
6644- EXT (2 );
6637+ A ( EXT (2 ) );
66456638 break ;
66466639 case OpCode::LinAlgMatrixStoreToDescriptor:
66476640 A (pV);
@@ -6655,8 +6648,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66556648 case OpCode::LinAlgMatrixStoreToMemory:
66566649 A (pV);
66576650 A (pI32);
6658- EXT (0 );
6659- EXT (1 );
6651+ A ( EXT (0 ) );
6652+ TGSM ( EXT (1 ) );
66606653 A (pI32);
66616654 A (pI32);
66626655 A (pI32);
@@ -6666,31 +6659,31 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66666659 A (pI32);
66676660 break ;
66686661 case OpCode::LinAlgMatrixMultiply:
6669- EXT (0 );
6662+ A ( EXT (0 ) );
66706663 A (pI32);
6671- EXT (1 );
6672- EXT (2 );
6664+ A ( EXT (1 ) );
6665+ A ( EXT (2 ) );
66736666 break ;
66746667 case OpCode::LinAlgMatrixAccumulate:
6675- EXT (0 );
6668+ A ( EXT (0 ) );
66766669 A (pI32);
6677- EXT (1 );
6678- EXT (2 );
6670+ A ( EXT (1 ) );
6671+ A ( EXT (2 ) );
66796672 break ;
66806673 case OpCode::LinAlgMatVecMul:
6681- EXT (0 );
6674+ A ( EXT (0 ) );
66826675 A (pI32);
6683- EXT (1 );
6684- EXT (2 );
6676+ A ( EXT (1 ) );
6677+ A ( EXT (2 ) );
66856678 A (pI32);
66866679 break ;
66876680 case OpCode::LinAlgMatVecMulAdd:
6688- EXT (0 );
6681+ A ( EXT (0 ) );
66896682 A (pI32);
6690- EXT (1 );
6691- EXT (2 );
6683+ A ( EXT (1 ) );
6684+ A ( EXT (2 ) );
66926685 A (pI32);
6693- EXT (3 );
6686+ A ( EXT (3 ) );
66946687 A (pI32);
66956688 break ;
66966689 case OpCode::LinAlgMatrixAccumulateToDescriptor:
@@ -6705,17 +6698,17 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
67056698 case OpCode::LinAlgMatrixAccumulateToMemory:
67066699 A (pV);
67076700 A (pI32);
6708- EXT (0 );
6709- EXT (1 );
6701+ A ( EXT (0 ) );
6702+ TGSM ( EXT (1 ) );
67106703 A (pI32);
67116704 A (pI32);
67126705 A (pI32);
67136706 break ;
67146707 case OpCode::LinAlgMatrixOuterProduct:
6715- EXT (0 );
6708+ A ( EXT (0 ) );
67166709 A (pI32);
6717- EXT (1 );
6718- EXT (2 );
6710+ A ( EXT (1 ) );
6711+ A ( EXT (2 ) );
67196712 break ;
67206713
67216714 //
@@ -7082,16 +7075,13 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70827075 case OpCode::MatVecMulAdd:
70837076 case OpCode::LinAlgFillMatrix:
70847077 case OpCode::LinAlgCopyConvertMatrix:
7085- case OpCode::LinAlgMatrixLoadFromMemory:
70867078 case OpCode::LinAlgMatrixGetElement:
70877079 if (FT->getNumParams () < 2 )
70887080 return nullptr ;
70897081 return llvm::StructType::get (Ctx,
70907082 {FT->getReturnType (), FT->getParamType (1 )});
70917083
70927084 case OpCode::OuterProductAccumulate:
7093- case OpCode::LinAlgMatrixStoreToMemory:
7094- case OpCode::LinAlgMatrixAccumulateToMemory:
70957085 if (FT->getNumParams () < 3 )
70967086 return nullptr ;
70977087 return llvm::StructType::get (Ctx,
@@ -7104,12 +7094,27 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
71047094 {FT->getReturnType (), FT->getParamType (1 ),
71057095 FT->getParamType (2 ), FT->getParamType (3 )});
71067096
7097+ case OpCode::LinAlgMatrixLoadFromMemory:
7098+ if (FT->getNumParams () < 2 )
7099+ return nullptr ;
7100+ return llvm::StructType::get (
7101+ Ctx,
7102+ {FT->getReturnType (), FT->getParamType (1 )->getPointerElementType ()});
7103+
71077104 case OpCode::LinAlgMatrixSetElement:
71087105 if (FT->getNumParams () < 4 )
71097106 return nullptr ;
71107107 return llvm::StructType::get (
71117108 Ctx, {FT->getReturnType (), FT->getParamType (1 ), FT->getParamType (3 )});
71127109
7110+ case OpCode::LinAlgMatrixStoreToMemory:
7111+ case OpCode::LinAlgMatrixAccumulateToMemory:
7112+ if (FT->getNumParams () < 3 )
7113+ return nullptr ;
7114+ return llvm::StructType::get (
7115+ Ctx,
7116+ {FT->getParamType (1 ), FT->getParamType (2 )->getPointerElementType ()});
7117+
71137118 case OpCode::LinAlgMatrixMultiply:
71147119 case OpCode::LinAlgMatrixAccumulate:
71157120 case OpCode::LinAlgMatVecMul:
0 commit comments