
[PyTorch] Enable head dim 256 for FA4#2932

Open
yaox12 wants to merge 2 commits into NVIDIA:main from yaox12:xiny/headdim256_fa

Conversation

@yaox12
Member

@yaox12 yaox12 commented Apr 27, 2026

Description

Enables head dim 256 for FA4. Requires FA4 version 4.0.0b11.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Delegate FA4 head-dim validation to its own _validate_head_dims helper, enabling head_dim=256.
  • Bump the required FA4 version to 4.0.0b11.
  • Add an SM100-gated test for head_dim=256 with FA4.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 yaox12 marked this pull request as draft April 27, 2026 09:31
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from bdcc02e to 3b3f7d0 on April 27, 2026 09:31
@greptile-apps
Contributor

greptile-apps Bot commented Apr 27, 2026

Greptile Summary

This PR enables head_dim=256 support in FlashAttention 4 by delegating head-dimension validation to FA4's own _validate_head_dims function (replacing a hand-maintained lookup table), bumping the required FA4 version to 4.0.0b11, and adding a properly SM100-gated test. It also removes stale cuDNN version guards from existing FA4 tests.

  • backends.py: Imports _validate_head_dims from flash_attn.cute.interface and stores it on FlashAttentionUtils.v4_validate_head_dims. Because this import is co-located with the two existing FA4 symbols in a bare else block, a missing symbol in an older FA4 version raises an unhandled ImportError that breaks the entire module load.
  • utils.py: get_attention_backend now calls FA4's validator inside a try/except AssertionError, and the SM100 MLA backward-kernel workaround is restructured from elif to if so it remains active after the new validation block (see the sketch after this list).
  • test_attention.py: test_dpa_fa4_hdim256 is a new, dedicated test gated on get_device_compute_capability() == (10, 0), keeping the CI signal unambiguous on non-SM100 hardware.
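A minimal sketch of the utils.py pattern described above, assuming v4_validate_head_dims is stored on FlashAttentionUtils (per the backends.py bullet) and asserts on unsupported head dims; the import path, function wrapper, argument names, and the MLA-workaround flag are illustrative, not the PR's exact code:

```python
# Assumed import path, based on the review's file references.
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    FlashAttentionUtils,
)


def _select_fa4_backend(use_flash_attention_4, head_dim_qk, head_dim_v, misaligned_mla):
    """Sketch of the backend-selection flow described in the review, not the PR's code."""
    if use_flash_attention_4 and FlashAttentionUtils.v4_validate_head_dims is not None:
        try:
            # Delegate head-dim validation to FA4 itself instead of a hand-maintained
            # lookup table. The validator's argument list here is an assumption.
            FlashAttentionUtils.v4_validate_head_dims(head_dim_qk, head_dim_v)
        except AssertionError:
            # FA4 rejects this head-dim combination; fall back to another backend.
            use_flash_attention_4 = False

    # A plain `if` (not `elif`) so the SM100 MLA backward-kernel workaround
    # still applies after the validation block above.
    if use_flash_attention_4 and misaligned_mla:
        use_flash_attention_4 = False

    return use_flash_attention_4
```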

Confidence Score: 4/5

Safe to merge once the _validate_head_dims import is isolated from the two existing FA4 symbols so that an older FA4 install does not crash backends.py on load.

The import of _validate_head_dims is placed in the same from … import block as the existing, already-working FA4 symbols. Any FA4 build that does not expose this private symbol — including the previously recommended 4.0.0b8 — will trigger an unhandled ImportError at module load time, breaking all FA4 functionality for users who have not yet upgraded. The rest of the changes (delegation to FA4's validator, MLA workaround restructuring, SM100-gated test) are correct.

transformer_engine/pytorch/attention/dot_product_attention/backends.py — the _validate_head_dims import needs its own try/except to stay safe with older FA4 installs.

Important Files Changed

  • transformer_engine/pytorch/attention/dot_product_attention/backends.py: Adds the _validate_head_dims import from FA4 in the same statement as the two existing FA4 symbols; an ImportError for this private symbol in older FA4 installs would crash the entire backends module.
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: Replaces the inline head-dim table with a call to FA4's own _validate_head_dims, adds the v4_validate_head_dims class attribute, and restructures the MLA workaround from elif to if; the logic is correct.
  • tests/pytorch/attention/test_attention.py: Adds a properly SM100-gated test_dpa_fa4_hdim256 test (see the sketch after this list) and removes stale cuDNN-version skipifs from all FA4 tests.
  • qa/L3_pytorch_FA_versions_test/test.sh: Updates the FA4 version from 4.0.0b8 to 4.0.0b11 for both the SM90 and SM100+ test matrices.
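A rough sketch of the SM100 gating described for test_dpa_fa4_hdim256; only the test name and the (10, 0) capability check come from this review, while the import path and the test body are assumptions:

```python
import pytest

# Assumed import path; the review only names the helper.
from transformer_engine.pytorch.utils import get_device_compute_capability


@pytest.mark.skipif(
    get_device_compute_capability() != (10, 0),
    reason="FA4 head_dim=256 is exercised on SM100 only",
)
def test_dpa_fa4_hdim256():
    # Placeholder body: the real test runs DotProductAttention with
    # head_dim=256 through the FA4 backend.
    ...
```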

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[backends.py module load] --> B{FA4 package installed?}
    B -- No --> C[flash_attn_func_v4 = None]
    B -- Yes --> D[import flash_attn_func, flash_attn_varlen_func,\n_validate_head_dims]
    D -- ImportError if symbol missing --> E[Unhandled ImportError breaks backends.py load]
    D -- OK --> F[v4_validate_head_dims = _fa4_validate_head_dims]
    F --> G[get_attention_backend called]
    G --> H{use_flash_attention_4 and v4_validate_head_dims != None?}
    H -- No --> I[Skip FA4 head-dim validation]
    H -- Yes --> J[Call v4_validate_head_dims]
    J -- AssertionError --> K[use_flash_attention_4 = False]
    J -- OK --> L{SM100 MLA workaround needed?}
    L -- Yes misaligned --> M[use_flash_attention_4 = False]
    L -- No --> N[FA4 selected]

Comments Outside Diff (1)

  1. transformer_engine/pytorch/attention/dot_product_attention/backends.py, lines 166-174

    P1 The _validate_head_dims import is added to the same from … import statement as flash_attn_func and flash_attn_varlen_func, inside the bare else block of a try/except PackageNotFoundError. If _validate_head_dims does not exist in the installed FA4 (e.g., any version older than the one that added this private symbol, such as 4.0.0b8), the whole import statement raises an unhandled ImportError, crashing backends.py module load for every user who still has an older FA4 installed — including all existing FA4 tests. The intent is clearly to keep v4_validate_head_dims = None as a graceful "not supported" sentinel, but the current structure defeats that. Importing _validate_head_dims separately with its own try/except preserves the graceful fallback and matches the class-level default.
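A minimal sketch of the separate guarded import this comment suggests; the flash_attn.cute.interface module path and the symbol names come from this review, while the exact surrounding structure of backends.py (including FlashAttentionUtils being in scope) is assumed:

```python
# Sketch only: isolate the private-symbol import so an older FA4 build degrades
# gracefully instead of raising at module load time.
try:
    from flash_attn.cute.interface import _validate_head_dims as _fa4_validate_head_dims
except ImportError:
    # Older FA4 builds (e.g. 4.0.0b8) do not expose this private helper; keep
    # None as the "not supported" sentinel matching the class-level default.
    _fa4_validate_head_dims = None

# FlashAttentionUtils is assumed to already be imported in this module.
FlashAttentionUtils.v4_validate_head_dims = _fa4_validate_head_dims
```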

Reviews (2): Last reviewed commit: "update CI, fix lint, resolve comments"

Comment thread on tests/pytorch/attention/test_attention.py (outdated)
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from 3b3f7d0 to 9a93156 on May 6, 2026 02:44
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from ae74e44 to 8aa5242 on May 6, 2026 02:55
@yaox12
Member Author

yaox12 commented May 6, 2026

/te-ci pytorch L3

@yaox12 yaox12 marked this pull request as ready for review May 6, 2026 02:59
@yaox12
Member Author

yaox12 commented May 6, 2026

@vcherepanov-nv @KshitijLakhani Please review.

