Skip to content

Commit

Permalink
Merge pull request #692 from go-resty/revert-and-fix-pr-541
Browse files Browse the repository at this point in the history
Revert and fix pr 541
  • Loading branch information
jeevatkm authored Sep 17, 2023
2 parents 15cedd4 + 46e2525 commit 69a6954
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v2
with:
fetch-depth: 0

- name: Setup Go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
cache: true
cache-dependency-path: go.sum

- name: Format
run: diff -u <(echo -n) <(go fmt $(go list ./...))
Expand Down
17 changes: 17 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,22 @@ func TestClientAllowsGetMethodPayload(t *testing.T) {
assertEqual(t, payload, resp.String())
}

func TestClientAllowsGetMethodPayloadIoReader(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

c := dc()
c.SetAllowGetMethodPayload(true)

payload := "test-payload"
body := bytes.NewReader([]byte(payload))
resp, err := c.R().SetBody(body).Get(ts.URL + "/get-method-payload-test")

assertError(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())
assertEqual(t, payload, resp.String())
}

func TestClientAllowsGetMethodPayloadDisabled(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()
Expand Down Expand Up @@ -836,6 +852,7 @@ func TestClientOnResponseError(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
var assertErrorHook = func(r *Request, err error) {
assertNotNil(t, r)
v, ok := err.(*ResponseError)
Expand Down
15 changes: 11 additions & 4 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ CL:

func createHTTPRequest(c *Client, r *Request) (err error) {
if r.bodyBuf == nil {
if c.setContentLength || r.setContentLength {
if reader, ok := r.Body.(io.Reader); ok && isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
r.RawRequest, err = http.NewRequest(r.Method, r.URL, reader)
} else if c.setContentLength || r.setContentLength {
r.RawRequest, err = http.NewRequest(r.Method, r.URL, http.NoBody)
} else {
r.RawRequest, err = http.NewRequest(r.Method, r.URL, nil)
Expand Down Expand Up @@ -441,9 +443,14 @@ func handleRequestBody(c *Client, r *Request) (err error) {
r.bodyBuf = nil

if reader, ok := r.Body.(io.Reader); ok {
r.bodyBuf = acquireBuffer()
_, err = r.bodyBuf.ReadFrom(reader)
r.Body = nil
if c.setContentLength || r.setContentLength { // keep backward compatibility
r.bodyBuf = acquireBuffer()
_, err = r.bodyBuf.ReadFrom(reader)
r.Body = nil
} else {
// Otherwise buffer less processing for `io.Reader`, sounds good.
return
}
} else if b, ok := r.Body.([]byte); ok {
bodyBytes = b
} else if s, ok := r.Body.(string); ok {
Expand Down

0 comments on commit 69a6954

Please sign in to comment.