[JAX][Common] Enable cuDNN fused attn backend for NO_MASK + bidirectional SWA #2961
Signed-off-by: Kshitij Lakhani <[email protected]>
…so add a helper to pick window based on cuDNN version support in fused_attn.cpp
Signed-off-by: Kshitij Lakhani <[email protected]>
for more information, see https://pre-commit.ci
/te-ci L0 L1 L2
Greptile Summary
This PR enables the cuDNN fused attention backend for the
Confidence Score: 4/5
Safe to merge. The C++ backend change is a one-line, well-scoped addition that mirrors the pattern already used for PADDING_MASK in the same block, and the JAX warning change is purely cosmetic.
The core backend logic change is minimal and follows the established pattern for existing mask types in the 9.6 SWA branch. The test helper is readable and the cuDNN version gating is correct. The only notable gap is that on cuDNN ≥ 9.6 the existing SWA tests for NO_MASK now exclusively exercise the bidirectional path, leaving the left-only (right=0) path untested on that version family.
tests/jax/test_fused_attn.py: the _get_swa_window_size_for_test helper shifts NO_MASK SWA tests from left-only to bidirectional on cuDNN ≥ 9.6, narrowing coverage of the left-only path on newer cuDNN versions.
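To make the version gating described in the review concrete, here is a minimal sketch of a window-picking test helper. It is not the PR's `_get_swa_window_size_for_test`: the function name `pick_swa_window_for_test`, its signature, the `(major, minor)` version tuple, and the `max_seqlen_kv // 10` window width are all assumptions chosen for illustration.

```python
from typing import Tuple

# Threshold taken from the PR description ("cuDNN 9.6+"), encoded here
# as a (major, minor) tuple. The encoding is an assumption of this sketch.
MIN_CUDNN_FOR_BIDIRECTIONAL_SWA = (9, 6)


def pick_swa_window_for_test(
    cudnn_version: Tuple[int, int], max_seqlen_kv: int
) -> Tuple[int, int]:
    """Return a (left, right) window for a NO_MASK sliding-window test.

    Hypothetical stand-in for the PR's test helper; the window widths
    are illustrative, not the values used in tests/jax/test_fused_attn.py.
    """
    left = max(1, max_seqlen_kv // 10)
    if cudnn_version >= MIN_CUDNN_FOR_BIDIRECTIONAL_SWA:
        # Newer cuDNN: exercise the bidirectional path (right > 0).
        return (left, left)
    # Older cuDNN: fall back to the left-only window (right == 0).
    return (left, 0)
```

With a helper of this shape, any given cuDNN version runs exactly one of the two paths, which is the coverage gap the review points out. Parametrizing the test over both window shapes on cuDNN ≥ 9.6 would be one way to keep the left-only path covered.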
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[nvte_get_fused_attn_backend] --> B{cuDNN version?}
B -->|< 9.2| C{SWA window?}
C -->|left=-1, right=-1 or 0| D[Allow]
C -->|other| E[Reject SWA]
B -->|>= 9.2| F{Sliding window branch}
F -->|left=-1, right=-1, NO_MASK| G[Allow full attn]
F -->|left=any, right=0, NO_MASK/CAUSAL/...| H{Format check}
H -->|BSHD/SBHD| I[Allow left-only SWA]
H -->|other| J[Reject]
B -->|>= 9.6| K{9.6 sliding window}
K -->|left=-1, right=-1 or 0| L[Allow - any mask]
K -->|left=any, right>=0 or -1| M{Mask type?}
M -->|CAUSAL_BOTTOM_RIGHT + sm_arch checks| N[Allow bidir SWA]
M -->|NO_MASK NEW| N
M -->|PADDING_MASK| N
M -->|PADDING_CAUSAL_MASK| N
M -->|PADDING_CAUSAL_BOTTOM_RIGHT + sm_arch checks| N
M -->|CAUSAL_MASK or other| O[Reject bidir SWA]
N --> P{max_seqlen_q <= max_seqlen_kv + NO_BIAS + dropout=0}
P -->|Yes| Q[flag_arb = true - cuDNN fused backend]
P -->|No| R[Reject]
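The cuDNN ≥ 9.6 branch of the flowchart can be read as a small predicate. The sketch below is a Python paraphrase of the diagram only, not the C++ in fused_attn.cpp: the function name, the `Mask` enum, and the `sm_arch_ok` flag (a placeholder for the sm_arch checks the diagram mentions without detailing) are assumptions.

```python
from enum import Enum


class Mask(Enum):
    NO_MASK = "NO_MASK"
    PADDING_MASK = "PADDING_MASK"
    CAUSAL_MASK = "CAUSAL_MASK"
    PADDING_CAUSAL_MASK = "PADDING_CAUSAL_MASK"
    CAUSAL_BOTTOM_RIGHT = "CAUSAL_BOTTOM_RIGHT"
    PADDING_CAUSAL_BOTTOM_RIGHT = "PADDING_CAUSAL_BOTTOM_RIGHT"


# Mask types the flowchart routes to "Allow bidir SWA" on cuDNN >= 9.6.
# NO_MASK is the entry this PR adds.
BIDIR_SWA_MASKS = {
    Mask.NO_MASK,
    Mask.PADDING_MASK,
    Mask.PADDING_CAUSAL_MASK,
    Mask.CAUSAL_BOTTOM_RIGHT,
    Mask.PADDING_CAUSAL_BOTTOM_RIGHT,
}
# The two mask types the flowchart pairs with extra sm_arch checks.
NEEDS_SM_ARCH_CHECK = {Mask.CAUSAL_BOTTOM_RIGHT, Mask.PADDING_CAUSAL_BOTTOM_RIGHT}


def swa_allowed_cudnn_96(mask, left, right, max_seqlen_q, max_seqlen_kv,
                         has_bias, dropout, sm_arch_ok=True):
    """Paraphrase of the '>= 9.6' branch of the flowchart (hypothetical)."""
    # left=-1 with right=-1 or 0: allowed for any mask type.
    if left == -1 and right in (-1, 0):
        return True
    # Any other window: only the listed mask types qualify.
    if mask not in BIDIR_SWA_MASKS:
        return False
    if mask in NEEDS_SM_ARCH_CHECK and not sm_arch_ok:
        return False
    # Final gate: max_seqlen_q <= max_seqlen_kv, NO_BIAS, dropout = 0.
    return max_seqlen_q <= max_seqlen_kv and not has_bias and dropout == 0.0
```

For example, `swa_allowed_cudnn_96(Mask.NO_MASK, 4, 4, 128, 128, False, 0.0)` returns `True`, while the same call with `Mask.CAUSAL_MASK` returns `False`, matching the "Reject bidir SWA" edge in the diagram.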
The CI failures are unrelated to attention, so this is safe to merge.
Description
Enable the cuDNN fused attn backend for a right-sided window (right > 0) with NO_MASK when using cuDNN 9.6+.
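As an illustration of what a right-sided window means when no causal or padding mask is applied, the sketch below builds the visibility pattern for a `(left, right)` window in plain Python. It assumes equal query and key lengths and treats -1 as unbounded on that side. The function names are hypothetical, and the alignment convention for unequal sequence lengths is deliberately out of scope.

```python
def sliding_window_mask(seqlen, left, right):
    """Return mask[i][j] == True where query i may attend to key j.

    A value of -1 means unbounded on that side.
    Assumes seqlen_q == seqlen_kv (an assumption of this sketch).
    """
    def visible(i, j):
        left_ok = left == -1 or j >= i - left
        right_ok = right == -1 or j <= i + right
        return left_ok and right_ok

    return [[visible(i, j) for j in range(seqlen)] for i in range(seqlen)]


def render(mask):
    """Render a mask as text: 'x' = attended, '.' = masked out."""
    return "\n".join("".join("x" if v else "." for v in row) for row in mask)


# Left-only window (right == 0): each query sees itself and one key behind.
print(render(sliding_window_mask(5, left=1, right=0)))
print()
# Bidirectional window (right > 0): each query also sees one key ahead.
print(render(sliding_window_mask(5, left=1, right=1)))
```

With NO_MASK, the window alone decides visibility, so the bidirectional case produces a symmetric band around the diagonal, whereas the left-only case produces a band that stops at the diagonal.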