Skip to content

Commit ca7587b

Browse files
authored
[SM6.10] Implement CopyConvertMatrix Builtin (#8201)
Fixes #7897
1 parent 37383d4 commit ca7587b

5 files changed

Lines changed: 273 additions & 1 deletion

File tree

lib/HLSL/HLOperationLower.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7206,7 +7206,23 @@ Value *TranslateLinAlgCopyConvertMatrix(CallInst *CI, IntrinsicOp IOP,
72067206
HLOperationLowerHelper &Helper,
72077207
HLObjectOperationLowerHelper *ObjHelper,
72087208
bool &Translated) {
7209-
DXASSERT(false, "Not implemented.");
7209+
hlsl::OP *HlslOp = &Helper.hlslOP;
7210+
IRBuilder<> Builder(CI);
7211+
7212+
Value *MatrixRPtr = CI->getArgOperand(1);
7213+
DXASSERT_NOMSG(isa<PointerType>(MatrixRPtr->getType()));
7214+
Type *MatrixRTy = MatrixRPtr->getType()->getPointerElementType();
7215+
7216+
Value *MatrixSrc = CI->getArgOperand(2);
7217+
Value *Transpose = CI->getArgOperand(3);
7218+
7219+
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7220+
Function *DxilFunc =
7221+
HlslOp->GetOpFunc(OpCode, {MatrixRTy, MatrixSrc->getType()});
7222+
7223+
Value *MatrixR = Builder.CreateCall(DxilFunc, {OpArg, MatrixSrc, Transpose});
7224+
Builder.CreateStore(MatrixR, MatrixRPtr);
7225+
72107226
return nullptr;
72117227
}
72127228

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s
3+
4+
[numthreads(1,1,1)]
5+
void main() {
6+
// CHECK-LABEL: define void @main()
7+
8+
// CHECK: call %dx.types.LinAlgMatrixC4M5N4U1S2 @dx.op.linAlgCopyConvertMatrix.mC4M5N4U1S2.mC2M5N4U1S2(i32 -2147483635, %dx.types.LinAlgMatrixC2M5N4U1S2 {{.*}}, i1 false) ; LinAlgCopyConvertMatrix(srcMatrix,transpose)
9+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(2, 5, 4, 1, 2)]] mat1;
10+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat2;
11+
__builtin_LinAlg_CopyConvertMatrix(mat2, mat1, false);
12+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_CopyConvertMatrix 'void (__builtin_LinAlgMatrix {{.*}}, __builtin_LinAlgMatrix {{.*}}, bool)' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} ret '__builtin_LinAlgMatrix {{.*}}'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} source '__builtin_LinAlgMatrix {{.*}}'
7+
// CHECK-NEXT: ParmVarDecl {{.*}} transpose 'bool'
8+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 405
9+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
10+
11+
[shader("compute")]
12+
[numthreads(1,1,1)]
13+
void main() {
14+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat1;
15+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat2;
16+
__builtin_LinAlg_CopyConvertMatrix(mat2, mat1, true);
17+
}
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: not %dxc -T lib_6_10 %s 2>&1 | FileCheck %s
3+
4+
// CHECK: Opcode LinAlgCopyConvertMatrix not valid in shader model lib_6_10(gs).
5+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgCopyConvertMatrix{{.*}} of function 'mainGS'.
6+
// CHECK-NEXT: Opcode LinAlgCopyConvertMatrix not valid in shader model lib_6_10(ds).
7+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgCopyConvertMatrix{{.*}} of function 'mainDS'.
8+
// CHECK-NEXT: Opcode LinAlgCopyConvertMatrix not valid in shader model lib_6_10(hs).
9+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgCopyConvertMatrix{{.*}} of function 'mainHS'.
10+
// CHECK-NEXT: Opcode LinAlgCopyConvertMatrix not valid in shader model lib_6_10(vs).
11+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgCopyConvertMatrix{{.*}} of function 'mainVS'.
12+
// CHECK-NEXT: Opcode LinAlgCopyConvertMatrix not valid in shader model lib_6_10(ps).
13+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgCopyConvertMatrix{{.*}} of function 'mainPS'.
14+
// CHECK-NEXT: Opcode LinAlgCopyConvertMatrix not valid in shader model lib_6_10(miss).
15+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgCopyConvertMatrix{{.*}} of function '{{.*}}mainMS@@YAXURayPayload@@@Z'.
16+
// CHECK-NEXT: Opcode LinAlgCopyConvertMatrix not valid in shader model lib_6_10(closesthit).
17+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgCopyConvertMatrix{{.*}} of function '{{.*}}mainCH@@YAXURayPayload@@UAttribs@@@Z'.
18+
// CHECK-NEXT: Opcode LinAlgCopyConvertMatrix not valid in shader model lib_6_10(anyhit).
19+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgCopyConvertMatrix{{.*}} of function '{{.*}}mainAH@@YAXURayPayload@@UAttribs@@@Z'.
20+
// CHECK-NEXT: Opcode LinAlgCopyConvertMatrix not valid in shader model lib_6_10(callable).
21+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgCopyConvertMatrix{{.*}} of function '{{.*}}mainCALL@@YAXUAttribs@@@Z'.
22+
// CHECK-NEXT: Opcode LinAlgCopyConvertMatrix not valid in shader model lib_6_10(intersection).
23+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgCopyConvertMatrix{{.*}} of function '{{.*}}mainIS@@YAXXZ'.
24+
// CHECK-NEXT: Opcode LinAlgCopyConvertMatrix not valid in shader model lib_6_10(raygeneration).
25+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgCopyConvertMatrix{{.*}} of function '{{.*}}mainRG@@YAXXZ'.
26+
// CHECK-NEXT: Opcode LinAlgCopyConvertMatrix not valid in shader model lib_6_10(node).
27+
// CHECK-NEXT: note: at {{.*}} @dx.op.linAlgCopyConvertMatrix{{.*}} of function 'mainNS'.
28+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
29+
// CHECK-NEXT: Function uses features incompatible with the shader stage (node) of the entry function.
30+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
31+
// CHECK-NEXT: Function uses features incompatible with the shader stage (raygeneration) of the entry function.
32+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
33+
// CHECK-NEXT: Function uses features incompatible with the shader stage (intersection) of the entry function.
34+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
35+
// CHECK-NEXT: Function uses features incompatible with the shader stage (callable) of the entry function.
36+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
37+
// CHECK-NEXT: Function uses features incompatible with the shader stage (anyhit) of the entry function.
38+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
39+
// CHECK-NEXT: Function uses features incompatible with the shader stage (closesthit) of the entry function.
40+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
41+
// CHECK-NEXT: Function uses features incompatible with the shader stage (miss) of the entry function.
42+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
43+
// CHECK-NEXT: Function uses features incompatible with the shader stage (ps) of the entry function.
44+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
45+
// CHECK-NEXT: Function uses features incompatible with the shader stage (vs) of the entry function.
46+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
47+
// CHECK-NEXT: Function uses features incompatible with the shader stage (hs) of the entry function.
48+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
49+
// CHECK-NEXT: Function uses features incompatible with the shader stage (ds) of the entry function.
50+
// CHECK-NEXT: Entry function performs some operation that is incompatible with the shader stage or other entry properties. See other errors for details.
51+
// CHECK-NEXT: Function uses features incompatible with the shader stage (gs) of the entry function.
52+
// CHECK-NEXT: Validation failed.
53+
54+
void CallFunction()
55+
{
56+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat1;
57+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat2;
58+
__builtin_LinAlg_CopyConvertMatrix(mat2, mat1, true);
59+
}
60+
61+
// --- Allowed Stages ---
62+
63+
[shader("compute")]
64+
[numthreads(4,4,4)]
65+
void mainCS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
66+
CallFunction();
67+
}
68+
69+
struct Verts {
70+
float4 position : SV_Position;
71+
};
72+
73+
[shader("mesh")]
74+
[NumThreads(8, 8, 2)]
75+
[OutputTopology("triangle")]
76+
void mainMeS(out vertices Verts verts[32], uint ix : SV_GroupIndex) {
77+
CallFunction();
78+
SetMeshOutputCounts(32, 16);
79+
Verts v = {0.0, 0.0, 0.0, 0.0};
80+
verts[ix] = v;
81+
}
82+
83+
struct AmpPayload {
84+
float2 dummy;
85+
};
86+
87+
[numthreads(8, 1, 1)]
88+
[shader("amplification")]
89+
void mainAS()
90+
{
91+
CallFunction();
92+
AmpPayload pld;
93+
pld.dummy = float2(1.0,2.0);
94+
DispatchMesh(8, 1, 1, pld);
95+
}
96+
97+
// --- Prohibited Stages ---
98+
99+
[shader("pixel")]
100+
float4 mainPS(uint ix : SV_PrimitiveID) : SV_TARGET {
101+
CallFunction();
102+
return 1.0;
103+
}
104+
105+
[shader("vertex")]
106+
float4 mainVS(uint ix : SV_VertexID) : OUT {
107+
CallFunction();
108+
return 1.0;
109+
}
110+
111+
[shader("node")]
112+
[nodedispatchgrid(8,1,1)]
113+
[numthreads(64,2,2)]
114+
void mainNS() {
115+
CallFunction();
116+
}
117+
118+
[shader("raygeneration")]
119+
void mainRG() {
120+
CallFunction();
121+
}
122+
123+
[shader("intersection")]
124+
void mainIS() {
125+
CallFunction();
126+
}
127+
128+
struct Attribs { float2 barys; };
129+
130+
[shader("callable")]
131+
void mainCALL(inout Attribs attrs) {
132+
CallFunction();
133+
}
134+
135+
struct [raypayload] RayPayload
136+
{
137+
float elem
138+
: write(caller,closesthit,anyhit,miss)
139+
: read(caller,closesthit,anyhit,miss);
140+
};
141+
142+
[shader("anyhit")]
143+
void mainAH(inout RayPayload pld, in Attribs attrs) {
144+
CallFunction();
145+
}
146+
147+
[shader("closesthit")]
148+
void mainCH(inout RayPayload pld, in Attribs attrs) {
149+
CallFunction();
150+
}
151+
152+
[shader("miss")]
153+
void mainMS(inout RayPayload pld) {
154+
CallFunction();
155+
}
156+
157+
struct PosStruct {
158+
float4 pos : SV_Position;
159+
};
160+
161+
struct PCStruct
162+
{
163+
float Edges[3] : SV_TessFactor;
164+
float Inside : SV_InsideTessFactor;
165+
float4 test : TEST;
166+
};
167+
168+
PCStruct HSPatch(InputPatch<PosStruct, 3> ip,
169+
OutputPatch<PosStruct, 3> op,
170+
uint ix : SV_PrimitiveID)
171+
{
172+
PCStruct a;
173+
a.Edges[0] = ip[0].pos.w;
174+
a.Edges[1] = ip[0].pos.w;
175+
a.Edges[2] = ip[0].pos.w;
176+
a.Inside = ip[0].pos.w;
177+
return a;
178+
}
179+
180+
[shader("hull")]
181+
[domain("tri")]
182+
[partitioning("fractional_odd")]
183+
[outputtopology("triangle_cw")]
184+
[outputcontrolpoints(3)]
185+
[patchconstantfunc("HSPatch")]
186+
PosStruct mainHS(InputPatch<PosStruct, 3> p, uint ix : SV_OutputControlPointID)
187+
{
188+
CallFunction();
189+
PosStruct s;
190+
s.pos = p[ix].pos;
191+
return s;
192+
}
193+
194+
[shader("domain")]
195+
[domain("tri")]
196+
PosStruct mainDS(const OutputPatch<PosStruct, 3> patch,
197+
uint ix : SV_PrimitiveID)
198+
{
199+
CallFunction();
200+
PosStruct v;
201+
v.pos = patch[0].pos;
202+
return v;
203+
}
204+
205+
float4 a;
206+
207+
[shader("geometry")]
208+
[maxvertexcount(1)]
209+
void mainGS(triangle float4 array[3] : SV_Position, uint ix : SV_GSInstanceID,
210+
inout PointStream<PosStruct> OutputStream)
211+
{
212+
CallFunction();
213+
PosStruct s;
214+
s.pos = a;
215+
OutputStream.Append(s);
216+
OutputStream.RestartStrip();
217+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: %dxc -I %hlsl_headers -T cs_6_9 -E main %s -verify
2+
3+
[numthreads(4,1,1)]
4+
void main() {
5+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 0, 0)]] mat;
6+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 1, 1, 1, 1)]] mat2;
7+
8+
// expected-error@+1{{intrinsic __builtin_LinAlg_CopyConvertMatrix potentially used by ''main'' requires shader model 6.10 or greater}}
9+
__builtin_LinAlg_CopyConvertMatrix(mat, mat2, true);
10+
}

0 commit comments

Comments
 (0)