diff --git a/github/transport.go b/github/transport.go index 743c9e5de7..d8d714c7d0 100644 --- a/github/transport.go +++ b/github/transport.go @@ -2,6 +2,7 @@ package github import ( "bytes" + "context" "errors" "io" "log" @@ -66,7 +67,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err // for read and write requests. See isWriteMethod for the distinction between them. if rlt.nextRequestDelay > 0 { log.Printf("[DEBUG] Sleeping %s between operations", rlt.nextRequestDelay) - time.Sleep(rlt.nextRequestDelay) + sleep(req.Context(), rlt.nextRequestDelay) } rlt.nextRequestDelay = rlt.calculateNextDelay(req.Method) @@ -82,6 +83,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err // See https://github.com/google/go-github/pull/986 r1, r2, err := drainBody(resp.Body) if err != nil { + rlt.smartLock(false) return nil, err } resp.Body = r1 @@ -95,7 +97,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err retryAfter := arlErr.GetRetryAfter() log.Printf("[WARN] Abuse detection mechanism triggered, sleeping for %s before retrying", retryAfter) - time.Sleep(retryAfter) + sleep(req.Context(), retryAfter) rlt.smartLock(false) return rlt.RoundTrip(req) } @@ -106,7 +108,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err retryAfter := time.Until(rlErr.Rate.Reset.Time) log.Printf("[WARN] Rate limit %d reached, sleeping for %s (until %s) before retrying", rlErr.Rate.Limit, retryAfter, time.Now().Add(retryAfter)) - time.Sleep(retryAfter) + sleep(req.Context(), retryAfter) rlt.smartLock(false) return rlt.RoundTrip(req) } @@ -116,6 +118,17 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err return resp, nil } +// sleep is used an alternative to time.Sleep that supports cancellation via the passed context.Context. +func sleep(ctx context.Context, dur time.Duration) { + t := time.NewTimer(dur) + defer t.Stop() + + select { + case <-t.C: + case <-ctx.Done(): + } +} + // smartLock wraps the mutex locking system and performs its operation via a boolean input for locking and unlocking. // It also skips the locking when parallelRequests is set to true since, in this case, the lock is not needed. func (rlt *RateLimitTransport) smartLock(lock bool) { diff --git a/github/transport_test.go b/github/transport_test.go index fab8242a29..042a73602c 100644 --- a/github/transport_test.go +++ b/github/transport_test.go @@ -160,6 +160,42 @@ func TestRateLimitTransport_abuseLimit_get(t *testing.T) { } } +func TestRateLimitTransport_abuseLimit_get_cancelled(t *testing.T) { + ts := githubApiMock([]*mockResponse{ + { + ExpectedUri: "/repos/test/blah", + ResponseBody: `{ + "message": "You have triggered an abuse detection mechanism and have been temporarily blocked from content creation. Please retry your request again later.", + "documentation_url": "https://developer.github.com/v3/#abuse-rate-limits" +}`, + StatusCode: 403, + ResponseHeaders: map[string]string{ + "Retry-After": "10", + }, + }, + }) + defer ts.Close() + + httpClient := http.DefaultClient + httpClient.Transport = NewRateLimitTransport(http.DefaultTransport) + + client := github.NewClient(httpClient) + u, _ := url.Parse(ts.URL + "/") + client.BaseURL = u + + ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + _, _, err := client.Repositories.Get(ctx, "test", "blah") + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Expected context deadline exceeded, got: %v", err) + } + if time.Since(start) > time.Second { + t.Fatalf("Waited for longer than expected: %s", time.Since(start)) + } +} + func TestRateLimitTransport_abuseLimit_post(t *testing.T) { ts := githubApiMock([]*mockResponse{ {