Skip to content

Commit 0d151c8

Browse files
authored
[SM6.10] Implement MatrixGetElem Builtin (microsoft#8198)
Fixes microsoft#7900
1 parent 6929512 commit 0d151c8

6 files changed

Lines changed: 288 additions & 2 deletions

File tree

lib/HLSL/HLOperationLower.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7104,7 +7104,22 @@ Value *TranslateLinAlgMatrixGetElement(CallInst *CI, IntrinsicOp IOP,
7104 7104
HLOperationLowerHelper &Helper,
7105 7105
HLObjectOperationLowerHelper *ObjHelper,
7106 7106
bool &Translated) {
7107-
DXASSERT(false, "Not implemented.");
7107+
hlsl::OP *HlslOp = &Helper.hlslOP;
7108+
IRBuilder<> Builder(CI);
7109+
7110+
Value *RetElemPtr = CI->getArgOperand(1);
7111+
DXASSERT_NOMSG(isa<PointerType>(RetElemPtr->getType()));
7112+
Type *RetTy = RetElemPtr->getType()->getPointerElementType();
7113+
7114+
Value *Matrix = CI->getArgOperand(2);
7115+
Value *Index = CI->getArgOperand(3);
7116+
7117+
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7118+
Function *DxilFunc = HlslOp->GetOpFunc(OpCode, {RetTy, Matrix->getType()});
7119+
7120+
Value *RetElem = Builder.CreateCall(DxilFunc, {OpArg, Matrix, Index});
7121+
Builder.CreateStore(RetElem, RetElemPtr);
7122+
7108 7123
return nullptr;
7109 7124
}
7110 7125

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s
3+
4+
[numthreads(1,1,1)]
5+
void main() {
6+
// CHECK-LABEL: define void @main()
7+
8+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
9+
10+
// CHECK: call i32 @dx.op.linAlgMatrixGetElement.i32.mC4M5N4U1S2(i32 -2147483630, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, i32 0) ; LinAlgMatrixGetElement(matrix,threadLocalIndex)
11+
uint elem1;
12+
__builtin_LinAlg_MatrixGetElement(elem1, mat, 0);
13+
// CHECK: call float @dx.op.linAlgMatrixGetElement.f32.mC4M5N4U1S2(i32 -2147483630, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, i32 1) ; LinAlgMatrixGetElement(matrix,threadLocalIndex)
14+
float elem2;
15+
__builtin_LinAlg_MatrixGetElement(elem2, mat, 1);
16+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixGetElement 'void (unsigned int &, __builtin_LinAlgMatrix {{.*}}, unsigned int)' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} ret 'unsigned int &&__restrict'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} matrix '__builtin_LinAlgMatrix {{.*}}'
7+
// CHECK-NEXT: ParmVarDecl {{.*}} threadLocalIndex 'unsigned int'
8+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 408
9+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
10+
11+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixGetElement 'void (float &, __builtin_LinAlgMatrix {{.*}}, unsigned int)' extern
12+
// CHECK-NEXT: ParmVarDecl {{.*}} ret 'float &&__restrict'
13+
// CHECK-NEXT: ParmVarDecl {{.*}} matrix '__builtin_LinAlgMatrix {{.*}}'
14+
// CHECK-NEXT: ParmVarDecl {{.*}} threadLocalIndex 'unsigned int'
15+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 408
16+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
17+
18+
[shader("compute")]
19+
[numthreads(1,1,1)]
20+
void main() {
21+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat;
22+
23+
uint elem1;
24+
__builtin_LinAlg_MatrixGetElement(elem1, mat, 3);
25+
26+
float elem2;
27+
__builtin_LinAlg_MatrixGetElement(elem2, mat, 4);
28+
}
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: not %dxc -T lib_6_10 %s 2>&1 | FileCheck %s
3+
4+
// CHECK: Opcode LinAlgMatrixGetElement not valid in shader model lib_6_10(gs).
5+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetElement{{.*}} of function 'mainGS'.
6+
// CHECK-NEXT: Opcode LinAlgMatrixGetElement not valid in shader model lib_6_10(ds).
7+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetElement{{.*}} of function 'mainDS'.
8+
// CHECK-NEXT: Opcode LinAlgMatrixGetElement not valid in shader model lib_6_10(hs).
9+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetElement{{.*}} of function 'mainHS'.
10+
// CHECK-NEXT: Opcode LinAlgMatrixGetElement not valid in shader model lib_6_10(vs).
11+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetElement{{.*}} of function 'mainVS'.
12+
// CHECK-NEXT: Opcode LinAlgMatrixGetElement not valid in shader model lib_6_10(ps).
13+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetElement{{.*}} of function 'mainPS'.
14+
// CHECK-NEXT: Opcode LinAlgMatrixGetElement not valid in shader model lib_6_10(miss).
15+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetElement{{.*}} of function '{{.*}}mainMS@@YAXURayPayload@@@Z'.
16+
// CHECK-NEXT: Opcode LinAlgMatrixGetElement not valid in shader model lib_6_10(closesthit).
17+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetElement{{.*}} of function '{{.*}}mainCH@@YAXURayPayload@@UAttribs@@@Z'.
18+
// CHECK-NEXT: Opcode LinAlgMatrixGetElement not valid in shader model lib_6_10(anyhit).
19+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetElement{{.*}} of function '{{.*}}mainAH@@YAXURayPayload@@UAttribs@@@Z'.
20+
// CHECK-NEXT: Opcode LinAlgMatrixGetElement not valid in shader model lib_6_10(callable).
21+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetElement{{.*}} of function '{{.*}}mainCALL@@YAXUAttribs@@@Z'.
22+
// CHECK-NEXT: Opcode LinAlgMatrixGetElement not valid in shader model lib_6_10(intersection).
23+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetElement{{.*}} of function '{{.*}}mainIS@@YAXXZ'.
24+
// CHECK-NEXT: Opcode LinAlgMatrixGetElement not valid in shader model lib_6_10(raygeneration).
25+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetElement{{.*}} of function '{{.*}}mainRG@@YAXXZ'.
26+
// CHECK-NEXT: Opcode LinAlgMatrixGetElement not valid in shader model lib_6_10(node).
27+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetElement{{.*}} of function 'mainNS'.
28+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
29+
// CHECK-NEXT: Function uses features incompatible with the shader stage (node) of the entry function.
30+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
31+
// CHECK-NEXT: Function uses features incompatible with the shader stage (raygeneration) of the entry function.
32+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
33+
// CHECK-NEXT: Function uses features incompatible with the shader stage (intersection) of the entry function.
34+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
35+
// CHECK-NEXT: Function uses features incompatible with the shader stage (callable) of the entry function.
36+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
37+
// CHECK-NEXT: Function uses features incompatible with the shader stage (anyhit) of the entry function.
38+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
39+
// CHECK-NEXT: Function uses features incompatible with the shader stage (closesthit) of the entry function.
40+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
41+
// CHECK-NEXT: Function uses features incompatible with the shader stage (miss) of the entry function.
42+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
43+
// CHECK-NEXT: Function uses features incompatible with the shader stage (ps) of the entry function.
44+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
45+
// CHECK-NEXT: Function uses features incompatible with the shader stage (vs) of the entry function.
46+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
47+
// CHECK-NEXT: Function uses features incompatible with the shader stage (hs) of the entry function.
48+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
49+
// CHECK-NEXT: Function uses features incompatible with the shader stage (ds) of the entry function.
50+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
51+
// CHECK-NEXT: Function uses features incompatible with the shader stage (gs) of the entry function.
52+
// CHECK-NEXT: Validation failed.
53+
54+
void CallFunction()
55+
{
56+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
57+
float elem;
58+
__builtin_LinAlg_MatrixGetElement(elem, mat, 0);
59+
}
60+
61+
// --- Allowed Stages ---
62+
63+
[shader("compute")]
64+
[numthreads(4,4,4)]
65+
void mainCS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
66+
CallFunction();
67+
}
68+
69+
struct Verts {
70+
float4 position : SV_Position;
71+
};
72+
73+
[shader("mesh")]
74+
[NumThreads(8, 8, 2)]
75+
[OutputTopology("triangle")]
76+
void mainMeS(out vertices Verts verts[32], uint ix : SV_GroupIndex) {
77+
CallFunction();
78+
SetMeshOutputCounts(32, 16);
79+
Verts v = {0.0, 0.0, 0.0, 0.0};
80+
verts[ix] = v;
81+
}
82+
83+
struct AmpPayload {
84+
float2 dummy;
85+
};
86+
87+
[numthreads(8, 1, 1)]
88+
[shader("amplification")]
89+
void mainAS()
90+
{
91+
CallFunction();
92+
AmpPayload pld;
93+
pld.dummy = float2(1.0,2.0);
94+
DispatchMesh(8, 1, 1, pld);
95+
}
96+
97+
// --- Prohibited Stages ---
98+
99+
[shader("pixel")]
100+
float4 mainPS(uint ix : SV_PrimitiveID) : SV_TARGET {
101+
CallFunction();
102+
return 1.0;
103+
}
104+
105+
[shader("vertex")]
106+
float4 mainVS(uint ix : SV_VertexID) : OUT {
107+
CallFunction();
108+
return 1.0;
109+
}
110+
111+
[shader("node")]
112+
[nodedispatchgrid(8,1,1)]
113+
[numthreads(64,2,2)]
114+
void mainNS() {
115+
CallFunction();
116+
}
117+
118+
[shader("raygeneration")]
119+
void mainRG() {
120+
CallFunction();
121+
}
122+
123+
[shader("intersection")]
124+
void mainIS() {
125+
CallFunction();
126+
}
127+
128+
struct Attribs { float2 barys; };
129+
130+
[shader("callable")]
131+
void mainCALL(inout Attribs attrs) {
132+
CallFunction();
133+
}
134+
135+
struct [raypayload] RayPayload
136+
{
137+
float elem
138+
: write(caller,closesthit,anyhit,miss)
139+
: read(caller,closesthit,anyhit,miss);
140+
};
141+
142+
[shader("anyhit")]
143+
void mainAH(inout RayPayload pld, in Attribs attrs) {
144+
CallFunction();
145+
}
146+
147+
[shader("closesthit")]
148+
void mainCH(inout RayPayload pld, in Attribs attrs) {
149+
CallFunction();
150+
}
151+
152+
[shader("miss")]
153+
void mainMS(inout RayPayload pld) {
154+
CallFunction();
155+
}
156+
157+
struct PosStruct {
158+
float4 pos : SV_Position;
159+
};
160+
161+
struct PCStruct
162+
{
163+
float Edges[3] : SV_TessFactor;
164+
float Inside : SV_InsideTessFactor;
165+
float4 test : TEST;
166+
};
167+
168+
PCStruct HSPatch(InputPatch<PosStruct, 3> ip,
169+
OutputPatch<PosStruct, 3> op,
170+
uint ix : SV_PrimitiveID)
171+
{
172+
PCStruct a;
173+
a.Edges[0] = ip[0].pos.w;
174+
a.Edges[1] = ip[0].pos.w;
175+
a.Edges[2] = ip[0].pos.w;
176+
a.Inside = ip[0].pos.w;
177+
return a;
178+
}
179+
180+
[shader("hull")]
181+
[domain("tri")]
182+
[partitioning("fractional_odd")]
183+
[outputtopology("triangle_cw")]
184+
[outputcontrolpoints(3)]
185+
[patchconstantfunc("HSPatch")]
186+
PosStruct mainHS(InputPatch<PosStruct, 3> p, uint ix : SV_OutputControlPointID)
187+
{
188+
CallFunction();
189+
PosStruct s;
190+
s.pos = p[ix].pos;
191+
return s;
192+
}
193+
194+
[shader("domain")]
195+
[domain("tri")]
196+
PosStruct mainDS(const OutputPatch<PosStruct, 3> patch,
197+
uint ix : SV_PrimitiveID)
198+
{
199+
CallFunction();
200+
PosStruct v;
201+
v.pos = patch[0].pos;
202+
return v;
203+
}
204+
205+
float4 a;
206+
207+
[shader("geometry")]
208+
[maxvertexcount(1)]
209+
void mainGS(triangle float4 array[3] : SV_Position, uint ix : SV_GSInstanceID,
210+
inout PointStream<PosStruct> OutputStream)
211+
{
212+
CallFunction();
213+
PosStruct s;
214+
s.pos = a;
215+
OutputStream.Append(s);
216+
OutputStream.RestartStrip();
217+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: %dxc -I %hlsl_headers -T cs_6_9 -E main %s -verify
2+
3+
[numthreads(4,1,1)]
4+
void main() {
5+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 0, 0)]] mat;
6+
uint elem;
7+
8+
// expected-error@+1{{intrinsic __builtin_LinAlg_MatrixGetElement potentially used by ''main'' requires shader model 6.10 or greater}}
9+
__builtin_LinAlg_MatrixGetElement(elem, mat, 0);
10+
}

utils/hct/gen_intrin_main.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix
405 405
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(out LinAlgMatrix ret, in int GroupSharedMem, in uint offset, in uint stride, in uint layout);
406 406
uint [[min_sm=6.10]] __builtin_LinAlg_MatrixLength(in LinAlgMatrix matrix);
407 407
uint<2> [[min_sm=6.10]] __builtin_LinAlg_MatrixGetCoordinate(in LinAlgMatrix matrix, in uint threadLocalIndex);
408-
numeric [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(in LinAlgMatrix matrix, in uint threadLocalIndex);
408+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(out numeric ret, in LinAlgMatrix matrix, in uint threadLocalIndex);
409 409
void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(out LinAlgMatrix ret, in LinAlgMatrix matrix, in uint threadLocalIndex, in numeric value);
410 410
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout);
411 411
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, in int GroupSharedMem, in uint offset, in uint stride, in uint layout);

0 commit comments

Comments
 (0)