Skip to content

Commit

Permalink
Introduce getters so a request's body can be read multiple times
Browse files Browse the repository at this point in the history
In #646, I attempted to fix a problem whereby when using HTTP/2 we were
sending invalid requests during automatic retries because a `Request`
struct with a body cannot be reused in the context of HTTP/2.

Unfortunately, it wasn't blowing up quite as spectacularly as before,
the fix didn't quite work either. It would set a new `Body` before
retries, but that body would be pointing to an `io.Reader` that was
already exhausted as it had been fully consumed during the original
request, thereby producing an empty request body.

Here we modify the `Do` and `CallMultipart` interfaces so that they take
body buffers instead of just readers. When making a request, we create a
new reader for those buffers every time, thus ensuring a fresh one.

This is relatively trivial when making non-multipart requests (which is
most requests) because we were already producing a buffer in almost the
right place. It's a little more complicated for multipart requests (file
uploads) because we had encoding scheme that wasn't quite compatible,
but I've done some refactoring there (and added more tests) to bring
things in line.

We also add a much improved test framework this time around that
verifies that the problem is fixed and will definitively stay fixed.

Fixes #647.
  • Loading branch information
brandur committed Aug 6, 2018
1 parent dbb7859 commit 4b4a17b
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 53 deletions.
26 changes: 12 additions & 14 deletions fileupload.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package stripe

import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
Expand Down Expand Up @@ -59,40 +60,37 @@ type FileUploadList struct {
Data []*FileUpload `json:"data"`
}

// AppendDetails adds the file upload details to an io.ReadWriter. It returns
// the boundary string for a multipart/form-data request and an error (if one
// exists).
func (f *FileUploadParams) AppendDetails(body io.ReadWriter) (string, error) {
// GetBody gets an appropriate multipart form payload to use in a request body
// to create a new file.
func (f *FileUploadParams) GetBody() (*bytes.Buffer, string, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
var err error

if f.Purpose != nil {
err = writer.WriteField("purpose", StringValue(f.Purpose))
err := writer.WriteField("purpose", StringValue(f.Purpose))
if err != nil {
return "", err
return nil, "", err
}
}

// Support both FileReader/Filename and File with
// the former being the newer preferred version
if f.FileReader != nil && f.Filename != nil {
part, err := writer.CreateFormFile("file", filepath.Base(StringValue(f.Filename)))
if err != nil {
return "", err
return nil, "", err
}

_, err = io.Copy(part, f.FileReader)
if err != nil {
return "", err
return nil, "", err
}
}

err = writer.Close()
err := writer.Close()
if err != nil {
return "", err
return nil, "", err
}

return writer.Boundary(), nil
return body, writer.Boundary(), nil
}

// UnmarshalJSON handles deserialization of a FileUpload.
Expand Down
6 changes: 2 additions & 4 deletions fileupload/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
package fileupload

import (
"bytes"
"fmt"
"net/http"

Expand All @@ -27,14 +26,13 @@ func (c Client) New(params *stripe.FileUploadParams) (*stripe.FileUpload, error)
return nil, fmt.Errorf("params cannot be nil, and params.Purpose and params.File must be set")
}

body := &bytes.Buffer{}
boundary, err := params.AppendDetails(body)
bodyBuffer, boundary, err := params.GetBody()
if err != nil {
return nil, err
}

upload := &stripe.FileUpload{}
err = c.B.CallMultipart(http.MethodPost, "/files", c.Key, boundary, body, &params.Params, upload)
err = c.B.CallMultipart(http.MethodPost, "/files", c.Key, boundary, bodyBuffer, &params.Params, upload)

return upload, err
}
Expand Down
2 changes: 1 addition & 1 deletion fileupload/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ func TestFileUploadNew(t *testing.T) {
}

fileupload, err := New(uploadParams)
assert.Nil(t, err)
assert.NoError(t, err)
assert.NotNil(t, fileupload)
}
27 changes: 27 additions & 0 deletions fileupload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package stripe

import (
"encoding/json"
"os"
"testing"

assert "github.com/stretchr/testify/require"
Expand All @@ -27,3 +28,29 @@ func TestFileUpload_UnmarshalJSON(t *testing.T) {
assert.Equal(t, "file_123", v.ID)
}
}

func TestFileUploadParams_GetBody(t *testing.T) {
f, err := os.Open("fileupload/test_data.pdf")
if err != nil {
t.Errorf("Unable to open test file upload file %v\n", err)
}

p := &FileUploadParams{
FileReader: f,
Filename: String(f.Name()),
}

buffer, boundary, err := p.GetBody()
assert.NoError(t, err)

assert.NotEqual(t, 0, buffer.Len())

// Copied from the check performed by `multipart.Writer.SetBoundary`. A
// very basic check that the string we got back indeed looks like a
// boundary.
//
// rfc2046#section-5.1.1
if len(boundary) < 1 || len(boundary) > 70 {
t.Errorf("invalid boundary length")
}
}
93 changes: 59 additions & 34 deletions stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (a *AppInfo) formatUserAgent() string {
type Backend interface {
Call(method, path, key string, params ParamsContainer, v interface{}) error
CallRaw(method, path, key string, body *form.Values, params *Params, v interface{}) error
CallMultipart(method, path, key, boundary string, body io.Reader, params *Params, v interface{}) error
CallMultipart(method, path, key, boundary string, body *bytes.Buffer, params *Params, v interface{}) error
SetMaxNetworkRetries(maxNetworkRetries int)
}

Expand Down Expand Up @@ -152,6 +152,12 @@ type BackendConfiguration struct {
MaxNetworkRetries int
LogLevel int
Logger Printfer

// networkRetriesSleep indicates whether the backend should use the normal
// sleep between retries.
//
// See also SetNetworkRetriesSleep.
networkRetriesSleep bool
}

// Call is the Backend.Call implementation for invoking Stripe APIs.
Expand Down Expand Up @@ -181,7 +187,7 @@ func (s *BackendConfiguration) Call(method, path, key string, params ParamsConta
}

// CallMultipart is the Backend.CallMultipart implementation for invoking Stripe APIs.
func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string, body io.Reader, params *Params, v interface{}) error {
func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string, body *bytes.Buffer, params *Params, v interface{}) error {
contentType := "multipart/form-data; boundary=" + boundary

req, err := s.NewRequest(method, path, key, contentType, params)
Expand All @@ -198,24 +204,24 @@ func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string,

// CallRaw is the implementation for invoking Stripe APIs internally without a backend.
func (s *BackendConfiguration) CallRaw(method, path, key string, form *form.Values, params *Params, v interface{}) error {
var data string
var body string
if form != nil && !form.Empty() {
data = form.Encode()
body = form.Encode()

// On `GET`, move the payload into the URL
if method == http.MethodGet {
path += "?" + data
data = ""
path += "?" + body
body = ""
}
}
dataBuffer := bytes.NewBufferString(data)
bodyBuffer := bytes.NewBufferString(body)

req, err := s.NewRequest(method, path, key, "application/x-www-form-urlencoded", params)
if err != nil {
return err
}

if err := s.Do(req, dataBuffer, v); err != nil {
if err := s.Do(req, bodyBuffer, v); err != nil {
return err
}

Expand Down Expand Up @@ -281,7 +287,7 @@ func (s *BackendConfiguration) NewRequest(method, path, key, contentType string,
// Do is used by Call to execute an API request and parse the response. It uses
// the backend's HTTP client to execute the request and unmarshals the response
// into v. It also handles unmarshaling errors returned by the API.
func (s *BackendConfiguration) Do(req *http.Request, body io.Reader, v interface{}) error {
func (s *BackendConfiguration) Do(req *http.Request, body *bytes.Buffer, v interface{}) error {
if s.LogLevel > 1 {
s.Logger.Printf("Requesting %v %v%v\n", req.Method, req.URL.Host, req.URL.Path)
}
Expand Down Expand Up @@ -312,7 +318,12 @@ func (s *BackendConfiguration) Do(req *http.Request, body io.Reader, v interface
// every time we execute it, and this seems to empirically resolve the
// problem.
if body != nil {
req.Body = nopReadCloser{body}
// We can safely reuse the same buffer that we used to encode our body,
// but return a new reader to it everytime so that each read is from
// the beginning.
reader := bytes.NewReader(body.Bytes())

req.Body = nopReadCloser{reader}
}

res, err = s.HTTPClient.Do(req)
Expand Down Expand Up @@ -341,7 +352,7 @@ func (s *BackendConfiguration) Do(req *http.Request, body io.Reader, v interface
s.Logger.Printf("Request failed with: %s (error: %v)\n", string(resBody), err)
}

sleepDuration := sleepTime(retry)
sleepDuration := s.sleepTime(retry)
retry++

if s.LogLevel > 1 {
Expand Down Expand Up @@ -469,6 +480,15 @@ func (s *BackendConfiguration) SetMaxNetworkRetries(maxNetworkRetries int) {
s.MaxNetworkRetries = maxNetworkRetries
}

// SetNetworkRetriesSleep allows the normal sleep between network retries to be
// enabled or disabled.
//
// This function is available for internal testing only and should never be
// used in production.
func (s *BackendConfiguration) SetNetworkRetriesSleep(sleep bool) {
s.networkRetriesSleep = sleep
}

// Checks if an error is a problem that we should retry on. This includes both
// socket errors that may represent an intermittent problem and some special
// HTTP statuses.
Expand All @@ -487,6 +507,34 @@ func (s *BackendConfiguration) shouldRetry(err error, resp *http.Response, numRe
return false
}

// sleepTime calculates sleeping/delay time in milliseconds between failure and a new one request.
func (s *BackendConfiguration) sleepTime(numRetries int) time.Duration {
// We disable sleeping in some cases for tests.
if !s.networkRetriesSleep {
return 0 * time.Second
}

// Apply exponential backoff with minNetworkRetriesDelay on the
// number of num_retries so far as inputs.
delay := minNetworkRetriesDelay + minNetworkRetriesDelay*time.Duration(numRetries*numRetries)

// Do not allow the number to exceed maxNetworkRetriesDelay.
if delay > maxNetworkRetriesDelay {
delay = maxNetworkRetriesDelay
}

// Apply some jitter by randomizing the value in the range of 75%-100%.
jitter := rand.Int63n(int64(delay / 4))
delay -= time.Duration(jitter)

// But never sleep less than the base sleep seconds.
if delay < minNetworkRetriesDelay {
delay = minNetworkRetriesDelay
}

return delay
}

// Backends are the currently supported endpoints.
type Backends struct {
API, Uploads Backend
Expand Down Expand Up @@ -859,26 +907,3 @@ func newBackendConfiguration(backendType SupportedBackend, config *BackendConfig
URL: config.URL,
}
}

// sleepTime calculates sleeping/delay time in milliseconds between failure and a new one request.
func sleepTime(numRetries int) time.Duration {
// Apply exponential backoff with minNetworkRetriesDelay on the
// number of num_retries so far as inputs.
delay := minNetworkRetriesDelay + minNetworkRetriesDelay*time.Duration(numRetries*numRetries)

// Do not allow the number to exceed maxNetworkRetriesDelay.
if delay > maxNetworkRetriesDelay {
delay = maxNetworkRetriesDelay
}

// Apply some jitter by randomizing the value in the range of 75%-100%.
jitter := rand.Int63n(int64(delay / 4))
delay -= time.Duration(jitter)

// But never sleep less than the base sleep seconds.
if delay < minNetworkRetriesDelay {
delay = minNetworkRetriesDelay
}

return delay
}
Loading

0 comments on commit 4b4a17b

Please sign in to comment.