@@ -5,50 +5,50 @@ ByteAddressBuffer bias_buffer;
55RWByteAddressBuffer rw_matrix_buffer;
66
77void UseCoopVec () {
8- vector <float , 4 > output_vector;
9- static const uint is_output_unsigned = 0 ;
10-
11- vector <float , 4 > input_vector;
12- const uint is_input_unsigned = 0 ;
13- const uint input_interpretation = 9 ; /*F32*/
14-
15- const uint matrix_offset = 0 ;
16- const uint matrix_interpretation = 9 ; /*F32*/
17- const uint matrix_dimM = 4 ;
18- const uint matrix_dimK = 4 ;
19- const uint matrix_layout = 0 ; /*RowMajor*/
20- const bool matrix_is_transposed = false ;
21- const uint matrix_stride = 64 ;
22-
23- __builtin_MatVecMul (output_vector, is_output_unsigned, input_vector,
24- is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
25- matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
26- matrix_is_transposed, matrix_stride);
27-
28- const uint bias_offset = 0 ;
29- const uint bias_interpretation = 9 ; /*F32*/
30-
31- __builtin_MatVecMulAdd (output_vector, is_output_unsigned, input_vector,
32- is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
33- matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
34- matrix_is_transposed, matrix_stride, bias_buffer, bias_offset,
35- bias_interpretation);
36-
37- vector <uint , 8 > input_vector1;
38- vector <uint , 8 > input_vector2;
39- const uint opa_matrix_offset = 0 ;
40- const uint opa_matrix_interpretation = 5 ; /*U32*/
41- const uint opa_matrix_layout = 3 ; /*OuterProductOptimal*/
42- const uint opa_matrix_stride = 64 ;
43-
44- __builtin_OuterProductAccumulate (input_vector1, input_vector2,
45- rw_matrix_buffer, opa_matrix_offset, opa_matrix_interpretation,
46- opa_matrix_layout, opa_matrix_stride);
47-
48- const uint va_matrix_offset = 0 ;
49-
50- __builtin_VectorAccumulate (input_vector1, rw_matrix_buffer,
51- va_matrix_offset);
8+ vector <float , 4 > output_vector;
9+ static const uint is_output_unsigned = 0 ;
10+
11+ vector <float , 4 > input_vector;
12+ const uint is_input_unsigned = 0 ;
13+ const uint input_interpretation = 9 ; /*F32*/
14+
15+ const uint matrix_offset = 0 ;
16+ const uint matrix_interpretation = 9 ; /*F32*/
17+ const uint matrix_dimM = 4 ;
18+ const uint matrix_dimK = 4 ;
19+ const uint matrix_layout = 0 ; /*RowMajor*/
20+ const bool matrix_is_transposed = false ;
21+ const uint matrix_stride = 64 ;
22+
23+ __builtin_MatVecMul (output_vector, is_output_unsigned, input_vector,
24+ is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
25+ matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
26+ matrix_is_transposed, matrix_stride);
27+
28+ const uint bias_offset = 0 ;
29+ const uint bias_interpretation = 9 ; /*F32*/
30+
31+ __builtin_MatVecMulAdd (output_vector, is_output_unsigned, input_vector,
32+ is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
33+ matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
34+ matrix_is_transposed, matrix_stride, bias_buffer, bias_offset,
35+ bias_interpretation);
36+
37+ vector <uint , 8 > input_vector1;
38+ vector <uint , 8 > input_vector2;
39+ const uint opa_matrix_offset = 0 ;
40+ const uint opa_matrix_interpretation = 5 ; /*U32*/
41+ const uint opa_matrix_layout = 3 ; /*OuterProductOptimal*/
42+ const uint opa_matrix_stride = 64 ;
43+
44+ __builtin_OuterProductAccumulate (input_vector1, input_vector2,
45+ rw_matrix_buffer, opa_matrix_offset, opa_matrix_interpretation,
46+ opa_matrix_layout, opa_matrix_stride);
47+
48+ const uint va_matrix_offset = 0 ;
49+
50+ __builtin_VectorAccumulate (input_vector1, rw_matrix_buffer,
51+ va_matrix_offset);
5252}
5353
5454// CHECK: define void @ps_main()
@@ -59,7 +59,7 @@ void UseCoopVec() {
5959
6060[Shader ("pixel" )]
6161void ps_main ()
62- {
62+ {
6363 UseCoopVec ();
6464}
6565
@@ -72,8 +72,8 @@ void ps_main()
7272[Shader ("compute" )]
7373[NumThreads (1 ,1 ,1 )]
7474void cs_main ()
75- {
76- UseCoopVec ();
75+ {
76+ UseCoopVec ();
7777}
7878
7979// CHECK: define void @vs_main()
@@ -85,11 +85,11 @@ void cs_main()
8585[Shader ("vertex" )]
8686void vs_main ()
8787{
88- UseCoopVec ();
88+ UseCoopVec ();
8989}
9090
9191struct MyRecord{
92- uint a;
92+ uint a;
9393};
9494
9595// CHECK: define void @ns_main()
@@ -101,8 +101,8 @@ struct MyRecord{
101101[Shader ("node" )]
102102[NodeLaunch ("thread" )]
103103void ns_main (ThreadNodeInputRecord<MyRecord> input)
104- {
105- UseCoopVec ();
104+ {
105+ UseCoopVec ();
106106}
107107
108108// Vertex shader output structure
@@ -125,7 +125,7 @@ struct GS_OUT {
125125[shader ("geometry" )]
126126[maxvertexcount (3 )]
127127void gs_main (point VS_OUT input[1 ],
128- inout TriangleStream <GS_OUT> OutputStream)
128+ inout TriangleStream <GS_OUT> OutputStream)
129129{
130130 UseCoopVec ();
131131}
0 commit comments