Skip to content

Commit 14805d9

Browse files
eileenaaaHarshil Goelmatthewmcneely
authored
perf(concurrency): cancel remaining goroutines when has error (#9484)
**Description** This PR optimizes goroutine execution in multi-task scenarios: - Uses errgroup.WithContext to manage multiple goroutines and propagate cancellation signals. - When any goroutine encounters an error, errgroup automatically cancels the remaining goroutines, allowing them to exit early. - Reduces unnecessary computation and resource usage when partial failures occur. This improves system responsiveness and reduces wasted CPU cycles in error cases. **Checklist** - [x] Code compiles correctly and linting passes locally - [ ] For all _code_ changes, an entry added to the `CHANGELOG.md` file describing and linking to this PR - [ ] Tests added for new functionality, or regression tests for bug fixes added as applicable --------- Co-authored-by: Harshil Goel <[email protected]> Co-authored-by: Matthew McNeely <[email protected]>
1 parent 80cf21c commit 14805d9

1 file changed

Lines changed: 24 additions & 19 deletions

File tree

worker/task.go

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -823,9 +823,9 @@ func (qs *queryState) handleUidPostings(
823823
needFiltering := needsStringFiltering(srcFn, q.Langs, q.Attr)
824824
isList := schema.State().IsList(q.Attr)
825825

826-
errCh := make(chan error, numGo)
827826
outputs := make([]*pb.Result, numGo)
828827

828+
eg, egCtx := errgroup.WithContext(ctx)
829829
calculate := func(start, end int) error {
830830
x.AssertTrue(start%width == 0)
831831
out := &pb.Result{}
@@ -834,8 +834,8 @@ func (qs *queryState) handleUidPostings(
834834
for i := start; i < end; i++ {
835835
if i%100 == 0 {
836836
select {
837-
case <-ctx.Done():
838-
return ctx.Err()
837+
case <-egCtx.Done():
838+
return egCtx.Err()
839839
default:
840840
}
841841
}
@@ -978,14 +978,12 @@ func (qs *queryState) handleUidPostings(
978978
if end > srcFn.n {
979979
end = srcFn.n
980980
}
981-
go func(start, end int) {
982-
errCh <- calculate(start, end)
983-
}(start, end)
981+
eg.Go(func() error {
982+
return calculate(start, end)
983+
})
984984
}
985-
for range numGo {
986-
if err := <-errCh; err != nil {
987-
return err
988-
}
985+
if err := eg.Wait(); err != nil {
986+
return err
989987
}
990988
// All goroutines are done. Now attach their results.
991989
out := args.out
@@ -1635,11 +1633,20 @@ func (qs *queryState) filterGeoFunction(ctx context.Context, arg funcArgs) error
16351633
attribute.Int("num_go", numGo),
16361634
attribute.Int("width", width)))
16371635

1636+
eg, egCtx := errgroup.WithContext(ctx)
16381637
filtered := make([]*pb.List, numGo)
16391638
filter := func(idx, start, end int) error {
16401639
filtered[idx] = &pb.List{}
16411640
out := filtered[idx]
1642-
for _, uid := range uids.Uids[start:end] {
1641+
for i := start; i < end; i++ {
1642+
uid := uids.Uids[i]
1643+
if i%100 == 0 {
1644+
select {
1645+
case <-egCtx.Done():
1646+
return egCtx.Err()
1647+
default:
1648+
}
1649+
}
16431650
pl, err := qs.cache.Get(x.DataKey(attr, uid))
16441651
if err != nil {
16451652
return err
@@ -1661,21 +1668,19 @@ func (qs *queryState) filterGeoFunction(ctx context.Context, arg funcArgs) error
16611668
return nil
16621669
}
16631670

1664-
errCh := make(chan error, numGo)
16651671
for i := range numGo {
16661672
start := i * width
16671673
end := start + width
16681674
if end > len(uids.Uids) {
16691675
end = len(uids.Uids)
16701676
}
1671-
go func(idx, start, end int) {
1672-
errCh <- filter(idx, start, end)
1673-
}(i, start, end)
1677+
idx := i
1678+
eg.Go(func() error {
1679+
return filter(idx, start, end)
1680+
})
16741681
}
1675-
for range numGo {
1676-
if err := <-errCh; err != nil {
1677-
return err
1678-
}
1682+
if err := eg.Wait(); err != nil {
1683+
return err
16791684
}
16801685
final := &pb.List{}
16811686
for _, out := range filtered {

0 commit comments

Comments
 (0)