Skip to content

Commit 6ff022c

Browse files
committed
Rework based on feedback
1 parent 28f9487 commit 6ff022c

8 files changed

Lines changed: 117 additions & 91 deletions

File tree

include/dxc/DXIL/DxilOperations.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ class OP {
212212
TS_UDT = 8, // Ex: %"struct.MyStruct" *
213213
TS_Object = 9, // Ex: %"class.StructuredBuffer<Foo>"
214214
TS_Vector = 10, // Ex: <8 x i16>
215-
TS_Array = 11, // Ex: [8 x float]
216215
TS_MaskBitCount, // Types used in Mask end here
217216
// TS_Extended is only used to identify the unnamed struct type used to wrap
218217
// multiple overloads when using GetTypeSlot.

lib/DXIL/DxilOperations.cpp

Lines changed: 64 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2863,8 +2863,8 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
28632863
"linAlgMatrixLoadFromMemory",
28642864
Attribute::None,
28652865
2,
2866-
{{0x200}, {0x800}},
2867-
{{0x0}, {0x0}}}, // Overloads: o,a
2866+
{{0x200}, {0x63}},
2867+
{{0x0}, {0x0}}}, // Overloads: o,hfwi
28682868
{OC::LinAlgMatrixLength,
28692869
"LinAlgMatrixLength",
28702870
OCC::LinAlgMatrixLength,
@@ -2911,8 +2911,8 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
29112911
"linAlgMatrixStoreToMemory",
29122912
Attribute::None,
29132913
2,
2914-
{{0x200}, {0x800}},
2915-
{{0x0}, {0x0}}}, // Overloads: o,a
2914+
{{0x200}, {0x63}},
2915+
{{0x0}, {0x0}}}, // Overloads: o,hfwi
29162916
{OC::LinAlgMatrixQueryAccumulatorLayout,
29172917
"LinAlgMatrixQueryAccumulatorLayout",
29182918
OCC::LinAlgMatrixQueryAccumulatorLayout,
@@ -2967,8 +2967,8 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
29672967
"linAlgMatrixAccumulateToMemory",
29682968
Attribute::None,
29692969
2,
2970-
{{0x200}, {0x800}},
2971-
{{0x0}, {0x0}}}, // Overloads: o,a
2970+
{{0x200}, {0x63}},
2971+
{{0x0}, {0x0}}}, // Overloads: o,hfwi
29722972
{OC::LinAlgMatrixOuterProduct,
29732973
"LinAlgMatrixOuterProduct",
29742974
OCC::LinAlgMatrixOuterProduct,
@@ -3152,8 +3152,6 @@ unsigned OP::GetTypeSlot(Type *pType) {
31523152
return TS_Extended;
31533153
case Type::VectorTyID:
31543154
return TS_Vector;
3155-
case Type::ArrayTyID:
3156-
return TS_Array;
31573155
default:
31583156
break;
31593157
}
@@ -3194,12 +3192,6 @@ StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl<char> &Storage) {
31943192
GetOverloadTypeName(OP::GetTypeSlot(VecTy->getElementType()))))
31953193
.toStringRef(Storage);
31963194
}
3197-
case TS_Array: {
3198-
if (Ty->isPointerTy())
3199-
Ty = Ty->getPointerElementType();
3200-
ArrayType *ArrTy = cast<ArrayType>(Ty);
3201-
return GetOverloadTypeName(OP::GetTypeSlot(ArrTy->getArrayElementType()));
3202-
}
32033195
case TS_Extended: {
32043196
DXASSERT(isa<StructType>(Ty),
32053197
"otherwise, extended overload type not wrapped in struct type.");
@@ -4332,9 +4324,10 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
43324324
#define VEC2(_y) A(VectorType::get(_y, 2))
43334325
#define VEC4(_y) A(GetStructVectorType(4, _y))
43344326
#define VEC9(_y) A(VectorType::get(_y, 9))
4327+
#define TGSM(_y) A(PointerType::get(_y, DXIL::kTGSMAddrSpace))
43354328

43364329
// Extended Overload types are wrapped in an anonymous struct
4337-
#define EXT(_y) A(cast<StructType>(pOverloadType)->getElementType(_y))
4330+
#define EXT(_y) cast<StructType>(pOverloadType)->getElementType(_y)
43384331

43394332
/* <py::lines('OPCODE-OLOAD-FUNCS')>hctdb_instrhelp.get_oloads_funcs()</py>*/
43404333
switch (opCode) { // return opCode
@@ -6445,9 +6438,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
64456438

64466439
// Linear Algebra Operations
64476440
case OpCode::MatVecMul:
6448-
EXT(0);
6441+
A(EXT(0));
64496442
A(pI32);
6450-
EXT(1);
6443+
A(EXT(1));
64516444
A(pI1);
64526445
A(pI32);
64536446
A(pRes);
@@ -6461,9 +6454,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
64616454
A(pI1);
64626455
break;
64636456
case OpCode::MatVecMulAdd:
6464-
EXT(0);
6457+
A(EXT(0));
64656458
A(pI32);
6466-
EXT(1);
6459+
A(EXT(1));
64676460
A(pI1);
64686461
A(pI32);
64696462
A(pRes);
@@ -6482,8 +6475,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
64826475
case OpCode::OuterProductAccumulate:
64836476
A(pV);
64846477
A(pI32);
6485-
EXT(0);
6486-
EXT(1);
6478+
A(EXT(0));
6479+
A(EXT(1));
64876480
A(pRes);
64886481
A(pI32);
64896482
A(pI32);
@@ -6586,21 +6579,21 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
65866579

65876580
// Linear Algebra Operations
65886581
case OpCode::LinAlgMatrixMultiplyAccumulate:
6589-
EXT(0);
6582+
A(EXT(0));
65906583
A(pI32);
6591-
EXT(1);
6592-
EXT(2);
6593-
EXT(3);
6584+
A(EXT(1));
6585+
A(EXT(2));
6586+
A(EXT(3));
65946587
break;
65956588
case OpCode::LinAlgFillMatrix:
6596-
EXT(0);
6589+
A(EXT(0));
65976590
A(pI32);
6598-
EXT(1);
6591+
A(EXT(1));
65996592
break;
66006593
case OpCode::LinAlgCopyConvertMatrix:
6601-
EXT(0);
6594+
A(EXT(0));
66026595
A(pI32);
6603-
EXT(1);
6596+
A(EXT(1));
66046597
A(pI1);
66056598
break;
66066599
case OpCode::LinAlgMatrixLoadFromDescriptor:
@@ -6612,9 +6605,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66126605
A(pI32);
66136606
break;
66146607
case OpCode::LinAlgMatrixLoadFromMemory:
6615-
EXT(0);
6608+
A(EXT(0));
66166609
A(pI32);
6617-
EXT(1);
6610+
TGSM(EXT(1));
66186611
A(pI32);
66196612
A(pI32);
66206613
A(pI32);
@@ -6631,17 +6624,17 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66316624
A(pI32);
66326625
break;
66336626
case OpCode::LinAlgMatrixGetElement:
6634-
EXT(0);
6627+
A(EXT(0));
66356628
A(pI32);
6636-
EXT(1);
6629+
A(EXT(1));
66376630
A(pI32);
66386631
break;
66396632
case OpCode::LinAlgMatrixSetElement:
6640-
EXT(0);
6633+
A(EXT(0));
66416634
A(pI32);
6642-
EXT(1);
6635+
A(EXT(1));
66436636
A(pI32);
6644-
EXT(2);
6637+
A(EXT(2));
66456638
break;
66466639
case OpCode::LinAlgMatrixStoreToDescriptor:
66476640
A(pV);
@@ -6655,8 +6648,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66556648
case OpCode::LinAlgMatrixStoreToMemory:
66566649
A(pV);
66576650
A(pI32);
6658-
EXT(0);
6659-
EXT(1);
6651+
A(EXT(0));
6652+
TGSM(EXT(1));
66606653
A(pI32);
66616654
A(pI32);
66626655
A(pI32);
@@ -6666,31 +6659,31 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66666659
A(pI32);
66676660
break;
66686661
case OpCode::LinAlgMatrixMultiply:
6669-
EXT(0);
6662+
A(EXT(0));
66706663
A(pI32);
6671-
EXT(1);
6672-
EXT(2);
6664+
A(EXT(1));
6665+
A(EXT(2));
66736666
break;
66746667
case OpCode::LinAlgMatrixAccumulate:
6675-
EXT(0);
6668+
A(EXT(0));
66766669
A(pI32);
6677-
EXT(1);
6678-
EXT(2);
6670+
A(EXT(1));
6671+
A(EXT(2));
66796672
break;
66806673
case OpCode::LinAlgMatVecMul:
6681-
EXT(0);
6674+
A(EXT(0));
66826675
A(pI32);
6683-
EXT(1);
6684-
EXT(2);
6676+
A(EXT(1));
6677+
A(EXT(2));
66856678
A(pI32);
66866679
break;
66876680
case OpCode::LinAlgMatVecMulAdd:
6688-
EXT(0);
6681+
A(EXT(0));
66896682
A(pI32);
6690-
EXT(1);
6691-
EXT(2);
6683+
A(EXT(1));
6684+
A(EXT(2));
66926685
A(pI32);
6693-
EXT(3);
6686+
A(EXT(3));
66946687
A(pI32);
66956688
break;
66966689
case OpCode::LinAlgMatrixAccumulateToDescriptor:
@@ -6705,17 +6698,17 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
67056698
case OpCode::LinAlgMatrixAccumulateToMemory:
67066699
A(pV);
67076700
A(pI32);
6708-
EXT(0);
6709-
EXT(1);
6701+
A(EXT(0));
6702+
TGSM(EXT(1));
67106703
A(pI32);
67116704
A(pI32);
67126705
A(pI32);
67136706
break;
67146707
case OpCode::LinAlgMatrixOuterProduct:
6715-
EXT(0);
6708+
A(EXT(0));
67166709
A(pI32);
6717-
EXT(1);
6718-
EXT(2);
6710+
A(EXT(1));
6711+
A(EXT(2));
67196712
break;
67206713

67216714
//
@@ -7082,16 +7075,13 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70827075
case OpCode::MatVecMulAdd:
70837076
case OpCode::LinAlgFillMatrix:
70847077
case OpCode::LinAlgCopyConvertMatrix:
7085-
case OpCode::LinAlgMatrixLoadFromMemory:
70867078
case OpCode::LinAlgMatrixGetElement:
70877079
if (FT->getNumParams() < 2)
70887080
return nullptr;
70897081
return llvm::StructType::get(Ctx,
70907082
{FT->getReturnType(), FT->getParamType(1)});
70917083

70927084
case OpCode::OuterProductAccumulate:
7093-
case OpCode::LinAlgMatrixStoreToMemory:
7094-
case OpCode::LinAlgMatrixAccumulateToMemory:
70957085
if (FT->getNumParams() < 3)
70967086
return nullptr;
70977087
return llvm::StructType::get(Ctx,
@@ -7104,12 +7094,27 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
71047094
{FT->getReturnType(), FT->getParamType(1),
71057095
FT->getParamType(2), FT->getParamType(3)});
71067096

7097+
case OpCode::LinAlgMatrixLoadFromMemory:
7098+
if (FT->getNumParams() < 2)
7099+
return nullptr;
7100+
return llvm::StructType::get(
7101+
Ctx,
7102+
{FT->getReturnType(), FT->getParamType(1)->getPointerElementType()});
7103+
71077104
case OpCode::LinAlgMatrixSetElement:
71087105
if (FT->getNumParams() < 4)
71097106
return nullptr;
71107107
return llvm::StructType::get(
71117108
Ctx, {FT->getReturnType(), FT->getParamType(1), FT->getParamType(3)});
71127109

7110+
case OpCode::LinAlgMatrixStoreToMemory:
7111+
case OpCode::LinAlgMatrixAccumulateToMemory:
7112+
if (FT->getNumParams() < 3)
7113+
return nullptr;
7114+
return llvm::StructType::get(
7115+
Ctx,
7116+
{FT->getParamType(1), FT->getParamType(2)->getPointerElementType()});
7117+
71137118
case OpCode::LinAlgMatrixMultiply:
71147119
case OpCode::LinAlgMatrixAccumulate:
71157120
case OpCode::LinAlgMatVecMul:

lib/HLSL/HLOperationLower.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7242,11 +7242,15 @@ Value *TranslateLinAlgMatrixLoadFromMemory(
72427242
Value *Stride = CI->getArgOperand(4);
72437243
Value *Layout = CI->getArgOperand(5);
72447244

7245+
Value *Zero = Builder.getInt32(0);
7246+
Value *ArrPtr = Builder.CreateGEP(Arr, {Zero, Zero});
7247+
Type *ArrEltTy = ArrPtr->getType()->getPointerElementType();
7248+
72457249
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7246-
Function *DxilFunc = HlslOp->GetOpFunc(OpCode, {MatrixType, Arr->getType()});
7250+
Function *DxilFunc = HlslOp->GetOpFunc(OpCode, {MatrixType, ArrEltTy});
72477251

72487252
Value *Matrix =
7249-
Builder.CreateCall(DxilFunc, {OpArg, Arr, Offset, Stride, Layout});
7253+
Builder.CreateCall(DxilFunc, {OpArg, ArrPtr, Offset, Stride, Layout});
72507254
Builder.CreateStore(Matrix, MatrixPtr);
72517255

72527256
return nullptr;
@@ -7265,12 +7269,15 @@ Value *TranslateLinAlgMatrixAccumStoreToMemory(
72657269
Value *Stride = CI->getArgOperand(4);
72667270
Value *Layout = CI->getArgOperand(5);
72677271

7272+
Value *Zero = Builder.getInt32(0);
7273+
Value *ArrPtr = Builder.CreateGEP(Arr, {Zero, Zero});
7274+
Type *ArrEltTy = ArrPtr->getType()->getPointerElementType();
7275+
72687276
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7269-
Function *DxilFunc =
7270-
HlslOp->GetOpFunc(OpCode, {Matrix->getType(), Arr->getType()});
7277+
Function *DxilFunc = HlslOp->GetOpFunc(OpCode, {Matrix->getType(), ArrEltTy});
72717278

72727279
return Builder.CreateCall(DxilFunc,
7273-
{OpArg, Matrix, Arr, Offset, Stride, Layout});
7280+
{OpArg, Matrix, ArrPtr, Offset, Stride, Layout});
72747281
}
72757282

72767283
} // namespace

tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixaccumulatetomemory/nominal.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ groupshared float SharedArr[64];
88
void main() {
99
// CHECK-LABEL: define void @main()
1010

11-
// CHECK: call void @dx.op.linAlgMatrixAccumulateToMemory.mC4M5N4U1S2.f32(i32 -2147483620, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, [64 x float] addrspace(3)* nonnull @{{.*}}, i32 1, i32 2, i32 3) ; LinAlgMatrixAccumulateToMemory(matrix,memory,offset,stride,layout)
11+
// CHECK: call void @dx.op.linAlgMatrixAccumulateToMemory.mC4M5N4U1S2.f32(i32 -2147483620, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, float addrspace(3)* getelementptr {{.*}}, i32 1, i32 2, i32 3) ; LinAlgMatrixAccumulateToMemory(matrix,memory,offset,stride,layout)
1212
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
1313
__builtin_LinAlg_MatrixAccumulateToMemory(mat, SharedArr, 1, 2, 3);
1414
}

tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixloadfrommemory/nominal.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ groupshared float SharedArr[64];
88
void main() {
99
// CHECK-LABEL: define void @main()
1010

11-
// CHECK: call %dx.types.LinAlgMatrixC4M5N4U1S2 @dx.op.linAlgMatrixLoadFromMemory.mC4M5N4U1S2.f32(i32 -2147483633, [64 x float] addrspace(3)* nonnull @{{.*}}, i32 1, i32 2, i32 3) ; LinAlgMatrixLoadFromMemory(memory,offset,stride,layout)
11+
// CHECK: call %dx.types.LinAlgMatrixC4M5N4U1S2 @dx.op.linAlgMatrixLoadFromMemory.mC4M5N4U1S2.f32(i32 -2147483633, float addrspace(3)* getelementptr {{.*}}, i32 1, i32 2, i32 3) ; LinAlgMatrixLoadFromMemory(memory,offset,stride,layout)
1212
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
1313
__builtin_LinAlg_MatrixLoadFromMemory(mat, SharedArr, 1, 2, 3);
1414
}

tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixstoretomemory/nominal.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ groupshared float SharedArr[64];
88
void main() {
99
// CHECK-LABEL: define void @main()
1010

11-
// CHECK: call void @dx.op.linAlgMatrixStoreToMemory.mC4M5N4U1S2.f32(i32 -2147483627, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, [64 x float] addrspace(3)* nonnull @{{.*}}, i32 1, i32 2, i32 3) ; LinAlgMatrixStoreToMemory(matrix,memory,offset,stride,layout)
11+
// CHECK: call void @dx.op.linAlgMatrixStoreToMemory.mC4M5N4U1S2.f32(i32 -2147483627, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, float addrspace(3)* getelementptr {{.*}}, i32 1, i32 2, i32 3) ; LinAlgMatrixStoreToMemory(matrix,memory,offset,stride,layout)
1212
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
1313
__builtin_LinAlg_MatrixStoreToMemory(mat, SharedArr, 1, 2, 3);
1414
}

0 commit comments

Comments
 (0)