Skip to content

Commit 2549c89

Browse files
committed
use errgroup to cancel goroutine
1 parent 9fed7b1 commit 2549c89

1 file changed

Lines changed: 17 additions & 36 deletions

File tree

worker/task.go

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -795,20 +795,18 @@ func (qs *queryState) handleUidPostings(
795795
needFiltering := needsStringFiltering(srcFn, q.Langs, q.Attr)
796796
isList := schema.State().IsList(q.Attr)
797797

798-
errCh := make(chan error, numGo)
799798
outputs := make([]*pb.Result, numGo)
800799

801-
cctx, ccancel := context.WithCancel(ctx)
802-
defer ccancel()
800+
eg, egCtx := errgroup.WithContext(ctx)
803801
calculate := func(start, end int) error {
804802
x.AssertTrue(start%width == 0)
805803
out := &pb.Result{}
806804
outputs[start/width] = out
807805

808806
for i := start; i < end; i++ {
809807
select {
810-
case <-cctx.Done():
811-
return cctx.Err()
808+
case <-egCtx.Done():
809+
return egCtx.Err()
812810
default:
813811
}
814812
if i%100 == 0 {
@@ -957,20 +955,12 @@ func (qs *queryState) handleUidPostings(
957955
if end > srcFn.n {
958956
end = srcFn.n
959957
}
960-
go func(start, end int) {
961-
if err := calculate(start, end); err != nil {
962-
errCh <- err
963-
ccancel()
964-
return
965-
} else {
966-
errCh <- nil
967-
}
968-
}(start, end)
958+
eg.Go(func() error {
959+
return calculate(start, end)
960+
})
969961
}
970-
for range numGo {
971-
if err := <-errCh; err != nil {
972-
return err
973-
}
962+
if err := eg.Wait(); err != nil {
963+
return err
974964
}
975965
// All goroutines are done. Now attach their results.
976966
out := args.out
@@ -1610,16 +1600,15 @@ func (qs *queryState) filterGeoFunction(ctx context.Context, arg funcArgs) error
16101600
attribute.Int("num_go", numGo),
16111601
attribute.Int("width", width)))
16121602

1613-
cctx, ccancel := context.WithCancel(ctx)
1614-
defer ccancel()
1603+
eg, egCtx := errgroup.WithContext(ctx)
16151604
filtered := make([]*pb.List, numGo)
16161605
filter := func(idx, start, end int) error {
16171606
filtered[idx] = &pb.List{}
16181607
out := filtered[idx]
16191608
for _, uid := range uids.Uids[start:end] {
16201609
select {
1621-
case <-cctx.Done():
1622-
return cctx.Err()
1610+
case <-egCtx.Done():
1611+
return egCtx.Err()
16231612
default:
16241613
}
16251614
pl, err := qs.cache.Get(x.DataKey(attr, uid))
@@ -1643,27 +1632,19 @@ func (qs *queryState) filterGeoFunction(ctx context.Context, arg funcArgs) error
16431632
return nil
16441633
}
16451634

1646-
errCh := make(chan error, numGo)
16471635
for i := range numGo {
16481636
start := i * width
16491637
end := start + width
16501638
if end > len(uids.Uids) {
16511639
end = len(uids.Uids)
16521640
}
1653-
go func(idx, start, end int) {
1654-
if err := filter(idx, start, end); err != nil {
1655-
errCh <- err
1656-
ccancel()
1657-
return
1658-
} else {
1659-
errCh <- nil
1660-
}
1661-
}(i, start, end)
1641+
idx := i
1642+
eg.Go(func() error {
1643+
return filter(idx, start, end)
1644+
})
16621645
}
1663-
for range numGo {
1664-
if err := <-errCh; err != nil {
1665-
return err
1666-
}
1646+
if err := eg.Wait(); err != nil {
1647+
return err
16671648
}
16681649
final := &pb.List{}
16691650
for _, out := range filtered {

0 commit comments

Comments
 (0)