@@ -15,6 +15,7 @@ import (
1515 "math"
1616 "os"
1717 "strings"
18+ "sync"
1819 "sync/atomic"
1920 "time"
2021 "unsafe"
@@ -33,8 +34,11 @@ import (
3334 "github.com/hypermodeinc/dgraph/v25/schema"
3435 "github.com/hypermodeinc/dgraph/v25/tok"
3536 "github.com/hypermodeinc/dgraph/v25/tok/hnsw"
37+ "github.com/hypermodeinc/dgraph/v25/tok/index"
3638 "github.com/hypermodeinc/dgraph/v25/types"
3739 "github.com/hypermodeinc/dgraph/v25/x"
40+
41+ "github.com/viterin/vek/vek32"
3842)
3943
4044var emptyCountParams countParams
@@ -162,7 +166,7 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo)
162166 // retrieve vector from inUuid save as inVec
163167 inVec := types .BytesAsFloatArray (data [0 ].Value .([]byte ))
164168 tc := hnsw .NewTxnCache (NewViTxn (txn ), txn .StartTs )
165- indexer , err := info .factorySpecs [0 ].CreateIndex (attr )
169+ indexer , err := info .factorySpecs [0 ].CreateIndex (attr , 0 )
166170 if err != nil {
167171 return []* pb.DirectedEdge {}, err
168172 }
@@ -1361,6 +1365,198 @@ func (rb *indexRebuildInfo) prefixesForTokIndexes() ([][]byte, error) {
13611365 return prefixes , nil
13621366}
13631367
// vectorCentroids carries the working state used while clustering vectors
// around a set of centroids during a vector index rebuild.
//
// The per-centroid slices (centroids, counts, weights, mutexs) are indexed by
// centroid number.
type vectorCentroids struct {
	// dimension is the number of components read from each vector.
	dimension int
	// numCenters is the number of centroids being maintained.
	numCenters int

	// centroids holds the current position of every centroid.
	centroids [][]float32
	// counts is how many vectors were added to each centroid since the last
	// call to updateCentroids.
	counts []int64
	// weights is the component-wise sum of the vectors added to each centroid
	// since the last call to updateCentroids.
	weights [][]float32
	// mutexs guards counts[i] and weights[i] for centroid i.
	mutexs []*sync.Mutex
}
1377+
1378+ func (vc * vectorCentroids ) findCentroid (input []float32 ) int {
1379+ minIdx := 0
1380+ minDist := math .MaxFloat32
1381+ for i , centroid := range vc .centroids {
1382+ dist := vek32 .Distance (centroid , input )
1383+ if float64 (dist ) < minDist {
1384+ minDist = float64 (dist )
1385+ minIdx = i
1386+ }
1387+ }
1388+ return minIdx
1389+ }
1390+
1391+ func (vc * vectorCentroids ) addVector (vec []float32 ) {
1392+ idx := vc .findCentroid (vec )
1393+ vc .mutexs [idx ].Lock ()
1394+ defer vc .mutexs [idx ].Unlock ()
1395+ for i := 0 ; i < vc .dimension ; i ++ {
1396+ vc.weights [idx ][i ] += vec [i ]
1397+ }
1398+ vc .counts [idx ]++
1399+ }
1400+
1401+ func (vc * vectorCentroids ) updateCentroids () {
1402+ for i := 0 ; i < vc .numCenters ; i ++ {
1403+ for j := 0 ; j < vc .dimension ; j ++ {
1404+ vc.centroids [i ][j ] = vc.weights [i ][j ] / float32 (vc .counts [i ])
1405+ vc.weights [i ][j ] = 0
1406+ }
1407+ fmt .Printf ("%d, " , vc .counts [i ])
1408+ vc .counts [i ] = 0
1409+ }
1410+ fmt .Println ()
1411+ }
1412+
1413+ func (vc * vectorCentroids ) randomInit () {
1414+ vc .dimension = len (vc .centroids [0 ])
1415+ vc .numCenters = len (vc .centroids )
1416+ vc .centroids = make ([][]float32 , vc .numCenters )
1417+ vc .counts = make ([]int64 , vc .numCenters )
1418+ vc .weights = make ([][]float32 , vc .numCenters )
1419+ vc .mutexs = make ([]* sync.Mutex , vc .numCenters )
1420+ for i := 0 ; i < vc .numCenters ; i ++ {
1421+ vc .weights [i ] = make ([]float32 , vc .dimension )
1422+ vc .counts [i ] = 0
1423+ vc .mutexs [i ] = & sync.Mutex {}
1424+ }
1425+ }
1426+
1427+ func (vc * vectorCentroids ) addSeedCentroid (vec []float32 ) {
1428+ vc .centroids = append (vc .centroids , vec )
1429+ }
1430+
// numCentroids is the maximum number of seed centroids collected by
// rebuildVectorIndex before it stops scanning for more.
const numCentroids = 1000
1432+
1433+ func rebuildVectorIndex (ctx context.Context , factorySpecs []* tok.FactoryCreateSpec , rb * IndexRebuild ) error {
1434+ pk := x.ParsedKey {Attr : rb .Attr }
1435+ vc := & vectorCentroids {}
1436+
1437+ MemLayerInstance .IterateDisk (ctx , IterateDiskArgs {
1438+ Prefix : pk .DataPrefix (),
1439+ ReadTs : rb .StartTs ,
1440+ AllVersions : false ,
1441+ Reverse : false ,
1442+ CheckInclusion : func (uid uint64 ) error {
1443+ return nil
1444+ },
1445+ Function : func (l * List , pk x.ParsedKey ) error {
1446+ val , err := l .Value (rb .StartTs )
1447+ if err != nil {
1448+ return err
1449+ }
1450+ inVec := types .BytesAsFloatArray (val .Value .([]byte ))
1451+ vc .addSeedCentroid (inVec )
1452+ if len (vc .centroids ) == numCentroids {
1453+ return ErrStopIteration
1454+ }
1455+ return nil
1456+ },
1457+ StartKey : x .DataKey (rb .Attr , 0 ),
1458+ })
1459+
1460+ vc .randomInit ()
1461+
1462+ fmt .Println ("Clustering Vectors" )
1463+ for range 5 {
1464+ builder := rebuilder {attr : rb .Attr , prefix : pk .DataPrefix (), startTs : rb .StartTs }
1465+ builder .fn = func (uid uint64 , pl * List , txn * Txn ) ([]* pb.DirectedEdge , error ) {
1466+ edges := []* pb.DirectedEdge {}
1467+ val , err := pl .Value (txn .StartTs )
1468+ if err != nil {
1469+ return []* pb.DirectedEdge {}, err
1470+ }
1471+
1472+ inVec := types .BytesAsFloatArray (val .Value .([]byte ))
1473+ vc .addVector (inVec )
1474+ return edges , nil
1475+ }
1476+
1477+ err := builder .RunWithoutTemp (ctx )
1478+ if err != nil {
1479+ return err
1480+ }
1481+
1482+ vc .updateCentroids ()
1483+ }
1484+
1485+ tcs := make ([]* hnsw.TxnCache , vc .numCenters )
1486+ txns := make ([]* Txn , vc .numCenters )
1487+ indexers := make ([]index.VectorIndex [float32 ], vc .numCenters )
1488+ for i := 0 ; i < vc .numCenters ; i ++ {
1489+ txns [i ] = NewTxn (rb .StartTs )
1490+ tcs [i ] = hnsw .NewTxnCache (NewViTxn (txns [i ]), rb .StartTs )
1491+ indexers_i , err := factorySpecs [0 ].CreateIndex (pk .Attr , i )
1492+ if err != nil {
1493+ return err
1494+ }
1495+ vc .mutexs [i ] = & sync.Mutex {}
1496+ indexers [i ] = indexers_i
1497+ }
1498+
1499+ var edgesCreated atomic.Int64
1500+
1501+ numPasses := vc .numCenters / 100
1502+ for pass_idx := range numPasses {
1503+ builder := rebuilder {attr : rb .Attr , prefix : pk .DataPrefix (), startTs : rb .StartTs }
1504+ builder .fn = func (uid uint64 , pl * List , txn * Txn ) ([]* pb.DirectedEdge , error ) {
1505+ val , err := pl .Value (txn .StartTs )
1506+ if err != nil {
1507+ return []* pb.DirectedEdge {}, err
1508+ }
1509+
1510+ inVec := types .BytesAsFloatArray (val .Value .([]byte ))
1511+ idx := vc .findCentroid (inVec )
1512+ if idx % numPasses != pass_idx {
1513+ return []* pb.DirectedEdge {}, nil
1514+ }
1515+ vc .mutexs [idx ].Lock ()
1516+ defer vc .mutexs [idx ].Unlock ()
1517+ _ , err = indexers [idx ].Insert (ctx , tcs [idx ], uid , inVec )
1518+ if err != nil {
1519+ return []* pb.DirectedEdge {}, err
1520+ }
1521+
1522+ edgesCreated .Add (int64 (1 ))
1523+ return nil , nil
1524+ }
1525+
1526+ err := builder .RunWithoutTemp (ctx )
1527+ if err != nil {
1528+ return err
1529+ }
1530+
1531+ for idx := range vc .counts {
1532+ if idx % numPasses != pass_idx {
1533+ continue
1534+ }
1535+ txns [idx ].Update ()
1536+ writer := NewTxnWriter (pstore )
1537+
1538+ x .ExponentialRetry (int (x .Config .MaxRetries ),
1539+ 20 * time .Millisecond , func () error {
1540+ err := txns [idx ].CommitToDisk (writer , rb .StartTs )
1541+ if err == badger .ErrBannedKey {
1542+ glog .Errorf ("Error while writing to banned namespace." )
1543+ return nil
1544+ }
1545+ return err
1546+ })
1547+
1548+ txns [idx ].cache .plists = nil
1549+ txns [idx ] = nil
1550+ tcs [idx ] = nil
1551+ indexers [idx ] = nil
1552+ }
1553+
1554+ fmt .Printf ("Created %d edges in pass %d out of %d\n " , edgesCreated .Load (), pass_idx , numPasses )
1555+ }
1556+
1557+ return nil
1558+ }
1559+
13641560// rebuildTokIndex rebuilds index for a given attribute.
13651561// We commit mutations with startTs and ignore the errors.
13661562func rebuildTokIndex (ctx context.Context , rb * IndexRebuild ) error {
@@ -1392,6 +1588,9 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {
13921588 }
13931589
13941590 runForVectors := (len (factorySpecs ) != 0 )
1591+ if runForVectors {
1592+ return rebuildVectorIndex (ctx , factorySpecs , rb )
1593+ }
13951594
13961595 pk := x.ParsedKey {Attr : rb .Attr }
13971596 builder := rebuilder {attr : rb .Attr , prefix : pk .DataPrefix (), startTs : rb .StartTs }
0 commit comments