Skip to content

Commit f5ddbbd

Browse files
Shiva
authored and committed
add more tests
1 parent b6c53f0 commit f5ddbbd

3 files changed

Lines changed: 33 additions & 51 deletions

File tree

dgraph/cmd/bulk/reduce.go

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,27 +62,28 @@ func (r *reducer) run() error {
6262

6363
if len(vectorIndexSpecs) > 0 {
6464
fmt.Printf("Creating shared vector database for %d vector predicate(s)\n", len(vectorIndexSpecs))
65+
// Track which predicates belong to which output shard
66+
predToOutputShard = make(map[string]int)
6567

66-
// Create single shared vectorTmpDb
6768
sharedVectorDb = r.createVectorTmpBadger()
6869

69-
// Initialize posting and schema ONCE (avoids race condition!)
7070
posting.Init(sharedVectorDb, 0, false)
7171
schema.Init(sharedVectorDb)
7272
for pred, sch := range r.schema.schemaMap {
73+
_, ok := vectorIndexSpecs[pred]
74+
if !ok {
75+
continue
76+
}
7377
schema.State().Set(pred, sch)
7478
}
75-
76-
// Track which predicates belong to which output shard
77-
predToOutputShard = make(map[string]int)
7879
}
7980

8081
thr := y.NewThrottle(r.opt.NumReducers)
8182
for i := range r.opt.ReduceShards {
8283
if err := thr.Do(); err != nil {
8384
return err
8485
}
85-
go func(shardId int, db *badger.DB, tmpDb *badger.DB) {
86+
go func(shardId int, db *badger.DB, tmpDb *badger.DB, vectorTmpDb *badger.DB) {
8687
defer thr.Done(nil)
8788

8889
mapFiles := filenamesInTree(dirs[shardId])
@@ -117,10 +118,10 @@ func (r *reducer) run() error {
117118

118119
// Create vector indexer using shared DB (if vectors exist)
119120
var vi *vectorIndexer
120-
if sharedVectorDb != nil && len(vectorIndexSpecs) > 0 {
121+
if vectorTmpDb != nil && len(vectorIndexSpecs) > 0 {
121122
fmt.Printf("Initializing vector indexer for shard %d with %d predicate(s)\n",
122123
shardId, len(vectorIndexSpecs))
123-
vi = newVectorIndexerShared(r, sharedVectorDb, vectorIndexSpecs,
124+
vi = newVectorIndexerShared(r, vectorTmpDb, vectorIndexSpecs,
124125
shardId, &predToShardMu, predToOutputShard)
125126
}
126127

@@ -149,7 +150,7 @@ func (r *reducer) run() error {
149150
fmt.Printf("Error while closing iterator: %v", err)
150151
}
151152
}
152-
}(i, r.createBadger(i), r.createTmpBadger())
153+
}(i, r.createBadger(i), r.createTmpBadger(), sharedVectorDb)
153154
}
154155
if err := thr.Finish(); err != nil {
155156
return err
@@ -294,8 +295,7 @@ func newMapIterator(filename string) (*pb.MapHeader, *mapIterator) {
294295
type encodeRequest struct {
295296
cbuf *z.Buffer
296297
countBuf *z.Buffer
297-
vectorBuf *z.Buffer // Buffer for vector entries to be indexed
298-
vi *vectorIndexer // Vector indexer for routing vector predicates to tmpDb
298+
vectorBuf *z.Buffer // Buffer for vector entries to be indexed
299299
wg *sync.WaitGroup
300300
listCh chan *z.Buffer
301301
splitCh chan *bpb.KVList
@@ -318,11 +318,11 @@ func (r *reducer) streamIdFor(pred string) uint32 {
318318
return streamId
319319
}
320320

321-
func (r *reducer) encode(entryCh chan *encodeRequest, closer *z.Closer) {
321+
func (r *reducer) encode(entryCh chan *encodeRequest, vi *vectorIndexer, closer *z.Closer) {
322322
defer closer.Done()
323323

324324
for req := range entryCh {
325-
r.toList(req)
325+
r.toList(req, vi)
326326
req.wg.Done()
327327
}
328328
}
@@ -470,6 +470,9 @@ func (r *reducer) startWriting(ci *countIndexer, vi *vectorIndexer, writerCh cha
470470

471471
count(req)
472472
if vi != nil {
473+
if err := vi.flushWriteBatch(); err != nil {
474+
glog.Errorf("Error flushing vector write batch before HNSW insertion: %v", err)
475+
}
473476
vector(req)
474477
}
475478
}
@@ -646,7 +649,7 @@ func (r *reducer) reduce(partitionKeys [][]byte, mapItrs []*mapIterator, ci *cou
646649
for range cpu {
647650
// Start listening to encode entries
648651
// For time being let's lease 100 stream id for each encoder.
649-
go r.encode(encoderCh, encoderCloser)
652+
go r.encode(encoderCh, vi, encoderCloser)
650653
}
651654
// Start listening to write the badger list.
652655
writerCloser := z.NewCloser(1)
@@ -661,7 +664,6 @@ func (r *reducer) reduce(partitionKeys [][]byte, mapItrs []*mapIterator, ci *cou
661664
listCh: make(chan *z.Buffer, 3),
662665
splitCh: ci.splitCh,
663666
countBuf: getBuf(r.opt.TmpDir),
664-
vi: vi,
665667
}
666668
// Only allocate vectorBuf when we have vector predicates to index
667669
if vi != nil {
@@ -733,7 +735,7 @@ func (r *reducer) reduce(partitionKeys [][]byte, mapItrs []*mapIterator, ci *cou
733735
writerCloser.SignalAndWait()
734736
}
735737

736-
func (r *reducer) toList(req *encodeRequest) {
738+
func (r *reducer) toList(req *encodeRequest, vi *vectorIndexer) {
737739
cbuf := req.cbuf
738740
defer func() {
739741
atomic.AddInt64(&r.prog.numEncoding, -int64(cbuf.LenNoPadding()))
@@ -888,8 +890,8 @@ func (r *reducer) toList(req *encodeRequest) {
888890
}
889891
}
890892

891-
// Check if this is a vector predicate that should be routed to tmpDb
892-
isVectorPred := req.vi != nil && pk.IsData() && req.vi.isVectorPredicate(pk.Attr)
893+
// Check if this is a vector predicate that should be routed to vectorTmpDb
894+
isVectorPred := vi != nil && pk.IsData() && vi.isVectorPredicate(pk.Attr)
893895

894896
shouldSplit := proto.Size(pl) > (1<<20)/2 && len(pl.Pack.Blocks) > 1
895897
if shouldSplit {
@@ -908,7 +910,7 @@ func (r *reducer) toList(req *encodeRequest) {
908910
// Vector predicates go to vectorTmpDb
909911
for _, kv := range kvs {
910912
kv.Version = writeVersionTs
911-
if err := req.vi.writeVectorKV(kv); err != nil {
913+
if err := vi.writeVectorKV(kv); err != nil {
912914
glog.Errorf("Error writing vector posting to tmpDb: %v", err)
913915
}
914916
}
@@ -931,7 +933,7 @@ func (r *reducer) toList(req *encodeRequest) {
931933

932934
if isVectorPred {
933935
// Vector predicates go to vectorTmpDb
934-
if err := req.vi.writeVectorKV(kv); err != nil {
936+
if err := vi.writeVectorKV(kv); err != nil {
935937
glog.Errorf("Error writing vector posting to tmpDb: %v", err)
936938
}
937939
} else {

dgraph/cmd/bulk/vector_indexer.go

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ type vectorIndexer struct {
4141
predToShardMu *sync.Mutex // Mutex for predToOutputShard (shared across shards)
4242
predToOutputShard map[string]int // Predicate → output shard mapping (shared across shards)
4343

44-
// For batched writes
44+
// For batched writes of vector posting lists to shared DB
4545
writeBatch *badger.WriteBatch
4646
writeCount int
4747

@@ -143,7 +143,6 @@ func unmarshalVectorEntry(data []byte) *vectorEntry {
143143
}
144144

145145
// newVectorIndexerShared creates a new vectorIndexer for a shard using a shared vectorTmpDb.
146-
// This avoids the global pstore race condition by using a single shared DB with posting.Init()
147146
// called once before any shards start. Indexers are created lazily when vectors arrive.
148147
func newVectorIndexerShared(r *reducer, sharedVectorDb *badger.DB, indexSpecs map[string]*pb.VectorIndexSpec,
149148
shardId int, predToShardMu *sync.Mutex, predToOutputShard map[string]int) *vectorIndexer {
@@ -161,13 +160,6 @@ func newVectorIndexerShared(r *reducer, sharedVectorDb *badger.DB, indexSpecs ma
161160
predToShardMu: predToShardMu,
162161
predToOutputShard: predToOutputShard,
163162
}
164-
165-
// NOTE: posting.Init() and schema.Init() are called ONCE in reduce.go
166-
// before any shards start, so we don't call them here.
167-
168-
// NOTE: Indexers are created LAZILY when vectors arrive for a predicate.
169-
// This avoids creating indexers for predicates that don't exist in this shard.
170-
171163
glog.Infof("Vector indexer created for shard %d (lazy initialization, %d potential predicates)",
172164
shardId, len(indexSpecs))
173165

@@ -180,7 +172,6 @@ func (vi *vectorIndexer) getOrCreateIndexer(pred string) (index.VectorIndex[floa
180172
vi.mu.Lock()
181173
defer vi.mu.Unlock()
182174

183-
// Already created?
184175
if indexer, ok := vi.indexers[pred]; ok {
185176
return indexer, vi.txnCaches[pred], nil
186177
}

systest/vector/load_test.go

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ func TestBulkLoadVectorIndex(t *testing.T) {
219219
t.Log("Step 6: Verifying vector similarity queries work on bulk loaded data...")
220220
fmt.Println("vectors: ", len(vectors))
221221
for i, vector := range vectors {
222-
similarVectors, err := targetGc.QueryMultipleVectorsUsingSimilarTo(vector, pred, 10)
222+
similarVectors, err := targetGc.QueryMultipleVectorsUsingSimilarTo(vector, pred, 5)
223223
require.NoError(t, err)
224224
require.GreaterOrEqual(t, len(similarVectors), 4,
225225
"similar_to query should return results for vector %d", i)
@@ -255,9 +255,6 @@ func vectorsEqual(a, b []float32) bool {
255255
return true
256256
}
257257

258-
// TestBulkLoadVectorIndexMultipleGroups tests bulk loading vector data with multiple
259-
// alpha groups (shards). This ensures vector indexing works correctly when predicates
260-
// are distributed across different shards.
261258
func TestBulkLoadVectorIndexMultipleGroups(t *testing.T) {
262259
// if runtime.GOOS != "linux" && os.Getenv("DGRAPH_BINARY") == "" {
263260
// fmt.Println("You can set the DGRAPH_BINARY environment variable to path of a native dgraph binary to run these tests")
@@ -415,7 +412,7 @@ func TestBulkLoadVectorIndexMultipleGroups(t *testing.T) {
415412
sampleSize := 10
416413

417414
for i := 0; i < sampleSize; i++ {
418-
similarVectors, err := targetGc.QueryMultipleVectorsUsingSimilarTo(vectors[i], pred, 10)
415+
similarVectors, err := targetGc.QueryMultipleVectorsUsingSimilarTo(vectors[i], pred, 5)
419416
require.NoError(t, err)
420417
require.GreaterOrEqual(t, len(similarVectors), 4,
421418
"similar_to query should return results for predicate %s vector %d", pred, i)
@@ -437,7 +434,7 @@ func TestBulkLoadMixedPredicates(t *testing.T) {
437434

438435
// Schema with vectors AND other indexed predicates
439436
mixedSchema := `
440-
vec_embedding: float32vector @index(hnsw(exponent: "5", metric: "euclidean")) .
437+
project_description_v: float32vector @index(hnsw(exponent: "5", metric: "euclidean")) .
441438
name: string @index(term, fulltext) .
442439
age: int @index(int) .
443440
score: float .
@@ -485,7 +482,7 @@ func TestBulkLoadMixedPredicates(t *testing.T) {
485482
vecStr := fmt.Sprintf(`"[%s]"`, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(vec)), ", "), "[]"))
486483

487484
// Add vector predicate
488-
rdfBuilder.WriteString(fmt.Sprintf("<0x%x> <vec_embedding> %s .\n", uid, vecStr))
485+
rdfBuilder.WriteString(fmt.Sprintf("<0x%x> <project_description_v> %s .\n", uid, vecStr))
489486
// Add string predicate
490487
rdfBuilder.WriteString(fmt.Sprintf("<0x%x> <name> \"Person %d\" .\n", uid, i))
491488
// Add int predicate
@@ -597,7 +594,7 @@ func TestBulkLoadMixedPredicates(t *testing.T) {
597594

598595
// Verify vector similarity query
599596
similarQuery := fmt.Sprintf(`{
600-
vector(func: similar_to(vec_embedding, 5, "%v")) {
597+
vector(func: similar_to(project_description_v, 5, "%v")) {
601598
uid
602599
name
603600
}
@@ -610,8 +607,6 @@ func TestBulkLoadMixedPredicates(t *testing.T) {
610607
t.Log("All mixed predicate types verified successfully!")
611608
}
612609

613-
// TestBulkLoadVectorDimensions tests bulk loading vectors with different dimensions
614-
// to ensure the implementation handles various vector sizes correctly.
615610
func TestBulkLoadVectorDimensions(t *testing.T) {
616611
// if runtime.GOOS != "linux" && os.Getenv("DGRAPH_BINARY") == "" {
617612
// t.Skip("Skipping test on non-Linux platforms due to dgraph binary dependency")
@@ -630,7 +625,7 @@ func TestBulkLoadVectorDimensions(t *testing.T) {
630625

631626
for _, tc := range testCases {
632627
t.Run(tc.name, func(t *testing.T) {
633-
predName := fmt.Sprintf("vec_%s", tc.name)
628+
predName := "project_description_v"
634629
schema := fmt.Sprintf(`%s: float32vector @index(hnsw(exponent: "5", metric: "euclidean")) .`, predName)
635630

636631
// Step 1: Create source cluster
@@ -704,7 +699,7 @@ func TestBulkLoadVectorDimensions(t *testing.T) {
704699

705700
targetCluster, err := dgraphtest.NewLocalCluster(targetConf)
706701
require.NoError(t, err)
707-
defer func() { targetCluster.Cleanup(t.Failed()) }()
702+
// defer func() { targetCluster.Cleanup(t.Failed()) }()
708703
require.NoError(t, targetCluster.Start())
709704

710705
targetGc, targetCleanup, err := targetCluster.Client()
@@ -724,15 +719,13 @@ func TestBulkLoadVectorDimensions(t *testing.T) {
724719
similarVectors, err := targetGc.QueryMultipleVectorsUsingSimilarTo(vector, predName, 5)
725720
require.NoError(t, err)
726721
require.GreaterOrEqual(t, len(similarVectors), 4,
727-
"similar_to query should return results for vector %d")
722+
"similar_to query should return results for vector")
728723
}
729724

730725
})
731726
}
732727
}
733728

734-
// TestBulkLoadVectorMetrics tests bulk loading vectors with different distance metrics
735-
// (euclidean, cosine, dotproduct) to ensure all HNSW configurations work correctly.
736729
func TestBulkLoadVectorMetrics(t *testing.T) {
737730
if runtime.GOOS != "linux" && os.Getenv("DGRAPH_BINARY") == "" {
738731
t.Skip("Skipping test on non-Linux platforms due to dgraph binary dependency")
@@ -741,13 +734,13 @@ func TestBulkLoadVectorMetrics(t *testing.T) {
741734
metrics := []string{"euclidean", "cosine", "dotproduct"}
742735
numVectors := 200
743736
vectorDim := 10
737+
predName := "project_description_v"
744738

745739
// Build schema with all metric types
746740
var schemaBuilder strings.Builder
747741
for _, metric := range metrics {
748742
schemaBuilder.WriteString(fmt.Sprintf(
749-
"vec_%s: float32vector @index(hnsw(exponent: \"5\", metric: \"%s\")) .\n",
750-
metric, metric))
743+
"project_description_v: float32vector @index(hnsw(exponent: \"5\", metric: \"%s\")) .\n", metric))
751744
}
752745
schema := schemaBuilder.String()
753746

@@ -779,7 +772,6 @@ func TestBulkLoadVectorMetrics(t *testing.T) {
779772
// Generate and load vectors for each metric type
780773
allVectors := make(map[string][][]float32)
781774
for _, metric := range metrics {
782-
predName := fmt.Sprintf("vec_%s", metric)
783775
rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, vectorDim, predName)
784776
allVectors[predName] = vectors
785777

@@ -845,7 +837,6 @@ func TestBulkLoadVectorMetrics(t *testing.T) {
845837
// Step 5: Verify each metric type
846838
t.Log("Step 5: Verifying each metric type...")
847839
for _, metric := range metrics {
848-
predName := fmt.Sprintf("vec_%s", metric)
849840
vectors := allVectors[predName]
850841

851842
// Verify count
@@ -872,8 +863,6 @@ func TestBulkLoadVectorMetrics(t *testing.T) {
872863
t.Log("All vector metrics verified successfully!")
873864
}
874865

875-
// TestBulkLoadVectorEdgeCases tests edge cases like empty vector predicates,
876-
// single vector, and predicates with no data.
877866
func TestBulkLoadVectorEdgeCases(t *testing.T) {
878867
// if runtime.GOOS != "linux" && os.Getenv("DGRAPH_BINARY") == "" {
879868
// t.Skip("Skipping test on non-Linux platforms due to dgraph binary dependency")

0 commit comments

Comments (0)