Skip to content

Commit bf5bf68

Browse files
authored
[SM6.10] Implement MatrixLoadFromDescriptor Builtin (#8190)
Fixes #7902
1 parent ee4fd34 commit bf5bf68

5 files changed

Lines changed: 230 additions & 1 deletion

File tree

lib/HLSL/HLOperationLower.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6973,7 +6973,25 @@ Value *TranslateLinAlgMatrixLoadFromDescriptor(
69736973
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
69746974
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
69756975
bool &Translated) {
// Lowers the high-level call CI to the DXIL op identified by OpCode.
// A call to the DXIL function is emitted at CI's position and its result is
// stored through the matrix pointer taken from call operand 1. Operands 2-5
// are forwarded unchanged as handle, offset, stride and layout.
// Always returns nullptr. IOP, ObjHelper and Translated are not used by the
// added lines below.
6976-
// Previous stub body: the single line this hunk deletes.
DXASSERT(false, "Not implemented.");
6976+
hlsl::OP *HlslOp = &Helper.hlslOP;
6977+
IRBuilder<> Builder(CI);
6978+
6979+
// Operand 1 is the destination matrix, passed by pointer. Operand 0 is not
// read here -- presumably the HL opcode; confirm against the HL call layout.
Value *MatrixPtr = CI->getArgOperand(1);
6980+
DXASSERT_NOMSG(isa<PointerType>(MatrixPtr->getType()));
6981+
// The pointee type is the matrix type; it selects the DXIL overload below.
Type *MatrixType = MatrixPtr->getType()->getPointerElementType();
6982+
6983+
Value *ResHandle = CI->getArgOperand(2);
6984+
Value *Offset = CI->getArgOperand(3);
6985+
Value *Stride = CI->getArgOperand(4);
6986+
Value *Layout = CI->getArgOperand(5);
6987+
6988+
// The first DXIL argument is the opcode itself, as a u32 constant.
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
6989+
Function *DxilFunc = HlslOp->GetOpFunc(OpCode, MatrixType);
6990+
6991+
Value *Matrix =
6992+
Builder.CreateCall(DxilFunc, {OpArg, ResHandle, Offset, Stride, Layout});
6993+
// The DXIL call yields the matrix as a value; write it to the caller's storage.
Builder.CreateStore(Matrix, MatrixPtr);
6994+
69776995
// No replacement value is handed back for CI.
return nullptr;
69786996
}
69796997

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s
3+
4+
ByteAddressBuffer inbuf;
5+
6+
// Compute entry point for the cs_6_10 codegen test. It declares a LinAlg
// matrix with attributes (1, 1, 1, 0, 0) and loads it from inbuf with offset,
// stride and layout all 0; the FileCheck pattern inside the body matches the
// DXIL call emitted for that load.
[numthreads(1,1,1)]
7+
void main() {
8+
// CHECK-LABEL: define void @main()
9+
10+
// CHECK: %{{.*}} = call %dx.types.LinAlgMatrixC1M1N1U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC1M1N1U0S0(i32 -2147483634, %dx.types.Handle %{{.*}}, i32 0, i32 0, i32 0) ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout)
11+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 1, 1, 0, 0)]] mat;
12+
__builtin_LinAlg_MatrixLoadFromDescriptor(mat, inbuf, 0, 0, 0);
13+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixLoadFromDescriptor 'void (__builtin_LinAlgMatrix & {{.*}}, RWByteAddressBuffer, unsigned int, unsigned int, unsigned int)' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} ret '__builtin_LinAlgMatrix &&__restrict {{.*}}'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} buf 'RWByteAddressBuffer'
7+
// CHECK-NEXT: ParmVarDecl {{.*}} offset 'unsigned int'
8+
// CHECK-NEXT: ParmVarDecl {{.*}} stride 'unsigned int'
9+
// CHECK-NEXT: ParmVarDecl {{.*}} layout 'unsigned int'
10+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 410
11+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
12+
13+
RWByteAddressBuffer inbuf;
14+
15+
// Compute entry point for the AST-dump test. It calls the builtin once on a
// RWByteAddressBuffer; the FileCheck directives at the top of the file match
// the implicit declaration of that builtin in the -ast-dump-implicit output.
[shader("compute")]
16+
[numthreads(1,1,1)]
17+
void main() {
18+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat;
19+
__builtin_LinAlg_MatrixLoadFromDescriptor(mat, inbuf, 0, 0, 0);
20+
}
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 %s -verify
3+
4+
// expected-no-diagnostics
5+
6+
RWByteAddressBuffer buf;
7+
// Shared helper invoked by every stage entry point in this file. It declares
// a LinAlg matrix with attributes (4, 5, 4, 1, 2) and loads it from the global
// buf, passing 5 for each of the three trailing integer arguments.
void CallFunction()
8+
{
9+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat;
10+
__builtin_LinAlg_MatrixLoadFromDescriptor(mat, buf, 5, 5, 5);
11+
}
12+
13+
// --- Allowed Stages ---
14+
15+
// Compute-stage entry point; its only work is the CallFunction() call.
[shader("compute")]
16+
[numthreads(4,4,4)]
17+
void mainCS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
18+
CallFunction();
19+
}
20+
21+
// Vertex output element of mainMeS: a single SV_Position float4.
struct Verts {
22+
float4 position : SV_Position;
23+
};
24+
25+
// Mesh-stage entry point. Calls CallFunction(), then sets the mesh output
// counts and writes one zero-initialized vertex at index ix.
[shader("mesh")]
26+
[NumThreads(8, 8, 2)]
27+
[OutputTopology("triangle")]
28+
void mainMeS(out vertices Verts verts[32], uint ix : SV_GroupIndex) {
29+
CallFunction();
30+
SetMeshOutputCounts(32, 16);
31+
Verts v = {0.0, 0.0, 0.0, 0.0};
32+
verts[ix] = v;
33+
}
34+
35+
// Payload type that mainAS fills and passes to DispatchMesh.
struct AmpPayload {
36+
float2 dummy;
37+
};
38+
39+
// Amplification-stage entry point. Calls CallFunction(), then fills an
// AmpPayload and dispatches an 8x1x1 mesh grid with it.
[numthreads(8, 1, 1)]
40+
[shader("amplification")]
41+
void mainAS()
42+
{
43+
CallFunction();
44+
AmpPayload pld;
45+
pld.dummy = float2(1.0,2.0);
46+
DispatchMesh(8, 1, 1, pld);
47+
}
48+
49+
// Pixel-stage entry point. Calls CallFunction() and returns a constant.
[shader("pixel")]
50+
float4 mainPS(uint ix : SV_PrimitiveID) : SV_TARGET {
51+
CallFunction();
52+
return 1.0;
53+
}
54+
55+
// Vertex-stage entry point. Calls CallFunction() and returns a constant.
[shader("vertex")]
56+
float4 mainVS(uint ix : SV_VertexID) : OUT {
57+
CallFunction();
58+
return 1.0;
59+
}
60+
61+
// Node-shader entry point; its only work is the CallFunction() call.
[shader("node")]
62+
[nodedispatchgrid(8,1,1)]
63+
[numthreads(64,2,2)]
64+
void mainNS() {
65+
CallFunction();
66+
}
67+
68+
// Ray-generation entry point; its only work is the CallFunction() call.
[shader("raygeneration")]
69+
void mainRG() {
70+
CallFunction();
71+
}
72+
73+
// Intersection entry point; its only work is the CallFunction() call.
[shader("intersection")]
73+
void mainIS() {
74+
CallFunction();
75+
}
77+
78+
// Parameter struct shared by mainCALL, mainAH and mainCH.
struct Attribs { float2 barys; };
79+
80+
// Callable-shader entry point; its only work is the CallFunction() call.
[shader("callable")]
81+
void mainCALL(inout Attribs attrs) {
82+
CallFunction();
83+
}
84+
85+
// Ray payload used by mainAH, mainCH and mainMS. Its one float field carries
// write and read qualifiers listing caller, closesthit, anyhit and miss.
struct [raypayload] RayPayload
86+
{
87+
float elem
88+
: write(caller,closesthit,anyhit,miss)
89+
: read(caller,closesthit,anyhit,miss);
90+
};
91+
92+
// Any-hit entry point; its only work is the CallFunction() call.
[shader("anyhit")]
93+
void mainAH(inout RayPayload pld, in Attribs attrs) {
94+
CallFunction();
95+
}
96+
97+
// Closest-hit entry point; its only work is the CallFunction() call.
[shader("closesthit")]
98+
void mainCH(inout RayPayload pld, in Attribs attrs) {
99+
CallFunction();
100+
}
101+
102+
// Miss-shader entry point; its only work is the CallFunction() call.
[shader("miss")]
103+
void mainMS(inout RayPayload pld) {
104+
CallFunction();
105+
}
106+
107+
// Position-only struct used as the patch element, return type or stream
// element by HSPatch, mainHS, mainDS and mainGS.
struct PosStruct {
108+
float4 pos : SV_Position;
109+
};
110+
111+
// Return type of HSPatch: three edge tessellation factors, one inside
// factor, and an extra float4 with the TEST semantic.
struct PCStruct
112+
{
113+
float Edges[3] : SV_TessFactor;
114+
float Inside : SV_InsideTessFactor;
115+
float4 test : TEST;
116+
};
117+
118+
// Patch-constant function named by the patchconstantfunc attribute on mainHS.
// Every tessellation factor is set from ip[0].pos.w; the test member is never
// assigned. Unlike the stage entry points, it does not call CallFunction().
PCStruct HSPatch(InputPatch<PosStruct, 3> ip,
119+
OutputPatch<PosStruct, 3> op,
120+
uint ix : SV_PrimitiveID)
121+
{
122+
PCStruct a;
123+
a.Edges[0] = ip[0].pos.w;
124+
a.Edges[1] = ip[0].pos.w;
125+
a.Edges[2] = ip[0].pos.w;
126+
a.Inside = ip[0].pos.w;
127+
return a;
128+
}
129+
130+
// Hull-stage entry point. Calls CallFunction(), then returns the position of
// input control point ix unchanged.
[shader("hull")]
131+
[domain("tri")]
132+
[partitioning("fractional_odd")]
133+
[outputtopology("triangle_cw")]
134+
[outputcontrolpoints(3)]
135+
[patchconstantfunc("HSPatch")]
136+
PosStruct mainHS(InputPatch<PosStruct, 3> p, uint ix : SV_OutputControlPointID)
137+
{
138+
CallFunction();
139+
PosStruct s;
140+
s.pos = p[ix].pos;
141+
return s;
142+
}
143+
144+
// Domain-stage entry point. Calls CallFunction(), then returns the position
// of patch element 0.
[shader("domain")]
145+
[domain("tri")]
146+
PosStruct mainDS(const OutputPatch<PosStruct, 3> patch,
147+
uint ix : SV_PrimitiveID)
148+
{
149+
CallFunction();
150+
PosStruct v;
151+
v.pos = patch[0].pos;
152+
return v;
153+
}
154+
155+
float4 a;
156+
157+
// Geometry-stage entry point. Calls CallFunction(), then appends one vertex
// whose position comes from the global a and restarts the strip.
[shader("geometry")]
158+
[maxvertexcount(1)]
159+
void mainGS(triangle float4 array[3] : SV_Position, uint ix : SV_GSInstanceID,
160+
inout PointStream<PosStruct> OutputStream)
161+
{
162+
CallFunction();
163+
PosStruct s;
164+
s.pos = a;
165+
OutputStream.Append(s);
166+
OutputStream.RestartStrip();
167+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: %dxc -I %hlsl_headers -T cs_6_9 -E main %s -verify
2+
3+
RWByteAddressBuffer inbuf;
4+
5+
// Negative test entry point. The RUN line compiles this file for cs_6_9, so
// the builtin call in the body must produce the shader-model 6.10 diagnostic
// named by the verify directive directly above that call.
[numthreads(4,1,1)]
6+
void main() {
7+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 0, 0)]] mat;
8+
9+
// expected-error@+1{{intrinsic __builtin_LinAlg_MatrixLoadFromDescriptor potentially used by ''main'' requires shader model 6.10 or greater}}
10+
__builtin_LinAlg_MatrixLoadFromDescriptor(mat, inbuf, 1, 1, 1);
11+
}

0 commit comments

Comments
 (0)