Skip to content

Commit 5ba6901

Browse files
authored
[SM6.10] Implement MatrixGetCoord Builtin (microsoft#8197)
Fixes microsoft#7899
1 parent 19ba28b commit 5ba6901

8 files changed

Lines changed: 266 additions & 8 deletions

File tree

lib/DXIL/DxilOperations.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6602,7 +6602,7 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66026602
A(pETy);
66036603
break;
66046604
case OpCode::LinAlgMatrixGetCoordinate:
6605-
VEC4(pETy);
6605+
VEC2(pI32);
66066606
A(pI32);
66076607
A(pETy);
66086608
A(pI32);
@@ -6895,6 +6895,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
68956895
case OpCode::VectorReduceOr:
68966896
case OpCode::FDot:
68976897
case OpCode::LinAlgMatrixLength:
6898+
case OpCode::LinAlgMatrixGetCoordinate:
68986899
case OpCode::LinAlgMatrixStoreToDescriptor:
68996900
case OpCode::LinAlgMatrixAccumulateToDescriptor:
69006901
if (FT->getNumParams() <= 1)
@@ -7044,8 +7045,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70447045
case OpCode::SampleCmpLevel:
70457046
case OpCode::SampleCmpGrad:
70467047
case OpCode::SampleCmpBias:
7047-
case OpCode::RawBufferVectorLoad:
7048-
case OpCode::LinAlgMatrixGetCoordinate: {
7048+
case OpCode::RawBufferVectorLoad: {
70497049
StructType *ST = cast<StructType>(Ty);
70507050
return ST->getElementType(0);
70517051
}

lib/HLSL/HLOperationLower.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6998,8 +6998,16 @@ Value *TranslateLinAlgMatrixGetCoordinate(
69986998
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
69996999
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
70007000
bool &Translated) {
7001-
DXASSERT(false, "Not implemented.");
7002-
return nullptr;
7001+
hlsl::OP *HlslOp = &Helper.hlslOP;
7002+
IRBuilder<> Builder(CI);
7003+
7004+
Value *Matrix = CI->getArgOperand(1);
7005+
Value *Index = CI->getArgOperand(2);
7006+
7007+
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7008+
Function *DxilFunc = HlslOp->GetOpFunc(OpCode, Matrix->getType());
7009+
7010+
return Builder.CreateCall(DxilFunc, {OpArg, Matrix, Index});
70037011
}
70047012

70057013
Value *TranslateLinAlgMatrixGetElement(CallInst *CI, IntrinsicOp IOP,
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s
3+
4+
[numthreads(1,1,1)]
5+
void main() {
6+
// CHECK-LABEL: define void @main()
7+
8+
// CHECK: call <2 x i32> @dx.op.linAlgMatrixGetCoordinate.mC4M5N4U1S2(i32 -2147483631, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, i32 1) ; LinAlgMatrixGetCoordinate(matrix,threadLocalIndex)
9+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
10+
uint2 coord = __builtin_LinAlg_MatrixGetCoordinate(mat, 1);
11+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixGetCoordinate 'vector<uint, 2> (__builtin_LinAlgMatrix {{.*}}, unsigned int)' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} matrix '__builtin_LinAlgMatrix {{.*}}'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} threadLocalIndex 'unsigned int'
7+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 407
8+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
9+
10+
[shader("compute")]
11+
[numthreads(1,1,1)]
12+
void main() {
13+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat;
14+
uint2 coord = __builtin_LinAlg_MatrixGetCoordinate(mat, 0);
15+
}
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: not %dxc -T lib_6_10 %s 2>&1 | FileCheck %s
3+
4+
// CHECK: Opcode LinAlgMatrixGetCoordinate not valid in shader model lib_6_10(gs).
5+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetCoordinate{{.*}} of function 'mainGS'.
6+
// CHECK-NEXT: Opcode LinAlgMatrixGetCoordinate not valid in shader model lib_6_10(ds).
7+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetCoordinate{{.*}} of function 'mainDS'.
8+
// CHECK-NEXT: Opcode LinAlgMatrixGetCoordinate not valid in shader model lib_6_10(hs).
9+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetCoordinate{{.*}} of function 'mainHS'.
10+
// CHECK-NEXT: Opcode LinAlgMatrixGetCoordinate not valid in shader model lib_6_10(vs).
11+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetCoordinate{{.*}} of function 'mainVS'.
12+
// CHECK-NEXT: Opcode LinAlgMatrixGetCoordinate not valid in shader model lib_6_10(ps).
13+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetCoordinate{{.*}} of function 'mainPS'.
14+
// CHECK-NEXT: Opcode LinAlgMatrixGetCoordinate not valid in shader model lib_6_10(miss).
15+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetCoordinate{{.*}} of function '{{.*}}mainMS@@YAXURayPayload@@@Z'.
16+
// CHECK-NEXT: Opcode LinAlgMatrixGetCoordinate not valid in shader model lib_6_10(closesthit).
17+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetCoordinate{{.*}} of function '{{.*}}mainCH@@YAXURayPayload@@UAttribs@@@Z'.
18+
// CHECK-NEXT: Opcode LinAlgMatrixGetCoordinate not valid in shader model lib_6_10(anyhit).
19+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetCoordinate{{.*}} of function '{{.*}}mainAH@@YAXURayPayload@@UAttribs@@@Z'.
20+
// CHECK-NEXT: Opcode LinAlgMatrixGetCoordinate not valid in shader model lib_6_10(callable).
21+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetCoordinate{{.*}} of function '{{.*}}mainCALL@@YAXUAttribs@@@Z'.
22+
// CHECK-NEXT: Opcode LinAlgMatrixGetCoordinate not valid in shader model lib_6_10(intersection).
23+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetCoordinate{{.*}} of function '{{.*}}mainIS@@YAXXZ'.
24+
// CHECK-NEXT: Opcode LinAlgMatrixGetCoordinate not valid in shader model lib_6_10(raygeneration).
25+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetCoordinate{{.*}} of function '{{.*}}mainRG@@YAXXZ'.
26+
// CHECK-NEXT: Opcode LinAlgMatrixGetCoordinate not valid in shader model lib_6_10(node).
27+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgMatrixGetCoordinate{{.*}} of function 'mainNS'.
28+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
29+
// CHECK-NEXT: Function uses features incompatible with the shader stage (node) of the entry function.
30+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
31+
// CHECK-NEXT: Function uses features incompatible with the shader stage (raygeneration) of the entry function.
32+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
33+
// CHECK-NEXT: Function uses features incompatible with the shader stage (intersection) of the entry function.
34+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
35+
// CHECK-NEXT: Function uses features incompatible with the shader stage (callable) of the entry function.
36+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
37+
// CHECK-NEXT: Function uses features incompatible with the shader stage (anyhit) of the entry function.
38+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
39+
// CHECK-NEXT: Function uses features incompatible with the shader stage (closesthit) of the entry function.
40+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
41+
// CHECK-NEXT: Function uses features incompatible with the shader stage (miss) of the entry function.
42+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
43+
// CHECK-NEXT: Function uses features incompatible with the shader stage (ps) of the entry function.
44+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
45+
// CHECK-NEXT: Function uses features incompatible with the shader stage (vs) of the entry function.
46+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
47+
// CHECK-NEXT: Function uses features incompatible with the shader stage (hs) of the entry function.
48+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
49+
// CHECK-NEXT: Function uses features incompatible with the shader stage (ds) of the entry function.
50+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
51+
// CHECK-NEXT: Function uses features incompatible with the shader stage (gs) of the entry function.
52+
// CHECK-NEXT: Validation failed.
53+
54+
void CallFunction()
55+
{
56+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat;
57+
uint2 coord = __builtin_LinAlg_MatrixGetCoordinate(mat, 0);
58+
}
59+
60+
// --- Allowed Stages ---
61+
62+
[shader("compute")]
63+
[numthreads(4,4,4)]
64+
void mainCS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
65+
CallFunction();
66+
}
67+
68+
struct Verts {
69+
float4 position : SV_Position;
70+
};
71+
72+
[shader("mesh")]
73+
[NumThreads(8, 8, 2)]
74+
[OutputTopology("triangle")]
75+
void mainMeS(out vertices Verts verts[32], uint ix : SV_GroupIndex) {
76+
CallFunction();
77+
SetMeshOutputCounts(32, 16);
78+
Verts v = {0.0, 0.0, 0.0, 0.0};
79+
verts[ix] = v;
80+
}
81+
82+
struct AmpPayload {
83+
float2 dummy;
84+
};
85+
86+
[numthreads(8, 1, 1)]
87+
[shader("amplification")]
88+
void mainAS()
89+
{
90+
CallFunction();
91+
AmpPayload pld;
92+
pld.dummy = float2(1.0,2.0);
93+
DispatchMesh(8, 1, 1, pld);
94+
}
95+
96+
// --- Prohibited Stages ---
97+
98+
[shader("pixel")]
99+
float4 mainPS(uint ix : SV_PrimitiveID) : SV_TARGET {
100+
CallFunction();
101+
return 1.0;
102+
}
103+
104+
[shader("vertex")]
105+
float4 mainVS(uint ix : SV_VertexID) : OUT {
106+
CallFunction();
107+
return 1.0;
108+
}
109+
110+
[shader("node")]
111+
[nodedispatchgrid(8,1,1)]
112+
[numthreads(64,2,2)]
113+
void mainNS() {
114+
CallFunction();
115+
}
116+
117+
[shader("raygeneration")]
118+
void mainRG() {
119+
CallFunction();
120+
}
121+
122+
[shader("intersection")]
123+
void mainIS() {
124+
CallFunction();
125+
}
126+
127+
struct Attribs { float2 barys; };
128+
129+
[shader("callable")]
130+
void mainCALL(inout Attribs attrs) {
131+
CallFunction();
132+
}
133+
134+
struct [raypayload] RayPayload
135+
{
136+
float elem
137+
: write(caller,closesthit,anyhit,miss)
138+
: read(caller,closesthit,anyhit,miss);
139+
};
140+
141+
[shader("anyhit")]
142+
void mainAH(inout RayPayload pld, in Attribs attrs) {
143+
CallFunction();
144+
}
145+
146+
[shader("closesthit")]
147+
void mainCH(inout RayPayload pld, in Attribs attrs) {
148+
CallFunction();
149+
}
150+
151+
[shader("miss")]
152+
void mainMS(inout RayPayload pld) {
153+
CallFunction();
154+
}
155+
156+
struct PosStruct {
157+
float4 pos : SV_Position;
158+
};
159+
160+
struct PCStruct
161+
{
162+
float Edges[3] : SV_TessFactor;
163+
float Inside : SV_InsideTessFactor;
164+
float4 test : TEST;
165+
};
166+
167+
PCStruct HSPatch(InputPatch<PosStruct, 3> ip,
168+
OutputPatch<PosStruct, 3> op,
169+
uint ix : SV_PrimitiveID)
170+
{
171+
PCStruct a;
172+
a.Edges[0] = ip[0].pos.w;
173+
a.Edges[1] = ip[0].pos.w;
174+
a.Edges[2] = ip[0].pos.w;
175+
a.Inside = ip[0].pos.w;
176+
return a;
177+
}
178+
179+
[shader("hull")]
180+
[domain("tri")]
181+
[partitioning("fractional_odd")]
182+
[outputtopology("triangle_cw")]
183+
[outputcontrolpoints(3)]
184+
[patchconstantfunc("HSPatch")]
185+
PosStruct mainHS(InputPatch<PosStruct, 3> p, uint ix : SV_OutputControlPointID)
186+
{
187+
CallFunction();
188+
PosStruct s;
189+
s.pos = p[ix].pos;
190+
return s;
191+
}
192+
193+
[shader("domain")]
194+
[domain("tri")]
195+
PosStruct mainDS(const OutputPatch<PosStruct, 3> patch,
196+
uint ix : SV_PrimitiveID)
197+
{
198+
CallFunction();
199+
PosStruct v;
200+
v.pos = patch[0].pos;
201+
return v;
202+
}
203+
204+
float4 a;
205+
206+
[shader("geometry")]
207+
[maxvertexcount(1)]
208+
void mainGS(triangle float4 array[3] : SV_Position, uint ix : SV_GSInstanceID,
209+
inout PointStream<PosStruct> OutputStream)
210+
{
211+
CallFunction();
212+
PosStruct s;
213+
s.pos = a;
214+
OutputStream.Append(s);
215+
OutputStream.RestartStrip();
216+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: %dxc -I %hlsl_headers -T cs_6_9 -E main %s -verify
2+
3+
[numthreads(4,1,1)]
4+
void main() {
5+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 0, 0)]] mat;
6+
7+
// expected-error@+1{{intrinsic __builtin_LinAlg_MatrixGetCoordinate potentially used by ''main'' requires shader model 6.10 or greater}}
8+
uint2 coord = __builtin_LinAlg_MatrixGetCoordinate(mat, 0);
9+
}

utils/hct/hctdb.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6444,9 +6444,7 @@ def populate_ExperimentalOps(self):
64446444
"o",
64456445
"",
64466446
[
6447-
db_dxil_param(
6448-
0, "$vec4", "", "operation result"
6449-
), # TODO: this should be <2 x i32>
6447+
db_dxil_param(0, "int2", "", "operation result"),
64506448
db_dxil_param(2, "$o", "matrix", "matrix to be examined"),
64516449
db_dxil_param(
64526450
3, "i32", "threadLocalIndex", "thread-local index to be examined"

utils/hct/hctdb_instrhelp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ def print_opfunc_table(self):
629629
"u64": "A(pI64);",
630630
"u8": "A(pI8);",
631631
"v": "A(pV);",
632+
"int2": "VEC2(pI32);",
632633
"$vec2": "VEC2(pETy);",
633634
"$vec4": "VEC4(pETy);",
634635
"$vec9": "VEC9(pETy);",

0 commit comments

Comments
 (0)