Skip to content

Fix: rerotation with multi-gpus#241

Merged
SimJeg merged 1 commit into
NVIDIA:mainfrom
giulio98:fix-multigpu-rerotation
Jun 16, 2026
Merged

Fix: rerotation with multi-gpus#241
SimJeg merged 1 commit into
NVIDIA:mainfrom
giulio98:fix-multigpu-rerotation

Conversation

@giulio98

@giulio98 giulio98 commented Jun 16, 2026

Copy link
Copy Markdown
Contributor

Fixes #240

This PR fixes a device mismatch in key re-rotation when models are split across multiple GPUs with device_map="auto".

KeyRerotationPress._rerotate_cos_sin() could receive x, selected_positions, and inv_freq on different CUDA devices. This caused re-rotation to fail with errors like:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

The fix uses x.device as the canonical device inside _rerotate_cos_sin() and moves selected_positions and inv_freq there before computing the rotary frequencies.

This fixes FinchPress and should also apply to other presses using KeyRerotationPress.rerotate_keys().

Tested locally with FinchPress and meta-llama/Llama-3.1-8B-Instruct using device_map="auto" on a multi-GPU setup.

Checklist

Before submitting a PR, please make sure:

  • Tests are working (make test)
  • Code is formatted correctly (make style, on errors try fix with make format)
  • Copyright header is included
  • All commits are signed-off using git commit -s
  • (new press) mypress_press.py is in the presses directory
  • (new press) MyPress is in __init__.py
  • (new press) README.md is updated with a 1 liner about the new press in the Available presses section
  • (new press) New press is in the default_presses list in tests/default_presses.py
  • (new press) A docstring is provided that follows the same structure as the existing ones

Signed-off-by: Giulio Corallo <[email protected]>
Signed-off-by: Giulio Corallo <[email protected]>
@copy-pr-bot

copy-pr-bot Bot commented Jun 16, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@SimJeg SimJeg merged commit fee1af1 into NVIDIA:main Jun 16, 2026
2 checks passed
@SimJeg

SimJeg commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator

Thanks for the contribution @giulio98 !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

KeyRerotationPress fails on multi-GPU setups

2 participants