A lightweight, from-scratch Flash Attention CUDA implementation (forward pass only).
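Flash Attention computes exact softmax attention without materializing the full score matrix. It visits K/V one tile at a time and keeps, per query row, a running max, a running softmax denominator and an unnormalized output accumulator (the "online softmax"). The sketch below shows that update rule in plain Python for a single query row. It is only an illustration of the technique, not TinyFA's kernel; the function names and `block_size` are hypothetical.

```python
import math


def attention_row_naive(q, keys, values, scale):
    """Reference: softmax(scale * q . K^T) @ V for one query vector."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def attention_row_tiled(q, keys, values, scale, block_size=2):
    """Same result, but K/V are consumed one tile at a time (online softmax)."""
    dim = len(values[0])
    m = -math.inf      # running max of the scores seen so far
    l = 0.0            # running softmax denominator
    acc = [0.0] * dim  # unnormalized output accumulator
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        s_blk = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_new = max(m, max(s_blk))
        # Rescale earlier partial results to the new max (exp(-inf) == 0.0 on the first tile).
        correction = math.exp(m - m_new)
        p_blk = [math.exp(s - m_new) for s in s_blk]
        l = l * correction + sum(p_blk)
        acc = [a * correction + sum(p * v[d] for p, v in zip(p_blk, v_blk))
               for d, a in enumerate(acc)]
        m = m_new
    # Normalize once, after the last tile.
    return [a / l for a in acc]


q = [0.1, -0.3, 0.5, 0.2]
keys = [[0.2, 0.1, -0.4, 0.3], [0.5, -0.2, 0.1, 0.0], [-0.3, 0.4, 0.2, 0.1],
        [0.0, 0.3, -0.1, 0.6], [0.4, 0.4, 0.4, -0.2]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]
scale = 1.0 / math.sqrt(len(q))
naive = attention_row_naive(q, keys, values, scale)
tiled = attention_row_tiled(q, keys, values, scale, block_size=2)
print(naive)
print(tiled)  # equal to naive up to floating-point rounding
```

The final division by `l` is deferred to the end, so each tile costs one rescale of the accumulator instead of a pass over all previous scores.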
- Data Types: FP16, BF16, FP32
- Head Dimensions: 32, 64, 96, 128, 192, 256
- Causal Attention
- Grouped-Query Attention (numHeadsQ % numHeadsKV == 0)
- Fixed Length / Variable Length (cu_seqlens packed format)
- Paged Attention (packed Q + paged KV cache)
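The packed, GQA and paged layouts listed above reduce to a few index calculations. The pure-Python sketch below spells out the conventions assumed here: `cu_seqlens` as prefix sums of per-sequence lengths, consecutive query heads sharing one KV head under GQA (the convention used by flash-attn), and a block table that maps a token's logical page to a physical block of a `[numBlocks, numKvHeads, pageSize, headDim]` cache. The helper names are hypothetical, and the GQA and paging conventions are assumptions rather than details read from TinyFA's source.

```python
from itertools import accumulate


def cu_seqlens(lengths):
    """Prefix sums [0, l0, l0 + l1, ...] with batch + 1 entries.
    Tokens of sequence b occupy rows cu[b]:cu[b + 1] of the packed tensor."""
    return [0] + list(accumulate(lengths))


def kv_head_for_q_head(q_head, num_heads_q, num_heads_kv):
    """GQA (assumed convention): each run of num_heads_q // num_heads_kv
    consecutive query heads reads the same KV head."""
    if num_heads_q % num_heads_kv != 0:
        raise ValueError("numHeadsQ must be divisible by numHeadsKV")
    return q_head // (num_heads_q // num_heads_kv)


def paged_kv_location(block_table, seq, token, page_size):
    """Paged KV (assumed convention): logical token index -> (physical block, slot).
    The key vector would then be k_cache_pool[block, kv_head, slot, :]."""
    logical_page, slot = divmod(token, page_size)
    return block_table[seq][logical_page], slot


# Query lengths 1, 32, 1 and KV lengths 64, 32, 32 (a mixed prefill + decode batch)
print(cu_seqlens([1, 32, 1]))        # [0, 1, 33, 34]
print(cu_seqlens([64, 32, 32]))      # [0, 64, 96, 128]
print(kv_head_for_q_head(5, 32, 8))  # 1
# Hypothetical physical block IDs, one row per sequence
block_table = [[7, 3, 12, 40], [5, 9, 0, 0], [11, 2, 0, 0]]
print(paged_kv_location(block_table, 0, 37, 16))  # (12, 5)
```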
- NVIDIA GPU (Turing or newer, compute capability >= 7.5)
- CUDA Toolkit >= 11.8
- PyTorch (with CUDA support)
- Python >= 3.8
- C++17 compatible compiler
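To check a machine against the GPU requirement before building, a small helper can compare the compute capability with 7.5 and format it as an `smXY` string like the `sm80` value passed to `TFA_TARGET_SM` in the build examples. This is a sketch; the helper names are hypothetical, and whether `TFA_TARGET_SM` accepts architectures other than `sm80` in this exact format is an assumption.

```python
def meets_tinyfa_requirement(capability):
    """True for Turing or newer, i.e. compute capability >= 7.5."""
    major, minor = capability
    return (major, minor) >= (7, 5)


def sm_target(capability):
    """Format (major, minor) as an 'smXY' string, e.g. (8, 0) -> 'sm80'.
    Assumption: TFA_TARGET_SM accepts this format for architectures besides sm80."""
    major, minor = capability
    return f"sm{major}{minor}"


for cap in [(7, 0), (7, 5), (8, 0), (8, 6), (9, 0)]:
    print(cap, meets_tinyfa_requirement(cap), sm_target(cap))
```

With PyTorch installed, `torch.cuda.get_device_capability()` returns the `(major, minor)` tuple these helpers expect.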
git clone --recursive https://github.com/keith2018/TinyFA.git
cd TinyFA/python
pip install --no-build-isolation .

By default, only HeadDim 64 and 128 are compiled. Use environment variables to enable additional head dimensions or to target specific configurations:
# Enable additional head dimensions
TFA_TARGET_HEADDIM_32=1 TFA_TARGET_HEADDIM_256=1 pip install --no-build-isolation .
# Target a specific GPU architecture
TFA_TARGET_SM=sm80 pip install --no-build-isolation .
# Target a specific data type
TFA_TARGET_DTYPE=fp16 pip install --no-build-isolation .
# Combine for fastest compilation (only HeadDim=128, fp16, sm80)
TFA_TARGET_SM=sm80 TFA_TARGET_DTYPE=fp16 TFA_TARGET_HEADDIM_64=0 \
pip install --no-build-isolation .

TinyFA reaches 94–96% of the forward-pass throughput of Dao-AILab/flash-attention on the following setup:
- Device: NVIDIA A100-SXM4-40GB
- Configuration: batch=2, numHeads=32, headDim=128, SeqLen=4096
| Dtype | Causal | TinyFA (ms) | TinyFA (TFLOPS) | flash_attn (ms) | flash_attn (TFLOPS) | Relative (TinyFA / flash_attn) |
|---|---|---|---|---|---|---|
| fp16 | False | 2.832 | 194.09 | 2.654 | 207.14 | 0.94x |
| fp16 | True | 1.636 | 168.01 | 1.557 | 176.54 | 0.95x |
| bf16 | False | 2.743 | 200.40 | 2.623 | 209.58 | 0.96x |
| bf16 | True | 1.624 | 169.28 | 1.537 | 178.82 | 0.95x |
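The TFLOPS columns are consistent with the usual forward-pass FLOP count of 4 · batch · numHeads · SeqLen² · headDim (two matmuls, Q·Kᵀ and P·V), halved for causal attention. That `benchmark.py` uses exactly this formula is an assumption; the sketch below reproduces the table from the ms columns, and the small differences in the second decimal come from the ms values being rounded.

```python
def attention_fwd_flops(batch, num_heads, seq_len, head_dim, causal):
    """Forward FLOPs: Q @ K^T and P @ V each cost 2 * seq_len^2 * head_dim per head."""
    flops = 4 * batch * num_heads * seq_len * seq_len * head_dim
    # The causal mask skips about half of the score matrix.
    return flops / 2 if causal else flops


def tflops(flops, ms):
    """Convert a FLOP count and a runtime in milliseconds to TFLOPS."""
    return flops / (ms * 1e-3) / 1e12


full_flops = attention_fwd_flops(2, 32, 4096, 128, causal=False)
causal_flops = attention_fwd_flops(2, 32, 4096, 128, causal=True)
print(f"{tflops(full_flops, 2.832):.1f}")    # fp16, non-causal TinyFA: 194.1 (table: 194.09)
print(f"{tflops(causal_flops, 1.636):.1f}")  # fp16, causal TinyFA: 168.0 (table: 168.01)
print(f"{2.654 / 2.832:.2f}")                # Relative for the first row: 0.94
```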
cd benchmarks
# Default: fp16, head_dim=128, SeqLen=[512, 1024, 2048, 4096]
python benchmark.py --dtype fp16 --head-dim 128
# With causal mask
python benchmark.py --dtype fp16 --head-dim 128 --causal
# Sweep all combinations (fp16/bf16 x causal x head_dim)
python benchmark.py --sweep

import torch
from tiny_flash_attn import flash_attn_forward
# Q: [batch, seqQ, numHeadsQ, headDim]
# K: [batch, seqKV, numHeadsKV, headDim]
# V: [batch, seqKV, numHeadsKV, headDim]
Q = torch.randn(2, 1024, 32, 128, dtype=torch.float16, device="cuda")
K = torch.randn(2, 1024, 32, 128, dtype=torch.float16, device="cuda")
V = torch.randn(2, 1024, 32, 128, dtype=torch.float16, device="cuda")
O = flash_attn_forward(Q, K, V, is_causal=False)
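# O: [batch, seqQ, numHeadsQ, headDim]

import torch
from tiny_flash_attn import flash_attn_varlen_forward
# Q: [totalQ, numHeadsQ, headDim] (packed sequences)
# K: [totalKV, numHeadsKV, headDim]
# V: [totalKV, numHeadsKV, headDim]
# cu_seqlens_q: [batch + 1], int32, cumulative sequence lengths
# cu_seqlens_kv: [batch + 1], int32
# Example: two sequences of lengths 300 and 700, packed back to back
cu_seqlens_q = torch.tensor([0, 300, 1000], dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.tensor([0, 300, 1000], dtype=torch.int32, device="cuda")
max_seqlen_q, max_seqlen_kv = 700, 700
Q = torch.randn(1000, 32, 128, dtype=torch.float16, device="cuda")
K = torch.randn(1000, 32, 128, dtype=torch.float16, device="cuda")
V = torch.randn(1000, 32, 128, dtype=torch.float16, device="cuda")
O = flash_attn_varlen_forward(
    Q, K, V,
    cu_seqlens_q, cu_seqlens_kv,
    max_seqlen_q, max_seqlen_kv,
    is_causal=False
)

# GQA: numHeadsQ must be divisible by numHeadsKV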
Q = torch.randn(2, 1024, 32, 128, dtype=torch.float16, device="cuda") # 32 query heads
K = torch.randn(2, 1024, 8, 128, dtype=torch.float16, device="cuda")   # 8 KV heads, each shared by 4 query heads
V = torch.randn(2, 1024, 8, 128, dtype=torch.float16, device="cuda")
O = flash_attn_forward(Q, K, V, is_causal=True)

import torch
from tiny_flash_attn import flash_attn_paged_varlen_forward
# Q: [totalTokens, numHeadsQ, headDim] (packed sequences)
# k_cache_pool: [numBlocks, numKvHeads, pageSize, headDim]
# v_cache_pool: [numBlocks, numKvHeads, pageSize, headDim]
# block_table: [batchSize, maxBlocksPerSeq] int32
batch_size = 3
page_size = 16
num_kv_heads = 8
head_dim = 128
num_blocks = 256
# mixed prefill + decode batch
cu_seqlens_q = torch.tensor([0, 1, 33, 34], dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.tensor([0, 64, 96, 128], dtype=torch.int32, device="cuda")
total_q = 34
Q = torch.randn(total_q, 32, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(num_blocks, num_kv_heads, page_size, head_dim,
                      dtype=torch.float16, device="cuda")
v_cache = torch.randn(num_blocks, num_kv_heads, page_size, head_dim,
                      dtype=torch.float16, device="cuda")
# block_table maps each sequence to physical pages
max_blocks_per_seq = 8
block_table = torch.zeros(batch_size, max_blocks_per_seq, dtype=torch.int32, device="cuda")
# fill block_table with physical block IDs...
# forward
O = flash_attn_paged_varlen_forward(
    Q, k_cache, v_cache,
    cu_seqlens_q, cu_seqlens_kv,
    block_table,
    max_seqlen_q=32, max_seqlen_kv=64,
    page_size=page_size,
    is_causal=True
)

#include "flash_attn/flash_api.cuh"
// Fixed length
tfa::flashAttn<__half>(Q, K, V, O,
                       batch, seqLenQ, seqLenKV,
                       numHeadsQ, numHeadsKV, headDim,
                       isCausal, stream);
// Variable length
tfa::flashAttnVarLen<__half>(Q, K, V, O,
                             cu_seqlens_q, cu_seqlens_kv,
                             batchSize, maxSeqLenQ, maxSeqLenKV,
                             numHeadsQ, numHeadsKV, headDim,
                             isCausal, stream);
// Paged VarLen (packed Q + paged KV cache)
tfa::flashAttnPagedVarLen<__half>(
    Q, O,                          // packed Q/O: [totalTokens, numHeadsQ, headDim]
    kCachePool, vCachePool,        // [numBlocks, numKvHeads, pageSize, headDim]
    cu_seqlens_q, cu_seqlens_kv,   // [batchSize + 1]
    blockTable,                    // [batchSize, maxBlocksPerSeq]
    batchSize, maxSeqLenQ, maxSeqLenKV,
    numHeadsQ, numHeadsKV, headDim,
    pageSize, maxBlocksPerSeq,
    isCausal, stream);

git submodule update --init --recursive
mkdir build && cd build
cmake .. -DTFA_BUILD_TESTS=ON
make -j$(nproc)
# Run tests
ctest --test-dir . --output-on-failure

- NVIDIA CUTLASS: CuTe sublibrary for tensor core abstractions
- Forward pass only
- No dropout
This code is licensed under the MIT License (see LICENSE).