-
Notifications
You must be signed in to change notification settings - Fork 43
Expand file tree
/
Copy pathwmma_test_kernel.h
More file actions
121 lines (109 loc) · 4.84 KB
/
wmma_test_kernel.h
File metadata and controls
121 lines (109 loc) · 4.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
//
// Copyright (c) 2021-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
//
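// Wave Matrix Multiply Accumulate (WMMA) using HIP compiler intrinsics.
// Multiplies two 16x16 fp16 matrices and stores the product into a 16x16 fp16 result matrix
// (on RDNA4/gfx12 the product is accumulated in fp32 and converted to fp16 when stored).
// frag_type is an alias for clang's internal vector type holding one lane's fp16 fragment
// (16 values on RDNA3/gfx11, 8 values on RDNA4/gfx12).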
// Collapse the per-target macros the compiler defines (e.g. __gfx1100__ when
// compiling for gfx1100) into one marker per GPU generation, so the code below
// can test a single macro with #if defined(...). Undefined target macros
// evaluate to 0 inside #if, so listing extra targets is harmless.
#if __gfx1030__ || __gfx1031__ || __gfx1032__ || __gfx1033__ || __gfx1034__ || __gfx1035__ || __gfx1036__
#define __gfx10__
#endif
#if __gfx1100__ || __gfx1101__ || __gfx1102__ || __gfx1103__ || __gfx1150__ || __gfx1151__ || __gfx1152__ || __gfx1153__
#define __gfx11__
#endif
#if __gfx1200__ || __gfx1201__
#define __gfx12__
#endif

// Two packed fp16 values (the result type of __builtin_amdgcn_cvt_pkrtz).
// Declared for every target because packFp32s uses it unconditionally;
// defining it only for gfx12 made every other target fail to compile.
typedef __fp16 half_2 __attribute__( ( ext_vector_type( 2 ) ) );

#if defined( __gfx12__ )
// RDNA4 (gfx12): each lane holds 8 fp16 values of A and of B, and 8 fp32 accumulators.
#define WMMA_DATA_WIDTH 8
typedef __fp16 frag_type __attribute__( ( ext_vector_type( 8 ) ) );
typedef float frag_type_c __attribute__( ( ext_vector_type( 8 ) ) );
#else
// Non-gfx12 targets (RDNA3/gfx11 layout): each lane holds 16 fp16 values of A and of B,
// and a 16-element fp16 accumulator of which only every other element carries results
// (which half is selected by the WMMA OPSEL operand).
#define WMMA_DATA_WIDTH 16
typedef __fp16 frag_type __attribute__( ( ext_vector_type( 16 ) ) );
typedef __fp16 frag_type_c __attribute__( ( ext_vector_type( 16 ) ) );
#endif
// Converts two fp32 values to fp16 with round-toward-zero ("rtz") and packs them
// into one half_2 via __builtin_amdgcn_cvt_pkrtz. Per the v_cvt_pkrtz_f16_f32 ISA
// description, a goes to the low half (element 0) and b to the high half
// (element 1) — confirm if code relies on that order.
// Declared inline: this is a function definition in a header, and a non-inline
// definition would clash if the header is included by more than one translation unit.
__device__ inline half_2 packFp32s( float a, float b ) { return __builtin_amdgcn_cvt_pkrtz( a, b ); }
// wmma_matmul: multiplies one 16x16 fp16 tile using a single WMMA instruction.
//
// a, b : 16x16 fp16 inputs, element (row, col) at [16 * row + col].
// c    : 16x16 fp16 output, same indexing; c = a * b under that convention.
//
// Launch requirements:
//  - Work is done by threads 0-31 of the block, forming one wave32. The *_w32 WMMA
//    builtins need all 32 lanes of that wave active, so launch with a block size
//    that is a multiple of 32 (normally exactly 32). Threads 32 and above exit early.
//  - blockIdx is not used: every block reads the same inputs and writes the same tile.
//  - Only targets that provide these WMMA builtins (RDNA3/gfx11, RDNA4/gfx12) compile.
extern "C" __global__ void wmma_matmul( __fp16* a, __fp16* b, __fp16* c )
{
	const int lIdx = threadIdx.x;
	// Threads 32+ would compute row indices >= 16 and read/write out of bounds.
	// In wave32 mode they make up whole separate waves, so this exit never leaves
	// the participating wave partially active for the WMMA below.
	if( lIdx >= 32 ) return;

	// Per-lane fragments of A and B, held in packed fp16 VGPRs
	// (16 values = 8 VGPRs on gfx11, 8 values = 4 VGPRs on gfx12).
	// a_frag holds part of row laneWrapped of a; b_frag holds part of column laneWrapped of b.
	frag_type a_frag;
	frag_type b_frag;
	// Accumulator starts at zero, so the result is exactly a * b.
	frag_type_c c_frag = {};

	// laneWrapped: the row of a / column of b and c this lane works on.
	// laneGroup (0 or 1): which half of the 32-lane wave this lane is in.
	const int laneWrapped = lIdx % 16;
	const int laneGroup = lIdx / 16;

#if defined( __gfx12__ )
	// RDNA4: lanes 0-15 load K = 0..7 and lanes 16-31 load K = 8..15.
	const int kBase = laneGroup * WMMA_DATA_WIDTH;
#if 1
	for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
	{
		b_frag[ele] = b[16 * ( kBase + ele ) + laneWrapped];
		a_frag[ele] = a[16 * laneWrapped + ( kBase + ele )];
	}
#else
	{ // Alternative load path: pack pairs of values with __builtin_amdgcn_cvt_pkrtz.
		half_2* a_ptr = reinterpret_cast<half_2*>( &a_frag );
		half_2* b_ptr = reinterpret_cast<half_2*>( &b_frag );
		for( int ele = 0; ele < WMMA_DATA_WIDTH / 2; ++ele )
		{
			const int e0 = ele * 2 + 0;
			const int e1 = ele * 2 + 1;
			b_ptr[ele] = packFp32s( b[16 * ( e0 + kBase ) + laneWrapped], b[16 * ( e1 + kBase ) + laneWrapped] );
			a_ptr[ele] = packFp32s( a[16 * laneWrapped + ( e0 + kBase )], a[16 * laneWrapped + ( e1 + kBase )] );
		}
	}
#endif
#else
	// RDNA3: both lane halves load the full K = 0..15 range (the inputs are replicated
	// across lanes 0-15 and 16-31), hence laneWrapped rather than lIdx.
	for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
	{
		b_frag[ele] = b[16 * ele + laneWrapped];
		a_frag[ele] = a[16 * laneWrapped + ele];
	}
#endif

	// Issue the WMMA instruction. VGPR layouts are described in:
	// RDNA3 ISA guide - https://developer.amd.com/wp-content/resources/RDNA3_Shader_ISA_December2022.pdf
	// RDNA4 ISA guide - https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna4-instruction-set-architecture.pdf
#if defined( __gfx12__ )
	// gfx12 variant: fp16 inputs, fp32 accumulator, no OPSEL operand.
	c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( a_frag, b_frag, c_frag );
#else
	// gfx11 variant: fp16 accumulator. The last argument (OPSEL) selects which fp16 half
	// of each c_frag VGPR receives the results: false = low half (even elements).
	c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( a_frag, b_frag, c_frag, false );
#endif

#if defined( __gfx12__ )
	// RDNA4: c_frag[ele] is row (laneGroup * 8 + ele), column laneWrapped.
	// The fp32 result is converted to fp16 by the store.
	for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
	{
		c[16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + laneWrapped] = c_frag[ele];
	}
#else
	// RDNA3: result VGPR ele holds row (2 * ele + laneGroup), column laneWrapped,
	// in its low half (OPSEL = false), i.e. c_frag[2 * ele].
	for( int ele = 0; ele < WMMA_DATA_WIDTH / 2; ++ele )
	{
		const int r = ele * 2 + laneGroup;
		c[16 * r + laneWrapped] = c_frag[ele * 2];
		// With OPSEL = true this would read c_frag[ele * 2 + 1] instead.
	}
#endif
}