TinyFA

A lightweight, from-scratch Flash Attention CUDA implementation (forward pass only).

Features

  • Data Types: FP16, BF16, FP32
  • Head Dimensions: 32, 64, 96, 128, 192, 256
  • Causal Attention
  • Grouped-Query Attention (numHeadsQ % numHeadsKV == 0)
  • Fixed Length / Variable Length (cu_seqlens packed format)
  • Paged Attention (packed Q + paged KV cache)

Installation

Requirements

  • NVIDIA GPU (Turing or newer, compute capability >= 7.5)
  • CUDA Toolkit >= 11.8
  • PyTorch (with CUDA support)
  • Python >= 3.8
  • C++17 compatible compiler

Install from source

git clone --recursive https://github.com/keith2018/TinyFA.git
cd TinyFA/python
pip install --no-build-isolation .

Faster compilation

By default, only head dimensions 64 and 128 are compiled. Use environment variables to enable additional head dimensions, or to restrict the build to a specific GPU architecture or data type and shorten compile time:

# Enable additional head dimensions
TFA_TARGET_HEADDIM_32=1 TFA_TARGET_HEADDIM_256=1 pip install --no-build-isolation .

# Target a specific GPU architecture
TFA_TARGET_SM=sm80 pip install --no-build-isolation .

# Target a specific data type
TFA_TARGET_DTYPE=fp16 pip install --no-build-isolation .

# Combine for fastest compilation (only HeadDim=128, fp16, sm80)
TFA_TARGET_SM=sm80 TFA_TARGET_DTYPE=fp16 TFA_TARGET_HEADDIM_64=0 \
pip install --no-build-isolation .

Benchmarks

On the configuration below, TinyFA reaches 94–96% of the forward-pass throughput of Dao-AILab/flash-attention.

  • Device: NVIDIA A100-SXM4-40GB
  • Configuration: batch=2, numHeads=32, headDim=128, SeqLen=4096
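
| Dtype | Causal | TinyFA (ms) | TinyFA (TFLOPS) | flash_attn (ms) | flash_attn (TFLOPS) | Relative |
|-------|--------|-------------|-----------------|-----------------|---------------------|----------|
| fp16  | False  | 2.832       | 194.09          | 2.654           | 207.14              | 0.94x    |
| fp16  | True   | 1.636       | 168.01          | 1.557           | 176.54              | 0.95x    |
| bf16  | False  | 2.743       | 200.40          | 2.623           | 209.58              | 0.96x    |
| bf16  | True   | 1.624       | 169.28          | 1.537           | 178.82              | 0.95x    |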
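
The TFLOPS columns appear consistent with the conventional forward-pass FLOP count for attention, 4 x batch x numHeads x SeqLen^2 x headDim, halved when the causal mask is on. That formula is an assumption inferred from the numbers in the table, not something confirmed against benchmark.py, and the helper name `attention_tflops` is illustrative. A minimal sketch that converts a measured latency into TFLOPS under that assumption:

```python
def attention_tflops(ms, batch, num_heads, seq_len, head_dim, causal=False):
    """Convert a forward-pass latency in milliseconds to TFLOPS.

    Assumes two matmuls (Q @ K^T and P @ V), each costing
    2 * seq_len * seq_len * head_dim FLOPs per head, and counts
    half of that work when the causal mask is applied.
    """
    flops = 4 * batch * num_heads * seq_len * seq_len * head_dim
    if causal:
        flops /= 2
    return flops / (ms * 1e-3) / 1e12


# Configuration from the table: batch=2, numHeads=32, headDim=128, SeqLen=4096
cfg = dict(batch=2, num_heads=32, seq_len=4096, head_dim=128)
print(f"{attention_tflops(2.654, **cfg):.2f}")               # flash_attn, fp16, non-causal
print(f"{attention_tflops(1.636, causal=True, **cfg):.2f}")  # TinyFA, fp16, causal
```

Because the latencies are printed to three decimals, the recomputed values can differ from the table in the last digit.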

Run benchmarks

cd benchmarks

# Default: fp16, head_dim=128, SeqLen=[512, 1024, 2048, 4096]
python benchmark.py --dtype fp16 --head-dim 128

# With causal mask
python benchmark.py --dtype fp16 --head-dim 128 --causal

# Sweep all combinations (fp16/bf16 x causal x head_dim)
python benchmark.py --sweep

Usage

Fixed-length Attention

import torch
from tiny_flash_attn import flash_attn_forward

# Q: [batch, seqQ,  numHeadsQ,  headDim]
# K: [batch, seqKV, numHeadsKV, headDim]
# V: [batch, seqKV, numHeadsKV, headDim]
Q = torch.randn(2, 1024, 32, 128, dtype=torch.float16, device="cuda")
K = torch.randn(2, 1024, 32, 128, dtype=torch.float16, device="cuda")
V = torch.randn(2, 1024, 32, 128, dtype=torch.float16, device="cuda")

O = flash_attn_forward(Q, K, V, is_causal=False)
# O: [batch, seqQ, numHeadsQ, headDim]

Variable-length Attention

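import torch
from tiny_flash_attn import flash_attn_varlen_forward

# Q: [totalQ,  numHeadsQ,  headDim]  (packed sequences)
# K: [totalKV, numHeadsKV, headDim]
# V: [totalKV, numHeadsKV, headDim]
# cu_seqlens_q:  [batch + 1], int32, cumulative sequence lengths
# cu_seqlens_kv: [batch + 1], int32

# Example: two sequences of lengths 100 and 156, packed back to back
cu_seqlens_q  = torch.tensor([0, 100, 256], dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.tensor([0, 100, 256], dtype=torch.int32, device="cuda")
max_seqlen_q = max_seqlen_kv = 156

Q = torch.randn(256, 32, 128, dtype=torch.float16, device="cuda")
K = torch.randn(256, 32, 128, dtype=torch.float16, device="cuda")
V = torch.randn(256, 32, 128, dtype=torch.float16, device="cuda")

O = flash_attn_varlen_forward(
    Q, K, V,
    cu_seqlens_q, cu_seqlens_kv,
    max_seqlen_q, max_seqlen_kv,
    is_causal=False
)
# O: [totalQ, numHeadsQ, headDim]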
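
The cumulative-length arrays are simply running sums of the per-sequence lengths with a leading zero. A minimal sketch in plain Python, where `build_cu_seqlens` is a hypothetical helper and not part of tiny_flash_attn:

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Return (cu_seqlens, max_seqlen) for a list of per-sequence lengths.

    cu_seqlens has len(seq_lens) + 1 entries: it starts at 0, and entry i + 1
    is the total number of tokens in sequences 0..i.
    """
    cu_seqlens = [0] + list(accumulate(seq_lens))
    return cu_seqlens, max(seq_lens)


cu_seqlens, max_seqlen = build_cu_seqlens([3, 5, 2])
print(cu_seqlens, max_seqlen)  # [0, 3, 8, 10] 5

# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor.
# Before calling the kernel, the list would be moved to the GPU, e.g.
# torch.tensor(cu_seqlens, dtype=torch.int32, device="cuda")
```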

GQA (Grouped-Query Attention)

# GQA: numHeadsQ must be divisible by numHeadsKV
Q = torch.randn(2, 1024, 32, 128, dtype=torch.float16, device="cuda")  # 32 query heads
K = torch.randn(2, 1024,  8, 128, dtype=torch.float16, device="cuda")  #  8 KV heads (4 query heads per KV head)
V = torch.randn(2, 1024,  8, 128, dtype=torch.float16, device="cuda")

O = flash_attn_forward(Q, K, V, is_causal=True)
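
With grouped-query attention, each KV head is shared by numHeadsQ / numHeadsKV query heads. The sketch below assumes the contiguous grouping that most implementations use (query head h reads KV head h // group_size). TinyFA's kernel is not quoted here, so treat both the mapping and the helper name `kv_head_for` as illustrative.

```python
def kv_head_for(q_head, num_heads_q, num_heads_kv):
    """Map a query head index to the KV head it shares (contiguous grouping)."""
    if num_heads_q % num_heads_kv != 0:
        raise ValueError("numHeadsQ must be divisible by numHeadsKV")
    group_size = num_heads_q // num_heads_kv  # query heads per KV head
    return q_head // group_size


# 32 query heads over 8 KV heads: 4 query heads share each KV head
mapping = [kv_head_for(h, 32, 8) for h in range(32)]
print(mapping[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
```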

Paged Attention

import torch
from tiny_flash_attn import flash_attn_paged_varlen_forward

# Q: [totalTokens, numHeadsQ, headDim] (packed sequences)
# k_cache_pool: [numBlocks, numKvHeads, pageSize, headDim]
# v_cache_pool: [numBlocks, numKvHeads, pageSize, headDim]
# block_table: [batchSize, maxBlocksPerSeq] int32

batch_size = 3
page_size = 16
num_kv_heads = 8
head_dim = 128
num_blocks = 256

# mixed prefill + decode batch
cu_seqlens_q  = torch.tensor([0, 1, 33, 34], dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.tensor([0, 64, 96, 128], dtype=torch.int32, device="cuda")
total_q = 34

Q = torch.randn(total_q, 32, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(num_blocks, num_kv_heads, page_size, head_dim,
                       dtype=torch.float16, device="cuda")
v_cache = torch.randn(num_blocks, num_kv_heads, page_size, head_dim,
                       dtype=torch.float16, device="cuda")

# block_table maps each sequence to physical pages
max_blocks_per_seq = 8
block_table = torch.zeros(batch_size, max_blocks_per_seq, dtype=torch.int32, device="cuda")

# fill block_table with physical block IDs...

# forward
O = flash_attn_paged_varlen_forward(
    Q, k_cache, v_cache,
    cu_seqlens_q, cu_seqlens_kv,
    block_table,
    max_seqlen_q=32, max_seqlen_kv=64,
    page_size=page_size,
    is_causal=True
)
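
The block table is left unfilled above. As an illustration of what goes into it, the sketch below hands out physical block IDs sequentially and pads unused slots with 0. The allocator, the padding value, and the helper names `build_block_table` and `locate` are assumptions; a real KV cache manager may assign blocks in any order. It also shows how a logical KV position resolves to a (physical block, offset) pair.

```python
def build_block_table(kv_lens, page_size, max_blocks_per_seq):
    """Assign sequential physical block IDs to each sequence (toy allocator)."""
    table, next_block = [], 0
    for kv_len in kv_lens:
        num_pages = -(-kv_len // page_size)  # ceiling division
        if num_pages > max_blocks_per_seq:
            raise ValueError("sequence needs more pages than max_blocks_per_seq")
        row = list(range(next_block, next_block + num_pages))
        next_block += num_pages
        table.append(row + [0] * (max_blocks_per_seq - num_pages))  # pad unused slots
    return table


def locate(table, seq_idx, pos, page_size):
    """Resolve logical KV position `pos` of a sequence to (physical block, offset)."""
    return table[seq_idx][pos // page_size], pos % page_size


# KV lengths from cu_seqlens_kv = [0, 64, 96, 128] -> 64, 32, 32 tokens
table = build_block_table([64, 32, 32], page_size=16, max_blocks_per_seq=8)
print(table[1])                             # [4, 5, 0, 0, 0, 0, 0, 0]
print(locate(table, 1, 20, page_size=16))   # (5, 4)
```

The nested list could then be copied into block_table, for example with torch.tensor(table, dtype=torch.int32, device="cuda").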

C++ API

#include "flash_attn/flash_api.cuh"

// Fixed length
tfa::flashAttn<__half>(Q, K, V, O,
    batch, seqLenQ, seqLenKV,
    numHeadsQ, numHeadsKV, headDim,
    isCausal, stream);

// Variable length
tfa::flashAttnVarLen<__half>(Q, K, V, O,
    cu_seqlens_q, cu_seqlens_kv,
    batchSize, maxSeqLenQ, maxSeqLenKV,
    numHeadsQ, numHeadsKV, headDim,
    isCausal, stream);

// Paged VarLen (packed Q + paged KV cache)
tfa::flashAttnPagedVarLen<__half>(
    Q, O,                         // packed Q/O: [totalTokens, numHeadsQ, headDim]
    kCachePool, vCachePool,       // [numBlocks, numKvHeads, pageSize, headDim]
    cu_seqlens_q, cu_seqlens_kv,  // [batchSize + 1]
    blockTable,                   // [batchSize, maxBlocksPerSeq]
    batchSize, maxSeqLenQ, maxSeqLenKV,
    numHeadsQ, numHeadsKV, headDim,
    pageSize, maxBlocksPerSeq,
    isCausal, stream);

Building & Running Tests

git submodule update --init --recursive

mkdir build && cd build
cmake .. -DTFA_BUILD_TESTS=ON
make -j$(nproc)

# Run tests
ctest --test-dir . --output-on-failure

Dependencies

Limitations

  • Forward pass only
  • No dropout

License

This code is licensed under the MIT License (see LICENSE).
