Skip to content

Commit 6929512

Browse files
authored
[SM6.10] Implement MatrixMul/MulAccum Builtins (microsoft#8200)
Fixes microsoft#7911
1 parent f94f123 commit 6929512

9 files changed

Lines changed: 554 additions & 2 deletions

File tree

lib/HLSL/HLOperationLower.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
// Lowers a high-level matrix*matrix multiply call (CI) into a call to the DXIL
// op function for OpCode and stores the product through the destination
// pointer operand.
//   CI        - the call being lowered; operand 1 is a pointer to the result
//               matrix, operands 2 and 3 are the A and B matrices.
//               (operand 0 is not read here - presumably the HL opcode; confirm)
//   OpCode    - DXIL opcode; passed as the first call argument and used to
//               look up the overloaded op function.
//   Helper    - supplies the hlsl::OP instance used for constants/functions.
//   IOP, ObjHelper, Translated - not referenced by this body.
// Returns nullptr: the result is delivered via the store, not a return value.
// In this diff rendering the line marked '-' is the removed placeholder and
// the lines marked '+' are the new implementation.
@@ -7121,15 +7121,50 @@ Value *TranslateLinAlgMatrixMatrixMultiply(
71217121
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
71227122
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
71237123
bool &Translated) {
7124-
DXASSERT(false, "Not implemented.");
7124+
hlsl::OP *HlslOp = &Helper.hlslOP;
7125+
IRBuilder<> Builder(CI);
7126+
// Destination: must be a pointer; its pointee type is the result matrix type.
// NOTE(review): the pointer check is an assert only - confirm callers
// guarantee it in builds where asserts are compiled out.
7127+
Value *MatrixCPtr = CI->getArgOperand(1);
7128+
DXASSERT_NOMSG(isa<PointerType>(MatrixCPtr->getType()));
7129+
Type *MatrixCTy = MatrixCPtr->getType()->getPointerElementType();
7130+
// Source operands are taken by value.
7131+
Value *MatrixA = CI->getArgOperand(2);
7132+
Value *MatrixB = CI->getArgOperand(3);
7133+
// The op function is overloaded on {result type, A type, B type}.
7134+
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7135+
Function *DxilFunc = HlslOp->GetOpFunc(
7136+
OpCode, {MatrixCTy, MatrixA->getType(), MatrixB->getType()});
7137+
// Emit the op call and write its result back through the destination pointer.
7138+
Value *MatrixC = Builder.CreateCall(DxilFunc, {OpArg, MatrixA, MatrixB});
7139+
Builder.CreateStore(MatrixC, MatrixCPtr);
7140+
71257141
return nullptr;
71267142
}
71277143

71287144
// Lowers a high-level matrix multiply-accumulate call (CI) into a call to the
// DXIL op function for OpCode and stores the result through the destination
// pointer operand.
//   CI        - the call being lowered; operand 1 is a pointer to the result
//               matrix R, operands 2, 3 and 4 are the A, B and C matrices.
//               (operand 0 is not read here - presumably the HL opcode; confirm)
//   OpCode    - DXIL opcode; passed as the first call argument and used to
//               look up the overloaded op function.
//   Helper    - supplies the hlsl::OP instance used for constants/functions.
//   IOP, ObjHelper, Translated - not referenced by this body.
// Returns nullptr: the result is delivered via the store, not a return value.
// In this diff rendering the line marked '-' is the removed placeholder and
// the lines marked '+' are the new implementation.
Value *TranslateLinAlgMatrixMatrixMultiplyAccumulate(
71297145
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
71307146
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
71317147
bool &Translated) {
7132-
DXASSERT(false, "Not implemented.");
7148+
hlsl::OP *HlslOp = &Helper.hlslOP;
7149+
IRBuilder<> Builder(CI);
7150+
// Destination: must be a pointer; its pointee type is the result matrix type.
// NOTE(review): the pointer check is an assert only - confirm callers
// guarantee it in builds where asserts are compiled out.
7151+
Value *MatrixRPtr = CI->getArgOperand(1);
7152+
DXASSERT_NOMSG(isa<PointerType>(MatrixRPtr->getType()));
7153+
Type *MatrixRTy = MatrixRPtr->getType()->getPointerElementType();
7154+
// Source operands are taken by value.
7155+
Value *MatrixA = CI->getArgOperand(2);
7156+
Value *MatrixB = CI->getArgOperand(3);
7157+
Value *MatrixC = CI->getArgOperand(4);
7158+
// The op function is overloaded on {result type, A type, B type, C type}.
7159+
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7160+
Function *DxilFunc =
7161+
HlslOp->GetOpFunc(OpCode, {MatrixRTy, MatrixA->getType(),
7162+
MatrixB->getType(), MatrixC->getType()});
7163+
// Emit the op call and write its result back through the destination pointer.
7164+
Value *MatrixR =
7165+
Builder.CreateCall(DxilFunc, {OpArg, MatrixA, MatrixB, MatrixC});
7166+
Builder.CreateStore(MatrixR, MatrixRPtr);
7167+
71337168
return nullptr;
71347169
}
71357170

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s
3+
4+
// Single-thread compute entry point for the lowering test: declares two
// matrices with identical attributes and multiplies mat1 by itself into mat2,
// so the emitted DXIL call can be matched by the pattern inside the body.
[numthreads(1,1,1)]
5+
void main() {
6+
// CHECK-LABEL: define void @main()
7+
8+
// CHECK: call %dx.types.LinAlgMatrixC4M5N4U1S2 @dx.op.linAlgMatrixMultiply.mC4M5N4U1S2.mC4M5N4U1S2.mC4M5N4U1S2(i32 -2147483625, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}) ; LinAlgMatrixMultiply(matrixA,matrixB)
9+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat1;
10+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat2;
11+
__builtin_LinAlg_MatrixMatrixMultiply(mat2, mat1, mat1);
12+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s
3+
4+
// Single-thread compute entry point for the lowering test: declares two
// matrices with identical attributes and issues one multiply-accumulate with
// mat2 as the destination and mat1 as all three sources, so the emitted DXIL
// call can be matched by the pattern inside the body.
[numthreads(1,1,1)]
5+
void main() {
6+
// CHECK-LABEL: define void @main()
7+
8+
// CHECK: call %dx.types.LinAlgMatrixC4M5N4U1S2 @dx.op.linAlgMatrixMultiplyAccumulate.mC4M5N4U1S2.mC4M5N4U1S2.mC4M5N4U1S2.mC4M5N4U1S2(i32 -2147483637, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}) ; LinAlgMatrixMultiplyAccumulate(matrixA,matrixB,matrixC)
9+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat1;
10+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat2;
11+
__builtin_LinAlg_MatrixMatrixMultiplyAccumulate(mat2, mat1, mat1, mat1);
12+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixMatrixMultiply 'void (__builtin_LinAlgMatrix {{.*}}, __builtin_LinAlgMatrix {{.*}}, __builtin_LinAlgMatrix {{.*}})' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} matrixC '__builtin_LinAlgMatrix {{.*}}'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} matrixA '__builtin_LinAlgMatrix {{.*}}'
7+
// CHECK-NEXT: ParmVarDecl {{.*}} matrixB '__builtin_LinAlgMatrix {{.*}}'
8+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 416
9+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
10+
11+
// Compute entry point for the AST-dump test: uses the multiply builtin once
// (mat1 as destination, mat2 and mat3 as sources) so that its implicit
// declaration appears in the dumped AST matched by the file header.
[shader("compute")]
12+
[numthreads(1,1,1)]
13+
void main() {
14+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat1;
15+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat2;
16+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat3;
17+
__builtin_LinAlg_MatrixMatrixMultiply(mat1, mat2, mat3);
18+
}
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: not %dxc -T lib_6_10 %s 2>&1 | FileCheck %s
3+
4+
// CHECK: Opcode LinAlgMatrixMultiply not valid in shader model lib_6_10(gs).
5+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixMultiply{{.*}} of function 'mainGS'.
6+
// CHECK-NEXT: Opcode LinAlgMatrixMultiply not valid in shader model lib_6_10(ds).
7+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixMultiply{{.*}} of function 'mainDS'.
8+
// CHECK-NEXT: Opcode LinAlgMatrixMultiply not valid in shader model lib_6_10(hs).
9+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixMultiply{{.*}} of function 'mainHS'.
10+
// CHECK-NEXT: Opcode LinAlgMatrixMultiply not valid in shader model lib_6_10(vs).
11+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixMultiply{{.*}} of function 'mainVS'.
12+
// CHECK-NEXT: Opcode LinAlgMatrixMultiply not valid in shader model lib_6_10(ps).
13+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixMultiply{{.*}} of function 'mainPS'.
14+
// CHECK-NEXT: Opcode LinAlgMatrixMultiply not valid in shader model lib_6_10(miss).
15+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixMultiply{{.*}} of function '{{.*}}mainMS@@YAXURayPayload@@@Z'.
16+
// CHECK-NEXT: Opcode LinAlgMatrixMultiply not valid in shader model lib_6_10(closesthit).
17+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixMultiply{{.*}} of function '{{.*}}mainCH@@YAXURayPayload@@UAttribs@@@Z'.
18+
// CHECK-NEXT: Opcode LinAlgMatrixMultiply not valid in shader model lib_6_10(anyhit).
19+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixMultiply{{.*}} of function '{{.*}}mainAH@@YAXURayPayload@@UAttribs@@@Z'.
20+
// CHECK-NEXT: Opcode LinAlgMatrixMultiply not valid in shader model lib_6_10(callable).
21+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixMultiply{{.*}} of function '{{.*}}mainCALL@@YAXUAttribs@@@Z'.
22+
// CHECK-NEXT: Opcode LinAlgMatrixMultiply not valid in shader model lib_6_10(intersection).
23+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixMultiply{{.*}} of function '{{.*}}mainIS@@YAXXZ'.
24+
// CHECK-NEXT: Opcode LinAlgMatrixMultiply not valid in shader model lib_6_10(raygeneration).
25+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixMultiply{{.*}} of function '{{.*}}mainRG@@YAXXZ'.
26+
// CHECK-NEXT: Opcode LinAlgMatrixMultiply not valid in shader model lib_6_10(node).
27+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixMultiply{{.*}} of function 'mainNS'.
28+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
29+
// CHECK-NEXT: Function uses features incompatible with the shader stage (node) of the entry function.
30+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
31+
// CHECK-NEXT: Function uses features incompatible with the shader stage (raygeneration) of the entry function.
32+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
33+
// CHECK-NEXT: Function uses features incompatible with the shader stage (intersection) of the entry function.
34+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
35+
// CHECK-NEXT: Function uses features incompatible with the shader stage (callable) of the entry function.
36+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
37+
// CHECK-NEXT: Function uses features incompatible with the shader stage (anyhit) of the entry function.
38+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
39+
// CHECK-NEXT: Function uses features incompatible with the shader stage (closesthit) of the entry function.
40+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
41+
// CHECK-NEXT: Function uses features incompatible with the shader stage (miss) of the entry function.
42+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
43+
// CHECK-NEXT: Function uses features incompatible with the shader stage (ps) of the entry function.
44+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
45+
// CHECK-NEXT: Function uses features incompatible with the shader stage (vs) of the entry function.
46+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
47+
// CHECK-NEXT: Function uses features incompatible with the shader stage (hs) of the entry function.
48+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
49+
// CHECK-NEXT: Function uses features incompatible with the shader stage (ds) of the entry function.
50+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
51+
// CHECK-NEXT: Function uses features incompatible with the shader stage (gs) of the entry function.
52+
// CHECK-NEXT: Validation failed.
53+
54+
// Shared helper called by every entry point in this file: declares three
// matrices with identical attributes and issues a single matrix multiply, so
// each shader stage exercises the same operation.
void CallFunction()
55+
{
56+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat1;
57+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat2;
58+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat3;
59+
__builtin_LinAlg_MatrixMatrixMultiply(mat1, mat2, mat3);
60+
}
61+
62+
// --- Allowed Stages ---
63+
64+
// Compute-stage entry point (allowed stage): the expected-output header of
// this file lists no diagnostic for it.
[shader("compute")]
65+
[numthreads(4,4,4)]
66+
void mainCS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
67+
CallFunction();
68+
}
69+
70+
// Per-vertex output written by the mesh entry point mainMeS.
struct Verts {
71+
float4 position : SV_Position;
72+
};
73+
74+
// Mesh-stage entry point (allowed stage): calls the shared helper, then
// declares its output counts and writes one zeroed vertex per thread index.
[shader("mesh")]
75+
[NumThreads(8, 8, 2)]
76+
[OutputTopology("triangle")]
77+
void mainMeS(out vertices Verts verts[32], uint ix : SV_GroupIndex) {
78+
CallFunction();
79+
SetMeshOutputCounts(32, 16);
80+
Verts v = {0.0, 0.0, 0.0, 0.0};
81+
verts[ix] = v;
82+
}
83+
84+
// Payload handed to DispatchMesh by the amplification entry point mainAS.
struct AmpPayload {
85+
float2 dummy;
86+
};
87+
88+
// Amplification-stage entry point (allowed stage): calls the shared helper,
// then fills a payload and dispatches mesh work with it.
[numthreads(8, 1, 1)]
89+
[shader("amplification")]
90+
void mainAS()
91+
{
92+
CallFunction();
93+
AmpPayload pld;
94+
pld.dummy = float2(1.0,2.0);
95+
DispatchMesh(8, 1, 1, pld);
96+
}
97+
98+
// --- Prohibited Stages ---
99+
100+
// Pixel-stage entry point (prohibited stage): the file header expects a
// "not valid in shader model lib_6_10(ps)" diagnostic naming 'mainPS'.
[shader("pixel")]
101+
float4 mainPS(uint ix : SV_PrimitiveID) : SV_TARGET {
102+
CallFunction();
103+
return 1.0;
104+
}
105+
106+
// Vertex-stage entry point (prohibited stage): the file header expects a
// "not valid in shader model lib_6_10(vs)" diagnostic naming 'mainVS'.
[shader("vertex")]
107+
float4 mainVS(uint ix : SV_VertexID) : OUT {
108+
CallFunction();
109+
return 1.0;
110+
}
111+
112+
// Node-stage entry point (prohibited stage): the file header expects a
// "not valid in shader model lib_6_10(node)" diagnostic naming 'mainNS'.
[shader("node")]
113+
[nodedispatchgrid(8,1,1)]
114+
[numthreads(64,2,2)]
115+
void mainNS() {
116+
CallFunction();
117+
}
118+
119+
// Ray-generation entry point (prohibited stage): the file header expects a
// "not valid in shader model lib_6_10(raygeneration)" diagnostic for it.
[shader("raygeneration")]
120+
void mainRG() {
121+
CallFunction();
122+
}
123+
124+
// Intersection entry point (prohibited stage): the file header expects a
// "not valid in shader model lib_6_10(intersection)" diagnostic for it.
[shader("intersection")]
125+
void mainIS() {
126+
CallFunction();
127+
}
128+
129+
// Parameter type shared by the callable, anyhit and closesthit entry points.
struct Attribs { float2 barys; };
130+
131+
// Callable entry point (prohibited stage): the file header expects a
// "not valid in shader model lib_6_10(callable)" diagnostic for it.
[shader("callable")]
132+
void mainCALL(inout Attribs attrs) {
133+
CallFunction();
134+
}
135+
136+
// Ray payload used by the anyhit, closesthit and miss entry points; its single
// field is marked readable and writable by caller, closesthit, anyhit and miss.
struct [raypayload] RayPayload
137+
{
138+
float elem
139+
: write(caller,closesthit,anyhit,miss)
140+
: read(caller,closesthit,anyhit,miss);
141+
};
142+
143+
// Any-hit entry point (prohibited stage): the file header expects a
// "not valid in shader model lib_6_10(anyhit)" diagnostic for it.
[shader("anyhit")]
144+
void mainAH(inout RayPayload pld, in Attribs attrs) {
145+
CallFunction();
146+
}
147+
148+
// Closest-hit entry point (prohibited stage): the file header expects a
// "not valid in shader model lib_6_10(closesthit)" diagnostic for it.
[shader("closesthit")]
149+
void mainCH(inout RayPayload pld, in Attribs attrs) {
150+
CallFunction();
151+
}
152+
153+
// Miss entry point (prohibited stage): the file header expects a
// "not valid in shader model lib_6_10(miss)" diagnostic for it.
[shader("miss")]
154+
void mainMS(inout RayPayload pld) {
155+
CallFunction();
156+
}
157+
158+
// Position-only element type used by the hull, domain and geometry entry
// points and by the patch-constant function HSPatch.
struct PosStruct {
159+
float4 pos : SV_Position;
160+
};
161+
162+
// Patch-constant data returned by HSPatch: three edge tessellation factors,
// one inside factor and an extra TEST-semantic value.
struct PCStruct
163+
{
164+
float Edges[3] : SV_TessFactor;
165+
float Inside : SV_InsideTessFactor;
166+
float4 test : TEST;
167+
};
168+
169+
// Patch-constant function named by mainHS's patchconstantfunc attribute.
// Sets every tessellation factor to the w component of the first input control
// point; it does not call CallFunction itself.
// NOTE(review): a.test is never assigned before the struct is returned -
// confirm this is intentional for the fixture.
PCStruct HSPatch(InputPatch<PosStruct, 3> ip,
170+
OutputPatch<PosStruct, 3> op,
171+
uint ix : SV_PrimitiveID)
172+
{
173+
PCStruct a;
174+
a.Edges[0] = ip[0].pos.w;
175+
a.Edges[1] = ip[0].pos.w;
176+
a.Edges[2] = ip[0].pos.w;
177+
a.Inside = ip[0].pos.w;
178+
return a;
179+
}
180+
181+
// Hull-stage entry point (prohibited stage): the file header expects a
// "not valid in shader model lib_6_10(hs)" diagnostic naming 'mainHS'.
// After calling the shared helper it passes the indexed control point through.
[shader("hull")]
182+
[domain("tri")]
183+
[partitioning("fractional_odd")]
184+
[outputtopology("triangle_cw")]
185+
[outputcontrolpoints(3)]
186+
[patchconstantfunc("HSPatch")]
187+
PosStruct mainHS(InputPatch<PosStruct, 3> p, uint ix : SV_OutputControlPointID)
188+
{
189+
CallFunction();
190+
PosStruct s;
191+
s.pos = p[ix].pos;
192+
return s;
193+
}
194+
195+
// Domain-stage entry point (prohibited stage): the file header expects a
// "not valid in shader model lib_6_10(ds)" diagnostic naming 'mainDS'.
// After calling the shared helper it returns the first patch point's position.
[shader("domain")]
196+
[domain("tri")]
197+
PosStruct mainDS(const OutputPatch<PosStruct, 3> patch,
198+
uint ix : SV_PrimitiveID)
199+
{
200+
CallFunction();
201+
PosStruct v;
202+
v.pos = patch[0].pos;
203+
return v;
204+
}
205+
206+
// Global value that the geometry entry point mainGS copies into its output.
float4 a;
207+
208+
// Geometry-stage entry point (prohibited stage): the file header expects a
// "not valid in shader model lib_6_10(gs)" diagnostic naming 'mainGS'.
// After calling the shared helper it appends one point built from the global
// 'a' and restarts the strip.
[shader("geometry")]
209+
[maxvertexcount(1)]
210+
void mainGS(triangle float4 array[3] : SV_Position, uint ix : SV_GSInstanceID,
211+
inout PointStream<PosStruct> OutputStream)
212+
{
213+
CallFunction();
214+
PosStruct s;
215+
s.pos = a;
216+
OutputStream.Append(s);
217+
OutputStream.RestartStrip();
218+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: %dxc -I %hlsl_headers -T cs_6_9 -E main %s -verify
2+
3+
// Compute entry point for the availability test. This file's RUN line targets
// cs_6_9, and the diagnostic directive in the body states that the multiply
// builtin requires shader model 6.10 or greater, so the call must be rejected.
[numthreads(4,1,1)]
4+
void main() {
5+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 0, 0)]] mat;
6+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 1, 1, 1, 1)]] mat2;
7+
8+
// expected-error@+1{{intrinsic __builtin_LinAlg_MatrixMatrixMultiply potentially used by ''main'' requires shader model 6.10 or greater}}
9+
__builtin_LinAlg_MatrixMatrixMultiply(mat2, mat, mat);
10+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixMatrixMultiplyAccumulate 'void (__builtin_LinAlgMatrix {{.*}}, __builtin_LinAlgMatrix {{.*}}, __builtin_LinAlgMatrix {{.*}}, __builtin_LinAlgMatrix {{.*}})' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} matrixR '__builtin_LinAlgMatrix {{.*}}'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} matrixA '__builtin_LinAlgMatrix {{.*}}'
7+
// CHECK-NEXT: ParmVarDecl {{.*}} matrixB '__builtin_LinAlgMatrix {{.*}}'
8+
// CHECK-NEXT: ParmVarDecl {{.*}} matrixC '__builtin_LinAlgMatrix {{.*}}'
9+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 417
10+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
11+
12+
// Compute entry point for the AST-dump test: uses the multiply-accumulate
// builtin once (mat1 as destination and as the fourth argument, mat2 and mat3
// as the other sources) so that its implicit declaration appears in the dumped
// AST matched by the file header.
[shader("compute")]
13+
[numthreads(1,1,1)]
14+
void main() {
15+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat1;
16+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat2;
17+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat3;
18+
__builtin_LinAlg_MatrixMatrixMultiplyAccumulate(mat1, mat2, mat3, mat1);
19+
}

0 commit comments

Comments
 (0)