Skip to content

Commit

Permalink
RateLimitRoundtripper: Fix mutex leak and not respecting context canc…
Browse files Browse the repository at this point in the history
…ellation
  • Loading branch information
pete-woods committed Sep 3, 2024
1 parent f3792c8 commit 2795869
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
18 changes: 15 additions & 3 deletions github/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package github

import (
"bytes"
"context"
"io"
"log"
"net/http"
Expand Down Expand Up @@ -65,7 +66,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
// for read and write requests. See isWriteMethod for the distinction between them.
if rlt.nextRequestDelay > 0 {
log.Printf("[DEBUG] Sleeping %s between operations", rlt.nextRequestDelay)
time.Sleep(rlt.nextRequestDelay)
sleep(req.Context(), rlt.nextRequestDelay)
}

rlt.nextRequestDelay = rlt.calculateNextDelay(req.Method)
Expand All @@ -81,6 +82,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
// See https://github.com/google/go-github/pull/986
r1, r2, err := drainBody(resp.Body)
if err != nil {
rlt.smartLock(false)
return nil, err
}
resp.Body = r1
Expand All @@ -93,7 +95,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
retryAfter := arlErr.GetRetryAfter()
log.Printf("[DEBUG] Abuse detection mechanism triggered, sleeping for %s before retrying",
retryAfter)
time.Sleep(retryAfter)
sleep(req.Context(), retryAfter)
rlt.smartLock(false)
return rlt.RoundTrip(req)
}
Expand All @@ -103,7 +105,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
retryAfter := time.Until(rlErr.Rate.Reset.Time)
log.Printf("[DEBUG] Rate limit %d reached, sleeping for %s (until %s) before retrying",
rlErr.Rate.Limit, retryAfter, time.Now().Add(retryAfter))
time.Sleep(retryAfter)
sleep(req.Context(), retryAfter)
rlt.smartLock(false)
return rlt.RoundTrip(req)
}
Expand All @@ -113,6 +115,16 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
return resp, nil
}

func sleep(ctx context.Context, dur time.Duration) {
t := time.NewTimer(dur)
defer t.Stop()

select {
case <-t.C:
case <-ctx.Done():
}
}

// smartLock wraps the mutex locking system and performs its operation via a boolean input for locking and unlocking.
// It also skips the locking when parallelRequests is set to true since, in this case, the lock is not needed.
func (rlt *RateLimitTransport) smartLock(lock bool) {
Expand Down
38 changes: 38 additions & 0 deletions github/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package github

import (
"context"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -159,6 +160,43 @@ func TestRateLimitTransport_abuseLimit_get(t *testing.T) {
}
}

func TestRateLimitTransport_abuseLimit_get_cancelled(t *testing.T) {
ts := githubApiMock([]*mockResponse{
{
ExpectedUri: "/repos/test/blah",
ResponseBody: `{
"message": "You have triggered an abuse detection mechanism and have been temporarily blocked from content creation. Please retry your request again later.",
"documentation_url": "https://developer.github.com/v3/#abuse-rate-limits"
}`,
StatusCode: 403,
ResponseHeaders: map[string]string{
"Retry-After": "10",
},
},
})
defer ts.Close()

httpClient := http.DefaultClient
httpClient.Transport = NewRateLimitTransport(http.DefaultTransport)

client := github.NewClient(httpClient)
u, _ := url.Parse(ts.URL + "/")
client.BaseURL = u

ctx := context.WithValue(context.Background(), ctxId, t.Name())
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()

start := time.Now()
_, _, err := client.Repositories.Get(ctx, "test", "blah")
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("Expected context deadline exceeded, got: %v", err)
}
if time.Since(start) > time.Second {
t.Fatalf("Waited for longer than expected: %s", time.Since(start))
}
}

func TestRateLimitTransport_abuseLimit_post(t *testing.T) {
ts := githubApiMock([]*mockResponse{
{
Expand Down

0 comments on commit 2795869

Please sign in to comment.