Add core serial TurboQuant attention and packed state #90
Open
scouzi1966 wants to merge 5 commits into feature/codex-turboquant-plumbing from
Conversation
Reviewer's Guide

Implements explicit TurboQuant-aware attention routing and converts TurboQuantKVCache to use packed low-bit MSE-style codec state as the serialized source of truth, while keeping dense shadow buffers for fallback attention and maintaining backward-compatible cache metadata and serialization.

Sequence diagram for TurboQuant-aware attention dispatch
sequenceDiagram
participant Caller
participant AttentionUtils as attentionWithCacheUpdate
participant KV as KVCache
participant TurboKV as TurboQuantKVCacheProtocol
participant QuantKV as QuantizedKVCacheProtocol
participant FastAttn as MLXFast
participant QuantAttn as quantizedScaledDotProductAttention
Caller->>AttentionUtils: attentionWithCacheUpdate(queries, keys, values, cache, scale, mask)
alt cache is nil
AttentionUtils->>FastAttn: scaledDotProductAttention(queries, keys, values, scale, mask)
FastAttn-->>AttentionUtils: output
AttentionUtils-->>Caller: output
else cache is TurboQuantKVCacheProtocol
AttentionUtils->>AttentionUtils: queries.dim(2) == 1?
alt decode path
AttentionUtils->>TurboKV: decodeAttention(queries, keys, values, scale, mask)
TurboKV-->>AttentionUtils: output
else prefill path
AttentionUtils->>TurboKV: prefillAttention(queries, keys, values, scale, mask)
TurboKV-->>AttentionUtils: output
end
AttentionUtils-->>Caller: output
else cache is QuantizedKVCacheProtocol
AttentionUtils->>QuantKV: updateQuantized(keys, values)
QuantKV-->>AttentionUtils: quantizedKeys, quantizedValues
AttentionUtils->>QuantAttn: quantizedScaledDotProductAttention(queries, quantizedKeys, quantizedValues, scale, mask, groupSize, bits, mode)
QuantAttn-->>AttentionUtils: output
AttentionUtils-->>Caller: output
else generic KVCache
AttentionUtils->>KV: update(keys, values)
KV-->>AttentionUtils: cachedKeys, cachedValues
AttentionUtils->>FastAttn: scaledDotProductAttention(queries, cachedKeys, cachedValues, scale, mask)
FastAttn-->>AttentionUtils: output
AttentionUtils-->>Caller: output
end
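Read top to bottom, the dispatch is a type-based cascade over the cache. The Swift sketch below mirrors the diagram; it is a minimal illustration that assumes the protocol members and argument labels shown above (`TurboQuantKVCacheProtocol`, `QuantizedKVCacheProtocol`, `quantizedScaledDotProductAttention`), which may differ from the PR's actual source.

```swift
import MLX
import MLXFast

// Minimal sketch of the dispatch above. Protocol members and argument labels
// follow the diagram; the PR's real signatures may differ.
func attentionWithCacheUpdateSketch(
    queries: MLXArray, keys: MLXArray, values: MLXArray,
    cache: KVCache?, scale: Float,
    mask: MLXFast.ScaledDotProductAttentionMaskMode
) -> MLXArray {
    guard let cache else {
        // No cache at all: fused attention over just the fresh keys/values.
        return MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask)
    }
    if let turbo = cache as? TurboQuantKVCacheProtocol {
        // Single-token queries take the decode path; multi-token spans prefill.
        if queries.dim(2) == 1 {
            return turbo.decodeAttention(
                queries: queries, keys: keys, values: values, scale: scale, mask: mask)
        }
        return turbo.prefillAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask)
    }
    if let quantized = cache as? QuantizedKVCacheProtocol {
        // Append to the quantized cache, then attend over the quantized state.
        let (quantizedKeys, quantizedValues) = quantized.updateQuantized(keys: keys, values: values)
        return quantizedScaledDotProductAttention(
            queries, quantizedKeys, quantizedValues, scale: scale, mask: mask,
            groupSize: quantized.groupSize, bits: quantized.bits, mode: quantized.mode)
    }
    // Generic dense cache: append, then fused attention over the full history.
    let (cachedKeys, cachedValues) = cache.update(keys: keys, values: values)
    return MLXFast.scaledDotProductAttention(
        queries: queries, keys: cachedKeys, values: cachedValues, scale: scale, mask: mask)
}
```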
Sequence diagram for TurboQuantKVCache packed update and fallback attention
sequenceDiagram
participant Model
participant TurboKV as TurboQuantKVCache
participant CodecK as TurboQuantMSECodec(keys)
participant CodecV as TurboQuantMSECodec(values)
participant FastAttn as MLXFast
Model->>TurboKV: decodeAttention(queries, keys, values, scale, mask)
activate TurboKV
TurboKV->>TurboKV: fallbackAttention(queries, keys, values, scale, mask)
TurboKV->>TurboKV: update(keys, values)
note over TurboKV: ensureCodecs(keyDim, valueDim)
TurboKV->>CodecK: quantize(keys)
CodecK-->>TurboKV: keyStateUpdate
TurboKV->>CodecV: quantize(values)
CodecV-->>TurboKV: valueStateUpdate
TurboKV->>TurboKV: appendPackedState(keyState, keyStateUpdate, previousOffset)
TurboKV->>TurboKV: appendPackedState(valueState, valueStateUpdate, previousOffset)
TurboKV->>TurboKV: appendShadow(keys, values, previousOffset)
TurboKV->>TurboKV: denseState()
TurboKV-->>TurboKV: cachedKeys, cachedValues
TurboKV->>FastAttn: scaledDotProductAttention(queries, cachedKeys, cachedValues, scale, mask)
FastAttn-->>TurboKV: attentionOutput
TurboKV-->>Model: attentionOutput
deactivate TurboKV
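The `appendPackedState` steps are where the packed codec state grows. Below is a sketch of that append using a hypothetical mirror of `TurboQuantMSEState`; the field shapes and axis layout are assumptions for illustration, not the PR's actual definitions.

```swift
import MLX

// Hypothetical mirror of TurboQuantMSEState: per-vector norms plus packed
// low-bit code indices. Axis layout is assumed for illustration only.
struct MSEStateSketch {
    var norms: MLXArray    // [batch, heads, seq]
    var indices: MLXArray  // [batch, heads, seq, packedDim]
}

// Sketch of appendPackedState: drop any stale tail beyond the previous offset
// (e.g. after a rollback), then concatenate the new rows on the sequence axis.
func appendPackedStateSketch(
    _ current: MSEStateSketch?, _ update: MSEStateSketch, previous: Int
) -> MSEStateSketch {
    guard let current else { return update }
    let keptNorms = split(current.norms, indices: [previous], axis: -1)[0]
    let keptIndices = split(current.indices, indices: [previous], axis: -2)[0]
    return MSEStateSketch(
        norms: concatenated([keptNorms, update.norms], axis: -1),
        indices: concatenated([keptIndices, update.indices], axis: -2))
}
```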
Updated class diagram for TurboQuantKVCache and TurboQuant codecs
classDiagram
class KVCache {
<<protocol>>
+offset: Int
+state: [MLXArray]
+metaState: [String]
+update(keys: MLXArray, values: MLXArray) MLXArray MLXArray
+truncateToOffset()
}
class TurboQuantKVCacheProtocol {
<<protocol>>
+configuration: TurboQuantConfiguration
+decodeAttention(queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXFast.ScaledDotProductAttentionMaskMode) MLXArray
+prefillAttention(queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXFast.ScaledDotProductAttentionMaskMode) MLXArray
}
class QuantizedKVCacheProtocol {
<<protocol>>
+groupSize: Int
+bits: Int
+mode: Int
+updateQuantized(keys: MLXArray, values: MLXArray) MLXArray MLXArray
}
class BaseKVCache {
+offset: Int
+state: [MLXArray]
+metaState: [String]
+innerState() [MLXArray]
+update(keys: MLXArray, values: MLXArray) MLXArray MLXArray
+truncateToOffset()
}
class TurboQuantMSEState {
+norms: MLXArray
+indices: MLXArray
}
class TurboQuantMSECodec {
+dim: Int
+bits: Int
+useRHT: Bool
+signs: MLXArray?
+codebook: MLXArray
+midpoints: [Float]
+TurboQuantMSECodec(dim: Int, bits: Int, seed: Int)
+quantize(vectors: MLXArray) TurboQuantMSEState
+dequantize(state: TurboQuantMSEState) MLXArray
-rotateForward(array: MLXArray) MLXArray
-rotateInverse(array: MLXArray) MLXArray
}
class TurboQuantConfiguration {
+bits: Float
+variant: TurboQuantVariant
+metadataPath: String?
+metadataVersion: Int
+transformVersion: String
+codebookVersion: String
}
class TurboQuantKVCache {
+configuration: TurboQuantConfiguration
+step: Int
+didGrow: Bool
-keyState: TurboQuantMSEState?
-valueState: TurboQuantMSEState?
-shadowKeys: MLXArray?
-shadowValues: MLXArray?
-legacyDenseState: (keys: MLXArray, values: MLXArray)?
-keyCodec: TurboQuantMSECodec?
-valueCodec: TurboQuantMSECodec?
-keyDimension: Int?
-valueDimension: Int?
+innerState() [MLXArray]
+update(keys: MLXArray, values: MLXArray) MLXArray MLXArray
+decodeAttention(queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXFast.ScaledDotProductAttentionMaskMode) MLXArray
+prefillAttention(queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXFast.ScaledDotProductAttentionMaskMode) MLXArray
+truncateToOffset()
+toUnquantized() KVCacheSimple
+state: [MLXArray]
+metaState: [String]
+debugDescription: String
-ensureCodecs(keyDim: Int, valueDim: Int)
-appendShadow(keys: MLXArray, values: MLXArray, previous: Int)
-appendPackedState(current: TurboQuantMSEState?, update: TurboQuantMSEState, previous: Int)
-rebuildFromDenseState(keys: MLXArray, values: MLXArray)
-rehydrateShadowFromPackedState()
-denseState() MLXArray MLXArray
-fallbackAttention(queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXFast.ScaledDotProductAttentionMaskMode) MLXArray
}
class KVCacheSimple {
+state: [MLXArray]
}
KVCache <|-- BaseKVCache
BaseKVCache <|-- TurboQuantKVCache
TurboQuantKVCacheProtocol <|.. TurboQuantKVCache
QuantizedKVCacheProtocol <|.. KVCache
TurboQuantConfiguration <.. TurboQuantKVCache
TurboQuantMSECodec <.. TurboQuantKVCache
TurboQuantMSEState <.. TurboQuantMSECodec
TurboQuantMSEState <.. TurboQuantKVCache
KVCacheSimple <.. TurboQuantKVCache
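To ground the `TurboQuantMSECodec` half of the diagram, here is a toy norm-plus-codes round trip. It is a sketch only: the real codec also applies the randomized Hadamard rotation (`rotateForward`/`rotateInverse`) and packs the indices into low-bit storage, both omitted here; the function names are hypothetical.

```swift
import MLX

// Toy version of the quantize/dequantize split in the diagram: store each
// vector's L2 norm separately, and quantize only the normalized components
// against a small shared codebook. The real TurboQuantMSECodec additionally
// applies a randomized Hadamard transform before coding; omitted here.
func toyQuantize(_ vectors: MLXArray, codebook: MLXArray) -> (norms: MLXArray, indices: MLXArray) {
    let norms = sqrt((vectors * vectors).sum(axis: -1, keepDims: true))
    let unit = vectors / norms
    // Nearest codebook entry per component: argmin over |x - c|.
    let distances = abs(expandedDimensions(unit, axis: -1) - codebook)
    return (norms, argMin(distances, axis: -1))
}

func toyDequantize(norms: MLXArray, indices: MLXArray, codebook: MLXArray) -> MLXArray {
    // Look the codes back up in the codebook, then restore the stored scale.
    take(codebook, indices, axis: 0) * norms
}
```

A round trip through `toyQuantize` then `toyDequantize` recovers the input up to codebook resolution, which is the MSE trade the packed cache makes.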
Hey - I've found 1 issue, and left some high level feedback:

- The low-bit pack/unpack helpers (`turboQuantPackLowBit` / `turboQuantUnpackLowBit`) currently loop over positions and perform per-index MLXArray slicing and assignment, which will be quite slow for long sequences; consider restructuring these into vectorized operations or a dedicated kernel-style helper so the bit packing work scales better (see the sketch after this list).
- In `TurboQuantKVCache.truncateToOffset()`, only `keyState`/`valueState` and shadow tensors are trimmed; if `legacyDenseState` is present it remains unmodified and will be re-quantized to the full length on the next access, which can be surprising; either trim `legacyDenseState` as well or make it explicit that truncate is a no-op when running in dense-compatibility mode.
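On the first point, the per-index loop can typically be replaced by a reshape plus a weighted reduction. A sketch of the idea for the 2-bit case follows; the helper name, bit width, and layout are assumptions for illustration, not the PR's actual code.

```swift
import MLX

// Hypothetical vectorized 2-bit packer: rather than looping over positions and
// assigning per-index slices, fold each group of four 2-bit codes into one
// byte with a broadcasted multiply and a sum along the group axis.
func packLowBitVectorizedSketch(_ codes: MLXArray) -> MLXArray {
    // codes: integer values in [0, 3], total length divisible by 4.
    let grouped = codes.reshaped([-1, 4]).asType(.uint32)
    let weights = MLXArray([Int32(1), 4, 16, 64]).asType(.uint32)  // 4^k per slot
    // Max per byte is 3 * (1 + 4 + 16 + 64) = 255, so the result fits in uint8.
    return (grouped * weights).sum(axis: -1).asType(.uint8)
}
```

Unpacking would mirror this with whole-array integer division and remainder by the same weights, again avoiding per-position assignment.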
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The low-bit pack/unpack helpers (`turboQuantPackLowBit` / `turboQuantUnpackLowBit`) currently loop over positions and perform per-index MLXArray slicing and assignment, which will be quite slow for long sequences; consider restructuring these into vectorized operations or a dedicated kernel-style helper so the bit packing work scales better.
- In `TurboQuantKVCache.truncateToOffset()`, only `keyState`/`valueState` and shadow tensors are trimmed; if `legacyDenseState` is present it remains unmodified and will be re-quantized to the full length on the next access, which can be surprising—either trim `legacyDenseState` as well or make it explicit that truncate is a no-op when running in dense-compatibility mode.
## Individual Comments
### Comment 1
<location path="Scripts/apply-mlx-patches.sh" line_range="22-24" />
<code_context>
+PATCH_FILES=("Qwen3VL.swift" "Qwen3Next.swift" "GatedDelta.swift" "Qwen3_5MoE.swift" "DeepseekV3.swift" "MiniMaxM2.swift" "NemotronH.swift" "GLM4MoeLite.swift" "GLM5MoeDsa.swift" "KimiK25.swift" "Gemma4Text.swift" "Gemma4VLM.swift" "LLMModelFactory.swift" "Load.swift" "Evaluate.swift" "LanguageModel.swift" "Tokenizer.swift" "AttentionUtils.swift" "Qwen3_5MoEVL.swift" "VLMModelFactory.swift" "SamplerTests.swift" "ToolCallFormat.swift" "KVCache.swift" "SwitchLayers.swift" "BatchKVCache.swift" "SSM.swift" "Chat.swift" "Gemma4FunctionParser.swift")
</code_context>
<issue_to_address>
**issue (bug_risk):** AttentionUtils.swift is added to PATCH_FILES/TARGET_PATHS but not to NEW_FILES, which may break patching on a clean tree.
On a clean checkout where this file doesn’t exist yet, the script won’t pre-create it, so `patch` may fail. Please add `AttentionUtils.swift` to NEW_FILES to keep the arrays consistent and ensure patching works from a clean tree.
</issue_to_address>
Summary
Add the core serial TurboQuant behavior on top of the plumbing branch.
What this PR includes
What this PR does not include
Validation
MACAFM_MLX_METALLIB="$PWD/default.metallib" swift test --filter TurboQuantCacheTests --parallel --num-workers 1
MACAFM_MLX_METALLIB="$PWD/default.metallib" swift test --filter 'KVCacheTruncateTests|BatchedPrefillTests' --parallel --num-workers 1

Stack
This is PR 2 of a stacked TurboQuant series and targets PR 1.
Summary by Sourcery
Introduce packed TurboQuant KV cache state and explicit attention dispatch hooks, routing model attention through TurboQuant-specific decode/prefill paths while preserving compatibility with existing quantized and dense caches.
New Features:
Enhancements:
Build:
Documentation:
Tests: