diff --git a/fileupload.go b/fileupload.go index 1f75c51549..a99dfe06ce 100644 --- a/fileupload.go +++ b/fileupload.go @@ -1,6 +1,7 @@ package stripe import ( + "bytes" "encoding/json" "io" "mime/multipart" @@ -59,40 +60,37 @@ type FileUploadList struct { Data []*FileUpload `json:"data"` } -// AppendDetails adds the file upload details to an io.ReadWriter. It returns -// the boundary string for a multipart/form-data request and an error (if one -// exists). -func (f *FileUploadParams) AppendDetails(body io.ReadWriter) (string, error) { +// GetBody gets an appropriate multipart form payload to use in a request body +// to create a new file. +func (f *FileUploadParams) GetBody() (*bytes.Buffer, string, error) { + body := &bytes.Buffer{} writer := multipart.NewWriter(body) - var err error if f.Purpose != nil { - err = writer.WriteField("purpose", StringValue(f.Purpose)) + err := writer.WriteField("purpose", StringValue(f.Purpose)) if err != nil { - return "", err + return nil, "", err } } - // Support both FileReader/Filename and File with - // the former being the newer preferred version if f.FileReader != nil && f.Filename != nil { part, err := writer.CreateFormFile("file", filepath.Base(StringValue(f.Filename))) if err != nil { - return "", err + return nil, "", err } _, err = io.Copy(part, f.FileReader) if err != nil { - return "", err + return nil, "", err } } - err = writer.Close() + err := writer.Close() if err != nil { - return "", err + return nil, "", err } - return writer.Boundary(), nil + return body, writer.Boundary(), nil } // UnmarshalJSON handles deserialization of a FileUpload. diff --git a/fileupload/client.go b/fileupload/client.go index 53c2ad9926..68abc0e23e 100644 --- a/fileupload/client.go +++ b/fileupload/client.go @@ -2,7 +2,6 @@ package fileupload import ( - "bytes" "fmt" "net/http" @@ -27,14 +26,13 @@ func (c Client) New(params *stripe.FileUploadParams) (*stripe.FileUpload, error) return nil, fmt.Errorf("params cannot be nil, and params.Purpose and params.File must be set") } - body := &bytes.Buffer{} - boundary, err := params.AppendDetails(body) + bodyBuffer, boundary, err := params.GetBody() if err != nil { return nil, err } upload := &stripe.FileUpload{} - err = c.B.CallMultipart(http.MethodPost, "/files", c.Key, boundary, body, ¶ms.Params, upload) + err = c.B.CallMultipart(http.MethodPost, "/files", c.Key, boundary, bodyBuffer, ¶ms.Params, upload) return upload, err } diff --git a/fileupload/client_test.go b/fileupload/client_test.go index d9a094f225..9a0da6cd91 100644 --- a/fileupload/client_test.go +++ b/fileupload/client_test.go @@ -53,6 +53,6 @@ func TestFileUploadNew(t *testing.T) { } fileupload, err := New(uploadParams) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, fileupload) } diff --git a/fileupload_test.go b/fileupload_test.go index dfa52fae6a..eb154ee898 100644 --- a/fileupload_test.go +++ b/fileupload_test.go @@ -2,6 +2,7 @@ package stripe import ( "encoding/json" + "os" "testing" assert "github.com/stretchr/testify/require" @@ -27,3 +28,29 @@ func TestFileUpload_UnmarshalJSON(t *testing.T) { assert.Equal(t, "file_123", v.ID) } } + +func TestFileUploadParams_GetBody(t *testing.T) { + f, err := os.Open("fileupload/test_data.pdf") + if err != nil { + t.Errorf("Unable to open test file upload file %v\n", err) + } + + p := &FileUploadParams{ + FileReader: f, + Filename: String(f.Name()), + } + + buffer, boundary, err := p.GetBody() + assert.NoError(t, err) + + assert.NotEqual(t, 0, buffer.Len()) + + // Copied from the check performed by `multipart.Writer.SetBoundary`. A + // very basic check that the string we got back indeed looks like a + // boundary. + // + // rfc2046#section-5.1.1 + if len(boundary) < 1 || len(boundary) > 70 { + t.Errorf("invalid boundary length") + } +} diff --git a/stripe.go b/stripe.go index 3e1a565079..516787b2c6 100644 --- a/stripe.go +++ b/stripe.go @@ -100,7 +100,7 @@ func (a *AppInfo) formatUserAgent() string { type Backend interface { Call(method, path, key string, params ParamsContainer, v interface{}) error CallRaw(method, path, key string, body *form.Values, params *Params, v interface{}) error - CallMultipart(method, path, key, boundary string, body io.Reader, params *Params, v interface{}) error + CallMultipart(method, path, key, boundary string, body *bytes.Buffer, params *Params, v interface{}) error SetMaxNetworkRetries(maxNetworkRetries int) } @@ -152,6 +152,12 @@ type BackendConfiguration struct { MaxNetworkRetries int LogLevel int Logger Printfer + + // networkRetriesSleep indicates whether the backend should use the normal + // sleep between retries. + // + // See also SetNetworkRetriesSleep. + networkRetriesSleep bool } // Call is the Backend.Call implementation for invoking Stripe APIs. @@ -181,7 +187,7 @@ func (s *BackendConfiguration) Call(method, path, key string, params ParamsConta } // CallMultipart is the Backend.CallMultipart implementation for invoking Stripe APIs. -func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string, body io.Reader, params *Params, v interface{}) error { +func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string, body *bytes.Buffer, params *Params, v interface{}) error { contentType := "multipart/form-data; boundary=" + boundary req, err := s.NewRequest(method, path, key, contentType, params) @@ -198,24 +204,24 @@ func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string, // CallRaw is the implementation for invoking Stripe APIs internally without a backend. func (s *BackendConfiguration) CallRaw(method, path, key string, form *form.Values, params *Params, v interface{}) error { - var data string + var body string if form != nil && !form.Empty() { - data = form.Encode() + body = form.Encode() // On `GET`, move the payload into the URL if method == http.MethodGet { - path += "?" + data - data = "" + path += "?" + body + body = "" } } - dataBuffer := bytes.NewBufferString(data) + bodyBuffer := bytes.NewBufferString(body) req, err := s.NewRequest(method, path, key, "application/x-www-form-urlencoded", params) if err != nil { return err } - if err := s.Do(req, dataBuffer, v); err != nil { + if err := s.Do(req, bodyBuffer, v); err != nil { return err } @@ -281,7 +287,7 @@ func (s *BackendConfiguration) NewRequest(method, path, key, contentType string, // Do is used by Call to execute an API request and parse the response. It uses // the backend's HTTP client to execute the request and unmarshals the response // into v. It also handles unmarshaling errors returned by the API. -func (s *BackendConfiguration) Do(req *http.Request, body io.Reader, v interface{}) error { +func (s *BackendConfiguration) Do(req *http.Request, body *bytes.Buffer, v interface{}) error { if s.LogLevel > 1 { s.Logger.Printf("Requesting %v %v%v\n", req.Method, req.URL.Host, req.URL.Path) } @@ -312,7 +318,12 @@ func (s *BackendConfiguration) Do(req *http.Request, body io.Reader, v interface // every time we execute it, and this seems to empirically resolve the // problem. if body != nil { - req.Body = nopReadCloser{body} + // We can safely reuse the same buffer that we used to encode our body, + // but return a new reader to it everytime so that each read is from + // the beginning. + reader := bytes.NewReader(body.Bytes()) + + req.Body = nopReadCloser{reader} } res, err = s.HTTPClient.Do(req) @@ -341,7 +352,7 @@ func (s *BackendConfiguration) Do(req *http.Request, body io.Reader, v interface s.Logger.Printf("Request failed with: %s (error: %v)\n", string(resBody), err) } - sleepDuration := sleepTime(retry) + sleepDuration := s.sleepTime(retry) retry++ if s.LogLevel > 1 { @@ -469,6 +480,15 @@ func (s *BackendConfiguration) SetMaxNetworkRetries(maxNetworkRetries int) { s.MaxNetworkRetries = maxNetworkRetries } +// SetNetworkRetriesSleep allows the normal sleep between network retries to be +// enabled or disabled. +// +// This function is available for internal testing only and should never be +// used in production. +func (s *BackendConfiguration) SetNetworkRetriesSleep(sleep bool) { + s.networkRetriesSleep = sleep +} + // Checks if an error is a problem that we should retry on. This includes both // socket errors that may represent an intermittent problem and some special // HTTP statuses. @@ -487,6 +507,34 @@ func (s *BackendConfiguration) shouldRetry(err error, resp *http.Response, numRe return false } +// sleepTime calculates sleeping/delay time in milliseconds between failure and a new one request. +func (s *BackendConfiguration) sleepTime(numRetries int) time.Duration { + // We disable sleeping in some cases for tests. + if !s.networkRetriesSleep { + return 0 * time.Second + } + + // Apply exponential backoff with minNetworkRetriesDelay on the + // number of num_retries so far as inputs. + delay := minNetworkRetriesDelay + minNetworkRetriesDelay*time.Duration(numRetries*numRetries) + + // Do not allow the number to exceed maxNetworkRetriesDelay. + if delay > maxNetworkRetriesDelay { + delay = maxNetworkRetriesDelay + } + + // Apply some jitter by randomizing the value in the range of 75%-100%. + jitter := rand.Int63n(int64(delay / 4)) + delay -= time.Duration(jitter) + + // But never sleep less than the base sleep seconds. + if delay < minNetworkRetriesDelay { + delay = minNetworkRetriesDelay + } + + return delay +} + // Backends are the currently supported endpoints. type Backends struct { API, Uploads Backend @@ -859,26 +907,3 @@ func newBackendConfiguration(backendType SupportedBackend, config *BackendConfig URL: config.URL, } } - -// sleepTime calculates sleeping/delay time in milliseconds between failure and a new one request. -func sleepTime(numRetries int) time.Duration { - // Apply exponential backoff with minNetworkRetriesDelay on the - // number of num_retries so far as inputs. - delay := minNetworkRetriesDelay + minNetworkRetriesDelay*time.Duration(numRetries*numRetries) - - // Do not allow the number to exceed maxNetworkRetriesDelay. - if delay > maxNetworkRetriesDelay { - delay = maxNetworkRetriesDelay - } - - // Apply some jitter by randomizing the value in the range of 75%-100%. - jitter := rand.Int63n(int64(delay / 4)) - delay -= time.Duration(jitter) - - // But never sleep less than the base sleep seconds. - if delay < minNetworkRetriesDelay { - delay = minNetworkRetriesDelay - } - - return delay -} diff --git a/stripe_test.go b/stripe_test.go index 7079a37fcf..cd48cdfa8d 100644 --- a/stripe_test.go +++ b/stripe_test.go @@ -1,9 +1,11 @@ package stripe_test import ( + "bytes" "context" "encoding/json" "net/http" + "net/http/httptest" "regexp" "runtime" "sync" @@ -65,6 +67,85 @@ func TestContext_Cancel(t *testing.T) { assert.Regexp(t, regexp.MustCompile(`(request canceled|context canceled\z)`), err.Error()) } +// Tests client retries. +// +// You can get pretty good visibility into what's going on by running just this +// test on verbose: +// +// go test . -run TestDo_Retry -test.v +// +func TestDo_Retry(t *testing.T) { + type testServerResponse struct { + Message string `json:"message"` + } + + message := "Hello, client." + requestNum := 0 + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + assert.NoError(t, err) + + // The body should always be the same with every retry. We've + // previously had regressions in this behavior as we switched to HTTP/2 + // and `Request` became non-reusable, so we want to check it with every + // request. + assert.Equal(t, "bar", r.Form.Get("foo")) + + switch requestNum { + case 0: + w.WriteHeader(http.StatusConflict) + w.Write([]byte(`{"error":"Conflict (this should be retried)."}`)) + + case 1: + response := testServerResponse{Message: message} + + data, err := json.Marshal(response) + assert.NoError(t, err) + + _, err = w.Write(data) + assert.NoError(t, err) + + default: + assert.Fail(t, "Should not have reached request %v", requestNum) + } + + requestNum++ + })) + defer testServer.Close() + + backend := stripe.GetBackendWithConfig( + stripe.APIBackend, + &stripe.BackendConfig{ + LogLevel: 3, + MaxNetworkRetries: 5, + URL: testServer.URL, + }, + ).(*stripe.BackendConfiguration) + + // Disable sleeping duration our tests. + backend.SetNetworkRetriesSleep(false) + + request, err := backend.NewRequest( + http.MethodPost, + "/hello", + "sk_test_123", + "application/x-www-form-urlencoded", + nil, + ) + assert.NoError(t, err) + + bodyBuffer := bytes.NewBufferString("foo=bar") + var response testServerResponse + err = backend.Do(request, bodyBuffer, &response) + + assert.NoError(t, err) + assert.Equal(t, message, response.Message) + + // We should have seen exactly two requests. + assert.Equal(t, 2, requestNum) +} + func TestFormatURLPath(t *testing.T) { assert.Equal(t, "/resources/1/subresources/2", stripe.FormatURLPath("/resources/%s/subresources/%s", "1", "2"))