Skip to content

Commit e520d0a

Browse files
committed
[SM6.10] Implement groupshared Builtins
Implements lowering for the groupshared-memory builtins (LinAlgMatrixLoadFromMemory, LinAlgMatrixStoreToMemory, and LinAlgMatrixAccumulateToMemory), following the pattern of the previous builtins.
1 parent fbc8aed commit e520d0a

16 files changed

Lines changed: 297 additions & 40 deletions

File tree

include/dxc/DXIL/DxilInstructions.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10651,14 +10651,14 @@ struct DxilInst_LinAlgMatrixLoadFromMemory {
1065110651
bool requiresUniformInputs() const { return false; }
1065210652
// Operand indexes
1065310653
enum OperandIdx {
10654-
arg_groupsharedArr = 1,
10654+
arg_memory = 1,
1065510655
arg_offset = 2,
1065610656
arg_stride = 3,
1065710657
arg_layout = 4,
1065810658
};
1065910659
// Accessors
10660-
llvm::Value *get_groupsharedArr() const { return Instr->getOperand(1); }
10661-
void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(1, val); }
10660+
llvm::Value *get_memory() const { return Instr->getOperand(1); }
10661+
void set_memory(llvm::Value *val) { Instr->setOperand(1, val); }
1066210662
llvm::Value *get_offset() const { return Instr->getOperand(2); }
1066310663
void set_offset(llvm::Value *val) { Instr->setOperand(2, val); }
1066410664
llvm::Value *get_stride() const { return Instr->getOperand(3); }
@@ -10854,16 +10854,16 @@ struct DxilInst_LinAlgMatrixStoreToMemory {
1085410854
// Operand indexes
1085510855
enum OperandIdx {
1085610856
arg_matrix = 1,
10857-
arg_groupsharedArr = 2,
10857+
arg_memory = 2,
1085810858
arg_offset = 3,
1085910859
arg_stride = 4,
1086010860
arg_layout = 5,
1086110861
};
1086210862
// Accessors
1086310863
llvm::Value *get_matrix() const { return Instr->getOperand(1); }
1086410864
void set_matrix(llvm::Value *val) { Instr->setOperand(1, val); }
10865-
llvm::Value *get_groupsharedArr() const { return Instr->getOperand(2); }
10866-
void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(2, val); }
10865+
llvm::Value *get_memory() const { return Instr->getOperand(2); }
10866+
void set_memory(llvm::Value *val) { Instr->setOperand(2, val); }
1086710867
llvm::Value *get_offset() const { return Instr->getOperand(3); }
1086810868
void set_offset(llvm::Value *val) { Instr->setOperand(3, val); }
1086910869
llvm::Value *get_stride() const { return Instr->getOperand(4); }
@@ -11091,16 +11091,16 @@ struct DxilInst_LinAlgMatrixAccumulateToMemory {
1109111091
// Operand indexes
1109211092
enum OperandIdx {
1109311093
arg_matrix = 1,
11094-
arg_groupsharedArr = 2,
11094+
arg_memory = 2,
1109511095
arg_offset = 3,
1109611096
arg_stride = 4,
1109711097
arg_layout = 5,
1109811098
};
1109911099
// Accessors
1110011100
llvm::Value *get_matrix() const { return Instr->getOperand(1); }
1110111101
void set_matrix(llvm::Value *val) { Instr->setOperand(1, val); }
11102-
llvm::Value *get_groupsharedArr() const { return Instr->getOperand(2); }
11103-
void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(2, val); }
11102+
llvm::Value *get_memory() const { return Instr->getOperand(2); }
11103+
void set_memory(llvm::Value *val) { Instr->setOperand(2, val); }
1110411104
llvm::Value *get_offset() const { return Instr->getOperand(3); }
1110511105
void set_offset(llvm::Value *val) { Instr->setOperand(3, val); }
1110611106
llvm::Value *get_stride() const { return Instr->getOperand(4); }

include/dxc/DXIL/DxilOperations.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ class OP {
212212
TS_UDT = 8, // Ex: %"struct.MyStruct" *
213213
TS_Object = 9, // Ex: %"class.StructuredBuffer<Foo>"
214214
TS_Vector = 10, // Ex: <8 x i16>
215+
TS_Array = 11, // Ex: [8 x float]
215216
TS_MaskBitCount, // Types used in Mask end here
216217
// TS_Extended is only used to identify the unnamed struct type used to wrap
217218
// multiple overloads when using GetTypeSlot.

lib/DXIL/DxilOperations.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2863,8 +2863,8 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
28632863
"linAlgMatrixLoadFromMemory",
28642864
Attribute::None,
28652865
2,
2866-
{{0x200}, {0x63}},
2867-
{{0x0}, {0x0}}}, // Overloads: o,hfwi
2866+
{{0x200}, {0x800}},
2867+
{{0x0}, {0x0}}}, // Overloads: o,a
28682868
{OC::LinAlgMatrixLength,
28692869
"LinAlgMatrixLength",
28702870
OCC::LinAlgMatrixLength,
@@ -2911,8 +2911,8 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
29112911
"linAlgMatrixStoreToMemory",
29122912
Attribute::None,
29132913
2,
2914-
{{0x200}, {0x63}},
2915-
{{0x0}, {0x0}}}, // Overloads: o,hfwi
2914+
{{0x200}, {0x800}},
2915+
{{0x0}, {0x0}}}, // Overloads: o,a
29162916
{OC::LinAlgMatrixQueryAccumulatorLayout,
29172917
"LinAlgMatrixQueryAccumulatorLayout",
29182918
OCC::LinAlgMatrixQueryAccumulatorLayout,
@@ -2967,8 +2967,8 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
29672967
"linAlgMatrixAccumulateToMemory",
29682968
Attribute::None,
29692969
2,
2970-
{{0x200}, {0x63}},
2971-
{{0x0}, {0x0}}}, // Overloads: o,hfwi
2970+
{{0x200}, {0x800}},
2971+
{{0x0}, {0x0}}}, // Overloads: o,a
29722972
{OC::LinAlgMatrixOuterProduct,
29732973
"LinAlgMatrixOuterProduct",
29742974
OCC::LinAlgMatrixOuterProduct,
@@ -3152,6 +3152,8 @@ unsigned OP::GetTypeSlot(Type *pType) {
31523152
return TS_Extended;
31533153
case Type::VectorTyID:
31543154
return TS_Vector;
3155+
case Type::ArrayTyID:
3156+
return TS_Array;
31553157
default:
31563158
break;
31573159
}
@@ -3166,26 +3168,39 @@ const char *OP::GetOverloadTypeName(unsigned TypeSlot) {
31663168
StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl<char> &Storage) {
31673169
DXASSERT(!Ty->isVoidTy(), "must not pass void type here");
31683170
unsigned TypeSlot = OP::GetTypeSlot(Ty);
3171+
31693172
if (TypeSlot < TS_BasicCount) {
31703173
return GetOverloadTypeName(TypeSlot);
3171-
} else if (TypeSlot == TS_UDT) {
3174+
}
3175+
3176+
switch (TypeSlot) {
3177+
case TS_UDT: {
31723178
if (Ty->isPointerTy())
31733179
Ty = Ty->getPointerElementType();
31743180
StructType *ST = cast<StructType>(Ty);
31753181
return ST->getStructName();
3176-
} else if (TypeSlot == TS_Object) {
3182+
}
3183+
case TS_Object: {
31773184
StructType *ST = cast<StructType>(Ty);
31783185
if (dxilutil::IsHLSLLinAlgMatrixType(Ty))
31793186
return (Twine("m") + Twine(dxilutil::GetHLSLLinAlgMatrixTypeMangling(ST)))
31803187
.toStringRef(Storage);
31813188
return ST->getStructName();
3182-
} else if (TypeSlot == TS_Vector) {
3189+
}
3190+
case TS_Vector: {
31833191
VectorType *VecTy = cast<VectorType>(Ty);
31843192
return (Twine("v") + Twine(VecTy->getNumElements()) +
31853193
Twine(
31863194
GetOverloadTypeName(OP::GetTypeSlot(VecTy->getElementType()))))
31873195
.toStringRef(Storage);
3188-
} else if (TypeSlot == TS_Extended) {
3196+
}
3197+
case TS_Array: {
3198+
if (Ty->isPointerTy())
3199+
Ty = Ty->getPointerElementType();
3200+
ArrayType *ArrTy = cast<ArrayType>(Ty);
3201+
return GetOverloadTypeName(OP::GetTypeSlot(ArrTy->getArrayElementType()));
3202+
}
3203+
case TS_Extended: {
31893204
DXASSERT(isa<StructType>(Ty),
31903205
"otherwise, extended overload type not wrapped in struct type.");
31913206
StructType *ST = cast<StructType>(Ty);
@@ -3200,11 +3215,14 @@ StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl<char> &Storage) {
32003215
OS << GetTypeName(ST->getElementType(I), TempStr);
32013216
}
32023217
return OS.str();
3203-
} else {
3204-
raw_svector_ostream OS(Storage);
3205-
Ty->print(OS);
3206-
return OS.str();
32073218
}
3219+
default:
3220+
break;
3221+
}
3222+
3223+
raw_svector_ostream OS(Storage);
3224+
Ty->print(OS);
3225+
return OS.str();
32083226
}
32093227

32103228
StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode,

lib/HLSL/HLOperationLower.cpp

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7226,6 +7226,53 @@ Value *TranslateLinAlgCopyConvertMatrix(CallInst *CI, IntrinsicOp IOP,
72267226
return nullptr;
72277227
}
72287228

7229+
// Lowers a LinAlg "load matrix from memory" HL call into a call to the DXIL
// op function selected by OpCode.
//
// Reads call operands 1..5 of CI as: destination matrix pointer, memory
// array, offset, stride, layout. The DXIL function is looked up with two
// overload types: the pointee type of the matrix pointer and the type of the
// memory operand. The value returned by the DXIL call is stored through the
// matrix pointer.
//
// IOP, ObjHelper and Translated are not used by this body.
// Always returns nullptr.
// NOTE(review): presumably nullptr is returned because the result is
// delivered through the store rather than as a replacement value - confirm
// against how the lowering driver treats a null return.
Value *TranslateLinAlgMatrixLoadFromMemory(
7230+
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
7231+
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
7232+
bool &Translated) {
7233+
hlsl::OP *HlslOp = &Helper.hlslOP;
7234+
// New instructions are emitted at CI's position.
IRBuilder<> Builder(CI);
7235+
7236+
// Operand 1 is the destination matrix; it must be a pointer so the loaded
// value can be stored through it below.
Value *MatrixPtr = CI->getArgOperand(1);
7237+
DXASSERT_NOMSG(isa<PointerType>(MatrixPtr->getType()));
7238+
Type *MatrixType = MatrixPtr->getType()->getPointerElementType();
7239+
7240+
Value *Arr = CI->getArgOperand(2);
7241+
Value *Offset = CI->getArgOperand(3);
7242+
Value *Stride = CI->getArgOperand(4);
7243+
Value *Layout = CI->getArgOperand(5);
7244+
7245+
// The DXIL opcode is passed as the first argument of the DXIL call.
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7246+
// Two overload types: the matrix type and the memory operand's type.
Function *DxilFunc = HlslOp->GetOpFunc(OpCode, {MatrixType, Arr->getType()});
7247+
7248+
Value *Matrix =
7249+
Builder.CreateCall(DxilFunc, {OpArg, Arr, Offset, Stride, Layout});
7250+
// Write the loaded matrix back through the caller-provided pointer.
Builder.CreateStore(Matrix, MatrixPtr);
7251+
7252+
return nullptr;
7253+
}
7254+
7255+
// Lowers a LinAlg "store/accumulate matrix to memory" HL call into a call to
// the DXIL op function selected by OpCode. The same body serves whichever
// opcode the caller passes in; nothing here is specific to store versus
// accumulate.
//
// Reads call operands 1..5 of CI as: matrix value, memory array, offset,
// stride, layout. The DXIL function is looked up with two overload types:
// the matrix operand's type and the memory operand's type.
//
// IOP, ObjHelper and Translated are not used by this body.
// Returns the newly created DXIL call instruction.
Value *TranslateLinAlgMatrixAccumStoreToMemory(
7256+
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
7257+
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
7258+
bool &Translated) {
7259+
hlsl::OP *HlslOp = &Helper.hlslOP;
7260+
// New instructions are emitted at CI's position.
IRBuilder<> Builder(CI);
7261+
7262+
// Unlike the load lowering, the matrix is taken as operand 1 directly and
// forwarded as-is; no pointer check or load is performed here.
Value *Matrix = CI->getArgOperand(1);
7263+
Value *Arr = CI->getArgOperand(2);
7264+
Value *Offset = CI->getArgOperand(3);
7265+
Value *Stride = CI->getArgOperand(4);
7266+
Value *Layout = CI->getArgOperand(5);
7267+
7268+
// The DXIL opcode is passed as the first argument of the DXIL call.
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7269+
// Two overload types: the matrix type and the memory operand's type.
Function *DxilFunc =
7270+
HlslOp->GetOpFunc(OpCode, {Matrix->getType(), Arr->getType()});
7271+
7272+
return Builder.CreateCall(DxilFunc,
7273+
{OpArg, Matrix, Arr, Offset, Stride, Layout});
7274+
}
7275+
72297276
} // namespace
72307277

72317278
// Lower table.
@@ -7989,14 +8036,16 @@ constexpr IntrinsicLower gLowerTable[] = {
79898036
{IntrinsicOp::IOP___builtin_LinAlg_MatrixLoadFromDescriptor,
79908037
TranslateLinAlgMatrixLoadFromDescriptor,
79918038
DXIL::OpCode::LinAlgMatrixLoadFromDescriptor},
7992-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixLoadFromMemory, EmptyLower,
8039+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixLoadFromMemory,
8040+
TranslateLinAlgMatrixLoadFromMemory,
79938041
DXIL::OpCode::LinAlgMatrixLoadFromMemory},
79948042
{IntrinsicOp::IOP___builtin_LinAlg_MatrixSetElement,
79958043
TranslateLinAlgMatrixSetElement, DXIL::OpCode::LinAlgMatrixSetElement},
79968044
{IntrinsicOp::IOP___builtin_LinAlg_MatrixStoreToDescriptor,
79978045
TranslateLinAlgMatrixAccumStoreToDescriptor,
79988046
DXIL::OpCode::LinAlgMatrixStoreToDescriptor},
7999-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixStoreToMemory, EmptyLower,
8047+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixStoreToMemory,
8048+
TranslateLinAlgMatrixAccumStoreToMemory,
80008049
DXIL::OpCode::LinAlgMatrixStoreToMemory},
80018050
{IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulate,
80028051
TranslateLinAlgMatrixAccumulate, DXIL::OpCode::LinAlgMatrixAccumulate},
@@ -8010,7 +8059,8 @@ constexpr IntrinsicLower gLowerTable[] = {
80108059
{IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulateToDescriptor,
80118060
TranslateLinAlgMatrixAccumStoreToDescriptor,
80128061
DXIL::OpCode::LinAlgMatrixAccumulateToDescriptor},
8013-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulateToMemory, EmptyLower,
8062+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulateToMemory,
8063+
TranslateLinAlgMatrixAccumStoreToMemory,
80148064
DXIL::OpCode::LinAlgMatrixAccumulateToMemory},
80158065
{IntrinsicOp::IOP___builtin_LinAlg_MatrixOuterProduct,
80168066
TranslateLinAlgMatrixOuterProduct, DXIL::OpCode::LinAlgMatrixOuterProduct},
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -HV 202x -E main %s | FileCheck %s
3+
4+
groupshared float SharedArr[64];
5+
6+
void fn(groupshared float Arr[64]) {
7+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
8+
__builtin_LinAlg_MatrixAccumulateToMemory(mat, Arr, 0, 0, 0);
9+
}
10+
11+
// CHECK: @{{.*}} = external addrspace(3) global [64 x float]
12+
13+
[numthreads(4,1,1)]
14+
void main() {
15+
// CHECK-LABEL: define void @main()
16+
17+
// CHECK: call void @dx.op.linAlgMatrixAccumulateToMemory.mC4M5N4U1S2.f32(i32 -2147483620, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, [64 x float] addrspace(3)* nonnull @{{.*}}, i32 0, i32 0, i32 0) ; LinAlgMatrixAccumulateToMemory(matrix,memory,offset,stride,layout)
18+
fn(SharedArr);
19+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -HV 202x -E main %s | FileCheck %s
3+
4+
groupshared float SharedArr[64];
5+
6+
void fn(groupshared float Arr[64]) {
7+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
8+
__builtin_LinAlg_MatrixLoadFromMemory(mat, Arr, 0, 0, 0);
9+
}
10+
11+
// CHECK: @{{.*}} = external addrspace(3) global [64 x float]
12+
13+
[numthreads(4,1,1)]
14+
void main() {
15+
// CHECK-LABEL: define void @main()
16+
17+
// CHECK: call %dx.types.LinAlgMatrixC4M5N4U1S2 @dx.op.linAlgMatrixLoadFromMemory.mC4M5N4U1S2.f32(i32 -2147483633, [64 x float] addrspace(3)* nonnull @{{.*}}, i32 0, i32 0, i32 0) ; LinAlgMatrixLoadFromMemory(memory,offset,stride,layout)
18+
fn(SharedArr);
19+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -HV 202x -E main %s | FileCheck %s
3+
4+
groupshared float SharedArr[64];
5+
6+
void fn(groupshared float Arr[64]) {
7+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
8+
__builtin_LinAlg_MatrixStoreToMemory(mat, Arr, 0, 0, 0);
9+
}
10+
11+
// CHECK: @{{.*}} = external addrspace(3) global [64 x float]
12+
13+
[numthreads(4,1,1)]
14+
void main() {
15+
// CHECK-LABEL: define void @main()
16+
17+
// CHECK: call void @dx.op.linAlgMatrixStoreToMemory.mC4M5N4U1S2.f32(i32 -2147483627, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, [64 x float] addrspace(3)* nonnull @{{.*}}, i32 0, i32 0, i32 0) ; LinAlgMatrixStoreToMemory(matrix,memory,offset,stride,layout)
18+
fn(SharedArr);
19+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixAccumulateToMemory 'void (__builtin_LinAlgMatrix {{.*}}, float const __attribute__((address_space(3))) (&)[64], unsigned int, unsigned int, unsigned int)' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} matrix '__builtin_LinAlgMatrix {{.*}}'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} memory 'float const __attribute__((address_space(3))) (&)[64]'
7+
// CHECK-NEXT: ParmVarDecl {{.*}} offset 'unsigned int'
8+
// CHECK-NEXT: ParmVarDecl {{.*}} stride 'unsigned int'
9+
// CHECK-NEXT: ParmVarDecl {{.*}} layout 'unsigned int'
10+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 420
11+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
12+
13+
groupshared float SharedArr[64];
14+
15+
void fn(groupshared float Arr[64]) {
16+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
17+
__builtin_LinAlg_MatrixAccumulateToMemory(mat, Arr, 0, 0, 0);
18+
}
19+
20+
[shader("compute")]
21+
[numthreads(1,1,1)]
22+
void main() {
23+
fn(SharedArr);
24+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: %dxc -T cs_6_9 -HV 202x -E main %s -verify
2+
3+
groupshared float SharedArr[64];
4+
5+
void fn(groupshared float Arr[64], float F) {
6+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
7+
8+
// expected-error@+1{{intrinsic __builtin_LinAlg_MatrixAccumulateToMemory potentially used by ''main'' requires shader model 6.10 or greater}}
9+
__builtin_LinAlg_MatrixAccumulateToMemory(mat, Arr, 0, 0, 0);
10+
}
11+
12+
[numthreads(4,1,1)]
13+
void main() {
14+
fn(SharedArr, 6.0);
15+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixLoadFromMemory 'void (__builtin_LinAlgMatrix {{.*}}, float const __attribute__((address_space(3))) (&)[64], unsigned int, unsigned int, unsigned int)' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} ret '__builtin_LinAlgMatrix {{.*}}'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} memory 'float const __attribute__((address_space(3))) (&)[64]'
7+
// CHECK-NEXT: ParmVarDecl {{.*}} offset 'unsigned int'
8+
// CHECK-NEXT: ParmVarDecl {{.*}} stride 'unsigned int'
9+
// CHECK-NEXT: ParmVarDecl {{.*}} layout 'unsigned int'
10+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 411
11+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
12+
13+
groupshared float SharedArr[64];
14+
15+
void fn(groupshared float Arr[64]) {
16+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
17+
__builtin_LinAlg_MatrixLoadFromMemory(mat, Arr, 0, 0, 0);
18+
}
19+
20+
[shader("compute")]
21+
[numthreads(1,1,1)]
22+
void main() {
23+
fn(SharedArr);
24+
}

0 commit comments

Comments
 (0)