Skip to content

Commit 8f92c94

Browse files
authored
[SM6.10] Implement MatrixAccumulate Builtin (#8194)
Fixes #8028
1 parent ac64043 commit 8f92c94

5 files changed

Lines changed: 228 additions & 1 deletion

File tree

lib/HLSL/HLOperationLower.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7063,7 +7063,23 @@ Value *TranslateLinAlgMatrixAccumulate(CallInst *CI, IntrinsicOp IOP,
7063 7063
HLOperationLowerHelper &Helper,
7064 7064
HLObjectOperationLowerHelper *ObjHelper,
7065 7065
bool &Translated) {
7066-
DXASSERT(false, "Not implemented.");
7066+
hlsl::OP *HlslOp = &Helper.hlslOP;
7067+
IRBuilder<> Builder(CI);
7068+
7069+
Value *MatrixCPtr = CI->getArgOperand(1);
7070+
DXASSERT_NOMSG(isa<PointerType>(MatrixCPtr->getType()));
7071+
Type *MatrixCType = MatrixCPtr->getType()->getPointerElementType();
7072+
7073+
Value *MatrixLHS = CI->getArgOperand(2);
7074+
Value *MatrixRHS = CI->getArgOperand(3);
7075+
7076+
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7077+
Function *DxilFunc = HlslOp->GetOpFunc(
7078+
OpCode, {MatrixCType, MatrixLHS->getType(), MatrixRHS->getType()});
7079+
7080+
Value *MatrixC = Builder.CreateCall(DxilFunc, {OpArg, MatrixLHS, MatrixRHS});
7081+
Builder.CreateStore(MatrixC, MatrixCPtr);
7082+
7067 7083
return nullptr;
7068 7084
}
7069 7085

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s
3+
4+
[numthreads(1,1,1)]
5+
void main() {
6+
// CHECK-LABEL: define void @main()
7+
8+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(5, 3, 4, 0, 0)]] mat1;
9+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 1, 1, 0, 0)]] mat2;
10+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(2, 2, 2, 2, 2)]] mat3;
11+
12+
// CHECK: call %dx.types.LinAlgMatrixC2M2N2U2S2 @dx.op.linAlgMatrixAccumulate.mC2M2N2U2S2.mC1M1N1U0S0.mC5M3N4U0S0(i32 -2147483624, %dx.types.LinAlgMatrixC1M1N1U0S0 {{.*}}, %dx.types.LinAlgMatrixC5M3N4U0S0 {{.*}}) ; LinAlgMatrixAccumulate(matrixLHS,matrixRHS)
13+
__builtin_LinAlg_MatrixAccumulate(mat3, mat2, mat1);
14+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixAccumulate 'void (__builtin_LinAlgMatrix {{.*}}, __builtin_LinAlgMatrix {{.*}}, __builtin_LinAlgMatrix {{.*}})' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} matrixC '__builtin_LinAlgMatrix {{.*}}'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} matrixLHS '__builtin_LinAlgMatrix {{.*}}'
7+
// CHECK-NEXT: ParmVarDecl {{.*}} matrixRHS '__builtin_LinAlgMatrix {{.*}}'
8+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 415
9+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
10+
11+
12+
[shader("compute")]
13+
[numthreads(1,1,1)]
14+
void main() {
15+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat;
16+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat2;
17+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat3;
18+
__builtin_LinAlg_MatrixAccumulate(mat, mat2, mat3);
19+
}
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 %s -verify
3+
4+
// expected-no-diagnostics
5+
6+
RWByteAddressBuffer buf;
7+
void CallFunction()
8+
{
9+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
10+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat2;
11+
__builtin_LinAlg_MatrixAccumulate(mat, mat2, mat2);
12+
}
13+
14+
// --- Allowed Stages ---
15+
16+
[shader("compute")]
17+
[numthreads(4,4,4)]
18+
void mainCS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
19+
CallFunction();
20+
}
21+
22+
struct Verts {
23+
float4 position : SV_Position;
24+
};
25+
26+
[shader("mesh")]
27+
[NumThreads(8, 8, 2)]
28+
[OutputTopology("triangle")]
29+
void mainMeS(out vertices Verts verts[32], uint ix : SV_GroupIndex) {
30+
CallFunction();
31+
SetMeshOutputCounts(32, 16);
32+
Verts v = {0.0, 0.0, 0.0, 0.0};
33+
verts[ix] = v;
34+
}
35+
36+
struct AmpPayload {
37+
float2 dummy;
38+
};
39+
40+
[numthreads(8, 1, 1)]
41+
[shader("amplification")]
42+
void mainAS()
43+
{
44+
CallFunction();
45+
AmpPayload pld;
46+
pld.dummy = float2(1.0,2.0);
47+
DispatchMesh(8, 1, 1, pld);
48+
}
49+
50+
[shader("pixel")]
51+
float4 mainPS(uint ix : SV_PrimitiveID) : SV_TARGET {
52+
CallFunction();
53+
return 1.0;
54+
}
55+
56+
[shader("vertex")]
57+
float4 mainVS(uint ix : SV_VertexID) : OUT {
58+
CallFunction();
59+
return 1.0;
60+
}
61+
62+
[shader("node")]
63+
[nodedispatchgrid(8,1,1)]
64+
[numthreads(64,2,2)]
65+
void mainNS() {
66+
CallFunction();
67+
}
68+
69+
[shader("raygeneration")]
70+
void mainRG() {
71+
CallFunction();
72+
}
73+
74+
[shader("intersection")]
75+
void mainIS() {
76+
CallFunction();
77+
}
78+
79+
struct Attribs { float2 barys; };
80+
81+
[shader("callable")]
82+
void mainCALL(inout Attribs attrs) {
83+
CallFunction();
84+
}
85+
86+
struct [raypayload] RayPayload
87+
{
88+
float elem
89+
: write(caller,closesthit,anyhit,miss)
90+
: read(caller,closesthit,anyhit,miss);
91+
};
92+
93+
[shader("anyhit")]
94+
void mainAH(inout RayPayload pld, in Attribs attrs) {
95+
CallFunction();
96+
}
97+
98+
[shader("closesthit")]
99+
void mainCH(inout RayPayload pld, in Attribs attrs) {
100+
CallFunction();
101+
}
102+
103+
[shader("miss")]
104+
void mainMS(inout RayPayload pld) {
105+
CallFunction();
106+
}
107+
108+
struct PosStruct {
109+
float4 pos : SV_Position;
110+
};
111+
112+
struct PCStruct
113+
{
114+
float Edges[3] : SV_TessFactor;
115+
float Inside : SV_InsideTessFactor;
116+
float4 test : TEST;
117+
};
118+
119+
PCStruct HSPatch(InputPatch<PosStruct, 3> ip,
120+
OutputPatch<PosStruct, 3> op,
121+
uint ix : SV_PrimitiveID)
122+
{
123+
PCStruct a;
124+
a.Edges[0] = ip[0].pos.w;
125+
a.Edges[1] = ip[0].pos.w;
126+
a.Edges[2] = ip[0].pos.w;
127+
a.Inside = ip[0].pos.w;
128+
return a;
129+
}
130+
131+
[shader("hull")]
132+
[domain("tri")]
133+
[partitioning("fractional_odd")]
134+
[outputtopology("triangle_cw")]
135+
[outputcontrolpoints(3)]
136+
[patchconstantfunc("HSPatch")]
137+
PosStruct mainHS(InputPatch<PosStruct, 3> p, uint ix : SV_OutputControlPointID)
138+
{
139+
CallFunction();
140+
PosStruct s;
141+
s.pos = p[ix].pos;
142+
return s;
143+
}
144+
145+
[shader("domain")]
146+
[domain("tri")]
147+
PosStruct mainDS(const OutputPatch<PosStruct, 3> patch,
148+
uint ix : SV_PrimitiveID)
149+
{
150+
CallFunction();
151+
PosStruct v;
152+
v.pos = patch[0].pos;
153+
return v;
154+
}
155+
156+
float4 a;
157+
158+
[shader("geometry")]
159+
[maxvertexcount(1)]
160+
void mainGS(triangle float4 array[3] : SV_Position, uint ix : SV_GSInstanceID,
161+
inout PointStream<PosStruct> OutputStream)
162+
{
163+
CallFunction();
164+
PosStruct s;
165+
s.pos = a;
166+
OutputStream.Append(s);
167+
OutputStream.RestartStrip();
168+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: %dxc -I %hlsl_headers -T cs_6_9 -E main %s -verify
2+
3+
[numthreads(4,1,1)]
4+
void main() {
5+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 0, 0)]] mat;
6+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 1, 1, 1, 1)]] mat2;
7+
8+
// expected-error@+1{{intrinsic __builtin_LinAlg_MatrixAccumulate potentially used by ''main'' requires shader model 6.10 or greater}}
9+
__builtin_LinAlg_MatrixAccumulate(mat2, mat, mat);
10+
}

0 commit comments

Comments
 (0)