Skip to content

transformers: compute CPU FlashSdpa P·V as a tile GEMM (~3.1×)#2319

Open
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feature/flash-sdpa-cpu-pv-gemm
Open

transformers: compute CPU FlashSdpa P·V as a tile GEMM (~3.1×)#2319
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feature/flash-sdpa-cpu-pv-gemm

Conversation

@czoli1976
Copy link
Copy Markdown
Contributor

What

The CPU FlashSdpaOp forward pass (transformers/src/ops/flash_sdpa.rs) computed the P·V step as head_dim separate dot products against strided columns of V:

for j in 0..head_dim {
    let sv = s.row(i).dot(&vblock.column(j));
    ...
}

vblock.column(j) is a strided view (stride = head_dim), so this is neither a GEMM nor vectorization-friendly. This replaces it with a single contiguous tile GEMM plus the online-softmax rescale:

let sv_tile = s.dot(&vblock);
for i in 0..q_range.len() {
    ...
    for j in 0..head_dim {
        oblock[(i, j)] = oblock[(i, j)] * mul_o + sv_tile[(i, j)] * mul_sv;
    }
}

The QK^T step above already used a 2-D .dot(); this just brings P·V to parity.

Why

  • Correctness: bit-exact — max_abs = 0.0 vs a naive softmax(QKᵀ·scale)·V reference.
  • Performance: ~3.1× on the whole flash_attention_gqa — H=8, Sq=Sk=512, D=64, release build: 76.2 → 24.6 ms/run. The strided P·V was ~68% of the runtime.

Tests

  • flash_sdpa_pv_matches_naive — correctness vs naive reference (non-square, multi-head).
  • bench_flash_sdpa_pv (#[ignore]) — the before/after A/B used for the numbers above.
  • Full tract-transformers suite green; fmt + clippy clean.

🤖 Generated with Claude Code

The forward flash-attention's P·V step computed `head_dim` separate dot products
against strided columns of V (s.row(i).dot(vblock.column(j))) -- no GEMM, and the
strided column access defeats vectorization. Replace it with one contiguous tile
GEMM `s.dot(&vblock)` plus the online-softmax rescale.

Bit-exact vs a naive host reference (max_abs 0.0). Bench (H8/Sq512/Sk512/D64,
release): 76.2 -> 24.6 ms/run = 3.1x on the whole flash_attention_gqa (the strided
P·V was ~68% of the time). QK^T already used a 2-D GEMM; this brings P·V to parity.

Co-Authored-By: Claude Opus 4.8 <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant