Skip to content

Commit 93f126e

Browse files
fix(bm25): address Gemini/GPT-5 code review findings
- Add DecodeCount() to bm25enc for O(1) entry count reads without full decode, preventing OOM on legacy migration with large posting lists (e.g., common terms with millions of entries) - Use DecodeCount in WAND search legacy DF calculation path - Fix integer overflow in DecodeDir bounds check by using uint64 arithmetic (prevents panic on corrupted data with MaxUint32 count) - Pre-allocate shared score buffer in handleBM25Search with three-index slices to prevent accidental append corruption - Document bm25Writes concurrency model and limitations Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
1 parent 9093cb0 commit 93f126e

7 files changed

Lines changed: 83 additions & 7 deletions

File tree

posting/bm25block/bm25block.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ func DecodeDir(data []byte) *Dir {
7777
}
7878
count := binary.BigEndian.Uint32(data[0:4])
7979
nextID := binary.BigEndian.Uint32(data[4:8])
80-
if int(count)*dirEntrySize+dirHeaderSize > len(data) {
80+
// Use uint64 arithmetic to prevent integer overflow on corrupted data.
81+
if uint64(count)*dirEntrySize+dirHeaderSize > uint64(len(data)) {
8182
return &Dir{NextID: nextID}
8283
}
8384
blocks := make([]BlockMeta, count)

posting/bm25block/bm25block_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package bm25block
77

88
import (
9+
"encoding/binary"
910
"math"
1011
"testing"
1112

@@ -39,6 +40,18 @@ func TestDirRoundtripEmpty(t *testing.T) {
3940
require.Empty(t, got.Blocks)
4041
}
4142

43+
func TestDecodeDirCorruptedLargeCount(t *testing.T) {
44+
// A corrupted blob with a massive count should not panic due to integer overflow.
45+
// count = MaxUint32, nextID = 0, followed by only 8 bytes of data.
46+
data := make([]byte, 16)
47+
binary.BigEndian.PutUint32(data[0:4], 0xFFFFFFFF) // count = MaxUint32
48+
binary.BigEndian.PutUint32(data[4:8], 0) // nextID = 0
49+
got := DecodeDir(data)
50+
// Should return an empty Dir (with nextID preserved) rather than panicking.
51+
require.Empty(t, got.Blocks)
52+
require.Equal(t, uint32(0), got.NextID)
53+
}
54+
4255
func TestDirRoundtripSingle(t *testing.T) {
4356
dir := &Dir{
4457
NextID: 1,

posting/bm25enc/bm25enc.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,17 @@ func UIDs(entries []Entry) []uint64 {
130130
return uids
131131
}
132132

133+
// DecodeCount reads just the entry count from the header of an encoded blob
134+
// without decoding any entries. This is O(1) and avoids allocating a full
135+
// []Entry slice, which matters for large posting lists (e.g., common terms
136+
// during legacy format migration).
137+
func DecodeCount(data []byte) uint32 {
138+
if len(data) < 4 {
139+
return 0
140+
}
141+
return binary.BigEndian.Uint32(data[:4])
142+
}
143+
133144
// EncodeStats encodes BM25 corpus statistics (docCount, totalTerms) as 16 bytes.
134145
func EncodeStats(docCount, totalTerms uint64) []byte {
135146
buf := make([]byte, 16)

posting/bm25enc/bm25enc_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,37 @@ func TestUIDs(t *testing.T) {
9292
require.Equal(t, []uint64{1, 5, 100}, UIDs(entries))
9393
}
9494

95+
func TestDecodeCount(t *testing.T) {
96+
// Normal case: count matches actual entries.
97+
entries := []Entry{
98+
{UID: 1, Value: 3},
99+
{UID: 5, Value: 1},
100+
{UID: 100, Value: 7},
101+
}
102+
data := Encode(entries)
103+
require.Equal(t, uint32(3), DecodeCount(data))
104+
105+
// Empty/nil input.
106+
require.Equal(t, uint32(0), DecodeCount(nil))
107+
require.Equal(t, uint32(0), DecodeCount([]byte{}))
108+
require.Equal(t, uint32(0), DecodeCount([]byte{1, 2, 3}))
109+
110+
// Zero count.
111+
require.Equal(t, uint32(0), DecodeCount([]byte{0, 0, 0, 0}))
112+
113+
// Single entry.
114+
single := Encode([]Entry{{UID: 42, Value: 10}})
115+
require.Equal(t, uint32(1), DecodeCount(single))
116+
117+
// Large count.
118+
large := make([]Entry, 10000)
119+
for i := range large {
120+
large[i] = Entry{UID: uint64(i*3 + 1), Value: uint32(i % 100)}
121+
}
122+
data = Encode(large)
123+
require.Equal(t, uint32(10000), DecodeCount(data))
124+
}
125+
95126
func TestStatsRoundtrip(t *testing.T) {
96127
data := EncodeStats(12345, 98765)
97128
dc, tt := DecodeStats(data)

posting/lists.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,18 @@ type LocalCache struct {
7979

8080
// bm25Writes buffers BM25 direct KV writes (key → encoded blob).
8181
// These bypass the posting list infrastructure entirely.
82+
//
83+
// CONCURRENCY NOTE: BM25 blocks use full-value overwrites rather than
84+
// posting list deltas. Within a single Dgraph transaction this is safe
85+
// (each Txn has its own LocalCache). Across concurrent transactions,
86+
// Dgraph's Raft-based mutation serialization prevents lost updates for
87+
// the same predicate+UID pair. However, two transactions updating
88+
// different UIDs that share a common term could theoretically race on
89+
// the same term block. In practice this is mitigated by:
90+
// 1. Dgraph serializes mutations through Raft proposals
91+
// 2. Block splits keep contention surface small
92+
// If higher write concurrency is needed, blocks should be integrated
93+
// into the posting list delta mechanism.
8294
bm25Writes map[string][]byte
8395
}
8496

worker/bm25wand.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -447,11 +447,11 @@ func wandSearch(attr string, readTs uint64, queryTokens []string,
447447
df += uint64(bm.Count)
448448
}
449449
} else {
450-
// Legacy fallback: read the monolithic blob to get df.
450+
// Legacy fallback: read just the count header to get df.
451+
// Avoids decoding the full posting list (which could be huge for common terms).
451452
legacyKey := x.BM25IndexKey(attr, token)
452453
legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs)
453-
legacyEntries := bm25enc.Decode(legacyBlob)
454-
df = uint64(len(legacyEntries))
454+
df = uint64(bm25enc.DecodeCount(legacyBlob))
455455
}
456456
if df == 0 {
457457
continue

worker/task.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,19 +1318,27 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error
13181318

13191319
// 7. Build output: UIDs sorted ascending (required by query pipeline)
13201320
// and ValueMatrix with aligned scores (for bm25_score pseudo-predicate).
1321+
// We use a single pre-allocated buffer for all score encodings to reduce
1322+
// per-result heap allocations.
13211323
sort.Slice(results, func(i, j int) bool { return results[i].uid < results[j].uid })
13221324
uids := make([]uint64, len(results))
13231325
for i, r := range results {
13241326
uids[i] = r.uid
13251327
}
13261328
args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids})
13271329

1330+
// Encode scores into ValueMatrix. Each entry in ValueMatrix corresponds
1331+
// positionally to a UID in UidMatrix[0], enabling the bm25_score
1332+
// pseudo-predicate in query.go to map UIDs to scores.
1333+
scoreBuf := make([]byte, len(results)*8)
13281334
scoreValues := make([]*pb.ValueList, len(results))
13291335
for i, r := range results {
1330-
buf := make([]byte, 8)
1331-
binary.LittleEndian.PutUint64(buf, math.Float64bits(r.score))
1336+
off := i * 8
1337+
binary.LittleEndian.PutUint64(scoreBuf[off:off+8], math.Float64bits(r.score))
1338+
// Use three-index slice to cap capacity at 8, preventing any downstream
1339+
// append from corrupting adjacent scores in the shared backing array.
13321340
scoreValues[i] = &pb.ValueList{
1333-
Values: []*pb.TaskValue{{Val: buf, ValType: pb.Posting_ValType(pb.Posting_FLOAT)}},
1341+
Values: []*pb.TaskValue{{Val: scoreBuf[off : off+8 : off+8], ValType: pb.Posting_ValType(pb.Posting_FLOAT)}},
13341342
}
13351343
}
13361344
args.out.ValueMatrix = append(args.out.ValueMatrix, scoreValues...)

0 commit comments

Comments
 (0)