[Feature] Add GVA support for Lightning #85
Code Review
This pull request introduces Grouped Value Attention (GVA) support to the lightning attention SM100 operators, allowing the number of value heads (HV) to be greater than the query/key heads (H) as long as it is divisible by H. The changes span the core operator implementation, the benchmark suite, and the test suite. The review feedback identifies critical safety improvements for the CUDA kernels, specifically recommending that input tensors (Q, K, V, initial_state, and state_pool) be made contiguous to prevent silent memory corruption, and that the state_pool size be validated to avoid out-of-bounds memory accesses during variable-length sequence processing.
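As a minimal sketch of the head constraint summarized above, the following pure-Python helpers check that HV is a multiple of H and map each value head to a query/key head. The helper names (`gva_group_size`, `qk_head_for_value_head`) are hypothetical, and the contiguous-block mapping (`v_head // group`) is an assumption for illustration; the PR's kernels may group heads differently.

```python
def gva_group_size(num_qk_heads: int, num_v_heads: int) -> int:
    """Return how many value heads share one query/key head."""
    if num_qk_heads <= 0:
        raise ValueError("H must be positive")
    if num_v_heads < num_qk_heads or num_v_heads % num_qk_heads != 0:
        raise ValueError(
            f"HV ({num_v_heads}) must be >= H ({num_qk_heads}) and divisible by H"
        )
    return num_v_heads // num_qk_heads


def qk_head_for_value_head(v_head: int, num_qk_heads: int, num_v_heads: int) -> int:
    """Map a value-head index to the query/key head it reads from.

    Assumption: value heads are grouped in contiguous blocks, so heads
    [0, group) use query/key head 0, [group, 2 * group) use head 1, etc.
    """
    group = gva_group_size(num_qk_heads, num_v_heads)
    if not 0 <= v_head < num_v_heads:
        raise IndexError(f"value head {v_head} out of range for HV={num_v_heads}")
    return v_head // group


if __name__ == "__main__":
    # With H=4 and HV=8, each query/key head serves two value heads.
    print([qk_head_for_value_head(hv, 4, 8) for hv in range(8)])
    # prints [0, 0, 1, 1, 2, 2, 3, 3]
```

With H equal to HV the group size is 1, which reduces to the ordinary one-to-one head layout.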
B, S, H, D = Q.shape
O = torch.zeros_like(Q)
if K.shape != Q.shape:
    raise ValueError(f"K must have the same shape as Q, got K={tuple(K.shape)}, Q={tuple(Q.shape)}")
if V.ndim != 4 or V.shape[0] != B or V.shape[1] != S or V.shape[3] != D:
    raise ValueError(f"V must have shape (B, S, HV, D), got {tuple(V.shape)}")
HV = V.shape[2]
if HV < H or HV % H != 0:
    raise ValueError(f"HV ({HV}) must be >= H ({H}) and divisible by H")
decay = _normalize_gva_decay(decay, H, HV)
if initial_state is not None and initial_state.shape != (B, HV, D, D):
    raise ValueError(f"initial_state must have shape {(B, HV, D, D)}, got {tuple(initial_state.shape)}")

O = torch.zeros_like(V)
Custom CUDA kernels using CuTe/CUTLASS layouts assume standard contiguous memory strides. If a non-contiguous tensor is passed as Q, K, V, or initial_state, the kernel will read and write the wrong memory locations, silently corrupting data. Making these tensors contiguous before passing them to the compiled kernel prevents this.
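To illustrate the stride assumption described above, here is a small pure-Python sketch that computes row-major strides and compares the offset a kernel would assume with the offset of the same element in a transposed view. The function names are hypothetical, the shapes are made up for the example, and the contiguity check is simplified (it ignores the special case of size-1 dimensions that PyTorch tolerates).

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (C-order) strides, in elements, for the given shape."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def is_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """Simplified check: strides must equal the row-major strides exactly."""
    return tuple(strides) == contiguous_strides(shape)


def element_offset(index: Sequence[int], strides: Sequence[int]) -> int:
    """Linear offset of a multi-dimensional index under the given strides."""
    return sum(i * s for i, s in zip(index, strides))


if __name__ == "__main__":
    # Data stored as (B, H, S, D) = (1, 2, 3, 4), then viewed as (B, S, H, D)
    # by swapping the strides of dimensions 1 and 2 (a transpose, no copy).
    view_shape = (1, 3, 2, 4)
    view_strides = (24, 4, 12, 1)

    assumed = contiguous_strides(view_shape)          # what the kernel assumes
    print(is_contiguous(view_shape, view_strides))    # prints False
    print(element_offset((0, 1, 0, 0), view_strides)) # prints 4 (actual location)
    print(element_offset((0, 1, 0, 0), assumed))      # prints 8 (where the kernel looks)
```

A copy made with `.contiguous()` rewrites the data so that the actual strides match the assumed ones, which is why the suggestion applies it before the kernel launch.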
B, S, H, D = Q.shape
if K.shape != Q.shape:
    raise ValueError(f"K must have the same shape as Q, got K={tuple(K.shape)}, Q={tuple(Q.shape)}")
if V.ndim != 4 or V.shape[0] != B or V.shape[1] != S or V.shape[3] != D:
    raise ValueError(f"V must have shape (B, S, HV, D), got {tuple(V.shape)}")
HV = V.shape[2]
if HV < H or HV % H != 0:
    raise ValueError(f"HV ({HV}) must be >= H ({H}) and divisible by H")
decay = _normalize_gva_decay(decay, H, HV)
if initial_state is not None:
    if initial_state.shape != (B, HV, D, D):
        raise ValueError(f"initial_state must have shape {(B, HV, D, D)}, got {tuple(initial_state.shape)}")
    initial_state = initial_state.contiguous()
Q = Q.contiguous()
K = K.contiguous()
V = V.contiguous()
O = torch.zeros_like(V)

_, T, H, D = Q.shape
if K.shape != Q.shape:
    raise ValueError(f"K must have the same shape as Q, got K={tuple(K.shape)}, Q={tuple(Q.shape)}")
if V.ndim != 4 or V.shape[0] != 1 or V.shape[1] != T or V.shape[3] != D:
    raise ValueError(f"V must have shape (1, T, HV, D), got {tuple(V.shape)}")
HV = V.shape[2]
if HV < H or HV % H != 0:
    raise ValueError(f"HV ({HV}) must be >= H ({H}) and divisible by H")
decay = _normalize_gva_decay(decay, H, HV)
N = cu_seqlens.shape[0] - 1
O = torch.zeros_like(Q)
O = torch.zeros_like(V)

# Allocate state pool if not provided
if state_pool is None:
    state_pool = torch.zeros(N, H, D, D, dtype=torch.float32, device=Q.device)
    state_pool = torch.zeros(N, HV, D, D, dtype=torch.float32, device=Q.device)
elif state_pool.ndim != 4 or state_pool.shape[1:] != (HV, D, D):
    raise ValueError(f"state_pool must have shape (pool_size, {HV}, {D}, {D}), got {tuple(state_pool.shape)}")
Like the standard forward pass, the varlen kernel assumes contiguous layouts for Q, K, V, and state_pool. Additionally, if state_pool is provided but its first dimension (pool size) is smaller than the maximum index in initial_state_indices (or N if indices are None), the kernel will perform out-of-bounds memory accesses, leading to undefined behavior or crashes. Adding explicit checks for both pool size and contiguity prevents these critical issues.
_, T, H, D = Q.shape
if K.shape != Q.shape:
    raise ValueError(f"K must have the same shape as Q, got K={tuple(K.shape)}, Q={tuple(Q.shape)}")
if V.ndim != 4 or V.shape[0] != 1 or V.shape[1] != T or V.shape[3] != D:
    raise ValueError(f"V must have shape (1, T, HV, D), got {tuple(V.shape)}")
HV = V.shape[2]
if HV < H or HV % H != 0:
    raise ValueError(f"HV ({HV}) must be >= H ({H}) and divisible by H")
decay = _normalize_gva_decay(decay, H, HV)
N = cu_seqlens.shape[0] - 1
Q = Q.contiguous()
K = K.contiguous()
V = V.contiguous()
O = torch.zeros_like(V)
# Allocate state pool if not provided
if state_pool is None:
    state_pool = torch.zeros(N, HV, D, D, dtype=torch.float32, device=Q.device)
elif state_pool.ndim != 4 or state_pool.shape[1:] != (HV, D, D):
    raise ValueError(f"state_pool must have shape (pool_size, {HV}, {D}, {D}), got {tuple(state_pool.shape)}")
else:
    required_size = N if initial_state_indices is None else int(initial_state_indices.max().item()) + 1
    if state_pool.shape[0] < required_size:
        raise ValueError(f"state_pool pool_size ({state_pool.shape[0]}) must be at least {required_size}")
    state_pool = state_pool.contiguous()
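The pool-size rule in the suggestion above can be sketched in isolation as follows. This is a pure-Python illustration, not the PR's code: plain lists stand in for the `initial_state_indices` tensor, the helper names (`required_pool_size`, `check_pool_size`) are hypothetical, and the handling of empty and negative indices is an added assumption that the suggestion does not cover.

```python
from typing import Optional, Sequence


def required_pool_size(num_seqs: int, state_indices: Optional[Sequence[int]]) -> int:
    """Smallest pool_size that keeps every state-slot access in bounds."""
    if state_indices is None:
        # Without explicit indices, sequence i is assumed to use slot i.
        return num_seqs
    if len(state_indices) == 0:
        # max() of an empty sequence raises, so handle this case explicitly.
        return 0
    if min(state_indices) < 0:
        raise ValueError("state indices must be non-negative")
    return max(state_indices) + 1


def check_pool_size(pool_size: int, num_seqs: int,
                    state_indices: Optional[Sequence[int]] = None) -> None:
    """Raise ValueError if the pool is too small for the requested slots."""
    needed = required_pool_size(num_seqs, state_indices)
    if pool_size < needed:
        raise ValueError(
            f"state_pool pool_size ({pool_size}) must be at least {needed}"
        )


if __name__ == "__main__":
    check_pool_size(pool_size=8, num_seqs=3, state_indices=[0, 5, 2])  # passes
    print(required_pool_size(3, None))       # prints 3
    print(required_pool_size(3, [0, 5, 2]))  # prints 6
```

Two caveats may apply to the tensor version. Calling `.max().item()` on a CUDA tensor synchronizes with the device, and `.max()` raises on an empty tensor. Also, if the kernel writes final states back into `state_pool` in place (an assumption; the review does not say), then `state_pool.contiguous()` on a non-contiguous pool would hand the kernel a copy and the caller's tensor would not be updated, so rejecting a non-contiguous pool with an error could be safer than copying it.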
#81
python tests/test_lightning_attn.py --test all
python tests/test_lightning_attn.py --test ref
python tests/test_lightning_attn.py --test fla
python tests/test_lightning_attn.py --test h0ht
python tests/test_lightning_attn.py --test varlen
python -m pytest tests/test_la_decode.py tests/test_la_decode_pool.py -v -s
python -m pytest tests/test_la_decode.py tests/test_la_decode_pool.py -v -s -k "gva or prefill_decode_e2e"
@icavan @KevinZeng08