Skip to content

Commit 15c1596

Browse files
authored
[SM6.10] Implement MatVecMul/MatVecMullAdd Builtins (#8192)
Fixes #7908 Fixes #7909
1 parent bf5bf68 commit 15c1596

9 files changed

Lines changed: 476 additions & 2 deletions

File tree

lib/HLSL/HLOperationLower.cpp

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6956,7 +6956,25 @@ Value *TranslateLinAlgMatVecMul(CallInst *CI, IntrinsicOp IOP,
69566956
HLOperationLowerHelper &Helper,
69576957
HLObjectOperationLowerHelper *ObjHelper,
69586958
bool &Translated) {
6959-
DXASSERT(false, "Not implemented.");
6959+
hlsl::OP *HlslOp = &Helper.hlslOP;
6960+
IRBuilder<> Builder(CI);
6961+
6962+
Value *ReturnVecPtr = CI->getArgOperand(1);
6963+
DXASSERT_NOMSG(isa<PointerType>(ReturnVecPtr->getType()));
6964+
Type *ReturnVecType = ReturnVecPtr->getType()->getPointerElementType();
6965+
6966+
Value *Matrix = CI->getArgOperand(2);
6967+
Value *InputVector = CI->getArgOperand(3);
6968+
Value *InputVectorInterp = CI->getArgOperand(4);
6969+
6970+
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
6971+
Function *DxilFunc = HlslOp->GetOpFunc(
6972+
OpCode, {ReturnVecType, Matrix->getType(), InputVector->getType()});
6973+
6974+
Value *ReturnVec = Builder.CreateCall(
6975+
DxilFunc, {OpArg, Matrix, InputVector, InputVectorInterp});
6976+
Builder.CreateStore(ReturnVec, ReturnVecPtr);
6977+
69606978
return nullptr;
69616979
}
69626980

@@ -6965,7 +6983,29 @@ Value *TranslateLinAlgMatVecMulAdd(CallInst *CI, IntrinsicOp IOP,
69656983
HLOperationLowerHelper &Helper,
69666984
HLObjectOperationLowerHelper *ObjHelper,
69676985
bool &Translated) {
6968-
DXASSERT(false, "Not implemented.");
6986+
hlsl::OP *HlslOp = &Helper.hlslOP;
6987+
IRBuilder<> Builder(CI);
6988+
6989+
Value *ReturnVecPtr = CI->getArgOperand(1);
6990+
DXASSERT_NOMSG(isa<PointerType>(ReturnVecPtr->getType()));
6991+
Type *ReturnVecType = ReturnVecPtr->getType()->getPointerElementType();
6992+
6993+
Value *Matrix = CI->getArgOperand(2);
6994+
Value *InputVector = CI->getArgOperand(3);
6995+
Value *InputVectorInterp = CI->getArgOperand(4);
6996+
Value *BiasVector = CI->getArgOperand(5);
6997+
Value *BiasVectorInterp = CI->getArgOperand(6);
6998+
6999+
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7000+
Function *DxilFunc = HlslOp->GetOpFunc(
7001+
OpCode, {ReturnVecType, Matrix->getType(), InputVector->getType(),
7002+
BiasVector->getType()});
7003+
7004+
Value *ReturnVec = Builder.CreateCall(
7005+
DxilFunc, {OpArg, Matrix, InputVector, InputVectorInterp, BiasVector,
7006+
BiasVectorInterp});
7007+
Builder.CreateStore(ReturnVec, ReturnVecPtr);
7008+
69697009
return nullptr;
69707010
}
69717011

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s
3+
4+
[numthreads(1,1,1)]
5+
void main() {
6+
// CHECK-LABEL: define void @main()
7+
8+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
9+
float4 vec = {1,2,3,4};
10+
float4 result;
11+
12+
// CHECK: call <4 x float> @dx.op.linAlgMatVecMul.v4f32.mC4M5N4U1S2.v4f32(i32 -2147483623, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, <4 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 4.000000e+00>, i32 1) ; LinAlgMatVecMul(matrix,inputVector,interpretation)
13+
__builtin_LinAlg_MatrixVectorMultiply(result, mat, vec, 1);
14+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s
3+
4+
[numthreads(1,1,1)]
5+
void main() {
6+
// CHECK-LABEL: define void @main()
7+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(5, 3, 4, 0, 0)]] mat;
8+
float4 vec = {1,2,3,4};
9+
float4 result;
10+
11+
// CHECK: call <4 x float> @dx.op.linAlgMatVecMulAdd.v4f32.mC5M3N4U0S0.v4f32.v4f32(i32 -2147483622, %dx.types.LinAlgMatrixC5M3N4U0S0 {{.*}}, <4 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 4.000000e+00>, i32 1, <4 x float> {{.*}}, i32 0) ; LinAlgMatVecMulAdd(matrix,inputVector,inputInterpretation,biasVector,biasInterpretation)
12+
__builtin_LinAlg_MatrixVectorMultiplyAdd(result, mat, vec, 1, result, 0);
13+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixVectorMultiply 'void (vector<float, 4> &, __builtin_LinAlgMatrix {{.*}}, vector<float, 4>, unsigned int)' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} ret 'vector<float, 4> &&__restrict'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} mat '__builtin_LinAlgMatrix {{.*}}'
7+
// CHECK-NEXT: ParmVarDecl {{.*}} input 'vector<float, 4>':'vector<float, 4>'
8+
// CHECK-NEXT: ParmVarDecl {{.*}} input_interp 'unsigned int'
9+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 422
10+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
11+
12+
[shader("compute")]
13+
[numthreads(1,1,1)]
14+
void main() {
15+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat;
16+
__builtin_LinAlg_FillMatrix(mat, 15);
17+
18+
float4 vec = {1,2,3,4};
19+
float4 result;
20+
__builtin_LinAlg_MatrixVectorMultiply(result, mat, vec, 1);
21+
}
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 %s -verify
3+
4+
// expected-no-diagnostics
5+
6+
RWByteAddressBuffer buf;
7+
void CallFunction()
8+
{
9+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
10+
float4 vec = {9,9,9,9};
11+
float4 result;
12+
__builtin_LinAlg_MatrixVectorMultiply(result, mat, vec, 1);
13+
}
14+
15+
// --- Allowed Stages ---
16+
17+
[shader("compute")]
18+
[numthreads(4,4,4)]
19+
void mainCS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
20+
CallFunction();
21+
}
22+
23+
struct Verts {
24+
float4 position : SV_Position;
25+
};
26+
27+
[shader("mesh")]
28+
[NumThreads(8, 8, 2)]
29+
[OutputTopology("triangle")]
30+
void mainMeS(out vertices Verts verts[32], uint ix : SV_GroupIndex) {
31+
CallFunction();
32+
SetMeshOutputCounts(32, 16);
33+
Verts v = {0.0, 0.0, 0.0, 0.0};
34+
verts[ix] = v;
35+
}
36+
37+
struct AmpPayload {
38+
float2 dummy;
39+
};
40+
41+
[numthreads(8, 1, 1)]
42+
[shader("amplification")]
43+
void mainAS()
44+
{
45+
CallFunction();
46+
AmpPayload pld;
47+
pld.dummy = float2(1.0,2.0);
48+
DispatchMesh(8, 1, 1, pld);
49+
}
50+
51+
[shader("pixel")]
52+
float4 mainPS(uint ix : SV_PrimitiveID) : SV_TARGET {
53+
CallFunction();
54+
return 1.0;
55+
}
56+
57+
[shader("vertex")]
58+
float4 mainVS(uint ix : SV_VertexID) : OUT {
59+
CallFunction();
60+
return 1.0;
61+
}
62+
63+
[shader("node")]
64+
[nodedispatchgrid(8,1,1)]
65+
[numthreads(64,2,2)]
66+
void mainNS() {
67+
CallFunction();
68+
}
69+
70+
[shader("raygeneration")]
71+
void mainRG() {
72+
CallFunction();
73+
}
74+
75+
[shader("intersection")]
76+
void mainIS() {
77+
CallFunction();
78+
}
79+
80+
struct Attribs { float2 barys; };
81+
82+
[shader("callable")]
83+
void mainCALL(inout Attribs attrs) {
84+
CallFunction();
85+
}
86+
87+
struct [raypayload] RayPayload
88+
{
89+
float elem
90+
: write(caller,closesthit,anyhit,miss)
91+
: read(caller,closesthit,anyhit,miss);
92+
};
93+
94+
[shader("anyhit")]
95+
void mainAH(inout RayPayload pld, in Attribs attrs) {
96+
CallFunction();
97+
}
98+
99+
[shader("closesthit")]
100+
void mainCH(inout RayPayload pld, in Attribs attrs) {
101+
CallFunction();
102+
}
103+
104+
[shader("miss")]
105+
void mainMS(inout RayPayload pld) {
106+
CallFunction();
107+
}
108+
109+
struct PosStruct {
110+
float4 pos : SV_Position;
111+
};
112+
113+
struct PCStruct
114+
{
115+
float Edges[3] : SV_TessFactor;
116+
float Inside : SV_InsideTessFactor;
117+
float4 test : TEST;
118+
};
119+
120+
PCStruct HSPatch(InputPatch<PosStruct, 3> ip,
121+
OutputPatch<PosStruct, 3> op,
122+
uint ix : SV_PrimitiveID)
123+
{
124+
PCStruct a;
125+
a.Edges[0] = ip[0].pos.w;
126+
a.Edges[1] = ip[0].pos.w;
127+
a.Edges[2] = ip[0].pos.w;
128+
a.Inside = ip[0].pos.w;
129+
return a;
130+
}
131+
132+
[shader("hull")]
133+
[domain("tri")]
134+
[partitioning("fractional_odd")]
135+
[outputtopology("triangle_cw")]
136+
[outputcontrolpoints(3)]
137+
[patchconstantfunc("HSPatch")]
138+
PosStruct mainHS(InputPatch<PosStruct, 3> p, uint ix : SV_OutputControlPointID)
139+
{
140+
CallFunction();
141+
PosStruct s;
142+
s.pos = p[ix].pos;
143+
return s;
144+
}
145+
146+
[shader("domain")]
147+
[domain("tri")]
148+
PosStruct mainDS(const OutputPatch<PosStruct, 3> patch,
149+
uint ix : SV_PrimitiveID)
150+
{
151+
CallFunction();
152+
PosStruct v;
153+
v.pos = patch[0].pos;
154+
return v;
155+
}
156+
157+
float4 a;
158+
159+
[shader("geometry")]
160+
[maxvertexcount(1)]
161+
void mainGS(triangle float4 array[3] : SV_Position, uint ix : SV_GSInstanceID,
162+
inout PointStream<PosStruct> OutputStream)
163+
{
164+
CallFunction();
165+
PosStruct s;
166+
s.pos = a;
167+
OutputStream.Append(s);
168+
OutputStream.RestartStrip();
169+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: %dxc -I %hlsl_headers -T cs_6_9 -E main %s -verify
2+
3+
[numthreads(4,1,1)]
4+
void main() {
5+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 0, 0)]] mat;
6+
float4 input = {1,2,3,4};
7+
float4 result;
8+
9+
// expected-error@+1{{intrinsic __builtin_LinAlg_MatrixVectorMultiply potentially used by ''main'' requires shader model 6.10 or greater}}
10+
__builtin_LinAlg_MatrixVectorMultiply(result, mat, input, 1);
11+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixVectorMultiplyAdd 'void (vector<float, 4> &, __builtin_LinAlgMatrix {{.*}}, vector<float, 4>, unsigned int, vector<float, 4>, unsigned int)' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} ret 'vector<float, 4> &&__restrict'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} mat '__builtin_LinAlgMatrix {{.*}}'
7+
// CHECK-NEXT: ParmVarDecl {{.*}} input 'vector<float, 4>':'vector<float, 4>'
8+
// CHECK-NEXT: ParmVarDecl {{.*}} input_interp 'unsigned int'
9+
// CHECK-NEXT: ParmVarDecl {{.*}} bias 'vector<float, 4>':'vector<float, 4>'
10+
// CHECK-NEXT: ParmVarDecl {{.*}} bias_interp 'unsigned int'
11+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 423
12+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
13+
14+
[shader("compute")]
15+
[numthreads(1,1,1)]
16+
void main() {
17+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat;
18+
__builtin_LinAlg_FillMatrix(mat, 15);
19+
20+
float4 input = {1,2,3,4};
21+
float4 bias = {5,6,7,8};
22+
float4 result;
23+
__builtin_LinAlg_MatrixVectorMultiplyAdd(result, mat, input, 1, bias, 2);
24+
}

0 commit comments

Comments
 (0)