From 0e4f60cb799b0f8bfdd58b51b9ce37007f89a9a5 Mon Sep 17 00:00:00 2001 From: Agata Migalska <38040787+agatamigalska-newstore@users.noreply.github.com> Date: Mon, 12 Nov 2018 23:05:15 +0100 Subject: [PATCH] `Do` function panics when HTTPClient returns a nil response (#714) --- stripe.go | 22 ++++++++++++++-------- stripe_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/stripe.go b/stripe.go index c7fb5c8102..18c4c474aa 100644 --- a/stripe.go +++ b/stripe.go @@ -351,17 +351,22 @@ func (s *BackendImplementation) Do(req *http.Request, body *bytes.Buffer, v inte break } - resBody, err := ioutil.ReadAll(res.Body) - res.Body.Close() if err != nil { if s.LogLevel > 0 { - s.Logger.Printf("Cannot read response: %v\n", err) + s.Logger.Printf("Request failed with error: %v\n", err) + } + } else { + resBody, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + if s.LogLevel > 0 { + s.Logger.Printf("Cannot read response: %v\n", err) + } + } else { + if s.LogLevel > 0 { + s.Logger.Printf("Request failed with body: %s (status: %v)\n", string(resBody), res.StatusCode) + } } - return err - } - - if s.LogLevel > 0 { - s.Logger.Printf("Request failed with: %s (error: %v)\n", string(resBody), err) } sleepDuration := s.sleepTime(retry) @@ -381,6 +386,7 @@ func (s *BackendImplementation) Do(req *http.Request, body *bytes.Buffer, v inte } return err } + defer res.Body.Close() resBody, err := ioutil.ReadAll(res.Body) diff --git a/stripe_test.go b/stripe_test.go index c83e9780ea..bbf7b6f655 100644 --- a/stripe_test.go +++ b/stripe_test.go @@ -9,10 +9,12 @@ import ( "regexp" "runtime" "sync" + "sync/atomic" "testing" + "time" assert "github.com/stretchr/testify/require" - stripe "github.com/stripe/stripe-go" + "github.com/stripe/stripe-go" . "github.com/stripe/stripe-go/testing" ) @@ -118,6 +120,51 @@ func TestDo_Retry(t *testing.T) { assert.Equal(t, 2, requestNum) } +func TestDo_RetryOnTimeout(t *testing.T) { + type testServerResponse struct { + Message string `json:"message"` + } + + timeout := time.Second + var counter uint32 + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddUint32(&counter, 1) + time.Sleep(timeout) + })) + defer testServer.Close() + + backend := stripe.GetBackendWithConfig( + stripe.APIBackend, + &stripe.BackendConfig{ + LogLevel: 3, + MaxNetworkRetries: 1, + URL: testServer.URL, + HTTPClient: &http.Client{Timeout: timeout}, + }, + ).(*stripe.BackendImplementation) + + backend.SetNetworkRetriesSleep(false) + + request, err := backend.NewRequest( + http.MethodPost, + "/hello", + "sk_test_123", + "application/x-www-form-urlencoded", + nil, + ) + assert.NoError(t, err) + + var body = bytes.NewBufferString("foo=bar") + var response testServerResponse + + err = backend.Do(request, body, &response) + + assert.Error(t, err) + // timeout should not prevent retry + assert.Equal(t, uint32(2), atomic.LoadUint32(&counter)) +} + func TestFormatURLPath(t *testing.T) { assert.Equal(t, "/resources/1/subresources/2", stripe.FormatURLPath("/resources/%s/subresources/%s", "1", "2"))