Skip to content

Vectorize promo logit computation#2

Open
mcognetta wants to merge 3 commits into
CSSLab:mainfrom
mcognetta:vectorized_promo
Open

Vectorize promo logit computation#2
mcognetta wants to merge 3 commits into
CSSLab:mainfrom
mcognetta:vectorized_promo

Conversation

@mcognetta
Copy link
Copy Markdown

@mcognetta mcognetta commented May 25, 2026

This replaces the nested loop in the promotion logit computation with a vectorized version that gives an ~18% speedup for the entire forward pass on a small local benchmark on my machine.

It also moves the rank7_indices and rank8_indices to the model definition, since these are static and don't need to be constructed each forward pass.

Comment thread maia3/models.py
bias = promo_biases[:, to_file, piece_idx] # (B,)
promotion_logits.append((base_score + bias).unsqueeze(1))
promotion_logits = torch.cat(promotion_logits, dim=1) # (B, 256)
base = scores_base[:, self.rank7_indices][:, :, self.rank8_indices] # (B, 8, 8)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rank_7/rank_8 could also just be replaced directly with a slice like:

base = scores_base[:, 48:56][:, :, 56:65].

This avoids a new allocation, etc. since it is just a contiguous view. Doesn't give much speedup on my machine though, so maybe not worth the loss in clarity.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant