-
Notifications
You must be signed in to change notification settings - Fork 81
Expand file tree
/
Copy pathless_slow_sm90a.ptx
More file actions
579 lines (500 loc) · 30.3 KB
/
less_slow_sm90a.ptx
File metadata and controls
579 lines (500 loc) · 30.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
/**
 * less_slow_sm90a.ptx
 *
 * Micro-kernels for building a performance-first mindset for CUDA-capable
 * GPUs using Parallel Thread eXecution (PTX) Intermediate Representation (IR)
 * for Hopper-generation Nvidia GPUs with Warp-Group-level MMA (WGMMA).
 *
 * ? You should start at `less_slow.cu` before reading this file.
 * ? You should also read `less_slow_sm70.ptx` before reading this file.
 * ? Also read intro to PTX: https://docs.nvidia.com/cuda/parallel-thread-execution/
 * ? Check the PTX ISA: https://docs.nvidia.com/cuda/pdf/ptx_isa_8.5.pdf
 *
 * You can validate this file by asking the Nvidia PTX Assembler to compile it
 * to `.cubin` for some target architecture:
 *
 * $ ptxas -o less_slow_sm90a_from_ptx.cubin -arch=sm_90a less_slow_sm90a.ptx
 * $ cuobjdump -sass less_slow_sm90a_from_ptx.cubin | grep -i mma
 *
 * Given how aggressively the compiler may unroll loops, and the number of
 * kernels in this file, you may want to deduplicate the output:
 *
 * $ cuobjdump -sass less_slow_sm90a_from_ptx.cubin | grep -i mma | \
 *     sed -r 's/\/\*[^*]+\*\///g' | \
 *     sed -r 's/^[[:space:]]+//; s/[[:space:]]+$//' | \
 *     sort -u
 */
.version 8.0 // PTX ISA 8.0: the first ISA version providing `wgmma` instructions
.target sm_90a // Hopper with architecture-specific (`a`) features, required by `wgmma`
.address_size 64 // 64-bit addressing
/**
 * Let's define some global memory buffers with externally visible symbols
 * (so the host side can reference them too) to output multiplication results.
 * Every kernel below stores a few of its accumulators into one of these sinks,
 * so that its work is observable and cannot be optimized away.
 */
.visible .global .align 4 .s32 dummy_sink_s32[32];
.visible .global .align 4 .f32 dummy_sink_f32[32];
/**
 * Sustained FP16 x FP16 -> FP32 throughput of `wgmma.mma_async` with the
 * m64n256k16 shape, issued 128 times back-to-back by one warp-group.
 *
 * Both operands are read from shared memory through 64-bit matrix descriptors.
 * The tiles are never initialized: only the instruction throughput matters.
 * NOTE(review): WGMMA instructions are `.sync.aligned` across a warp-group of
 * 4 warps, so this presumably must be launched with 128 threads - confirm.
 */
.visible .entry tops_f16f32_sm90tc_m64n256k16_loop128_ptx_kernel()
{
    // Accumulator registers used for both input and output of this MMA:
    // 64 x 256 FP32 outputs spread across the 128 threads of a warp-group.
    .reg .f32 accum<128>;
    // Descriptors for matrix A and matrix B operands
    .reg .b64 desc_a, desc_b;
    // F16 variables are stored in B16 slots in 2D arrays. NVCC prefers to
    // demote the slots down to `.b8` with `.align 2`, but the descriptor
    // encodes the start address in 16-byte units, so align tiles to 16 bytes.
    .shared .align 16 .b16 tile_a[64][16];
    .shared .align 16 .b16 tile_b[256][16];
    // Define registers to store shared memory addresses
    .reg .u64 addr_a, addr_b;
    // `mov` of a `.shared` variable yields its shared-state-space address,
    // which is what the descriptor encodes. Converting it to a generic address
    // with `cvta.shared` first would rely on the low bits of the generic
    // shared-window base being zero.
    mov.u64 addr_a, tile_a;
    mov.u64 addr_b, tile_b;
    // Start-address field: keep the low 18 bits, expressed in 16-byte units
    and.b64 addr_a, addr_a, 0x3FFFF;
    and.b64 addr_b, addr_b, 0x3FFFF;
    shr.u64 addr_a, addr_a, 4;
    shr.u64 addr_b, addr_b, 4;
    // Descriptor of the M x K matrix A; byte offsets are also in 16-byte units
    mov.u64 desc_a, addr_a;
    or.b64 desc_a, desc_a, ((128 >> 4) << 16); // Leading-dimension byte offset
    or.b64 desc_a, desc_a, ((256 >> 4) << 32); // Stride-dimension byte offset
    // Descriptor of the K x N matrix B
    mov.u64 desc_b, addr_b;
    or.b64 desc_b, desc_b, ((4096 >> 4) << 16); // Leading-dimension byte offset
    or.b64 desc_b, desc_b, ((128 >> 4) << 32); // Stride-dimension byte offset
    // General-purpose registers for loop control
    .reg .b32 loop_counter, loop_limit;
    // Predicate register for conditional branching (loop exit)
    .reg .pred exit_predicate;
    // Set up loop counter and loop limit to fill accumulators
    mov.u32 loop_counter, 0;
    mov.u32 loop_limit, 128;
    // Zero-initialize the accumulators, as registers may contain noise
    mov.f32 accum0, 0.0; mov.f32 accum1, 0.0; mov.f32 accum2, 0.0; mov.f32 accum3, 0.0;
    mov.f32 accum4, 0.0; mov.f32 accum5, 0.0; mov.f32 accum6, 0.0; mov.f32 accum7, 0.0;
    mov.f32 accum8, 0.0; mov.f32 accum9, 0.0; mov.f32 accum10, 0.0; mov.f32 accum11, 0.0;
    mov.f32 accum12, 0.0; mov.f32 accum13, 0.0; mov.f32 accum14, 0.0; mov.f32 accum15, 0.0;
    mov.f32 accum16, 0.0; mov.f32 accum17, 0.0; mov.f32 accum18, 0.0; mov.f32 accum19, 0.0;
    mov.f32 accum20, 0.0; mov.f32 accum21, 0.0; mov.f32 accum22, 0.0; mov.f32 accum23, 0.0;
    mov.f32 accum24, 0.0; mov.f32 accum25, 0.0; mov.f32 accum26, 0.0; mov.f32 accum27, 0.0;
    mov.f32 accum28, 0.0; mov.f32 accum29, 0.0; mov.f32 accum30, 0.0; mov.f32 accum31, 0.0;
    mov.f32 accum32, 0.0; mov.f32 accum33, 0.0; mov.f32 accum34, 0.0; mov.f32 accum35, 0.0;
    mov.f32 accum36, 0.0; mov.f32 accum37, 0.0; mov.f32 accum38, 0.0; mov.f32 accum39, 0.0;
    mov.f32 accum40, 0.0; mov.f32 accum41, 0.0; mov.f32 accum42, 0.0; mov.f32 accum43, 0.0;
    mov.f32 accum44, 0.0; mov.f32 accum45, 0.0; mov.f32 accum46, 0.0; mov.f32 accum47, 0.0;
    mov.f32 accum48, 0.0; mov.f32 accum49, 0.0; mov.f32 accum50, 0.0; mov.f32 accum51, 0.0;
    mov.f32 accum52, 0.0; mov.f32 accum53, 0.0; mov.f32 accum54, 0.0; mov.f32 accum55, 0.0;
    mov.f32 accum56, 0.0; mov.f32 accum57, 0.0; mov.f32 accum58, 0.0; mov.f32 accum59, 0.0;
    mov.f32 accum60, 0.0; mov.f32 accum61, 0.0; mov.f32 accum62, 0.0; mov.f32 accum63, 0.0;
    mov.f32 accum64, 0.0; mov.f32 accum65, 0.0; mov.f32 accum66, 0.0; mov.f32 accum67, 0.0;
    mov.f32 accum68, 0.0; mov.f32 accum69, 0.0; mov.f32 accum70, 0.0; mov.f32 accum71, 0.0;
    mov.f32 accum72, 0.0; mov.f32 accum73, 0.0; mov.f32 accum74, 0.0; mov.f32 accum75, 0.0;
    mov.f32 accum76, 0.0; mov.f32 accum77, 0.0; mov.f32 accum78, 0.0; mov.f32 accum79, 0.0;
    mov.f32 accum80, 0.0; mov.f32 accum81, 0.0; mov.f32 accum82, 0.0; mov.f32 accum83, 0.0;
    mov.f32 accum84, 0.0; mov.f32 accum85, 0.0; mov.f32 accum86, 0.0; mov.f32 accum87, 0.0;
    mov.f32 accum88, 0.0; mov.f32 accum89, 0.0; mov.f32 accum90, 0.0; mov.f32 accum91, 0.0;
    mov.f32 accum92, 0.0; mov.f32 accum93, 0.0; mov.f32 accum94, 0.0; mov.f32 accum95, 0.0;
    mov.f32 accum96, 0.0; mov.f32 accum97, 0.0; mov.f32 accum98, 0.0; mov.f32 accum99, 0.0;
    mov.f32 accum100, 0.0; mov.f32 accum101, 0.0; mov.f32 accum102, 0.0; mov.f32 accum103, 0.0;
    mov.f32 accum104, 0.0; mov.f32 accum105, 0.0; mov.f32 accum106, 0.0; mov.f32 accum107, 0.0;
    mov.f32 accum108, 0.0; mov.f32 accum109, 0.0; mov.f32 accum110, 0.0; mov.f32 accum111, 0.0;
    mov.f32 accum112, 0.0; mov.f32 accum113, 0.0; mov.f32 accum114, 0.0; mov.f32 accum115, 0.0;
    mov.f32 accum116, 0.0; mov.f32 accum117, 0.0; mov.f32 accum118, 0.0; mov.f32 accum119, 0.0;
    mov.f32 accum120, 0.0; mov.f32 accum121, 0.0; mov.f32 accum122, 0.0; mov.f32 accum123, 0.0;
    mov.f32 accum124, 0.0; mov.f32 accum125, 0.0; mov.f32 accum126, 0.0; mov.f32 accum127, 0.0;
    // Order the register writes above before the first `wgmma.mma_async`.
    // Reusing the same accumulators across same-shape WGMMAs needs no more fences.
    wgmma.fence.sync.aligned;
    // The main loop will repeat for 128 iterations
loop_start:
    setp.ge.u32 exit_predicate, loop_counter, loop_limit;
    @exit_predicate bra loop_exit;
    // Trailing immediates: scale-d = 1 (D = A * B + D), imm-scale-a = 1 and
    // imm-scale-b = 1 (no negation), imm-trans-a = 0 and imm-trans-b = 0.
    wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16
        { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7,
          accum8, accum9, accum10, accum11, accum12, accum13, accum14, accum15,
          accum16, accum17, accum18, accum19, accum20, accum21, accum22, accum23,
          accum24, accum25, accum26, accum27, accum28, accum29, accum30, accum31,
          accum32, accum33, accum34, accum35, accum36, accum37, accum38, accum39,
          accum40, accum41, accum42, accum43, accum44, accum45, accum46, accum47,
          accum48, accum49, accum50, accum51, accum52, accum53, accum54, accum55,
          accum56, accum57, accum58, accum59, accum60, accum61, accum62, accum63,
          accum64, accum65, accum66, accum67, accum68, accum69, accum70, accum71,
          accum72, accum73, accum74, accum75, accum76, accum77, accum78, accum79,
          accum80, accum81, accum82, accum83, accum84, accum85, accum86, accum87,
          accum88, accum89, accum90, accum91, accum92, accum93, accum94, accum95,
          accum96, accum97, accum98, accum99, accum100, accum101, accum102, accum103,
          accum104, accum105, accum106, accum107, accum108, accum109, accum110, accum111,
          accum112, accum113, accum114, accum115, accum116, accum117, accum118, accum119,
          accum120, accum121, accum122, accum123, accum124, accum125, accum126, accum127 },
        desc_a,
        desc_b,
        1, 1, 1, 0, 0;
    // Each iteration commits its MMA as a separate WGMMA group
    wgmma.commit_group.sync.aligned;
    // Increment the loop counter
    add.u32 loop_counter, loop_counter, 1;
    // Branch back to the beginning of the loop
    bra loop_start;
loop_exit:
    // Zero argument means - wait for all committed WGMMAs to complete.
    wgmma.wait_group.sync.aligned 0;
    // Use volatile stores to force the accumulator values to be written out.
    // This dummy write (to a global variable) makes the work observable and
    // prevents the multiplication pipeline from being optimized out.
    st.global.volatile.f32 [dummy_sink_f32], accum0;
    st.global.volatile.f32 [dummy_sink_f32+4], accum1;
    st.global.volatile.f32 [dummy_sink_f32+8], accum126;
    st.global.volatile.f32 [dummy_sink_f32+12], accum127;
    ret;
}
/**
 * Sustained BF16 x BF16 -> FP32 throughput of `wgmma.mma_async` with the
 * m64n256k16 shape, issued 128 times back-to-back by one warp-group.
 *
 * Both operands are read from shared memory through 64-bit matrix descriptors.
 * The tiles are never initialized: only the instruction throughput matters.
 * NOTE(review): WGMMA instructions are `.sync.aligned` across a warp-group of
 * 4 warps, so this presumably must be launched with 128 threads - confirm.
 */
.visible .entry tops_bf16f32_sm90tc_m64n256k16_loop128_ptx_kernel()
{
    // Accumulator registers used for both input and output of this MMA:
    // 64 x 256 FP32 outputs spread across the 128 threads of a warp-group.
    .reg .f32 accum<128>;
    // Descriptors for matrix A and matrix B operands
    .reg .b64 desc_a, desc_b;
    // BF16 variables are stored in B16 slots in 2D arrays. NVCC prefers to
    // demote the slots down to `.b8` with `.align 2`, but the descriptor
    // encodes the start address in 16-byte units, so align tiles to 16 bytes.
    .shared .align 16 .b16 tile_a[64][16];
    .shared .align 16 .b16 tile_b[256][16];
    // Define registers to store shared memory addresses
    .reg .u64 addr_a, addr_b;
    // `mov` of a `.shared` variable yields its shared-state-space address,
    // which is what the descriptor encodes. Converting it to a generic address
    // with `cvta.shared` first would rely on the low bits of the generic
    // shared-window base being zero.
    mov.u64 addr_a, tile_a;
    mov.u64 addr_b, tile_b;
    // Start-address field: keep the low 18 bits, expressed in 16-byte units
    and.b64 addr_a, addr_a, 0x3FFFF;
    and.b64 addr_b, addr_b, 0x3FFFF;
    shr.u64 addr_a, addr_a, 4;
    shr.u64 addr_b, addr_b, 4;
    // Descriptor of the M x K matrix A; byte offsets are also in 16-byte units
    mov.u64 desc_a, addr_a;
    or.b64 desc_a, desc_a, ((128 >> 4) << 16); // Leading-dimension byte offset
    or.b64 desc_a, desc_a, ((256 >> 4) << 32); // Stride-dimension byte offset
    // Descriptor of the K x N matrix B
    mov.u64 desc_b, addr_b;
    or.b64 desc_b, desc_b, ((4096 >> 4) << 16); // Leading-dimension byte offset
    or.b64 desc_b, desc_b, ((128 >> 4) << 32); // Stride-dimension byte offset
    // General-purpose registers for loop control
    .reg .b32 loop_counter, loop_limit;
    // Predicate register for conditional branching (loop exit)
    .reg .pred exit_predicate;
    // Set up loop counter and loop limit to fill accumulators
    mov.u32 loop_counter, 0;
    mov.u32 loop_limit, 128;
    // Zero-initialize the accumulators, as registers may contain noise
    mov.f32 accum0, 0.0; mov.f32 accum1, 0.0; mov.f32 accum2, 0.0; mov.f32 accum3, 0.0;
    mov.f32 accum4, 0.0; mov.f32 accum5, 0.0; mov.f32 accum6, 0.0; mov.f32 accum7, 0.0;
    mov.f32 accum8, 0.0; mov.f32 accum9, 0.0; mov.f32 accum10, 0.0; mov.f32 accum11, 0.0;
    mov.f32 accum12, 0.0; mov.f32 accum13, 0.0; mov.f32 accum14, 0.0; mov.f32 accum15, 0.0;
    mov.f32 accum16, 0.0; mov.f32 accum17, 0.0; mov.f32 accum18, 0.0; mov.f32 accum19, 0.0;
    mov.f32 accum20, 0.0; mov.f32 accum21, 0.0; mov.f32 accum22, 0.0; mov.f32 accum23, 0.0;
    mov.f32 accum24, 0.0; mov.f32 accum25, 0.0; mov.f32 accum26, 0.0; mov.f32 accum27, 0.0;
    mov.f32 accum28, 0.0; mov.f32 accum29, 0.0; mov.f32 accum30, 0.0; mov.f32 accum31, 0.0;
    mov.f32 accum32, 0.0; mov.f32 accum33, 0.0; mov.f32 accum34, 0.0; mov.f32 accum35, 0.0;
    mov.f32 accum36, 0.0; mov.f32 accum37, 0.0; mov.f32 accum38, 0.0; mov.f32 accum39, 0.0;
    mov.f32 accum40, 0.0; mov.f32 accum41, 0.0; mov.f32 accum42, 0.0; mov.f32 accum43, 0.0;
    mov.f32 accum44, 0.0; mov.f32 accum45, 0.0; mov.f32 accum46, 0.0; mov.f32 accum47, 0.0;
    mov.f32 accum48, 0.0; mov.f32 accum49, 0.0; mov.f32 accum50, 0.0; mov.f32 accum51, 0.0;
    mov.f32 accum52, 0.0; mov.f32 accum53, 0.0; mov.f32 accum54, 0.0; mov.f32 accum55, 0.0;
    mov.f32 accum56, 0.0; mov.f32 accum57, 0.0; mov.f32 accum58, 0.0; mov.f32 accum59, 0.0;
    mov.f32 accum60, 0.0; mov.f32 accum61, 0.0; mov.f32 accum62, 0.0; mov.f32 accum63, 0.0;
    mov.f32 accum64, 0.0; mov.f32 accum65, 0.0; mov.f32 accum66, 0.0; mov.f32 accum67, 0.0;
    mov.f32 accum68, 0.0; mov.f32 accum69, 0.0; mov.f32 accum70, 0.0; mov.f32 accum71, 0.0;
    mov.f32 accum72, 0.0; mov.f32 accum73, 0.0; mov.f32 accum74, 0.0; mov.f32 accum75, 0.0;
    mov.f32 accum76, 0.0; mov.f32 accum77, 0.0; mov.f32 accum78, 0.0; mov.f32 accum79, 0.0;
    mov.f32 accum80, 0.0; mov.f32 accum81, 0.0; mov.f32 accum82, 0.0; mov.f32 accum83, 0.0;
    mov.f32 accum84, 0.0; mov.f32 accum85, 0.0; mov.f32 accum86, 0.0; mov.f32 accum87, 0.0;
    mov.f32 accum88, 0.0; mov.f32 accum89, 0.0; mov.f32 accum90, 0.0; mov.f32 accum91, 0.0;
    mov.f32 accum92, 0.0; mov.f32 accum93, 0.0; mov.f32 accum94, 0.0; mov.f32 accum95, 0.0;
    mov.f32 accum96, 0.0; mov.f32 accum97, 0.0; mov.f32 accum98, 0.0; mov.f32 accum99, 0.0;
    mov.f32 accum100, 0.0; mov.f32 accum101, 0.0; mov.f32 accum102, 0.0; mov.f32 accum103, 0.0;
    mov.f32 accum104, 0.0; mov.f32 accum105, 0.0; mov.f32 accum106, 0.0; mov.f32 accum107, 0.0;
    mov.f32 accum108, 0.0; mov.f32 accum109, 0.0; mov.f32 accum110, 0.0; mov.f32 accum111, 0.0;
    mov.f32 accum112, 0.0; mov.f32 accum113, 0.0; mov.f32 accum114, 0.0; mov.f32 accum115, 0.0;
    mov.f32 accum116, 0.0; mov.f32 accum117, 0.0; mov.f32 accum118, 0.0; mov.f32 accum119, 0.0;
    mov.f32 accum120, 0.0; mov.f32 accum121, 0.0; mov.f32 accum122, 0.0; mov.f32 accum123, 0.0;
    mov.f32 accum124, 0.0; mov.f32 accum125, 0.0; mov.f32 accum126, 0.0; mov.f32 accum127, 0.0;
    // Order the register writes above before the first `wgmma.mma_async`.
    // Reusing the same accumulators across same-shape WGMMAs needs no more fences.
    wgmma.fence.sync.aligned;
    // The main loop will repeat for 128 iterations
loop_start:
    setp.ge.u32 exit_predicate, loop_counter, loop_limit;
    @exit_predicate bra loop_exit;
    // Trailing immediates: scale-d = 1 (D = A * B + D), imm-scale-a = 1 and
    // imm-scale-b = 1 (no negation), imm-trans-a = 0 and imm-trans-b = 0.
    wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16
        { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7,
          accum8, accum9, accum10, accum11, accum12, accum13, accum14, accum15,
          accum16, accum17, accum18, accum19, accum20, accum21, accum22, accum23,
          accum24, accum25, accum26, accum27, accum28, accum29, accum30, accum31,
          accum32, accum33, accum34, accum35, accum36, accum37, accum38, accum39,
          accum40, accum41, accum42, accum43, accum44, accum45, accum46, accum47,
          accum48, accum49, accum50, accum51, accum52, accum53, accum54, accum55,
          accum56, accum57, accum58, accum59, accum60, accum61, accum62, accum63,
          accum64, accum65, accum66, accum67, accum68, accum69, accum70, accum71,
          accum72, accum73, accum74, accum75, accum76, accum77, accum78, accum79,
          accum80, accum81, accum82, accum83, accum84, accum85, accum86, accum87,
          accum88, accum89, accum90, accum91, accum92, accum93, accum94, accum95,
          accum96, accum97, accum98, accum99, accum100, accum101, accum102, accum103,
          accum104, accum105, accum106, accum107, accum108, accum109, accum110, accum111,
          accum112, accum113, accum114, accum115, accum116, accum117, accum118, accum119,
          accum120, accum121, accum122, accum123, accum124, accum125, accum126, accum127 },
        desc_a,
        desc_b,
        1, 1, 1, 0, 0;
    // Each iteration commits its MMA as a separate WGMMA group
    wgmma.commit_group.sync.aligned;
    // Increment the loop counter
    add.u32 loop_counter, loop_counter, 1;
    // Branch back to the beginning of the loop
    bra loop_start;
loop_exit:
    // Zero argument means - wait for all committed WGMMAs to complete.
    wgmma.wait_group.sync.aligned 0;
    // Use volatile stores to force the accumulator values to be written out.
    // This dummy write (to a global variable) makes the work observable and
    // prevents the multiplication pipeline from being optimized out.
    st.global.volatile.f32 [dummy_sink_f32], accum0;
    st.global.volatile.f32 [dummy_sink_f32+4], accum1;
    st.global.volatile.f32 [dummy_sink_f32+8], accum126;
    st.global.volatile.f32 [dummy_sink_f32+12], accum127;
    ret;
}
/**
 * Sustained TF32 x TF32 -> FP32 throughput of `wgmma.mma_async` with the
 * m64n256k8 shape, issued 128 times back-to-back by one warp-group.
 *
 * Both operands are read from shared memory through 64-bit matrix descriptors.
 * The tiles are never initialized: only the instruction throughput matters.
 * NOTE(review): WGMMA instructions are `.sync.aligned` across a warp-group of
 * 4 warps, so this presumably must be launched with 128 threads - confirm.
 */
.visible .entry tops_tf32f32_sm90tc_m64n256k8_loop128_ptx_kernel()
{
    // Accumulator registers used for both input and output of this MMA:
    // 64 x 256 FP32 outputs spread across the 128 threads of a warp-group.
    .reg .f32 accum<128>;
    // Descriptors for matrix A and matrix B operands
    .reg .b64 desc_a, desc_b;
    // TF32 (19-bit) variables are stored in B32 slots in 2D arrays; the Tensor
    // Cores ignore the low 13 mantissa bits of every slot. The descriptor
    // encodes the start address in 16-byte units, so align tiles to 16 bytes.
    .shared .align 16 .b32 tile_a[64][8];
    .shared .align 16 .b32 tile_b[256][8];
    // Define registers to store shared memory addresses
    .reg .u64 addr_a, addr_b;
    // `mov` of a `.shared` variable yields its shared-state-space address,
    // which is what the descriptor encodes. Converting it to a generic address
    // with `cvta.shared` first would rely on the low bits of the generic
    // shared-window base being zero.
    mov.u64 addr_a, tile_a;
    mov.u64 addr_b, tile_b;
    // Start-address field: keep the low 18 bits, expressed in 16-byte units
    and.b64 addr_a, addr_a, 0x3FFFF;
    and.b64 addr_b, addr_b, 0x3FFFF;
    shr.u64 addr_a, addr_a, 4;
    shr.u64 addr_b, addr_b, 4;
    // Descriptor of the M x K matrix A; byte offsets are also in 16-byte units
    mov.u64 desc_a, addr_a;
    or.b64 desc_a, desc_a, ((128 >> 4) << 16); // Leading-dimension byte offset
    or.b64 desc_a, desc_a, ((256 >> 4) << 32); // Stride-dimension byte offset
    // Descriptor of the K x N matrix B
    mov.u64 desc_b, addr_b;
    or.b64 desc_b, desc_b, ((4096 >> 4) << 16); // Leading-dimension byte offset
    or.b64 desc_b, desc_b, ((128 >> 4) << 32); // Stride-dimension byte offset
    // General-purpose registers for loop control
    .reg .b32 loop_counter, loop_limit;
    // Predicate register for conditional branching (loop exit)
    .reg .pred exit_predicate;
    // Set up loop counter and loop limit to fill accumulators
    mov.u32 loop_counter, 0;
    mov.u32 loop_limit, 128;
    // Zero-initialize the accumulators, as registers may contain noise
    mov.f32 accum0, 0.0; mov.f32 accum1, 0.0; mov.f32 accum2, 0.0; mov.f32 accum3, 0.0;
    mov.f32 accum4, 0.0; mov.f32 accum5, 0.0; mov.f32 accum6, 0.0; mov.f32 accum7, 0.0;
    mov.f32 accum8, 0.0; mov.f32 accum9, 0.0; mov.f32 accum10, 0.0; mov.f32 accum11, 0.0;
    mov.f32 accum12, 0.0; mov.f32 accum13, 0.0; mov.f32 accum14, 0.0; mov.f32 accum15, 0.0;
    mov.f32 accum16, 0.0; mov.f32 accum17, 0.0; mov.f32 accum18, 0.0; mov.f32 accum19, 0.0;
    mov.f32 accum20, 0.0; mov.f32 accum21, 0.0; mov.f32 accum22, 0.0; mov.f32 accum23, 0.0;
    mov.f32 accum24, 0.0; mov.f32 accum25, 0.0; mov.f32 accum26, 0.0; mov.f32 accum27, 0.0;
    mov.f32 accum28, 0.0; mov.f32 accum29, 0.0; mov.f32 accum30, 0.0; mov.f32 accum31, 0.0;
    mov.f32 accum32, 0.0; mov.f32 accum33, 0.0; mov.f32 accum34, 0.0; mov.f32 accum35, 0.0;
    mov.f32 accum36, 0.0; mov.f32 accum37, 0.0; mov.f32 accum38, 0.0; mov.f32 accum39, 0.0;
    mov.f32 accum40, 0.0; mov.f32 accum41, 0.0; mov.f32 accum42, 0.0; mov.f32 accum43, 0.0;
    mov.f32 accum44, 0.0; mov.f32 accum45, 0.0; mov.f32 accum46, 0.0; mov.f32 accum47, 0.0;
    mov.f32 accum48, 0.0; mov.f32 accum49, 0.0; mov.f32 accum50, 0.0; mov.f32 accum51, 0.0;
    mov.f32 accum52, 0.0; mov.f32 accum53, 0.0; mov.f32 accum54, 0.0; mov.f32 accum55, 0.0;
    mov.f32 accum56, 0.0; mov.f32 accum57, 0.0; mov.f32 accum58, 0.0; mov.f32 accum59, 0.0;
    mov.f32 accum60, 0.0; mov.f32 accum61, 0.0; mov.f32 accum62, 0.0; mov.f32 accum63, 0.0;
    mov.f32 accum64, 0.0; mov.f32 accum65, 0.0; mov.f32 accum66, 0.0; mov.f32 accum67, 0.0;
    mov.f32 accum68, 0.0; mov.f32 accum69, 0.0; mov.f32 accum70, 0.0; mov.f32 accum71, 0.0;
    mov.f32 accum72, 0.0; mov.f32 accum73, 0.0; mov.f32 accum74, 0.0; mov.f32 accum75, 0.0;
    mov.f32 accum76, 0.0; mov.f32 accum77, 0.0; mov.f32 accum78, 0.0; mov.f32 accum79, 0.0;
    mov.f32 accum80, 0.0; mov.f32 accum81, 0.0; mov.f32 accum82, 0.0; mov.f32 accum83, 0.0;
    mov.f32 accum84, 0.0; mov.f32 accum85, 0.0; mov.f32 accum86, 0.0; mov.f32 accum87, 0.0;
    mov.f32 accum88, 0.0; mov.f32 accum89, 0.0; mov.f32 accum90, 0.0; mov.f32 accum91, 0.0;
    mov.f32 accum92, 0.0; mov.f32 accum93, 0.0; mov.f32 accum94, 0.0; mov.f32 accum95, 0.0;
    mov.f32 accum96, 0.0; mov.f32 accum97, 0.0; mov.f32 accum98, 0.0; mov.f32 accum99, 0.0;
    mov.f32 accum100, 0.0; mov.f32 accum101, 0.0; mov.f32 accum102, 0.0; mov.f32 accum103, 0.0;
    mov.f32 accum104, 0.0; mov.f32 accum105, 0.0; mov.f32 accum106, 0.0; mov.f32 accum107, 0.0;
    mov.f32 accum108, 0.0; mov.f32 accum109, 0.0; mov.f32 accum110, 0.0; mov.f32 accum111, 0.0;
    mov.f32 accum112, 0.0; mov.f32 accum113, 0.0; mov.f32 accum114, 0.0; mov.f32 accum115, 0.0;
    mov.f32 accum116, 0.0; mov.f32 accum117, 0.0; mov.f32 accum118, 0.0; mov.f32 accum119, 0.0;
    mov.f32 accum120, 0.0; mov.f32 accum121, 0.0; mov.f32 accum122, 0.0; mov.f32 accum123, 0.0;
    mov.f32 accum124, 0.0; mov.f32 accum125, 0.0; mov.f32 accum126, 0.0; mov.f32 accum127, 0.0;
    // Order the register writes above before the first `wgmma.mma_async`.
    // Reusing the same accumulators across same-shape WGMMAs needs no more fences.
    wgmma.fence.sync.aligned;
    // The main loop will repeat for 128 iterations
loop_start:
    setp.ge.u32 exit_predicate, loop_counter, loop_limit;
    @exit_predicate bra loop_exit;
    // Trailing immediates: scale-d = 1 (D = A * B + D), imm-scale-a = 1 and
    // imm-scale-b = 1 (no negation). TF32 inputs can't be transposed, so the
    // imm-trans-a / imm-trans-b operands are not accepted here.
    wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32
        { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7,
          accum8, accum9, accum10, accum11, accum12, accum13, accum14, accum15,
          accum16, accum17, accum18, accum19, accum20, accum21, accum22, accum23,
          accum24, accum25, accum26, accum27, accum28, accum29, accum30, accum31,
          accum32, accum33, accum34, accum35, accum36, accum37, accum38, accum39,
          accum40, accum41, accum42, accum43, accum44, accum45, accum46, accum47,
          accum48, accum49, accum50, accum51, accum52, accum53, accum54, accum55,
          accum56, accum57, accum58, accum59, accum60, accum61, accum62, accum63,
          accum64, accum65, accum66, accum67, accum68, accum69, accum70, accum71,
          accum72, accum73, accum74, accum75, accum76, accum77, accum78, accum79,
          accum80, accum81, accum82, accum83, accum84, accum85, accum86, accum87,
          accum88, accum89, accum90, accum91, accum92, accum93, accum94, accum95,
          accum96, accum97, accum98, accum99, accum100, accum101, accum102, accum103,
          accum104, accum105, accum106, accum107, accum108, accum109, accum110, accum111,
          accum112, accum113, accum114, accum115, accum116, accum117, accum118, accum119,
          accum120, accum121, accum122, accum123, accum124, accum125, accum126, accum127 },
        desc_a,
        desc_b,
        1, 1, 1;
    // Each iteration commits its MMA as a separate WGMMA group
    wgmma.commit_group.sync.aligned;
    // Increment the loop counter
    add.u32 loop_counter, loop_counter, 1;
    // Branch back to the beginning of the loop
    bra loop_start;
loop_exit:
    // Zero argument means - wait for all committed WGMMAs to complete.
    wgmma.wait_group.sync.aligned 0;
    // Use volatile stores to force the accumulator values to be written out.
    // This dummy write (to a global variable) makes the work observable and
    // prevents the multiplication pipeline from being optimized out.
    st.global.volatile.f32 [dummy_sink_f32], accum0;
    st.global.volatile.f32 [dummy_sink_f32+4], accum1;
    st.global.volatile.f32 [dummy_sink_f32+8], accum126;
    st.global.volatile.f32 [dummy_sink_f32+12], accum127;
    ret;
}
/**
 * This results in massive performance gains on Hopper:
 * - 16x16x8 MMA computed by individual warps: 74 T
 * - 64x16x8 WMMA computed by four warps together: 300 T
 * - 64x256x8 WGMMA computed by four warps together: 4.7 P ?!
 *
 * There are also "structured-sparse" variants of those instructions
 * (`wgmma.mma_async.sp`), in case half of our entries are zeros! Those,
 * however, simply expand the last dimension by 2x, making the instructions
 * even less practical for small matrices.
 */
/**
 * Sustained binary (B1) AND + population-count -> S32 throughput of
 * `wgmma.mma_async` with the m64n256k256 shape, issued 128 times
 * back-to-back by one warp-group.
 *
 * Both operands are read from shared memory through 64-bit matrix descriptors.
 * The tiles are never initialized: only the instruction throughput matters.
 * NOTE(review): WGMMA instructions are `.sync.aligned` across a warp-group of
 * 4 warps, so this presumably must be launched with 128 threads - confirm.
 */
.visible .entry tops_b1i32and_sm90tc_m64n256k256_loop128_ptx_kernel()
{
    // Accumulator registers used for both input and output of the MMA operation:
    // 64 x 256 S32 outputs spread across the 128 threads of a warp-group.
    .reg .s32 accum<128>;
    // Descriptors for matrix A and matrix B operands
    .reg .b64 desc_a, desc_b;
    // B1 variables are packed 8 per byte: 256 bits of K make 32 bytes per row.
    // The descriptor encodes the start address in 16-byte units, so align
    // tiles to 16 bytes.
    .shared .align 16 .b8 tile_a[64][32];
    .shared .align 16 .b8 tile_b[256][32];
    // Define registers to store shared memory addresses
    .reg .u64 addr_a, addr_b;
    // `mov` of a `.shared` variable yields its shared-state-space address,
    // which is what the descriptor encodes. Converting it to a generic address
    // with `cvta.shared` first would rely on the low bits of the generic
    // shared-window base being zero.
    mov.u64 addr_a, tile_a;
    mov.u64 addr_b, tile_b;
    // Start-address field: keep the low 18 bits, expressed in 16-byte units
    and.b64 addr_a, addr_a, 0x3FFFF;
    and.b64 addr_b, addr_b, 0x3FFFF;
    shr.u64 addr_a, addr_a, 4;
    shr.u64 addr_b, addr_b, 4;
    // Descriptor of the M x K matrix A; byte offsets are also in 16-byte units
    mov.u64 desc_a, addr_a;
    or.b64 desc_a, desc_a, ((128 >> 4) << 16); // Leading-dimension byte offset
    or.b64 desc_a, desc_a, ((256 >> 4) << 32); // Stride-dimension byte offset
    // Descriptor of the K x N matrix B
    mov.u64 desc_b, addr_b;
    or.b64 desc_b, desc_b, ((4096 >> 4) << 16); // Leading-dimension byte offset
    or.b64 desc_b, desc_b, ((128 >> 4) << 32); // Stride-dimension byte offset
    // General-purpose registers for loop control
    .reg .b32 loop_counter, loop_limit;
    // Predicate registers for conditional branching (loop exit) and scale flag
    .reg .pred exit_predicate, scale_d;
    // Set up loop counter and loop limit
    mov.u32 loop_counter, 0;
    mov.u32 loop_limit, 128;
    // Zero-initialize the accumulators, as registers may contain noise
    mov.s32 accum0, 0; mov.s32 accum1, 0; mov.s32 accum2, 0; mov.s32 accum3, 0;
    mov.s32 accum4, 0; mov.s32 accum5, 0; mov.s32 accum6, 0; mov.s32 accum7, 0;
    mov.s32 accum8, 0; mov.s32 accum9, 0; mov.s32 accum10, 0; mov.s32 accum11, 0;
    mov.s32 accum12, 0; mov.s32 accum13, 0; mov.s32 accum14, 0; mov.s32 accum15, 0;
    mov.s32 accum16, 0; mov.s32 accum17, 0; mov.s32 accum18, 0; mov.s32 accum19, 0;
    mov.s32 accum20, 0; mov.s32 accum21, 0; mov.s32 accum22, 0; mov.s32 accum23, 0;
    mov.s32 accum24, 0; mov.s32 accum25, 0; mov.s32 accum26, 0; mov.s32 accum27, 0;
    mov.s32 accum28, 0; mov.s32 accum29, 0; mov.s32 accum30, 0; mov.s32 accum31, 0;
    mov.s32 accum32, 0; mov.s32 accum33, 0; mov.s32 accum34, 0; mov.s32 accum35, 0;
    mov.s32 accum36, 0; mov.s32 accum37, 0; mov.s32 accum38, 0; mov.s32 accum39, 0;
    mov.s32 accum40, 0; mov.s32 accum41, 0; mov.s32 accum42, 0; mov.s32 accum43, 0;
    mov.s32 accum44, 0; mov.s32 accum45, 0; mov.s32 accum46, 0; mov.s32 accum47, 0;
    mov.s32 accum48, 0; mov.s32 accum49, 0; mov.s32 accum50, 0; mov.s32 accum51, 0;
    mov.s32 accum52, 0; mov.s32 accum53, 0; mov.s32 accum54, 0; mov.s32 accum55, 0;
    mov.s32 accum56, 0; mov.s32 accum57, 0; mov.s32 accum58, 0; mov.s32 accum59, 0;
    mov.s32 accum60, 0; mov.s32 accum61, 0; mov.s32 accum62, 0; mov.s32 accum63, 0;
    mov.s32 accum64, 0; mov.s32 accum65, 0; mov.s32 accum66, 0; mov.s32 accum67, 0;
    mov.s32 accum68, 0; mov.s32 accum69, 0; mov.s32 accum70, 0; mov.s32 accum71, 0;
    mov.s32 accum72, 0; mov.s32 accum73, 0; mov.s32 accum74, 0; mov.s32 accum75, 0;
    mov.s32 accum76, 0; mov.s32 accum77, 0; mov.s32 accum78, 0; mov.s32 accum79, 0;
    mov.s32 accum80, 0; mov.s32 accum81, 0; mov.s32 accum82, 0; mov.s32 accum83, 0;
    mov.s32 accum84, 0; mov.s32 accum85, 0; mov.s32 accum86, 0; mov.s32 accum87, 0;
    mov.s32 accum88, 0; mov.s32 accum89, 0; mov.s32 accum90, 0; mov.s32 accum91, 0;
    mov.s32 accum92, 0; mov.s32 accum93, 0; mov.s32 accum94, 0; mov.s32 accum95, 0;
    mov.s32 accum96, 0; mov.s32 accum97, 0; mov.s32 accum98, 0; mov.s32 accum99, 0;
    mov.s32 accum100, 0; mov.s32 accum101, 0; mov.s32 accum102, 0; mov.s32 accum103, 0;
    mov.s32 accum104, 0; mov.s32 accum105, 0; mov.s32 accum106, 0; mov.s32 accum107, 0;
    mov.s32 accum108, 0; mov.s32 accum109, 0; mov.s32 accum110, 0; mov.s32 accum111, 0;
    mov.s32 accum112, 0; mov.s32 accum113, 0; mov.s32 accum114, 0; mov.s32 accum115, 0;
    mov.s32 accum116, 0; mov.s32 accum117, 0; mov.s32 accum118, 0; mov.s32 accum119, 0;
    mov.s32 accum120, 0; mov.s32 accum121, 0; mov.s32 accum122, 0; mov.s32 accum123, 0;
    mov.s32 accum124, 0; mov.s32 accum125, 0; mov.s32 accum126, 0; mov.s32 accum127, 0;
    // When `scale_d` is true the MMA accumulates into D (D = A * B + D);
    // when false it would overwrite D with A * B.
    mov.pred scale_d, 1;
    // Order the register writes above before the first `wgmma.mma_async`.
    // Reusing the same accumulators across same-shape WGMMAs needs no more fences.
    wgmma.fence.sync.aligned;
    // The main loop will repeat for 128 iterations
loop_start:
    setp.ge.u32 exit_predicate, loop_counter, loop_limit;
    @exit_predicate bra loop_exit;
    // Binary MMA: AND the bit-rows of A and B, then pop-count into S32.
    // Only `scale_d` is accepted after the descriptors - no scale/transpose.
    wgmma.mma_async.sync.aligned.m64n256k256.s32.b1.b1.and.popc
        { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7,
          accum8, accum9, accum10, accum11, accum12, accum13, accum14, accum15,
          accum16, accum17, accum18, accum19, accum20, accum21, accum22, accum23,
          accum24, accum25, accum26, accum27, accum28, accum29, accum30, accum31,
          accum32, accum33, accum34, accum35, accum36, accum37, accum38, accum39,
          accum40, accum41, accum42, accum43, accum44, accum45, accum46, accum47,
          accum48, accum49, accum50, accum51, accum52, accum53, accum54, accum55,
          accum56, accum57, accum58, accum59, accum60, accum61, accum62, accum63,
          accum64, accum65, accum66, accum67, accum68, accum69, accum70, accum71,
          accum72, accum73, accum74, accum75, accum76, accum77, accum78, accum79,
          accum80, accum81, accum82, accum83, accum84, accum85, accum86, accum87,
          accum88, accum89, accum90, accum91, accum92, accum93, accum94, accum95,
          accum96, accum97, accum98, accum99, accum100, accum101, accum102, accum103,
          accum104, accum105, accum106, accum107, accum108, accum109, accum110, accum111,
          accum112, accum113, accum114, accum115, accum116, accum117, accum118, accum119,
          accum120, accum121, accum122, accum123, accum124, accum125, accum126, accum127 },
        desc_a,
        desc_b,
        scale_d;
    // Each iteration commits its MMA as a separate WGMMA group
    wgmma.commit_group.sync.aligned;
    // Increment the loop counter
    add.u32 loop_counter, loop_counter, 1;
    // Branch back to the beginning of the loop
    bra loop_start;
loop_exit:
    // Zero argument means - wait for all committed WGMMAs to complete.
    wgmma.wait_group.sync.aligned 0;
    // Use volatile stores to force the accumulator values to be written out.
    // This dummy write (to a global variable) makes the work observable and
    // prevents the multiplication pipeline from being optimized out.
    st.global.volatile.s32 [dummy_sink_s32], accum0;
    st.global.volatile.s32 [dummy_sink_s32+4], accum1;
    st.global.volatile.s32 [dummy_sink_s32+8], accum2;
    st.global.volatile.s32 [dummy_sink_s32+12], accum3;
    ret;
}