From 0b9efdb75d569af7cd9f1fdfccbf86ad46d6308e Mon Sep 17 00:00:00 2001 From: Giulio Corallo Date: Tue, 16 Jun 2026 12:03:48 +0200 Subject: [PATCH] Fix: rerotation with multi-gpus Signed-off-by: Giulio Corallo Signed-off-by: Giulio Corallo --- kvpress/presses/key_rerotation_press.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 770c3d99..d79db95a 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -71,9 +71,11 @@ def _rerotate_cos_sin(x, inv_freq, selected_positions): ``(bsz, num_key_value_heads, n_kept, d)``, matching ``dtype``/``device`` of ``x``. """ bsz, num_key_value_heads, n_kept = selected_positions.shape - device = selected_positions.device + device = x.device device_type = x.device.type dtype = x.dtype + selected_positions = selected_positions.to(device) + inv_freq = inv_freq.to(device) # Original positional indices idx = torch.arange(0, n_kept, device=device) # (n_kept,) idx = idx.unsqueeze(0) # (1, n_kept)