Skip to content

Commit 1775615

Browse files
perf(vector): Improve hnsw by sharding vectors
1 parent 104cbae commit 1775615

6 files changed

Lines changed: 293 additions & 31 deletions

File tree

posting/index.go

Lines changed: 200 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"math"
1616
"os"
1717
"strings"
18+
"sync"
1819
"sync/atomic"
1920
"time"
2021
"unsafe"
@@ -33,8 +34,11 @@ import (
3334
"github.com/hypermodeinc/dgraph/v25/schema"
3435
"github.com/hypermodeinc/dgraph/v25/tok"
3536
"github.com/hypermodeinc/dgraph/v25/tok/hnsw"
37+
"github.com/hypermodeinc/dgraph/v25/tok/index"
3638
"github.com/hypermodeinc/dgraph/v25/types"
3739
"github.com/hypermodeinc/dgraph/v25/x"
40+
41+
"github.com/viterin/vek/vek32"
3842
)
3943

4044
var emptyCountParams countParams
@@ -162,7 +166,7 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo)
162166
// retrieve vector from inUuid save as inVec
163167
inVec := types.BytesAsFloatArray(data[0].Value.([]byte))
164168
tc := hnsw.NewTxnCache(NewViTxn(txn), txn.StartTs)
165-
indexer, err := info.factorySpecs[0].CreateIndex(attr)
169+
indexer, err := info.factorySpecs[0].CreateIndex(attr, 0)
166170
if err != nil {
167171
return []*pb.DirectedEdge{}, err
168172
}
@@ -1361,6 +1365,198 @@ func (rb *indexRebuildInfo) prefixesForTokIndexes() ([][]byte, error) {
13611365
return prefixes, nil
13621366
}
13631367

1368+
// vectorCentroids carries the state of the clustering step that decides which
// shard a vector belongs to while a vector index is being rebuilt.
type vectorCentroids struct {
	// dimension is the length of each centroid and numCenters is the number
	// of centroids; both are filled in by randomInit.
	dimension, numCenters int

	// centroids[i] is the current centre of cluster i.
	centroids [][]float32
	// counts[i] is how many vectors were added to cluster i since the last
	// call to updateCentroids.
	counts []int64
	// weights[i] is the component-wise sum of the vectors added to cluster i
	// since the last call to updateCentroids.
	weights [][]float32
	// mutexs[i] is locked while cluster i's sums and count are modified.
	mutexs []*sync.Mutex
}
1377+
1378+
func (vc *vectorCentroids) findCentroid(input []float32) int {
1379+
minIdx := 0
1380+
minDist := math.MaxFloat32
1381+
for i, centroid := range vc.centroids {
1382+
dist := vek32.Distance(centroid, input)
1383+
if float64(dist) < minDist {
1384+
minDist = float64(dist)
1385+
minIdx = i
1386+
}
1387+
}
1388+
return minIdx
1389+
}
1390+
1391+
func (vc *vectorCentroids) addVector(vec []float32) {
1392+
idx := vc.findCentroid(vec)
1393+
vc.mutexs[idx].Lock()
1394+
defer vc.mutexs[idx].Unlock()
1395+
for i := 0; i < vc.dimension; i++ {
1396+
vc.weights[idx][i] += vec[i]
1397+
}
1398+
vc.counts[idx]++
1399+
}
1400+
1401+
func (vc *vectorCentroids) updateCentroids() {
1402+
for i := 0; i < vc.numCenters; i++ {
1403+
for j := 0; j < vc.dimension; j++ {
1404+
vc.centroids[i][j] = vc.weights[i][j] / float32(vc.counts[i])
1405+
vc.weights[i][j] = 0
1406+
}
1407+
fmt.Printf("%d, ", vc.counts[i])
1408+
vc.counts[i] = 0
1409+
}
1410+
fmt.Println()
1411+
}
1412+
1413+
func (vc *vectorCentroids) randomInit() {
1414+
vc.dimension = len(vc.centroids[0])
1415+
vc.numCenters = len(vc.centroids)
1416+
vc.centroids = make([][]float32, vc.numCenters)
1417+
vc.counts = make([]int64, vc.numCenters)
1418+
vc.weights = make([][]float32, vc.numCenters)
1419+
vc.mutexs = make([]*sync.Mutex, vc.numCenters)
1420+
for i := 0; i < vc.numCenters; i++ {
1421+
vc.weights[i] = make([]float32, vc.dimension)
1422+
vc.counts[i] = 0
1423+
vc.mutexs[i] = &sync.Mutex{}
1424+
}
1425+
}
1426+
1427+
func (vc *vectorCentroids) addSeedCentroid(vec []float32) {
1428+
vc.centroids = append(vc.centroids, vec)
1429+
}
1430+
1431+
// numCentroids is the maximum number of seed vectors collected before
// clustering starts in rebuildVectorIndex.
const numCentroids = 1_000
1432+
1433+
func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSpec, rb *IndexRebuild) error {
1434+
pk := x.ParsedKey{Attr: rb.Attr}
1435+
vc := &vectorCentroids{}
1436+
1437+
MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
1438+
Prefix: pk.DataPrefix(),
1439+
ReadTs: rb.StartTs,
1440+
AllVersions: false,
1441+
Reverse: false,
1442+
CheckInclusion: func(uid uint64) error {
1443+
return nil
1444+
},
1445+
Function: func(l *List, pk x.ParsedKey) error {
1446+
val, err := l.Value(rb.StartTs)
1447+
if err != nil {
1448+
return err
1449+
}
1450+
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1451+
vc.addSeedCentroid(inVec)
1452+
if len(vc.centroids) == numCentroids {
1453+
return ErrStopIteration
1454+
}
1455+
return nil
1456+
},
1457+
StartKey: x.DataKey(rb.Attr, 0),
1458+
})
1459+
1460+
vc.randomInit()
1461+
1462+
fmt.Println("Clustering Vectors")
1463+
for range 5 {
1464+
builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
1465+
builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
1466+
edges := []*pb.DirectedEdge{}
1467+
val, err := pl.Value(txn.StartTs)
1468+
if err != nil {
1469+
return []*pb.DirectedEdge{}, err
1470+
}
1471+
1472+
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1473+
vc.addVector(inVec)
1474+
return edges, nil
1475+
}
1476+
1477+
err := builder.RunWithoutTemp(ctx)
1478+
if err != nil {
1479+
return err
1480+
}
1481+
1482+
vc.updateCentroids()
1483+
}
1484+
1485+
tcs := make([]*hnsw.TxnCache, vc.numCenters)
1486+
txns := make([]*Txn, vc.numCenters)
1487+
indexers := make([]index.VectorIndex[float32], vc.numCenters)
1488+
for i := 0; i < vc.numCenters; i++ {
1489+
txns[i] = NewTxn(rb.StartTs)
1490+
tcs[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs)
1491+
indexers_i, err := factorySpecs[0].CreateIndex(pk.Attr, i)
1492+
if err != nil {
1493+
return err
1494+
}
1495+
vc.mutexs[i] = &sync.Mutex{}
1496+
indexers[i] = indexers_i
1497+
}
1498+
1499+
var edgesCreated atomic.Int64
1500+
1501+
numPasses := vc.numCenters / 100
1502+
for pass_idx := range numPasses {
1503+
builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
1504+
builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
1505+
val, err := pl.Value(txn.StartTs)
1506+
if err != nil {
1507+
return []*pb.DirectedEdge{}, err
1508+
}
1509+
1510+
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1511+
idx := vc.findCentroid(inVec)
1512+
if idx%numPasses != pass_idx {
1513+
return []*pb.DirectedEdge{}, nil
1514+
}
1515+
vc.mutexs[idx].Lock()
1516+
defer vc.mutexs[idx].Unlock()
1517+
_, err = indexers[idx].Insert(ctx, tcs[idx], uid, inVec)
1518+
if err != nil {
1519+
return []*pb.DirectedEdge{}, err
1520+
}
1521+
1522+
edgesCreated.Add(int64(1))
1523+
return nil, nil
1524+
}
1525+
1526+
err := builder.RunWithoutTemp(ctx)
1527+
if err != nil {
1528+
return err
1529+
}
1530+
1531+
for idx := range vc.counts {
1532+
if idx%numPasses != pass_idx {
1533+
continue
1534+
}
1535+
txns[idx].Update()
1536+
writer := NewTxnWriter(pstore)
1537+
1538+
x.ExponentialRetry(int(x.Config.MaxRetries),
1539+
20*time.Millisecond, func() error {
1540+
err := txns[idx].CommitToDisk(writer, rb.StartTs)
1541+
if err == badger.ErrBannedKey {
1542+
glog.Errorf("Error while writing to banned namespace.")
1543+
return nil
1544+
}
1545+
return err
1546+
})
1547+
1548+
txns[idx].cache.plists = nil
1549+
txns[idx] = nil
1550+
tcs[idx] = nil
1551+
indexers[idx] = nil
1552+
}
1553+
1554+
fmt.Printf("Created %d edges in pass %d out of %d\n", edgesCreated.Load(), pass_idx, numPasses)
1555+
}
1556+
1557+
return nil
1558+
}
1559+
13641560
// rebuildTokIndex rebuilds index for a given attribute.
13651561
// We commit mutations with startTs and ignore the errors.
13661562
func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {
@@ -1392,6 +1588,9 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {
13921588
}
13931589

13941590
runForVectors := (len(factorySpecs) != 0)
1591+
if runForVectors {
1592+
return rebuildVectorIndex(ctx, factorySpecs, rb)
1593+
}
13951594

13961595
pk := x.ParsedKey{Attr: rb.Attr}
13971596
builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}

tok/hnsw/persistent_factory.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,25 +87,27 @@ func (hf *persistentIndexFactory[T]) AllowedOptions() opt.AllowedOptions {
8787
func (hf *persistentIndexFactory[T]) Create(
8888
name string,
8989
o opt.Options,
90-
floatBits int) (index.VectorIndex[T], error) {
90+
floatBits int,
91+
split int) (index.VectorIndex[T], error) {
9192
hf.mu.Lock()
9293
defer hf.mu.Unlock()
93-
return hf.createWithLock(name, o, floatBits)
94+
return hf.createWithLock(name, o, floatBits, split)
9495
}
9596

9697
func (hf *persistentIndexFactory[T]) createWithLock(
9798
name string,
9899
o opt.Options,
99-
floatBits int) (index.VectorIndex[T], error) {
100-
if !hf.isNameAvailableWithLock(name) {
100+
floatBits int,
101+
split int) (index.VectorIndex[T], error) {
102+
if !hf.isNameAvailableWithLock(fmt.Sprintf("%s-%d", name, split)) {
101103
err := errors.New("index with name " + name + " already exists")
102104
return nil, err
103105
}
104106
retVal := &persistentHNSW[T]{
105107
pred: name,
106-
vecEntryKey: ConcatStrings(name, VecEntry),
107-
vecKey: ConcatStrings(name, VecKeyword),
108-
vecDead: ConcatStrings(name, VecDead),
108+
vecEntryKey: ConcatStrings(name, VecEntry, fmt.Sprintf("_%d", split)),
109+
vecKey: ConcatStrings(name, VecKeyword, fmt.Sprintf("_%d", split)),
110+
vecDead: ConcatStrings(name, VecDead, fmt.Sprintf("_%d", split)),
109111
floatBits: floatBits,
110112
nodeAllEdges: map[uint64][][]uint64{},
111113
}
@@ -152,7 +154,8 @@ func (hf *persistentIndexFactory[T]) removeWithLock(name string) error {
152154
func (hf *persistentIndexFactory[T]) CreateOrReplace(
153155
name string,
154156
o opt.Options,
155-
floatBits int) (index.VectorIndex[T], error) {
157+
floatBits int,
158+
split int) (index.VectorIndex[T], error) {
156159
hf.mu.Lock()
157160
defer hf.mu.Unlock()
158161
vi, err := hf.findWithLock(name)
@@ -165,5 +168,5 @@ func (hf *persistentIndexFactory[T]) CreateOrReplace(
165168
return nil, err
166169
}
167170
}
168-
return hf.createWithLock(name, o, floatBits)
171+
return hf.createWithLock(name, o, floatBits, split)
169172
}

tok/hnsw/persistent_hnsw.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package hnsw
88
import (
99
"context"
1010
"fmt"
11+
"sort"
1112
"strings"
1213
"time"
1314

@@ -254,6 +255,46 @@ func (ph *persistentHNSW[T]) Search(ctx context.Context, c index.CacheType, quer
254255
return r.Neighbors, err
255256
}
256257

258+
type resultRow[T c.Float] struct {
259+
uid uint64
260+
dist T
261+
}
262+
263+
func (ph *persistentHNSW[T]) MergeResults(ctx context.Context, c index.CacheType, list []uint64, query []T, maxResults int, filter index.SearchFilter[T]) ([]uint64, error) {
264+
var result []resultRow[T]
265+
266+
for i := range list {
267+
var vec []T
268+
err := ph.getVecFromUid(list[i], c, &vec)
269+
if err != nil {
270+
return nil, err
271+
}
272+
273+
dist, err := ph.simType.distanceScore(vec, query, ph.floatBits)
274+
if err != nil {
275+
return nil, err
276+
}
277+
result = append(result, resultRow[T]{
278+
uid: list[i],
279+
dist: dist,
280+
})
281+
}
282+
283+
sort.Slice(result, func(i, j int) bool {
284+
return result[i].dist < result[j].dist
285+
})
286+
287+
uids := []uint64{}
288+
for i := range maxResults {
289+
if i > len(result) {
290+
break
291+
}
292+
uids = append(uids, result[i].uid)
293+
}
294+
295+
return uids, nil
296+
}
297+
257298
// SearchWithUid searches the hnsw graph for the nearest neighbors of the query uid
258299
// and returns the traversal path and the nearest neighbors
259300
func (ph *persistentHNSW[T]) SearchWithUid(_ context.Context, c index.CacheType, queryUid uint64,

tok/index/index.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ type IndexFactory[T c.Float] interface {
3939
// same object.
4040
// The set of vectors to use in the index process is defined by
4141
// source.
42-
Create(name string, o opts.Options, floatBits int) (VectorIndex[T], error)
42+
Create(name string, o opts.Options, floatBits int, split int) (VectorIndex[T], error)
4343

4444
// Find is expected to retrieve the VectorIndex corresponding with the
4545
// name. If it attempts to find a name that does not exist, the VectorIndex
@@ -56,7 +56,7 @@ type IndexFactory[T c.Float] interface {
5656
// CreateOrReplace will create a new index -- as defined by the Create
5757
// function -- if it does not yet exist, otherwise, it will replace any
5858
// index with the given name.
59-
CreateOrReplace(name string, o opts.Options, floatBits int) (VectorIndex[T], error)
59+
CreateOrReplace(name string, o opts.Options, floatBits int, split int) (VectorIndex[T], error)
6060
}
6161

6262
// SearchFilter defines a predicate function that we will use to determine
@@ -93,6 +93,9 @@ type OptionalIndexSupport[T c.Float] interface {
9393
type VectorIndex[T c.Float] interface {
9494
OptionalIndexSupport[T]
9595

96+
MergeResults(ctx context.Context, c CacheType, list []uint64, query []T, maxResults int,
97+
filter SearchFilter[T]) ([]uint64, error)
98+
9699
// Search will find the uids for a given set of vectors based on the
97100
// input query, limiting to the specified maximum number of results.
98101
// The filter parameter indicates that we might discard certain parameters

0 commit comments

Comments
 (0)