Add TurboQuant serial Metal decode fast path#91
Open
scouzi1966 wants to merge 5 commits intofeature/codex-turboquant-corefrom
Open
Add TurboQuant serial Metal decode fast path#91scouzi1966 wants to merge 5 commits intofeature/codex-turboquant-corefrom
scouzi1966 wants to merge 5 commits intofeature/codex-turboquant-corefrom
Conversation
Reviewer's GuideImplements a narrow serial TurboQuant Metal-backed decode path by adding inline Metal kernels for MSE scoring/weighted sum, wiring a decode-only fast path through TurboQuantKVCache that ingests into the quantized cache and falls back to dense SDPA when unsupported, fixes shadow cache rehydration after packed cache restore, and adds a regression test plus a design note documenting the new Metal slice. Class diagram for TurboQuant Metal-backed decode pathclassDiagram
class TurboQuantMSECodec {
+prepareQueries(queries: MLXArray) MLXArray
+finalizeWeightedSum(output: MLXArray) MLXArray
+quantize(vectors: MLXArray) TurboQuantMSEState
+dequantize(state: TurboQuantMSEState) MLXArray
+bits: Int
+codebook: MLXArray
}
class TurboQuantMSEState {
+norms: MLXArray
+indices: MLXArray
}
class TurboQuantMSEKernelManager {
<<singleton>>
+shared: TurboQuantMSEKernelManager
+scoreKernel: MLXFast.MLXFastKernel?
+weightedSumKernel: MLXFast.MLXFastKernel?
-TurboQuantMSEKernelManager()
}
class TurboQuantKVCache {
+offset: Int
+keyState: TurboQuantMSEState?
+valueState: TurboQuantMSEState?
+keyCodec: TurboQuantMSECodec?
+valueCodec: TurboQuantMSECodec?
+update(keys: MLXArray, values: MLXArray) (MLXArray, MLXArray)
-ingest(keys: MLXArray, values: MLXArray) (TurboQuantMSEState, TurboQuantMSEState)
-appendShadow(keys: MLXArray, values: MLXArray, previous: Int) void
-rehydrateShadowFromPackedState() void
-groupedDecodeQueries(queries: MLXArray) MLXArray?
-fastDecodeAttention(queries: MLXArray, scale: Float, mask: MLXFast.ScaledDotProductAttentionMaskMode) MLXArray?
+decodeAttention(queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXFast.ScaledDotProductAttentionMaskMode) MLXArray
-denseState() (MLXArray, MLXArray)
}
class MLXFastKernel {
+call(inputs: [MLXArray], template: [(String, Int)], grid: (Int, Int, Int), threadGroup: (Int, Int, Int), outputShapes: [[Int]], outputDTypes: [MLXArray.DType]) [MLXArray]
}
class MLXFast {
+metalKernel(name: String, inputNames: [String], outputNames: [String], source: String) MLXFastKernel?
+scaledDotProductAttention(queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXFast.ScaledDotProductAttentionMaskMode) MLXArray
<<enum>> ScaledDotProductAttentionMaskMode
}
TurboQuantMSECodec --> TurboQuantMSEState : produces
TurboQuantMSEKernelManager --> MLXFastKernel : owns
TurboQuantMSEKernelManager ..> MLXFast : creates kernels via metalKernel
TurboQuantKVCache --> TurboQuantMSECodec : uses keyCodec
TurboQuantKVCache --> TurboQuantMSECodec : uses valueCodec
TurboQuantKVCache --> TurboQuantMSEState : maintains keyState
TurboQuantKVCache --> TurboQuantMSEState : maintains valueState
TurboQuantKVCache ..> TurboQuantMSEKernelManager : uses shared kernels
TurboQuantKVCache ..> MLXFast : falls back to scaledDotProductAttention
MLXFastKernel <|-- MLXFast.MLXFastKernel
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Hey - I've found 3 issues, and left some high level feedback:
- The Metal-backed weighted-sum path (makeTurboQuantMSEWeightedSumKernel, turboQuantMSEDecodeWeightedSum, and TurboQuantMSECodec.finalizeWeightedSum) is currently unused; consider either wiring it into fastDecodeAttention or dropping it for now to keep the slice minimal and avoid dead code that may drift from the score path.
- fastDecodeAttention assumes a single-token, grouped-query layout via groupedDecodeQueries but this is only enforced implicitly by dim checks; it may be worth adding lightweight assertions or clearer early-returns (e.g., when queries.dim(2) != 1 or queryHeads % kvHeads != 0) to make unsupported shapes fail in a more obvious and maintainable way.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The Metal-backed weighted-sum path (makeTurboQuantMSEWeightedSumKernel, turboQuantMSEDecodeWeightedSum, and TurboQuantMSECodec.finalizeWeightedSum) is currently unused; consider either wiring it into fastDecodeAttention or dropping it for now to keep the slice minimal and avoid dead code that may drift from the score path.
- fastDecodeAttention assumes a single-token, grouped-query layout via groupedDecodeQueries but this is only enforced implicitly by dim checks; it may be worth adding lightweight assertions or clearer early-returns (e.g., when queries.dim(2) != 1 or queryHeads % kvHeads != 0) to make unsupported shapes fail in a more obvious and maintainable way.
## Individual Comments
### Comment 1
<location path="Scripts/patches/KVCache.swift" line_range="933-941" />
<code_context>
+ let D = queries.dim(3)
+ let T = state.norms.dim(2)
+
+ return kernel(
+ [queries, state.norms, state.indices, codebook],
+ template: [
+ ("Dim", D),
+ ("RepeatCount", R),
+ ("Bits", bits),
+ ("PackedWidth", state.indices.dim(3)),
+ ],
+ grid: (32, R, B * H * T),
+ threadGroup: (32, 1, 1),
+ outputShapes: [[B, H, R, T]],
</code_context>
<issue_to_address>
**issue (bug_risk):** The score kernel does not validate shape consistency between `queries` and `state`, which can lead to subtle misindexing.
`turboQuantMSEDecodeScores` only checks `queries.ndim == 4` and `state.norms.dim(2) > 0`, but not that `B`, `H`, `R`, `T` from `queries` match `state.norms`/`state.indices`. With mismatched shapes (e.g., different `B`/`H`), the kernel will still launch and misinterpret memory. Please add cheap shape checks (e.g., `state.norms.dim(0) == B`, `state.norms.dim(1) == H`, `state.indices.dim(2) == T`) to enforce consistency before invoking the Metal kernel.
</issue_to_address>
### Comment 2
<location path="Scripts/patches/KVCache.swift" line_range="1896-1897" />
<code_context>
- public override func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
+ @discardableResult
+ private func ingest(keys: MLXArray, values: MLXArray) -> (TurboQuantMSEState, TurboQuantMSEState) {
let previous = offset
ensureCodecs(keyDim: keys.dim(3), valueDim: values.dim(3))
</code_context>
<issue_to_address>
**nitpick (bug_risk):** `previous` is computed but not used in `ingest`, which suggests either dead code or a missing use.
Since `previous` was previously passed to `appendShadow`, its presence here without any usage suggests ingestion/shadow logic may not have been fully migrated. If shadow handling is now separate, remove `previous`; otherwise, reintroduce the relevant call so state and shadow remain consistent.
</issue_to_address>
### Comment 3
<location path="docs/feature-codex-turboquant-metal.md" line_range="24" />
<code_context>
+- [KVCache.swift](/Volumes/edata/codex/dev/git/apr3/maclocal-api/Scripts/patches/KVCache.swift)
</code_context>
<issue_to_address>
**suggestion:** Consider using repository-relative links instead of absolute filesystem paths in markdown links.
Linking to `/Volumes/edata/...` will only work on your local machine. Use a repo-relative path like `/Scripts/patches/KVCache.swift` so the link works for other users and in documentation viewers.
```suggestion
- [KVCache.swift](/Scripts/patches/KVCache.swift)
```
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Add the first inline Metal-backed TurboQuant execution slice on top of the core serial implementation.
What this PR includes
TurboQuantKVCacheWhat this PR does not include
Validation
MACAFM_MLX_METALLIB="$PWD/default.metallib" swift test --filter TurboQuantCacheTests --parallel --num-workers 1MACAFM_MLX_METALLIB="$PWD/default.metallib" swift test --filter 'KVCacheTruncateTests|BatchedPrefillTests' --parallel --num-workers 1Stack
This is PR 3 of a stacked TurboQuant series and targets PR 2.
Summary by Sourcery
Introduce a Metal-backed TurboQuant serial decode path and improve TurboQuant cache correctness around packed-state restores.
New Features:
Bug Fixes:
Enhancements:
Documentation:
Tests: