From 4b4a17b109d5bd3c1bc93ffd51bece74fcc59ea9 Mon Sep 17 00:00:00 2001 From: Brandur Date: Mon, 6 Aug 2018 13:54:45 -0700 Subject: [PATCH] Introduce getters so a request's body can be read multiple times In #646, I attempted to fix a problem whereby when using HTTP/2 we were sending invalid requests during automatic retries because a `Request` struct with a body cannot be reused in the context of HTTP/2. Unfortunately, it wasn't blowing up quite as spectacularly as before, the fix didn't quite work either. It would set a new `Body` before retries, but that body would be pointing to an `io.Reader` that was already exhausted as it had been fully consumed during the original request, thereby producing an empty request body. Here we modify the `Do` and `CallMultipart` interfaces so that they take body buffers instead of just readers. When making a request, we create a new reader for those buffers every time, thus ensuring a fresh one. This is relatively trivial when making non-multipart requests (which is most requests) because we were already producing a buffer in almost the right place. It's a little more complicated for multipart requests (file uploads) because we had encoding scheme that wasn't quite compatible, but I've done some refactoring there (and added more tests) to bring things in line. We also add a much improved test framework this time around that verifies that the problem is fixed and will definitively stay fixed. Fixes #647. --- fileupload.go | 26 +++++------ fileupload/client.go | 6 +-- fileupload/client_test.go | 2 +- fileupload_test.go | 27 ++++++++++++ stripe.go | 93 +++++++++++++++++++++++++-------------- stripe_test.go | 81 ++++++++++++++++++++++++++++++++++ 6 files changed, 182 insertions(+), 53 deletions(-) diff --git a/fileupload.go b/fileupload.go index 1f75c51549..a99dfe06ce 100644 --- a/fileupload.go +++ b/fileupload.go @@ -1,6 +1,7 @@ package stripe import ( + "bytes" "encoding/json" "io" "mime/multipart" @@ -59,40 +60,37 @@ type FileUploadList struct { Data []*FileUpload `json:"data"` } -// AppendDetails adds the file upload details to an io.ReadWriter. It returns -// the boundary string for a multipart/form-data request and an error (if one -// exists). -func (f *FileUploadParams) AppendDetails(body io.ReadWriter) (string, error) { +// GetBody gets an appropriate multipart form payload to use in a request body +// to create a new file. +func (f *FileUploadParams) GetBody() (*bytes.Buffer, string, error) { + body := &bytes.Buffer{} writer := multipart.NewWriter(body) - var err error if f.Purpose != nil { - err = writer.WriteField("purpose", StringValue(f.Purpose)) + err := writer.WriteField("purpose", StringValue(f.Purpose)) if err != nil { - return "", err + return nil, "", err } } - // Support both FileReader/Filename and File with - // the former being the newer preferred version if f.FileReader != nil && f.Filename != nil { part, err := writer.CreateFormFile("file", filepath.Base(StringValue(f.Filename))) if err != nil { - return "", err + return nil, "", err } _, err = io.Copy(part, f.FileReader) if err != nil { - return "", err + return nil, "", err } } - err = writer.Close() + err := writer.Close() if err != nil { - return "", err + return nil, "", err } - return writer.Boundary(), nil + return body, writer.Boundary(), nil } // UnmarshalJSON handles deserialization of a FileUpload. diff --git a/fileupload/client.go b/fileupload/client.go index 53c2ad9926..68abc0e23e 100644 --- a/fileupload/client.go +++ b/fileupload/client.go @@ -2,7 +2,6 @@ package fileupload import ( - "bytes" "fmt" "net/http" @@ -27,14 +26,13 @@ func (c Client) New(params *stripe.FileUploadParams) (*stripe.FileUpload, error) return nil, fmt.Errorf("params cannot be nil, and params.Purpose and params.File must be set") } - body := &bytes.Buffer{} - boundary, err := params.AppendDetails(body) + bodyBuffer, boundary, err := params.GetBody() if err != nil { return nil, err } upload := &stripe.FileUpload{} - err = c.B.CallMultipart(http.MethodPost, "/files", c.Key, boundary, body, ¶ms.Params, upload) + err = c.B.CallMultipart(http.MethodPost, "/files", c.Key, boundary, bodyBuffer, ¶ms.Params, upload) return upload, err } diff --git a/fileupload/client_test.go b/fileupload/client_test.go index d9a094f225..9a0da6cd91 100644 --- a/fileupload/client_test.go +++ b/fileupload/client_test.go @@ -53,6 +53,6 @@ func TestFileUploadNew(t *testing.T) { } fileupload, err := New(uploadParams) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, fileupload) } diff --git a/fileupload_test.go b/fileupload_test.go index dfa52fae6a..eb154ee898 100644 --- a/fileupload_test.go +++ b/fileupload_test.go @@ -2,6 +2,7 @@ package stripe import ( "encoding/json" + "os" "testing" assert "github.com/stretchr/testify/require" @@ -27,3 +28,29 @@ func TestFileUpload_UnmarshalJSON(t *testing.T) { assert.Equal(t, "file_123", v.ID) } } + +func TestFileUploadParams_GetBody(t *testing.T) { + f, err := os.Open("fileupload/test_data.pdf") + if err != nil { + t.Errorf("Unable to open test file upload file %v\n", err) + } + + p := &FileUploadParams{ + FileReader: f, + Filename: String(f.Name()), + } + + buffer, boundary, err := p.GetBody() + assert.NoError(t, err) + + assert.NotEqual(t, 0, buffer.Len()) + + // Copied from the check performed by `multipart.Writer.SetBoundary`. A + // very basic check that the string we got back indeed looks like a + // boundary. + // + // rfc2046#section-5.1.1 + if len(boundary) < 1 || len(boundary) > 70 { + t.Errorf("invalid boundary length") + } +} diff --git a/stripe.go b/stripe.go index 3e1a565079..516787b2c6 100644 --- a/stripe.go +++ b/stripe.go @@ -100,7 +100,7 @@ func (a *AppInfo) formatUserAgent() string { type Backend interface { Call(method, path, key string, params ParamsContainer, v interface{}) error CallRaw(method, path, key string, body *form.Values, params *Params, v interface{}) error - CallMultipart(method, path, key, boundary string, body io.Reader, params *Params, v interface{}) error + CallMultipart(method, path, key, boundary string, body *bytes.Buffer, params *Params, v interface{}) error SetMaxNetworkRetries(maxNetworkRetries int) } @@ -152,6 +152,12 @@ type BackendConfiguration struct { MaxNetworkRetries int LogLevel int Logger Printfer + + // networkRetriesSleep indicates whether the backend should use the normal + // sleep between retries. + // + // See also SetNetworkRetriesSleep. + networkRetriesSleep bool } // Call is the Backend.Call implementation for invoking Stripe APIs. @@ -181,7 +187,7 @@ func (s *BackendConfiguration) Call(method, path, key string, params ParamsConta } // CallMultipart is the Backend.CallMultipart implementation for invoking Stripe APIs. -func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string, body io.Reader, params *Params, v interface{}) error { +func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string, body *bytes.Buffer, params *Params, v interface{}) error { contentType := "multipart/form-data; boundary=" + boundary req, err := s.NewRequest(method, path, key, contentType, params) @@ -198,24 +204,24 @@ func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string, // CallRaw is the implementation for invoking Stripe APIs internally without a backend. func (s *BackendConfiguration) CallRaw(method, path, key string, form *form.Values, params *Params, v interface{}) error { - var data string + var body string if form != nil && !form.Empty() { - data = form.Encode() + body = form.Encode() // On `GET`, move the payload into the URL if method == http.MethodGet { - path += "?" + data - data = "" + path += "?" + body + body = "" } } - dataBuffer := bytes.NewBufferString(data) + bodyBuffer := bytes.NewBufferString(body) req, err := s.NewRequest(method, path, key, "application/x-www-form-urlencoded", params) if err != nil { return err } - if err := s.Do(req, dataBuffer, v); err != nil { + if err := s.Do(req, bodyBuffer, v); err != nil { return err } @@ -281,7 +287,7 @@ func (s *BackendConfiguration) NewRequest(method, path, key, contentType string, // Do is used by Call to execute an API request and parse the response. It uses // the backend's HTTP client to execute the request and unmarshals the response // into v. It also handles unmarshaling errors returned by the API. -func (s *BackendConfiguration) Do(req *http.Request, body io.Reader, v interface{}) error { +func (s *BackendConfiguration) Do(req *http.Request, body *bytes.Buffer, v interface{}) error { if s.LogLevel > 1 { s.Logger.Printf("Requesting %v %v%v\n", req.Method, req.URL.Host, req.URL.Path) } @@ -312,7 +318,12 @@ func (s *BackendConfiguration) Do(req *http.Request, body io.Reader, v interface // every time we execute it, and this seems to empirically resolve the // problem. if body != nil { - req.Body = nopReadCloser{body} + // We can safely reuse the same buffer that we used to encode our body, + // but return a new reader to it everytime so that each read is from + // the beginning. + reader := bytes.NewReader(body.Bytes()) + + req.Body = nopReadCloser{reader} } res, err = s.HTTPClient.Do(req) @@ -341,7 +352,7 @@ func (s *BackendConfiguration) Do(req *http.Request, body io.Reader, v interface s.Logger.Printf("Request failed with: %s (error: %v)\n", string(resBody), err) } - sleepDuration := sleepTime(retry) + sleepDuration := s.sleepTime(retry) retry++ if s.LogLevel > 1 { @@ -469,6 +480,15 @@ func (s *BackendConfiguration) SetMaxNetworkRetries(maxNetworkRetries int) { s.MaxNetworkRetries = maxNetworkRetries } +// SetNetworkRetriesSleep allows the normal sleep between network retries to be +// enabled or disabled. +// +// This function is available for internal testing only and should never be +// used in production. +func (s *BackendConfiguration) SetNetworkRetriesSleep(sleep bool) { + s.networkRetriesSleep = sleep +} + // Checks if an error is a problem that we should retry on. This includes both // socket errors that may represent an intermittent problem and some special // HTTP statuses. @@ -487,6 +507,34 @@ func (s *BackendConfiguration) shouldRetry(err error, resp *http.Response, numRe return false } +// sleepTime calculates sleeping/delay time in milliseconds between failure and a new one request. +func (s *BackendConfiguration) sleepTime(numRetries int) time.Duration { + // We disable sleeping in some cases for tests. + if !s.networkRetriesSleep { + return 0 * time.Second + } + + // Apply exponential backoff with minNetworkRetriesDelay on the + // number of num_retries so far as inputs. + delay := minNetworkRetriesDelay + minNetworkRetriesDelay*time.Duration(numRetries*numRetries) + + // Do not allow the number to exceed maxNetworkRetriesDelay. + if delay > maxNetworkRetriesDelay { + delay = maxNetworkRetriesDelay + } + + // Apply some jitter by randomizing the value in the range of 75%-100%. + jitter := rand.Int63n(int64(delay / 4)) + delay -= time.Duration(jitter) + + // But never sleep less than the base sleep seconds. + if delay < minNetworkRetriesDelay { + delay = minNetworkRetriesDelay + } + + return delay +} + // Backends are the currently supported endpoints. type Backends struct { API, Uploads Backend @@ -859,26 +907,3 @@ func newBackendConfiguration(backendType SupportedBackend, config *BackendConfig URL: config.URL, } } - -// sleepTime calculates sleeping/delay time in milliseconds between failure and a new one request. -func sleepTime(numRetries int) time.Duration { - // Apply exponential backoff with minNetworkRetriesDelay on the - // number of num_retries so far as inputs. - delay := minNetworkRetriesDelay + minNetworkRetriesDelay*time.Duration(numRetries*numRetries) - - // Do not allow the number to exceed maxNetworkRetriesDelay. - if delay > maxNetworkRetriesDelay { - delay = maxNetworkRetriesDelay - } - - // Apply some jitter by randomizing the value in the range of 75%-100%. - jitter := rand.Int63n(int64(delay / 4)) - delay -= time.Duration(jitter) - - // But never sleep less than the base sleep seconds. - if delay < minNetworkRetriesDelay { - delay = minNetworkRetriesDelay - } - - return delay -} diff --git a/stripe_test.go b/stripe_test.go index 7079a37fcf..cd48cdfa8d 100644 --- a/stripe_test.go +++ b/stripe_test.go @@ -1,9 +1,11 @@ package stripe_test import ( + "bytes" "context" "encoding/json" "net/http" + "net/http/httptest" "regexp" "runtime" "sync" @@ -65,6 +67,85 @@ func TestContext_Cancel(t *testing.T) { assert.Regexp(t, regexp.MustCompile(`(request canceled|context canceled\z)`), err.Error()) } +// Tests client retries. +// +// You can get pretty good visibility into what's going on by running just this +// test on verbose: +// +// go test . -run TestDo_Retry -test.v +// +func TestDo_Retry(t *testing.T) { + type testServerResponse struct { + Message string `json:"message"` + } + + message := "Hello, client." + requestNum := 0 + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + assert.NoError(t, err) + + // The body should always be the same with every retry. We've + // previously had regressions in this behavior as we switched to HTTP/2 + // and `Request` became non-reusable, so we want to check it with every + // request. + assert.Equal(t, "bar", r.Form.Get("foo")) + + switch requestNum { + case 0: + w.WriteHeader(http.StatusConflict) + w.Write([]byte(`{"error":"Conflict (this should be retried)."}`)) + + case 1: + response := testServerResponse{Message: message} + + data, err := json.Marshal(response) + assert.NoError(t, err) + + _, err = w.Write(data) + assert.NoError(t, err) + + default: + assert.Fail(t, "Should not have reached request %v", requestNum) + } + + requestNum++ + })) + defer testServer.Close() + + backend := stripe.GetBackendWithConfig( + stripe.APIBackend, + &stripe.BackendConfig{ + LogLevel: 3, + MaxNetworkRetries: 5, + URL: testServer.URL, + }, + ).(*stripe.BackendConfiguration) + + // Disable sleeping duration our tests. + backend.SetNetworkRetriesSleep(false) + + request, err := backend.NewRequest( + http.MethodPost, + "/hello", + "sk_test_123", + "application/x-www-form-urlencoded", + nil, + ) + assert.NoError(t, err) + + bodyBuffer := bytes.NewBufferString("foo=bar") + var response testServerResponse + err = backend.Do(request, bodyBuffer, &response) + + assert.NoError(t, err) + assert.Equal(t, message, response.Message) + + // We should have seen exactly two requests. + assert.Equal(t, 2, requestNum) +} + func TestFormatURLPath(t *testing.T) { assert.Equal(t, "/resources/1/subresources/2", stripe.FormatURLPath("/resources/%s/subresources/%s", "1", "2"))