diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 770c3d99..d79db95a 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -71,9 +71,11 @@ def _rerotate_cos_sin(x, inv_freq, selected_positions): ``(bsz, num_key_value_heads, n_kept, d)``, matching ``dtype``/``device`` of ``x``. """ bsz, num_key_value_heads, n_kept = selected_positions.shape - device = selected_positions.device + device = x.device device_type = x.device.type dtype = x.dtype + selected_positions = selected_positions.to(device) + inv_freq = inv_freq.to(device) # Original positional indices idx = torch.arange(0, n_kept, device=device) # (n_kept,) idx = idx.unsqueeze(0) # (1, n_kept)