Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions github/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package github

import (
"bytes"
"context"
"errors"
"io"
"log"
Expand Down Expand Up @@ -66,7 +67,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
// for read and write requests. See isWriteMethod for the distinction between them.
if rlt.nextRequestDelay > 0 {
log.Printf("[DEBUG] Sleeping %s between operations", rlt.nextRequestDelay)
time.Sleep(rlt.nextRequestDelay)
sleep(req.Context(), rlt.nextRequestDelay)
}

rlt.nextRequestDelay = rlt.calculateNextDelay(req.Method)
Expand All @@ -82,6 +83,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
// See https://github.com/google/go-github/pull/986
r1, r2, err := drainBody(resp.Body)
if err != nil {
rlt.smartLock(false)
Comment thread
pete-woods marked this conversation as resolved.
return nil, err
}
resp.Body = r1
Expand All @@ -95,7 +97,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
retryAfter := arlErr.GetRetryAfter()
log.Printf("[WARN] Abuse detection mechanism triggered, sleeping for %s before retrying",
retryAfter)
time.Sleep(retryAfter)
sleep(req.Context(), retryAfter)
rlt.smartLock(false)
Comment thread
pete-woods marked this conversation as resolved.
return rlt.RoundTrip(req)
}
Expand All @@ -106,7 +108,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
retryAfter := time.Until(rlErr.Rate.Reset.Time)
log.Printf("[WARN] Rate limit %d reached, sleeping for %s (until %s) before retrying",
rlErr.Rate.Limit, retryAfter, time.Now().Add(retryAfter))
time.Sleep(retryAfter)
sleep(req.Context(), retryAfter)
rlt.smartLock(false)
return rlt.RoundTrip(req)
}
Expand All @@ -116,6 +118,17 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
return resp, nil
}

// sleep is used an alternative to time.Sleep that supports cancellation via the passed context.Context.
func sleep(ctx context.Context, dur time.Duration) {
Comment thread
pete-woods marked this conversation as resolved.
t := time.NewTimer(dur)
defer t.Stop()

select {
case <-t.C:
case <-ctx.Done():
}
}

// smartLock wraps the mutex locking system and performs its operation via a boolean input for locking and unlocking.
// It also skips the locking when parallelRequests is set to true since, in this case, the lock is not needed.
func (rlt *RateLimitTransport) smartLock(lock bool) {
Expand Down
36 changes: 36 additions & 0 deletions github/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,42 @@ func TestRateLimitTransport_abuseLimit_get(t *testing.T) {
}
}

func TestRateLimitTransport_abuseLimit_get_cancelled(t *testing.T) {
ts := githubApiMock([]*mockResponse{
{
ExpectedUri: "/repos/test/blah",
ResponseBody: `{
"message": "You have triggered an abuse detection mechanism and have been temporarily blocked from content creation. Please retry your request again later.",
"documentation_url": "https://developer.github.com/v3/#abuse-rate-limits"
}`,
StatusCode: 403,
ResponseHeaders: map[string]string{
"Retry-After": "10",
},
},
})
defer ts.Close()

httpClient := http.DefaultClient
httpClient.Transport = NewRateLimitTransport(http.DefaultTransport)

client := github.NewClient(httpClient)
u, _ := url.Parse(ts.URL + "/")
client.BaseURL = u

ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer cancel()

start := time.Now()
_, _, err := client.Repositories.Get(ctx, "test", "blah")
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("Expected context deadline exceeded, got: %v", err)
}
if time.Since(start) > time.Second {
t.Fatalf("Waited for longer than expected: %s", time.Since(start))
}
}

func TestRateLimitTransport_abuseLimit_post(t *testing.T) {
ts := githubApiMock([]*mockResponse{
{
Expand Down