Skip to content

Commit ac64043

Browse files
authored
[SM6.10] Implement MatrixOuterProduct Builtin (#8193)
Fixes #7910
1 parent 15c1596 commit ac64043

5 files changed

Lines changed: 227 additions & 1 deletion

File tree

lib/HLSL/HLOperationLower.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7039,7 +7039,22 @@ Value *TranslateLinAlgMatrixOuterProduct(
70397039
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
70407040
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
70417041
bool &Translated) {
7042-
DXASSERT(false, "Not implemented.");
7042+
hlsl::OP *HlslOp = &Helper.hlslOP;
7043+
IRBuilder<> Builder(CI);
7044+
7045+
Value *MatrixPtr = CI->getArgOperand(1);
7046+
DXASSERT_NOMSG(isa<PointerType>(MatrixPtr->getType()));
7047+
Type *MatrixType = MatrixPtr->getType()->getPointerElementType();
7048+
Value *VecA = CI->getArgOperand(2);
7049+
Value *VecB = CI->getArgOperand(3);
7050+
7051+
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7052+
Function *DxilFunc =
7053+
HlslOp->GetOpFunc(OpCode, {MatrixType, VecA->getType(), VecB->getType()});
7054+
7055+
Value *Matrix = Builder.CreateCall(DxilFunc, {OpArg, VecA, VecB});
7056+
Builder.CreateStore(Matrix, MatrixPtr);
7057+
70437058
return nullptr;
70447059
}
70457060

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s
3+
4+
[numthreads(1,1,1)]
5+
void main() {
6+
// CHECK-LABEL: define void @main()
7+
8+
float4 lhs = {1,2,3,4};
9+
float4 rhs = {4,3,2,1};
10+
// CHECK: call %dx.types.LinAlgMatrixC2M2N2U2S2 @dx.op.linAlgMatrixOuterProduct.mC2M2N2U2S2.v4f32.v4f32(i32 -2147483619, <4 x float> {{.*}}, <4 x float> {{.*}}) ; LinAlgMatrixOuterProduct(vectorA,vectorB)
11+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(2, 2, 2, 2, 2)]] mat;
12+
__builtin_LinAlg_MatrixOuterProduct(mat, lhs, rhs);
13+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixOuterProduct 'void (__builtin_LinAlgMatrix {{.*}}, vector<int, 4>, vector<int, 4>)' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} ret '__builtin_LinAlgMatrix {{.*}}'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} vecA 'vector<int, 4>':'vector<int, 4>'
7+
// CHECK-NEXT: ParmVarDecl {{.*}} vecB 'vector<int, 4>':'vector<int, 4>'
8+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 421
9+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
10+
11+
12+
[shader("compute")]
13+
[numthreads(1,1,1)]
14+
void main() {
15+
int4 vecA = {1,2,3,4};
16+
int4 vecB = {1,2,3,4};
17+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat;
18+
__builtin_LinAlg_MatrixOuterProduct(mat, vecA, vecB);
19+
}
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 %s -verify
3+
4+
// expected-no-diagnostics
5+
6+
void CallFunction()
7+
{
8+
int4 vecA = {9,9,9,9};
9+
int4 vecB = {3,3,3,3};
10+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
11+
__builtin_LinAlg_MatrixOuterProduct(mat, vecA, vecB);
12+
}
13+
14+
// --- Allowed Stages ---
15+
16+
[shader("compute")]
17+
[numthreads(4,4,4)]
18+
void mainCS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
19+
CallFunction();
20+
}
21+
22+
struct Verts {
23+
float4 position : SV_Position;
24+
};
25+
26+
[shader("mesh")]
27+
[NumThreads(8, 8, 2)]
28+
[OutputTopology("triangle")]
29+
void mainMeS(out vertices Verts verts[32], uint ix : SV_GroupIndex) {
30+
CallFunction();
31+
SetMeshOutputCounts(32, 16);
32+
Verts v = {0.0, 0.0, 0.0, 0.0};
33+
verts[ix] = v;
34+
}
35+
36+
struct AmpPayload {
37+
float2 dummy;
38+
};
39+
40+
[numthreads(8, 1, 1)]
41+
[shader("amplification")]
42+
void mainAS()
43+
{
44+
CallFunction();
45+
AmpPayload pld;
46+
pld.dummy = float2(1.0,2.0);
47+
DispatchMesh(8, 1, 1, pld);
48+
}
49+
50+
[shader("pixel")]
51+
float4 mainPS(uint ix : SV_PrimitiveID) : SV_TARGET {
52+
CallFunction();
53+
return 1.0;
54+
}
55+
56+
[shader("vertex")]
57+
float4 mainVS(uint ix : SV_VertexID) : OUT {
58+
CallFunction();
59+
return 1.0;
60+
}
61+
62+
[shader("node")]
63+
[nodedispatchgrid(8,1,1)]
64+
[numthreads(64,2,2)]
65+
void mainNS() {
66+
CallFunction();
67+
}
68+
69+
[shader("raygeneration")]
70+
void mainRG() {
71+
CallFunction();
72+
}
73+
74+
[shader("intersection")]
75+
void mainIS() {
76+
CallFunction();
77+
}
78+
79+
struct Attribs { float2 barys; };
80+
81+
[shader("callable")]
82+
void mainCALL(inout Attribs attrs) {
83+
CallFunction();
84+
}
85+
86+
struct [raypayload] RayPayload
87+
{
88+
float elem
89+
: write(caller,closesthit,anyhit,miss)
90+
: read(caller,closesthit,anyhit,miss);
91+
};
92+
93+
[shader("anyhit")]
94+
void mainAH(inout RayPayload pld, in Attribs attrs) {
95+
CallFunction();
96+
}
97+
98+
[shader("closesthit")]
99+
void mainCH(inout RayPayload pld, in Attribs attrs) {
100+
CallFunction();
101+
}
102+
103+
[shader("miss")]
104+
void mainMS(inout RayPayload pld) {
105+
CallFunction();
106+
}
107+
108+
struct PosStruct {
109+
float4 pos : SV_Position;
110+
};
111+
112+
struct PCStruct
113+
{
114+
float Edges[3] : SV_TessFactor;
115+
float Inside : SV_InsideTessFactor;
116+
float4 test : TEST;
117+
};
118+
119+
PCStruct HSPatch(InputPatch<PosStruct, 3> ip,
120+
OutputPatch<PosStruct, 3> op,
121+
uint ix : SV_PrimitiveID)
122+
{
123+
PCStruct a;
124+
a.Edges[0] = ip[0].pos.w;
125+
a.Edges[1] = ip[0].pos.w;
126+
a.Edges[2] = ip[0].pos.w;
127+
a.Inside = ip[0].pos.w;
128+
return a;
129+
}
130+
131+
[shader("hull")]
132+
[domain("tri")]
133+
[partitioning("fractional_odd")]
134+
[outputtopology("triangle_cw")]
135+
[outputcontrolpoints(3)]
136+
[patchconstantfunc("HSPatch")]
137+
PosStruct mainHS(InputPatch<PosStruct, 3> p, uint ix : SV_OutputControlPointID)
138+
{
139+
CallFunction();
140+
PosStruct s;
141+
s.pos = p[ix].pos;
142+
return s;
143+
}
144+
145+
[shader("domain")]
146+
[domain("tri")]
147+
PosStruct mainDS(const OutputPatch<PosStruct, 3> patch,
148+
uint ix : SV_PrimitiveID)
149+
{
150+
CallFunction();
151+
PosStruct v;
152+
v.pos = patch[0].pos;
153+
return v;
154+
}
155+
156+
float4 a;
157+
158+
[shader("geometry")]
159+
[maxvertexcount(1)]
160+
void mainGS(triangle float4 array[3] : SV_Position, uint ix : SV_GSInstanceID,
161+
inout PointStream<PosStruct> OutputStream)
162+
{
163+
CallFunction();
164+
PosStruct s;
165+
s.pos = a;
166+
OutputStream.Append(s);
167+
OutputStream.RestartStrip();
168+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: %dxc -I %hlsl_headers -T cs_6_9 -E main %s -verify
2+
3+
[numthreads(4,1,1)]
4+
void main() {
5+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 0, 0)]] mat;
6+
float4 lhs = {1,2,3,4};
7+
float4 rhs = {4,3,2,1};
8+
9+
// expected-error@+1{{intrinsic __builtin_LinAlg_MatrixOuterProduct potentially used by ''main'' requires shader model 6.10 or greater}}
10+
__builtin_LinAlg_MatrixOuterProduct(mat, lhs, rhs);
11+
}

0 commit comments

Comments
 (0)