Skip to content

Commit 0d519a3

Browse files
Add functionality to correctly evaluate scores
The previous implementation assumed that lower values were always better (as with distance metrics), which broke similarity search (cosine, dot product), where higher values are better.
1 parent ed39971 commit 0d519a3

4 files changed

Lines changed: 432 additions & 9 deletions

File tree

tok/hnsw/heap.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,104 @@ func buildPersistentHeapByInit[T c.Float](array []minPersistentHeapElement[T]) *
7070
heap.Init(minPersistentTupleHeap)
7171
return minPersistentTupleHeap
7272
}
73+
74+
// maxPersistentTupleHeap is a max-heap for similarity metrics (cosine, dot-product)
75+
// where higher values indicate better matches.
76+
type maxPersistentTupleHeap[T c.Float] []minPersistentHeapElement[T]
77+
78+
func (h maxPersistentTupleHeap[T]) Len() int {
79+
return len(h)
80+
}
81+
82+
func (h maxPersistentTupleHeap[T]) Less(i, j int) bool {
83+
return h[i].value > h[j].value // reversed for max-heap
84+
}
85+
86+
func (h maxPersistentTupleHeap[T]) Swap(i, j int) {
87+
h[i], h[j] = h[j], h[i]
88+
}
89+
90+
func (h *maxPersistentTupleHeap[T]) Push(x interface{}) {
91+
*h = append(*h, x.(minPersistentHeapElement[T]))
92+
}
93+
94+
func (h *maxPersistentTupleHeap[T]) PopLast() {
95+
heap.Remove(h, h.Len()-1)
96+
}
97+
98+
func (h *maxPersistentTupleHeap[T]) Pop() interface{} {
99+
old := *h
100+
n := len(old)
101+
x := old[n-1]
102+
*h = old[:n-1]
103+
return x
104+
}
105+
106+
// buildMaxPersistentHeapByInit will create a max-heap for similarity metrics
107+
// in time O(n), where n = length of array
108+
func buildMaxPersistentHeapByInit[T c.Float](array []minPersistentHeapElement[T]) *maxPersistentTupleHeap[T] {
109+
maxHeap := &maxPersistentTupleHeap[T]{}
110+
*maxHeap = array
111+
heap.Init(maxHeap)
112+
return maxHeap
113+
}
114+
115+
// candidateHeap is an interface for the candidate heap used in HNSW search.
116+
// It abstracts over min-heap (for distance metrics) and max-heap (for similarity metrics).
117+
type candidateHeap[T c.Float] interface {
118+
Len() int
119+
Push(x minPersistentHeapElement[T])
120+
Pop() minPersistentHeapElement[T]
121+
PopLast()
122+
}
123+
124+
// minHeapWrapper wraps minPersistentTupleHeap to implement candidateHeap interface
125+
type minHeapWrapper[T c.Float] struct {
126+
h *minPersistentTupleHeap[T]
127+
}
128+
129+
func (w *minHeapWrapper[T]) Len() int {
130+
return w.h.Len()
131+
}
132+
133+
func (w *minHeapWrapper[T]) Push(x minPersistentHeapElement[T]) {
134+
heap.Push(w.h, x)
135+
}
136+
137+
func (w *minHeapWrapper[T]) Pop() minPersistentHeapElement[T] {
138+
return heap.Pop(w.h).(minPersistentHeapElement[T])
139+
}
140+
141+
func (w *minHeapWrapper[T]) PopLast() {
142+
w.h.PopLast()
143+
}
144+
145+
// maxHeapWrapper wraps maxPersistentTupleHeap to implement candidateHeap interface
146+
type maxHeapWrapper[T c.Float] struct {
147+
h *maxPersistentTupleHeap[T]
148+
}
149+
150+
func (w *maxHeapWrapper[T]) Len() int {
151+
return w.h.Len()
152+
}
153+
154+
func (w *maxHeapWrapper[T]) Push(x minPersistentHeapElement[T]) {
155+
heap.Push(w.h, x)
156+
}
157+
158+
func (w *maxHeapWrapper[T]) Pop() minPersistentHeapElement[T] {
159+
return heap.Pop(w.h).(minPersistentHeapElement[T])
160+
}
161+
162+
func (w *maxHeapWrapper[T]) PopLast() {
163+
w.h.PopLast()
164+
}
165+
166+
// buildCandidateHeap creates the appropriate heap based on whether we're using
167+
// a distance metric (lower is better) or similarity metric (higher is better).
168+
func buildCandidateHeap[T c.Float](array []minPersistentHeapElement[T], isSimilarityMetric bool) candidateHeap[T] {
169+
if isSimilarityMetric {
170+
return &maxHeapWrapper[T]{h: buildMaxPersistentHeapByInit(array)}
171+
}
172+
return &minHeapWrapper[T]{h: buildPersistentHeapByInit(array)}
173+
}

tok/hnsw/helper.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,22 +210,30 @@ type SimilarityType[T c.Float] struct {
210210
distanceScore func(v, w []T, floatBits int) (T, error)
211211
insortHeap func(slice []minPersistentHeapElement[T], val minPersistentHeapElement[T]) []minPersistentHeapElement[T]
212212
isBetterScore func(a, b T) bool
213+
// isSimilarityMetric is true for metrics where higher values indicate better matches
214+
// (e.g., cosine similarity, dot product). For distance metrics like euclidean,
215+
// this is false because lower values indicate better matches.
216+
isSimilarityMetric bool
213217
}
214218

215219
func GetSimType[T c.Float](indexType string, floatBits int) SimilarityType[T] {
216220
switch {
217221
case indexType == Euclidean:
218222
return SimilarityType[T]{indexType: Euclidean, distanceScore: euclideanDistanceSq[T],
219-
insortHeap: insortPersistentHeapAscending[T], isBetterScore: isBetterScoreForDistance[T]}
223+
insortHeap: insortPersistentHeapAscending[T], isBetterScore: isBetterScoreForDistance[T],
224+
isSimilarityMetric: false}
220225
case indexType == Cosine:
221226
return SimilarityType[T]{indexType: Cosine, distanceScore: cosineSimilarity[T],
222-
insortHeap: insortPersistentHeapDescending[T], isBetterScore: isBetterScoreForSimilarity[T]}
227+
insortHeap: insortPersistentHeapDescending[T], isBetterScore: isBetterScoreForSimilarity[T],
228+
isSimilarityMetric: true}
223229
case indexType == DotProd:
224230
return SimilarityType[T]{indexType: DotProd, distanceScore: dotProduct[T],
225-
insortHeap: insortPersistentHeapDescending[T], isBetterScore: isBetterScoreForSimilarity[T]}
231+
insortHeap: insortPersistentHeapDescending[T], isBetterScore: isBetterScoreForSimilarity[T],
232+
isSimilarityMetric: true}
226233
default:
227234
return SimilarityType[T]{indexType: Euclidean, distanceScore: euclideanDistanceSq[T],
228-
insortHeap: insortPersistentHeapAscending[T], isBetterScore: isBetterScoreForDistance[T]}
235+
insortHeap: insortPersistentHeapAscending[T], isBetterScore: isBetterScoreForDistance[T],
236+
isSimilarityMetric: false}
229237
}
230238
}
231239

@@ -619,8 +627,14 @@ func (ph *persistentHNSW[T]) addNeighbors(ctx context.Context, tc *TxnCache,
619627
h := &HeapDataHolder{
620628
data: allLayerEdges[level],
621629
compare: func(i, j uint64) bool {
622-
return ph.distance_betw(ctx, tc, uuid, i, &inVec, &outVec) >
623-
ph.distance_betw(ctx, tc, uuid, j, &inVec, &outVec)
630+
distI := ph.distance_betw(ctx, tc, uuid, i, &inVec, &outVec)
631+
distJ := ph.distance_betw(ctx, tc, uuid, j, &inVec, &outVec)
632+
// We want to keep the BEST edges and remove the WORST.
633+
// Pop removes the root element. We need the WORST at root.
634+
// For distance metrics (lower is better): worst = highest, so max-heap (>)
635+
// For similarity metrics (higher is better): worst = lowest, so min-heap (<)
636+
// Using !isBetterScore gives us the correct heap type for each metric.
637+
return !ph.simType.isBetterScore(distI, distJ)
624638
}}
625639

626640
for _, e := range allLayerNeighbors[level] {

tok/hnsw/persistent_hnsw.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ func (ph *persistentHNSW[T]) applyOptions(o opt.Options) error {
106106
ph.simType = okSimType
107107
} else {
108108
ph.simType = SimilarityType[T]{indexType: Euclidean, distanceScore: euclideanDistanceSq[T],
109-
insortHeap: insortPersistentHeapAscending[T], isBetterScore: isBetterScoreForDistance[T]}
109+
insortHeap: insortPersistentHeapAscending[T], isBetterScore: isBetterScoreForDistance[T],
110+
isSimilarityMetric: false}
110111
}
111112
return nil
112113
}
@@ -170,13 +171,15 @@ func (ph *persistentHNSW[T]) searchPersistentLayer(
170171
}
171172

172173
r.setFirstPathNode(best)
173-
candidateHeap := *buildPersistentHeapByInit([]minPersistentHeapElement[T]{best})
174+
// Use the appropriate heap type based on metric: min-heap for distance metrics
175+
// (lower is better), max-heap for similarity metrics (higher is better).
176+
candidateHeap := buildCandidateHeap([]minPersistentHeapElement[T]{best}, ph.simType.isSimilarityMetric)
174177

175178
var allLayerEdges [][]uint64
176179

177180
//create set using map to append to on future visited nodes
178181
for candidateHeap.Len() != 0 {
179-
currCandidate := candidateHeap.Pop().(minPersistentHeapElement[T])
182+
currCandidate := candidateHeap.Pop()
180183
if r.numNeighbors() >= expectedNeighbors &&
181184
ph.simType.isBetterScore(r.lastNeighborScore(), currCandidate.value) {
182185
// Standard HNSW termination: once the current best candidate

0 commit comments

Comments
 (0)