@@ -9,61 +9,59 @@ def matmul(A,B,C,_):
99 tmp += A [i ,k ] * B [k ,j ]
1010 C [i ,j ] = tmp
1111
@njit(
    void(float64[:, ::1], float64[:, ::1], float64[:, :], numba.optional(int32)),
    cache=True,
)
def matmul_numba_serial(A, B, C, _):
    """Write the matrix product ``A @ B`` into ``C`` using one thread.

    Args:
        A: Left operand; the compiled signature requires a C-contiguous
            float64 2-D array.
        B: Right operand; same layout requirement as ``A``.
        C: Output buffer (float64, 2-D, any layout). It is overwritten, so
            it does not need to be zero-initialised by the caller.
        _: Unused by this function (presumably kept so every variant shares
            one call shape -- confirm against the caller).

    The loops run in i-k-j order so the innermost loop walks a row of ``B``
    and a row of ``C`` with unit stride.
    """
    n_rows = A.shape[0]
    n_inner = A.shape[-1]
    n_cols = B.shape[1]
    for i in range(n_rows):
        # The i-k-j order accumulates directly into C, so the row must start
        # from zero; otherwise stale values in a reused or uninitialised
        # buffer would leak into the result.
        for j in range(n_cols):
            C[i, j] = 0.0
        for k in range(n_inner):
            a_ik = A[i, k]  # invariant over the innermost loop
            for j in range(n_cols):
                C[i, j] += a_ik * B[k, j]
2018
@njit(
    void(float64[:, ::1], float64[:, ::1], float64[:, :], numba.optional(int32)),
    parallel=True,
    nogil=True,
    cache=True,
)
def matmul_numba_cpu(A, B, C, _):
    """Write the matrix product ``A @ B`` into ``C``, parallelised over rows.

    Args:
        A: Left operand; the compiled signature requires a C-contiguous
            float64 2-D array.
        B: Right operand; same layout requirement as ``A``.
        C: Output buffer (float64, 2-D, any layout). It is overwritten, so
            it does not need to be zero-initialised by the caller.
        _: Unused by this function (presumably kept so every variant shares
            one call shape -- confirm against the caller).

    Each ``prange`` iteration reads and writes only row ``i`` of ``C``, so
    the parallel iterations never touch the same output element.
    """
    n_rows = A.shape[0]
    n_inner = A.shape[1]
    n_cols = B.shape[1]
    for i in prange(n_rows):
        # The i-k-j order accumulates directly into C, so the row must start
        # from zero; otherwise stale values in a reused or uninitialised
        # buffer would leak into the result.
        for j in range(n_cols):
            C[i, j] = 0.0
        for k in range(n_inner):
            a_ik = A[i, k]  # invariant over the innermost loop
            for j in range(n_cols):
                C[i, j] += a_ik * B[k, j]
2925
3026
3127
@njit(
    void(float64[:, ::1], float64[:, ::1], float64[:, :], int32),
    parallel=True,
    nogil=True,
    cache=True,
)
def matmul_numba_block_cpu(A, B, C, bs=64):
    """Accumulate ``A @ B`` into ``C`` tile by tile, parallel over row tiles.

    Args:
        A: Left operand; the compiled signature requires a C-contiguous
            float64 2-D array.
        B: Right operand; same layout requirement as ``A``.
        C: Output buffer (float64, 2-D, any layout). It is only ever
            incremented here, never cleared, so it must already hold zeros
            for the result to equal the plain product.
        bs: Tile edge length used along all three loop dimensions.

    Each ``prange`` iteration owns the rows ``[row_lo, row_hi)``, so the
    parallel iterations write to disjoint parts of ``C``.
    """
    n_rows = A.shape[0]
    n_cols = B.shape[1]
    n_inner = A.shape[1]

    # Number of row tiles, counting a trailing partial tile when bs does not
    # divide the row count evenly.
    n_row_tiles = n_rows // bs
    if n_rows % bs > 0:
        n_row_tiles += 1

    for tile in prange(n_row_tiles):
        row_lo = tile * bs
        row_hi = min(row_lo + bs, n_rows)
        for k_lo in range(0, n_inner, bs):
            k_hi = min(k_lo + bs, n_inner)
            for col_lo in range(0, n_cols, bs):
                col_hi = min(col_lo + bs, n_cols)
                # i-k-j order inside the tile: the innermost loop walks a
                # row of B and a row of C.
                for i in range(row_lo, row_hi):
                    for k in range(k_lo, k_hi):
                        a_ik = A[i, k]
                        for j in range(col_lo, col_hi):
                            C[i, j] += a_ik * B[k, j]
4845
@njit(
    void(float64[:, ::1], float64[:, ::1], float64[:, :], int32),
    parallel=False,
    nogil=True,
    cache=True,
)
def matmul_numba_block_serial(A, B, C, bs=64):
    """Accumulate ``A @ B`` into ``C`` tile by tile on a single thread.

    Args:
        A: Left operand; the compiled signature requires a C-contiguous
            float64 2-D array.
        B: Right operand; same layout requirement as ``A``.
        C: Output buffer (float64, 2-D, any layout). It is only ever
            incremented here, never cleared, so it must already hold zeros
            for the result to equal the plain product.
        bs: Tile edge length used along all three loop dimensions.
    """
    n_rows = A.shape[0]
    n_cols = B.shape[1]
    n_inner = A.shape[1]

    # Number of row tiles, counting a trailing partial tile when bs does not
    # divide the row count evenly.
    n_row_tiles = n_rows // bs
    if n_rows % bs > 0:
        n_row_tiles += 1

    for tile in range(n_row_tiles):
        row_lo = tile * bs
        row_hi = min(row_lo + bs, n_rows)
        for k_lo in range(0, n_inner, bs):
            k_hi = min(k_lo + bs, n_inner)
            for col_lo in range(0, n_cols, bs):
                col_hi = min(col_lo + bs, n_cols)
                # i-k-j order inside the tile: the innermost loop walks a
                # row of B and a row of C.
                for i in range(row_lo, row_hi):
                    for k in range(k_lo, k_hi):
                        a_ik = A[i, k]
                        for j in range(col_lo, col_hi):
                            C[i, j] += a_ik * B[k, j]
6563
66- @cuda .jit (void (float64 [:,:],float64 [:,:],float64 [:,:]), cache = True )
64+ @cuda .jit (void (float64 [:,:: 1 ],float64 [:,:: 1 ],float64 [:,:]), cache = True )
6765def matmul_numba_gpu (A ,B ,C ):
6866 i , j = cuda .grid (ndim = 2 )
6967 if i < C .shape [0 ] and j < C .shape [1 ]:
@@ -73,7 +71,7 @@ def matmul_numba_gpu(A,B,C):
7371 C [i ,j ] = tmp
7472
7573BLOCK_SIZE = 16
76- @cuda .jit (void (float64 [:,:],float64 [:,:],float64 [:,:]), cache = True )
74+ @cuda .jit (void (float64 [:,:: 1 ],float64 [:,:: 1 ],float64 [:,:]), cache = True )
7775def matmul_numba_block_gpu (A ,B ,C ):
7876
7977 bi = cuda .blockIdx .y
0 commit comments