Skip to content

Commit

Permalink
Merge pull request #372 from okta/ttimonen_369
Browse files Browse the repository at this point in the history
  • Loading branch information
monde authored Apr 12, 2023
2 parents bdb33e7 + 0674179 commit 44b659d
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 9 deletions.
24 changes: 15 additions & 9 deletions okta/requestExecutor.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,8 @@ func (re *RequestExecutor) Do(ctx context.Context, req *http.Request, v interfac
re.freshCache = false
}
if !inCache {
resp, err := re.doWithRetries(ctx, req)
resp, done, err := re.doWithRetries(ctx, req)
defer done()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -492,12 +493,13 @@ func (o *oktaBackoff) Context() context.Context {
return o.ctx
}

func (re *RequestExecutor) doWithRetries(ctx context.Context, req *http.Request) (*http.Response, error) {
func (re *RequestExecutor) doWithRetries(ctx context.Context, req *http.Request) (*http.Response, func(), error) {
var bodyReader func() io.ReadCloser
done := func() {}
if req.Body != nil {
buf, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
return nil, done, err
}
bodyReader = func() io.ReadCloser {
return io.NopCloser(bytes.NewReader(buf))
Expand All @@ -508,9 +510,7 @@ func (re *RequestExecutor) doWithRetries(ctx context.Context, req *http.Request)
err error
)
if re.config.Okta.Client.RequestTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Second*time.Duration(re.config.Okta.Client.RequestTimeout))
defer cancel()
ctx, done = context.WithTimeout(ctx, time.Second*time.Duration(re.config.Okta.Client.RequestTimeout))
}
bOff := &oktaBackoff{
ctx: ctx,
Expand Down Expand Up @@ -549,7 +549,7 @@ func (re *RequestExecutor) doWithRetries(ctx context.Context, req *http.Request)
return errors.New("too many requests")
}
err = backoff.Retry(operation, bOff)
return resp, err
return resp, done, err
}

func tooManyRequests(resp *http.Response) bool {
Expand Down Expand Up @@ -649,7 +649,10 @@ func CheckResponseForError(resp *http.Response) error {
}
}
}
bodyBytes, _ := io.ReadAll(resp.Body)
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
copyBodyBytes := make([]byte, len(bodyBytes))
copy(copyBodyBytes, bodyBytes)
_ = resp.Body.Close()
Expand All @@ -668,7 +671,10 @@ func buildResponse(resp *http.Response, re *RequestExecutor, v interface{}) (*Re
if err != nil {
return response, err
}
bodyBytes, _ := io.ReadAll(resp.Body)
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
copyBodyBytes := make([]byte, len(bodyBytes))
copy(copyBodyBytes, bodyBytes)
_ = resp.Body.Close() // close it to avoid memory leaks
Expand Down
69 changes: 69 additions & 0 deletions tests/unit/request_executor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package unit

import (
"context"
"io"
"net/http"
"strings"
"testing"
"time"

"github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v2/tests"
"github.com/stretchr/testify/assert"
)

// readerFun makes it easier to implement an inline reader as a closure function.
type readerFun func(p []byte) (n int, err error)

// Read, part of io.Reader interface.
func (r readerFun) Read(p []byte) (n int, err error) { return r(p) }

// slowTransport provides a dummy http-like transport serving fixed content, but slowly.
type slowTransport struct{}

// RoundTrip, part of http.Transport interface. This servers 42 as a JSON response, but slowly.
// In particular, we serve the response immediately, but getting the body takes some milliseconds.
func (t slowTransport) RoundTrip(req *http.Request) (*http.Response, error) {
realBody := strings.NewReader("42")
// This body takes 1 millisecond to read. It also needs a valid context for the whole duration.
slowBody := func(p []byte) (n int, err error) {
select {
case <-req.Context().Done():
return 0, req.Context().Err()
case <-time.After(1 * time.Millisecond):
return realBody.Read(p)
}
}

rsp := &http.Response{
StatusCode: 200,
Body: io.NopCloser(readerFun(slowBody)),
Header: http.Header{},
Request: req,
}
rsp.Header.Set("Content-Type", "application/json")
return rsp, nil
}

// TestExecuteRequest tests that the request executor can handle a slow response.
// In particular, we want to make sure that the context is properly passed through
// and not canceled too early.
func TestExecuteRequest(t *testing.T) {
cfg := []okta.ConfigSetter{
okta.WithOrgUrl("https://fakeurl"), // This is ignored, but required for validator.
okta.WithToken("foo"), // ditto.
okta.WithHttpClientPtr(&http.Client{Transport: slowTransport{}}), // Use our more realistic transport.
okta.WithRequestTimeout(10), // The context issues are gated with actually having a timeout.
}
ctx, cl, err := tests.NewClient(context.Background(), cfg...)
assert.NoError(t, err, "Basic client errors")
req, err := http.NewRequest("GET", "https://fakeurl", http.NoBody)
assert.NoError(t, err, "Request building")
var out int
rs, err := cl.GetRequestExecutor().Do(ctx, req, &out)
assert.NoError(t, err, "Request execution")
if rs.StatusCode != 200 || out != 42 {
t.Errorf("Got val=%d status=%d, want 42 status=200", out, rs.StatusCode)
}
}

0 comments on commit 44b659d

Please sign in to comment.