@@ -4,15 +4,7 @@ ByteAddressBuffer matrix_buffer;
44ByteAddressBuffer bias_buffer;
55RWByteAddressBuffer rw_matrix_buffer;
66
7- // CHECK: define void @ps_main()
8- // CHECK: call <4 x float> @dx.op.matVecMul
9- // CHECK: call <4 x float> @dx.op.matVecMulAdd
10- // CHECK: call void @dx.op.outerProductAccumulate
11- // CHECK: call void @dx.op.vectorAccumulate
12-
13- [Shader ("pixel" )]
14- void ps_main ()
15- {
7+ void UseCoopVec () {
168 vector <float , 4 > output_vector;
179 static const uint is_output_unsigned = 0 ;
1810
@@ -59,6 +51,18 @@ void ps_main()
5951 va_matrix_offset);
6052}
6153
54+ // CHECK: define void @ps_main()
55+ // CHECK: call <4 x float> @dx.op.matVecMul
56+ // CHECK: call <4 x float> @dx.op.matVecMulAdd
57+ // CHECK: call void @dx.op.outerProductAccumulate
58+ // CHECK: call void @dx.op.vectorAccumulate
59+
60+ [Shader ("pixel" )]
61+ void ps_main ()
62+ {
63+ UseCoopVec ();
64+ }
65+
6266// CHECK: define void @cs_main()
6367// CHECK: call <4 x float> @dx.op.matVecMul
6468// CHECK: call <4 x float> @dx.op.matVecMulAdd
@@ -69,50 +73,7 @@ void ps_main()
6973[NumThreads (1 ,1 ,1 )]
7074void cs_main ()
7175{
72- vector <float , 4 > output_vector;
73- static const uint is_output_unsigned = 0 ;
74-
75- vector <float , 4 > input_vector;
76- const uint is_input_unsigned = 0 ;
77- const uint input_interpretation = 9 ; /*F32*/
78-
79- const uint matrix_offset = 0 ;
80- const uint matrix_interpretation = 9 ; /*F32*/
81- const uint matrix_dimM = 4 ;
82- const uint matrix_dimK = 4 ;
83- const uint matrix_layout = 0 ; /*RowMajor*/
84- const bool matrix_is_transposed = false ;
85- const uint matrix_stride = 64 ;
86-
87- __builtin_MatVecMul (output_vector, is_output_unsigned, input_vector,
88- is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
89- matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
90- matrix_is_transposed, matrix_stride);
91-
92- const uint bias_offset = 0 ;
93- const uint bias_interpretation = 9 ; /*F32*/
94-
95- __builtin_MatVecMulAdd (output_vector, is_output_unsigned, input_vector,
96- is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
97- matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
98- matrix_is_transposed, matrix_stride, bias_buffer, bias_offset,
99- bias_interpretation);
100-
101- vector <uint , 8 > input_vector1;
102- vector <uint , 8 > input_vector2;
103- const uint opa_matrix_offset = 0 ;
104- const uint opa_matrix_interpretation = 5 ; /*U32*/
105- const uint opa_matrix_layout = 3 ; /*OuterProductOptimal*/
106- const uint opa_matrix_stride = 64 ;
107-
108- __builtin_OuterProductAccumulate (input_vector1, input_vector2,
109- rw_matrix_buffer, opa_matrix_offset, opa_matrix_interpretation,
110- opa_matrix_layout, opa_matrix_stride);
111-
112- const uint va_matrix_offset = 0 ;
113-
114- __builtin_VectorAccumulate (input_vector1, rw_matrix_buffer,
115- va_matrix_offset);
76+ UseCoopVec ();
11677}
11778
11879// CHECK: define void @vs_main()
@@ -123,51 +84,9 @@ void cs_main()
12384
12485[Shader ("vertex" )]
12586void vs_main ()
126- {
127- vector <float , 4 > output_vector;
128- static const uint is_output_unsigned = 0 ;
129-
130- vector <float , 4 > input_vector;
131- const uint is_input_unsigned = 0 ;
132- const uint input_interpretation = 9 ; /*F32*/
133-
134- const uint matrix_offset = 0 ;
135- const uint matrix_interpretation = 9 ; /*F32*/
136- const uint matrix_dimM = 4 ;
137- const uint matrix_dimK = 4 ;
138- const uint matrix_layout = 0 ; /*RowMajor*/
139- const bool matrix_is_transposed = false ;
140- const uint matrix_stride = 64 ;
141-
142- __builtin_MatVecMul (output_vector, is_output_unsigned, input_vector,
143- is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
144- matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
145- matrix_is_transposed, matrix_stride);
146-
147- const uint bias_offset = 0 ;
148- const uint bias_interpretation = 9 ; /*F32*/
149-
150- __builtin_MatVecMulAdd (output_vector, is_output_unsigned, input_vector,
151- is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
152- matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
153- matrix_is_transposed, matrix_stride, bias_buffer, bias_offset,
154- bias_interpretation);
155-
156- vector <uint , 8 > input_vector1;
157- vector <uint , 8 > input_vector2;
158- const uint opa_matrix_offset = 0 ;
159- const uint opa_matrix_interpretation = 5 ; /*U32*/
160- const uint opa_matrix_layout = 3 ; /*OuterProductOptimal*/
161- const uint opa_matrix_stride = 64 ;
162-
163- __builtin_OuterProductAccumulate (input_vector1, input_vector2,
164- rw_matrix_buffer, opa_matrix_offset, opa_matrix_interpretation,
165- opa_matrix_layout, opa_matrix_stride);
166-
167- const uint va_matrix_offset = 0 ;
168-
169- __builtin_VectorAccumulate (input_vector1, rw_matrix_buffer,
170- va_matrix_offset); }
87+ {
88+ UseCoopVec ();
89+ }
17190
17291struct MyRecord{
17392 uint a;
@@ -183,50 +102,7 @@ struct MyRecord{
183102[NodeLaunch ("thread" )]
184103void ns_main (ThreadNodeInputRecord<MyRecord> input)
185104{
186- vector <float , 4 > output_vector;
187- static const uint is_output_unsigned = 0 ;
188-
189- vector <float , 4 > input_vector;
190- const uint is_input_unsigned = 0 ;
191- const uint input_interpretation = 9 ; /*F32*/
192-
193- const uint matrix_offset = 0 ;
194- const uint matrix_interpretation = 9 ; /*F32*/
195- const uint matrix_dimM = 4 ;
196- const uint matrix_dimK = 4 ;
197- const uint matrix_layout = 0 ; /*RowMajor*/
198- const bool matrix_is_transposed = false ;
199- const uint matrix_stride = 64 ;
200-
201- __builtin_MatVecMul (output_vector, is_output_unsigned, input_vector,
202- is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
203- matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
204- matrix_is_transposed, matrix_stride);
205-
206- const uint bias_offset = 0 ;
207- const uint bias_interpretation = 9 ; /*F32*/
208-
209- __builtin_MatVecMulAdd (output_vector, is_output_unsigned, input_vector,
210- is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
211- matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
212- matrix_is_transposed, matrix_stride, bias_buffer, bias_offset,
213- bias_interpretation);
214-
215- vector <uint , 8 > input_vector1;
216- vector <uint , 8 > input_vector2;
217- const uint opa_matrix_offset = 0 ;
218- const uint opa_matrix_interpretation = 5 ; /*U32*/
219- const uint opa_matrix_layout = 3 ; /*OuterProductOptimal*/
220- const uint opa_matrix_stride = 64 ;
221-
222- __builtin_OuterProductAccumulate (input_vector1, input_vector2,
223- rw_matrix_buffer, opa_matrix_offset, opa_matrix_interpretation,
224- opa_matrix_layout, opa_matrix_stride);
225-
226- const uint va_matrix_offset = 0 ;
227-
228- __builtin_VectorAccumulate (input_vector1, rw_matrix_buffer,
229- va_matrix_offset);
105+ UseCoopVec ();
230106}
231107
232108// Vertex shader output structure
@@ -250,49 +126,6 @@ struct GS_OUT {
250126[maxvertexcount (3 )]
251127void gs_main (point VS_OUT input[1 ],
252128 inout TriangleStream <GS_OUT> OutputStream)
253- {
254- vector <float , 4 > output_vector;
255- static const uint is_output_unsigned = 0 ;
256-
257- vector <float , 4 > input_vector;
258- const uint is_input_unsigned = 0 ;
259- const uint input_interpretation = 9 ; /*F32*/
260-
261- const uint matrix_offset = 0 ;
262- const uint matrix_interpretation = 9 ; /*F32*/
263- const uint matrix_dimM = 4 ;
264- const uint matrix_dimK = 4 ;
265- const uint matrix_layout = 0 ; /*RowMajor*/
266- const bool matrix_is_transposed = false ;
267- const uint matrix_stride = 64 ;
268-
269- __builtin_MatVecMul (output_vector, is_output_unsigned, input_vector,
270- is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
271- matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
272- matrix_is_transposed, matrix_stride);
273-
274- const uint bias_offset = 0 ;
275- const uint bias_interpretation = 9 ; /*F32*/
276-
277- __builtin_MatVecMulAdd (output_vector, is_output_unsigned, input_vector,
278- is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
279- matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
280- matrix_is_transposed, matrix_stride, bias_buffer, bias_offset,
281- bias_interpretation);
282-
283- vector <uint , 8 > input_vector1;
284- vector <uint , 8 > input_vector2;
285- const uint opa_matrix_offset = 0 ;
286- const uint opa_matrix_interpretation = 5 ; /*U32*/
287- const uint opa_matrix_layout = 3 ; /*OuterProductOptimal*/
288- const uint opa_matrix_stride = 64 ;
289-
290- __builtin_OuterProductAccumulate (input_vector1, input_vector2,
291- rw_matrix_buffer, opa_matrix_offset, opa_matrix_interpretation,
292- opa_matrix_layout, opa_matrix_stride);
293-
294- const uint va_matrix_offset = 0 ;
295-
296- __builtin_VectorAccumulate (input_vector1, rw_matrix_buffer,
297- va_matrix_offset);
129+ {
130+ UseCoopVec ();
298131}
0 commit comments