Skip to content

Commit 441a86f

Browse files
added changes
1 parent ee87b7e commit 441a86f

10 files changed

Lines changed: 759 additions & 178 deletions

File tree

posting/index.go

Lines changed: 176 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
"math"
1616
"os"
1717
"strings"
18-
"sync"
1918
"sync/atomic"
2019
"time"
2120
"unsafe"
@@ -34,11 +33,10 @@ import (
3433
"github.com/hypermodeinc/dgraph/v25/schema"
3534
"github.com/hypermodeinc/dgraph/v25/tok"
3635
"github.com/hypermodeinc/dgraph/v25/tok/hnsw"
37-
"github.com/hypermodeinc/dgraph/v25/tok/index"
36+
tokIndex "github.com/hypermodeinc/dgraph/v25/tok/index"
37+
3838
"github.com/hypermodeinc/dgraph/v25/types"
3939
"github.com/hypermodeinc/dgraph/v25/x"
40-
41-
"github.com/viterin/vek/vek32"
4240
)
4341

4442
var emptyCountParams countParams
@@ -166,7 +164,7 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo)
166164
// retrieve vector from inUuid save as inVec
167165
inVec := types.BytesAsFloatArray(data[0].Value.([]byte))
168166
tc := hnsw.NewTxnCache(NewViTxn(txn), txn.StartTs)
169-
indexer, err := info.factorySpecs[0].CreateIndex(attr, 0)
167+
indexer, err := info.factorySpecs[0].CreateIndex(attr)
170168
if err != nil {
171169
return []*pb.DirectedEdge{}, err
172170
}
@@ -1365,112 +1363,67 @@ func (rb *indexRebuildInfo) prefixesForTokIndexes() ([][]byte, error) {
13651363
return prefixes, nil
13661364
}
13671365

1368-
type vectorCentroids struct {
1369-
dimension int
1370-
numCenters int
1366+
const numCentroids = 1000
13711367

1372-
centroids [][]float32
1373-
counts []int64
1374-
weights [][]float32
1375-
mutexs []*sync.Mutex
1376-
}
1368+
func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSpec, rb *IndexRebuild) error {
1369+
pk := x.ParsedKey{Attr: rb.Attr}
13771370

1378-
func (vc *vectorCentroids) findCentroid(input []float32) int {
1379-
minIdx := 0
1380-
minDist := math.MaxFloat32
1381-
for i, centroid := range vc.centroids {
1382-
dist := vek32.Distance(centroid, input)
1383-
if float64(dist) < minDist {
1384-
minDist = float64(dist)
1385-
minIdx = i
1386-
}
1371+
indexer, err := factorySpecs[0].CreateIndex(pk.Attr)
1372+
if err != nil {
1373+
return err
13871374
}
1388-
return minIdx
1389-
}
13901375

1391-
func (vc *vectorCentroids) addVector(vec []float32) {
1392-
idx := vc.findCentroid(vec)
1393-
vc.mutexs[idx].Lock()
1394-
defer vc.mutexs[idx].Unlock()
1395-
for i := 0; i < vc.dimension; i++ {
1396-
vc.weights[idx][i] += vec[i]
1376+
if indexer.NumSeedVectors() > 0 {
1377+
count := 0
1378+
MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
1379+
Prefix: pk.DataPrefix(),
1380+
ReadTs: rb.StartTs,
1381+
AllVersions: false,
1382+
Reverse: false,
1383+
CheckInclusion: func(uid uint64) error {
1384+
return nil
1385+
},
1386+
Function: func(l *List, pk x.ParsedKey) error {
1387+
val, err := l.Value(rb.StartTs)
1388+
if err != nil {
1389+
return err
1390+
}
1391+
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1392+
count += 1
1393+
indexer.AddSeedVector(inVec)
1394+
if count == indexer.NumSeedVectors() {
1395+
return ErrStopIteration
1396+
}
1397+
return nil
1398+
},
1399+
StartKey: x.DataKey(rb.Attr, 0),
1400+
})
13971401
}
1398-
vc.counts[idx]++
1399-
}
14001402

1401-
func (vc *vectorCentroids) updateCentroids() {
1402-
for i := 0; i < vc.numCenters; i++ {
1403-
for j := 0; j < vc.dimension; j++ {
1404-
vc.centroids[i][j] = vc.weights[i][j] / float32(vc.counts[i])
1405-
vc.weights[i][j] = 0
1406-
}
1407-
fmt.Printf("%d, ", vc.counts[i])
1408-
vc.counts[i] = 0
1403+
txns := make([]*Txn, indexer.NumThreads())
1404+
for i := range txns {
1405+
txns[i] = NewTxn(rb.StartTs)
14091406
}
1410-
fmt.Println()
1411-
}
1412-
1413-
func (vc *vectorCentroids) randomInit() {
1414-
vc.dimension = len(vc.centroids[0])
1415-
vc.numCenters = len(vc.centroids)
1416-
vc.counts = make([]int64, vc.numCenters)
1417-
vc.weights = make([][]float32, vc.numCenters)
1418-
vc.mutexs = make([]*sync.Mutex, vc.numCenters)
1419-
for i := 0; i < vc.numCenters; i++ {
1420-
vc.weights[i] = make([]float32, vc.dimension)
1421-
vc.counts[i] = 0
1422-
vc.mutexs[i] = &sync.Mutex{}
1407+
caches := make([]tokIndex.CacheType, indexer.NumThreads())
1408+
for i := range caches {
1409+
caches[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs)
14231410
}
1424-
}
14251411

1426-
func (vc *vectorCentroids) addSeedCentroid(vec []float32) {
1427-
vc.centroids = append(vc.centroids, vec)
1428-
}
1412+
for pass_idx := range indexer.NumBuildPasses() {
1413+
fmt.Println("Building pass", pass_idx)
14291414

1430-
const numCentroids = 1000
1415+
indexer.StartBuild(caches)
14311416

1432-
func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSpec, rb *IndexRebuild) error {
1433-
pk := x.ParsedKey{Attr: rb.Attr}
1434-
vc := &vectorCentroids{}
1435-
vc.centroids = make([][]float32, 0)
1436-
1437-
MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
1438-
Prefix: pk.DataPrefix(),
1439-
ReadTs: rb.StartTs,
1440-
AllVersions: false,
1441-
Reverse: false,
1442-
CheckInclusion: func(uid uint64) error {
1443-
return nil
1444-
},
1445-
Function: func(l *List, pk x.ParsedKey) error {
1446-
val, err := l.Value(rb.StartTs)
1447-
if err != nil {
1448-
return err
1449-
}
1450-
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1451-
vc.addSeedCentroid(inVec)
1452-
if len(vc.centroids) == numCentroids {
1453-
return ErrStopIteration
1454-
}
1455-
return nil
1456-
},
1457-
StartKey: x.DataKey(rb.Attr, 0),
1458-
})
1459-
1460-
vc.randomInit()
1461-
1462-
fmt.Println("Clustering Vectors")
1463-
for range 5 {
14641417
builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
14651418
builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
14661419
edges := []*pb.DirectedEdge{}
1467-
val, err := pl.Value(txn.StartTs)
1420+
val, err := pl.Value(rb.StartTs)
14681421
if err != nil {
14691422
return []*pb.DirectedEdge{}, err
14701423
}
14711424

14721425
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1473-
vc.addVector(inVec)
1426+
indexer.BuildInsert(ctx, uid, inVec)
14741427
return edges, nil
14751428
}
14761429

@@ -1479,59 +1432,33 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14791432
return err
14801433
}
14811434

1482-
vc.updateCentroids()
1435+
indexer.EndBuild()
14831436
}
14841437

1485-
tcs := make([]*hnsw.TxnCache, vc.numCenters)
1486-
txns := make([]*Txn, vc.numCenters)
1487-
indexers := make([]index.VectorIndex[float32], vc.numCenters)
1488-
for i := 0; i < vc.numCenters; i++ {
1489-
txns[i] = NewTxn(rb.StartTs)
1490-
tcs[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs)
1491-
indexers_i, err := factorySpecs[0].CreateIndex(pk.Attr, i)
1492-
if err != nil {
1493-
return err
1494-
}
1495-
vc.mutexs[i] = &sync.Mutex{}
1496-
indexers[i] = indexers_i
1497-
}
1438+
for pass_idx := range indexer.NumIndexPasses() {
1439+
fmt.Println("Indexing pass", pass_idx)
14981440

1499-
var edgesCreated atomic.Int64
1441+
indexer.StartBuild(caches)
15001442

1501-
numPasses := vc.numCenters / 100
1502-
for pass_idx := range numPasses {
15031443
builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
15041444
builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
1505-
val, err := pl.Value(txn.StartTs)
1445+
edges := []*pb.DirectedEdge{}
1446+
val, err := pl.Value(rb.StartTs)
15061447
if err != nil {
15071448
return []*pb.DirectedEdge{}, err
15081449
}
15091450

15101451
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1511-
idx := vc.findCentroid(inVec)
1512-
if idx%numPasses != pass_idx {
1513-
return []*pb.DirectedEdge{}, nil
1514-
}
1515-
vc.mutexs[idx].Lock()
1516-
defer vc.mutexs[idx].Unlock()
1517-
_, err = indexers[idx].Insert(ctx, tcs[idx], uid, inVec)
1518-
if err != nil {
1519-
return []*pb.DirectedEdge{}, err
1520-
}
1521-
1522-
edgesCreated.Add(int64(1))
1523-
return nil, nil
1452+
indexer.BuildInsert(ctx, uid, inVec)
1453+
return edges, nil
15241454
}
15251455

15261456
err := builder.RunWithoutTemp(ctx)
15271457
if err != nil {
15281458
return err
15291459
}
15301460

1531-
for idx := range vc.counts {
1532-
if idx%numPasses != pass_idx {
1533-
continue
1534-
}
1461+
for _, idx := range indexer.EndBuild() {
15351462
txns[idx].Update()
15361463
writer := NewTxnWriter(pstore)
15371464

@@ -1547,14 +1474,132 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
15471474

15481475
txns[idx].cache.plists = nil
15491476
txns[idx] = nil
1550-
tcs[idx] = nil
1551-
indexers[idx] = nil
15521477
}
1553-
1554-
fmt.Printf("Created %d edges in pass %d out of %d\n", edgesCreated.Load(), pass_idx, numPasses)
15551478
}
15561479

15571480
return nil
1481+
1482+
// MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
1483+
// Prefix: pk.DataPrefix(),
1484+
// ReadTs: rb.StartTs,
1485+
// AllVersions: false,
1486+
// Reverse: false,
1487+
// CheckInclusion: func(uid uint64) error {
1488+
// return nil
1489+
// },
1490+
// Function: func(l *List, pk x.ParsedKey) error {
1491+
// val, err := l.Value(rb.StartTs)
1492+
// if err != nil {
1493+
// return err
1494+
// }
1495+
// inVec := types.BytesAsFloatArray(val.Value.([]byte))
1496+
// vc.addSeedCentroid(inVec)
1497+
// if len(vc.centroids) == numCentroids {
1498+
// return ErrStopIteration
1499+
// }
1500+
// return nil
1501+
// },
1502+
// StartKey: x.DataKey(rb.Attr, 0),
1503+
// })
1504+
1505+
// vc.randomInit()
1506+
1507+
// fmt.Println("Clustering Vectors")
1508+
// for range 5 {
1509+
// builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
1510+
// builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
1511+
// edges := []*pb.DirectedEdge{}
1512+
// val, err := pl.Value(txn.StartTs)
1513+
// if err != nil {
1514+
// return []*pb.DirectedEdge{}, err
1515+
// }
1516+
1517+
// inVec := types.BytesAsFloatArray(val.Value.([]byte))
1518+
// vc.addVector(inVec)
1519+
// return edges, nil
1520+
// }
1521+
1522+
// err := builder.RunWithoutTemp(ctx)
1523+
// if err != nil {
1524+
// return err
1525+
// }
1526+
1527+
// vc.updateCentroids()
1528+
// }
1529+
1530+
// tcs := make([]*hnsw.TxnCache, vc.numCenters)
1531+
// txns := make([]*Txn, vc.numCenters)
1532+
// indexers := make([]index.VectorIndex[float32], vc.numCenters)
1533+
// for i := 0; i < vc.numCenters; i++ {
1534+
// txns[i] = NewTxn(rb.StartTs)
1535+
// tcs[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs)
1536+
// indexers_i, err := factorySpecs[0].CreateIndex(pk.Attr, i)
1537+
// if err != nil {
1538+
// return err
1539+
// }
1540+
// vc.mutexs[i] = &sync.Mutex{}
1541+
// indexers[i] = indexers_i
1542+
// }
1543+
1544+
// var edgesCreated atomic.Int64
1545+
1546+
// numPasses := vc.numCenters / 100
1547+
// for pass_idx := range numPasses {
1548+
// builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
1549+
// builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
1550+
// val, err := pl.Value(txn.StartTs)
1551+
// if err != nil {
1552+
// return []*pb.DirectedEdge{}, err
1553+
// }
1554+
1555+
// inVec := types.BytesAsFloatArray(val.Value.([]byte))
1556+
// idx := vc.findCentroid(inVec)
1557+
// if idx%numPasses != pass_idx {
1558+
// return []*pb.DirectedEdge{}, nil
1559+
// }
1560+
// vc.mutexs[idx].Lock()
1561+
// defer vc.mutexs[idx].Unlock()
1562+
// _, err = indexers[idx].Insert(ctx, tcs[idx], uid, inVec)
1563+
// if err != nil {
1564+
// return []*pb.DirectedEdge{}, err
1565+
// }
1566+
1567+
// edgesCreated.Add(int64(1))
1568+
// return nil, nil
1569+
// }
1570+
1571+
// err := builder.RunWithoutTemp(ctx)
1572+
// if err != nil {
1573+
// return err
1574+
// }
1575+
1576+
// for idx := range vc.counts {
1577+
// if idx%numPasses != pass_idx {
1578+
// continue
1579+
// }
1580+
// txns[idx].Update()
1581+
// writer := NewTxnWriter(pstore)
1582+
1583+
// x.ExponentialRetry(int(x.Config.MaxRetries),
1584+
// 20*time.Millisecond, func() error {
1585+
// err := txns[idx].CommitToDisk(writer, rb.StartTs)
1586+
// if err == badger.ErrBannedKey {
1587+
// glog.Errorf("Error while writing to banned namespace.")
1588+
// return nil
1589+
// }
1590+
// return err
1591+
// })
1592+
1593+
// txns[idx].cache.plists = nil
1594+
// txns[idx] = nil
1595+
// tcs[idx] = nil
1596+
// indexers[idx] = nil
1597+
// }
1598+
1599+
// fmt.Printf("Created %d edges in pass %d out of %d\n", edgesCreated.Load(), pass_idx, numPasses)
1600+
// }
1601+
1602+
// return nil
15581603
}
15591604

15601605
// rebuildTokIndex rebuilds index for a given attribute.

0 commit comments

Comments
 (0)