@@ -15,7 +15,6 @@ import (
1515 "math"
1616 "os"
1717 "strings"
18- "sync"
1918 "sync/atomic"
2019 "time"
2120 "unsafe"
@@ -34,11 +33,10 @@ import (
3433 "github.com/hypermodeinc/dgraph/v25/schema"
3534 "github.com/hypermodeinc/dgraph/v25/tok"
3635 "github.com/hypermodeinc/dgraph/v25/tok/hnsw"
37- "github.com/hypermodeinc/dgraph/v25/tok/index"
36+ tokIndex "github.com/hypermodeinc/dgraph/v25/tok/index"
37+
3838 "github.com/hypermodeinc/dgraph/v25/types"
3939 "github.com/hypermodeinc/dgraph/v25/x"
40-
41- "github.com/viterin/vek/vek32"
4240)
4341
4442var emptyCountParams countParams
@@ -166,7 +164,7 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo)
166164 // retrieve vector from inUuid save as inVec
167165 inVec := types .BytesAsFloatArray (data [0 ].Value .([]byte ))
168166 tc := hnsw .NewTxnCache (NewViTxn (txn ), txn .StartTs )
169- indexer , err := info .factorySpecs [0 ].CreateIndex (attr , 0 )
167+ indexer , err := info .factorySpecs [0 ].CreateIndex (attr )
170168 if err != nil {
171169 return []* pb.DirectedEdge {}, err
172170 }
@@ -1365,112 +1363,67 @@ func (rb *indexRebuildInfo) prefixesForTokIndexes() ([][]byte, error) {
13651363 return prefixes , nil
13661364}
13671365
1368- type vectorCentroids struct {
1369- dimension int
1370- numCenters int
1366+ const numCentroids = 1000
13711367
1372- centroids [][]float32
1373- counts []int64
1374- weights [][]float32
1375- mutexs []* sync.Mutex
1376- }
1368+ func rebuildVectorIndex (ctx context.Context , factorySpecs []* tok.FactoryCreateSpec , rb * IndexRebuild ) error {
1369+ pk := x.ParsedKey {Attr : rb .Attr }
13771370
1378- func (vc * vectorCentroids ) findCentroid (input []float32 ) int {
1379- minIdx := 0
1380- minDist := math .MaxFloat32
1381- for i , centroid := range vc .centroids {
1382- dist := vek32 .Distance (centroid , input )
1383- if float64 (dist ) < minDist {
1384- minDist = float64 (dist )
1385- minIdx = i
1386- }
1371+ indexer , err := factorySpecs [0 ].CreateIndex (pk .Attr )
1372+ if err != nil {
1373+ return err
13871374 }
1388- return minIdx
1389- }
13901375
1391- func (vc * vectorCentroids ) addVector (vec []float32 ) {
1392- idx := vc .findCentroid (vec )
1393- vc .mutexs [idx ].Lock ()
1394- defer vc .mutexs [idx ].Unlock ()
1395- for i := 0 ; i < vc .dimension ; i ++ {
1396- vc.weights [idx ][i ] += vec [i ]
1376+ if indexer .NumSeedVectors () > 0 {
1377+ count := 0
1378+ MemLayerInstance .IterateDisk (ctx , IterateDiskArgs {
1379+ Prefix : pk .DataPrefix (),
1380+ ReadTs : rb .StartTs ,
1381+ AllVersions : false ,
1382+ Reverse : false ,
1383+ CheckInclusion : func (uid uint64 ) error {
1384+ return nil
1385+ },
1386+ Function : func (l * List , pk x.ParsedKey ) error {
1387+ val , err := l .Value (rb .StartTs )
1388+ if err != nil {
1389+ return err
1390+ }
1391+ inVec := types .BytesAsFloatArray (val .Value .([]byte ))
1392+ count += 1
1393+ indexer .AddSeedVector (inVec )
1394+ if count == indexer .NumSeedVectors () {
1395+ return ErrStopIteration
1396+ }
1397+ return nil
1398+ },
1399+ StartKey : x .DataKey (rb .Attr , 0 ),
1400+ })
13971401 }
1398- vc .counts [idx ]++
1399- }
14001402
1401- func (vc * vectorCentroids ) updateCentroids () {
1402- for i := 0 ; i < vc .numCenters ; i ++ {
1403- for j := 0 ; j < vc .dimension ; j ++ {
1404- vc.centroids [i ][j ] = vc.weights [i ][j ] / float32 (vc .counts [i ])
1405- vc.weights [i ][j ] = 0
1406- }
1407- fmt .Printf ("%d, " , vc .counts [i ])
1408- vc .counts [i ] = 0
1403+ txns := make ([]* Txn , indexer .NumThreads ())
1404+ for i := range txns {
1405+ txns [i ] = NewTxn (rb .StartTs )
14091406 }
1410- fmt .Println ()
1411- }
1412-
1413- func (vc * vectorCentroids ) randomInit () {
1414- vc .dimension = len (vc .centroids [0 ])
1415- vc .numCenters = len (vc .centroids )
1416- vc .counts = make ([]int64 , vc .numCenters )
1417- vc .weights = make ([][]float32 , vc .numCenters )
1418- vc .mutexs = make ([]* sync.Mutex , vc .numCenters )
1419- for i := 0 ; i < vc .numCenters ; i ++ {
1420- vc .weights [i ] = make ([]float32 , vc .dimension )
1421- vc .counts [i ] = 0
1422- vc .mutexs [i ] = & sync.Mutex {}
1407+ caches := make ([]tokIndex.CacheType , indexer .NumThreads ())
1408+ for i := range caches {
1409+ caches [i ] = hnsw .NewTxnCache (NewViTxn (txns [i ]), rb .StartTs )
14231410 }
1424- }
14251411
1426- func (vc * vectorCentroids ) addSeedCentroid (vec []float32 ) {
1427- vc .centroids = append (vc .centroids , vec )
1428- }
1412+ for pass_idx := range indexer .NumBuildPasses () {
1413+ fmt .Println ("Building pass" , pass_idx )
14291414
1430- const numCentroids = 1000
1415+ indexer . StartBuild ( caches )
14311416
1432- func rebuildVectorIndex (ctx context.Context , factorySpecs []* tok.FactoryCreateSpec , rb * IndexRebuild ) error {
1433- pk := x.ParsedKey {Attr : rb .Attr }
1434- vc := & vectorCentroids {}
1435- vc .centroids = make ([][]float32 , 0 )
1436-
1437- MemLayerInstance .IterateDisk (ctx , IterateDiskArgs {
1438- Prefix : pk .DataPrefix (),
1439- ReadTs : rb .StartTs ,
1440- AllVersions : false ,
1441- Reverse : false ,
1442- CheckInclusion : func (uid uint64 ) error {
1443- return nil
1444- },
1445- Function : func (l * List , pk x.ParsedKey ) error {
1446- val , err := l .Value (rb .StartTs )
1447- if err != nil {
1448- return err
1449- }
1450- inVec := types .BytesAsFloatArray (val .Value .([]byte ))
1451- vc .addSeedCentroid (inVec )
1452- if len (vc .centroids ) == numCentroids {
1453- return ErrStopIteration
1454- }
1455- return nil
1456- },
1457- StartKey : x .DataKey (rb .Attr , 0 ),
1458- })
1459-
1460- vc .randomInit ()
1461-
1462- fmt .Println ("Clustering Vectors" )
1463- for range 5 {
14641417 builder := rebuilder {attr : rb .Attr , prefix : pk .DataPrefix (), startTs : rb .StartTs }
14651418 builder .fn = func (uid uint64 , pl * List , txn * Txn ) ([]* pb.DirectedEdge , error ) {
14661419 edges := []* pb.DirectedEdge {}
1467- val , err := pl .Value (txn .StartTs )
1420+ val , err := pl .Value (rb .StartTs )
14681421 if err != nil {
14691422 return []* pb.DirectedEdge {}, err
14701423 }
14711424
14721425 inVec := types .BytesAsFloatArray (val .Value .([]byte ))
1473- vc . addVector ( inVec )
1426+ indexer . BuildInsert ( ctx , uid , inVec )
14741427 return edges , nil
14751428 }
14761429
@@ -1479,59 +1432,33 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14791432 return err
14801433 }
14811434
1482- vc . updateCentroids ()
1435+ indexer . EndBuild ()
14831436 }
14841437
1485- tcs := make ([]* hnsw.TxnCache , vc .numCenters )
1486- txns := make ([]* Txn , vc .numCenters )
1487- indexers := make ([]index.VectorIndex [float32 ], vc .numCenters )
1488- for i := 0 ; i < vc .numCenters ; i ++ {
1489- txns [i ] = NewTxn (rb .StartTs )
1490- tcs [i ] = hnsw .NewTxnCache (NewViTxn (txns [i ]), rb .StartTs )
1491- indexers_i , err := factorySpecs [0 ].CreateIndex (pk .Attr , i )
1492- if err != nil {
1493- return err
1494- }
1495- vc .mutexs [i ] = & sync.Mutex {}
1496- indexers [i ] = indexers_i
1497- }
1438+ for pass_idx := range indexer .NumIndexPasses () {
1439+ fmt .Println ("Indexing pass" , pass_idx )
14981440
1499- var edgesCreated atomic. Int64
1441+ indexer . StartBuild ( caches )
15001442
1501- numPasses := vc .numCenters / 100
1502- for pass_idx := range numPasses {
15031443 builder := rebuilder {attr : rb .Attr , prefix : pk .DataPrefix (), startTs : rb .StartTs }
15041444 builder .fn = func (uid uint64 , pl * List , txn * Txn ) ([]* pb.DirectedEdge , error ) {
1505- val , err := pl .Value (txn .StartTs )
1445+ edges := []* pb.DirectedEdge {}
1446+ val , err := pl .Value (rb .StartTs )
15061447 if err != nil {
15071448 return []* pb.DirectedEdge {}, err
15081449 }
15091450
15101451 inVec := types .BytesAsFloatArray (val .Value .([]byte ))
1511- idx := vc .findCentroid (inVec )
1512- if idx % numPasses != pass_idx {
1513- return []* pb.DirectedEdge {}, nil
1514- }
1515- vc .mutexs [idx ].Lock ()
1516- defer vc .mutexs [idx ].Unlock ()
1517- _ , err = indexers [idx ].Insert (ctx , tcs [idx ], uid , inVec )
1518- if err != nil {
1519- return []* pb.DirectedEdge {}, err
1520- }
1521-
1522- edgesCreated .Add (int64 (1 ))
1523- return nil , nil
1452+ indexer .BuildInsert (ctx , uid , inVec )
1453+ return edges , nil
15241454 }
15251455
15261456 err := builder .RunWithoutTemp (ctx )
15271457 if err != nil {
15281458 return err
15291459 }
15301460
1531- for idx := range vc .counts {
1532- if idx % numPasses != pass_idx {
1533- continue
1534- }
1461+ for _ , idx := range indexer .EndBuild () {
15351462 txns [idx ].Update ()
15361463 writer := NewTxnWriter (pstore )
15371464
@@ -1547,14 +1474,132 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
15471474
15481475 txns [idx ].cache .plists = nil
15491476 txns [idx ] = nil
1550- tcs [idx ] = nil
1551- indexers [idx ] = nil
15521477 }
1553-
1554- fmt .Printf ("Created %d edges in pass %d out of %d\n " , edgesCreated .Load (), pass_idx , numPasses )
15551478 }
15561479
15571480 return nil
1481+
1482+ // MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
1483+ // Prefix: pk.DataPrefix(),
1484+ // ReadTs: rb.StartTs,
1485+ // AllVersions: false,
1486+ // Reverse: false,
1487+ // CheckInclusion: func(uid uint64) error {
1488+ // return nil
1489+ // },
1490+ // Function: func(l *List, pk x.ParsedKey) error {
1491+ // val, err := l.Value(rb.StartTs)
1492+ // if err != nil {
1493+ // return err
1494+ // }
1495+ // inVec := types.BytesAsFloatArray(val.Value.([]byte))
1496+ // vc.addSeedCentroid(inVec)
1497+ // if len(vc.centroids) == numCentroids {
1498+ // return ErrStopIteration
1499+ // }
1500+ // return nil
1501+ // },
1502+ // StartKey: x.DataKey(rb.Attr, 0),
1503+ // })
1504+
1505+ // vc.randomInit()
1506+
1507+ // fmt.Println("Clustering Vectors")
1508+ // for range 5 {
1509+ // builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
1510+ // builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
1511+ // edges := []*pb.DirectedEdge{}
1512+ // val, err := pl.Value(txn.StartTs)
1513+ // if err != nil {
1514+ // return []*pb.DirectedEdge{}, err
1515+ // }
1516+
1517+ // inVec := types.BytesAsFloatArray(val.Value.([]byte))
1518+ // vc.addVector(inVec)
1519+ // return edges, nil
1520+ // }
1521+
1522+ // err := builder.RunWithoutTemp(ctx)
1523+ // if err != nil {
1524+ // return err
1525+ // }
1526+
1527+ // vc.updateCentroids()
1528+ // }
1529+
1530+ // tcs := make([]*hnsw.TxnCache, vc.numCenters)
1531+ // txns := make([]*Txn, vc.numCenters)
1532+ // indexers := make([]index.VectorIndex[float32], vc.numCenters)
1533+ // for i := 0; i < vc.numCenters; i++ {
1534+ // txns[i] = NewTxn(rb.StartTs)
1535+ // tcs[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs)
1536+ // indexers_i, err := factorySpecs[0].CreateIndex(pk.Attr, i)
1537+ // if err != nil {
1538+ // return err
1539+ // }
1540+ // vc.mutexs[i] = &sync.Mutex{}
1541+ // indexers[i] = indexers_i
1542+ // }
1543+
1544+ // var edgesCreated atomic.Int64
1545+
1546+ // numPasses := vc.numCenters / 100
1547+ // for pass_idx := range numPasses {
1548+ // builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
1549+ // builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
1550+ // val, err := pl.Value(txn.StartTs)
1551+ // if err != nil {
1552+ // return []*pb.DirectedEdge{}, err
1553+ // }
1554+
1555+ // inVec := types.BytesAsFloatArray(val.Value.([]byte))
1556+ // idx := vc.findCentroid(inVec)
1557+ // if idx%numPasses != pass_idx {
1558+ // return []*pb.DirectedEdge{}, nil
1559+ // }
1560+ // vc.mutexs[idx].Lock()
1561+ // defer vc.mutexs[idx].Unlock()
1562+ // _, err = indexers[idx].Insert(ctx, tcs[idx], uid, inVec)
1563+ // if err != nil {
1564+ // return []*pb.DirectedEdge{}, err
1565+ // }
1566+
1567+ // edgesCreated.Add(int64(1))
1568+ // return nil, nil
1569+ // }
1570+
1571+ // err := builder.RunWithoutTemp(ctx)
1572+ // if err != nil {
1573+ // return err
1574+ // }
1575+
1576+ // for idx := range vc.counts {
1577+ // if idx%numPasses != pass_idx {
1578+ // continue
1579+ // }
1580+ // txns[idx].Update()
1581+ // writer := NewTxnWriter(pstore)
1582+
1583+ // x.ExponentialRetry(int(x.Config.MaxRetries),
1584+ // 20*time.Millisecond, func() error {
1585+ // err := txns[idx].CommitToDisk(writer, rb.StartTs)
1586+ // if err == badger.ErrBannedKey {
1587+ // glog.Errorf("Error while writing to banned namespace.")
1588+ // return nil
1589+ // }
1590+ // return err
1591+ // })
1592+
1593+ // txns[idx].cache.plists = nil
1594+ // txns[idx] = nil
1595+ // tcs[idx] = nil
1596+ // indexers[idx] = nil
1597+ // }
1598+
1599+ // fmt.Printf("Created %d edges in pass %d out of %d\n", edgesCreated.Load(), pass_idx, numPasses)
1600+ // }
1601+
1602+ // return nil
15581603}
15591604
15601605// rebuildTokIndex rebuilds index for a given attribute.
0 commit comments