diff --git a/internal/net/http/transport/roundtrip.go b/internal/net/http/transport/roundtrip.go index be452c6ab0..5f185c98e2 100644 --- a/internal/net/http/transport/roundtrip.go +++ b/internal/net/http/transport/roundtrip.go @@ -32,6 +32,7 @@ type ert struct { bo backoff.Backoff } +// NewExpBackoff returns the backoff roundtripper implementation func NewExpBackoff(opts ...Option) http.RoundTripper { e := new(ert) for _, opt := range append(defaultOpts, opts...) { @@ -41,50 +42,78 @@ func NewExpBackoff(opts ...Option) http.RoundTripper { return e } +// RoundTrip round trip the HTTP request and return the response. +// If backoff is not set, the default roundTrip implementation will be used. +// It round trip the request and returns the response, and return any error occurred. +// It returns errors.ErrTransportRetryable to indicate if the request is consider as retryable. func (e *ert) RoundTrip(req *http.Request) (res *http.Response, err error) { if e.bo == nil { return e.roundTrip(req) } - var r interface{} - r, err = e.bo.Do(req.Context(), func() (interface{}, error) { - return e.roundTrip(req) + + var rterr error + _, err = e.bo.Do(req.Context(), func() (interface{}, error) { + r, reqerr := e.roundTrip(req) + if reqerr != nil { + // if the error is retryable, return the error and let backoff to retry. + if errors.Is(reqerr, errors.ErrTransportRetryable) { + return nil, reqerr + } + // if the error is not retryable, return nil error to terminate the backoff execution + rterr = reqerr + return nil, nil + } + res = r + return r, nil }) if err != nil { return nil, err } - - var ok bool - res, ok = r.(*http.Response) - if !ok { - return nil, errors.ErrInvalidTypeConversion(r, res) + if rterr != nil { + return nil, rterr } + return res, nil } func (e *ert) roundTrip(req *http.Request) (res *http.Response, err error) { res, err = e.transport.RoundTrip(req) - if err == nil { - return res, nil - } - if res == nil { - return nil, err - } - _, err = io.Copy(ioutil.Discard, res.Body) if err != nil { log.Error(err) + if res != nil { // just in case we check the response as it depends on RoundTrip impl. + closeBody(res.Body) + if retryableStatusCode(res.StatusCode) { + return nil, errors.Wrap(errors.ErrTransportRetryable, err.Error()) + } + } + return nil, err } - err = res.Body.Close() - if err != nil { - log.Error(err) + + if res != nil && retryableStatusCode(res.StatusCode) { + closeBody(res.Body) + return nil, errors.ErrTransportRetryable } - switch res.StatusCode { + return res, nil +} + +func retryableStatusCode(status int) bool { + switch status { case http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusServiceUnavailable, http.StatusMovedPermanently, http.StatusBadGateway, http.StatusGatewayTimeout: - return nil, errors.ErrTransportRetryable + return true + } + return false +} + +func closeBody(rc io.ReadCloser) { + if _, err := io.Copy(ioutil.Discard, rc); err != nil { + log.Error(err) + } + if err := rc.Close(); err != nil { + log.Error(err) } - return res, nil } diff --git a/internal/net/http/transport/roundtrip_test.go b/internal/net/http/transport/roundtrip_test.go index d0d9e925ef..cb2bd5f597 100644 --- a/internal/net/http/transport/roundtrip_test.go +++ b/internal/net/http/transport/roundtrip_test.go @@ -13,221 +13,99 @@ // See the License for the specific language governing permissions and // limitations under the License. // + +// Package transport provides http transport roundtrip option package transport import ( "bytes" "context" + "io" "io/ioutil" "net/http" + "net/http/httptest" + "os" "reflect" "testing" "github.com/vdaas/vald/internal/backoff" "github.com/vdaas/vald/internal/errors" - + "github.com/vdaas/vald/internal/log" "go.uber.org/goleak" ) -func TestNewExpBackoff(t *testing.T) { - type test struct { - name string - opts []Option - initialized bool - } - - tests := []test{ - { - name: "initialize success", - initialized: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewExpBackoff(tt.opts...) - - if (got != nil) != tt.initialized { - t.Error("New() is wrong") - } - }) - } +func TestMain(m *testing.M) { + log.Init() + os.Exit(m.Run()) } -func TestRoundTrip(t *testing.T) { +func TestNewExpBackoff(t *testing.T) { type args struct { - req *http.Request + opts []Option } - - type field struct { - bo backoff.Backoff - transport http.RoundTripper + type want struct { + want http.RoundTripper } - type test struct { - name string - args args - field field - checkFunc func(*http.Response, error) error + name string + args args + want want + checkFunc func(want, http.RoundTripper) error + beforeFunc func(args) + afterFunc func(args) + } + defaultCheckFunc := func(w want, got http.RoundTripper) error { + if !reflect.DeepEqual(got, w.want) { + return errors.Errorf("got = %v, want %v", got, w.want) + } + return nil } - tests := []test{ - func() test { - wantRes := new(http.Response) - - tr := &roundTripMock{ - RoundTripFunc: func(*http.Request) (*http.Response, error) { - return wantRes, nil - }, - } - - return test{ - name: "returns not error when backoff object is nil", - field: field{ - transport: tr, - }, - checkFunc: func(res *http.Response, err error) error { - if err != nil { - return errors.Errorf("error not nil. err: %v", err) - } - - if !reflect.DeepEqual(res, wantRes) { - return errors.Errorf("res not equals. want: %v, got: %v", wantRes, err) - } - - return nil - }, - } - }(), - - func() test { - wantRes := new(http.Response) - - tr := &roundTripMock{ - RoundTripFunc: func(*http.Request) (*http.Response, error) { - return wantRes, nil - }, - } - - bo := &backoffMock{ - DoFunc: func(ctx context.Context, fn func() (interface{}, error)) (interface{}, error) { - return fn() - }, - } - - return test{ - name: "returns not error when backoff object is not nil", - args: args{ - req: new(http.Request), - }, - field: field{ - transport: tr, - bo: bo, - }, - checkFunc: func(res *http.Response, err error) error { - if err != nil { - return errors.Errorf("error not nil. err: %v", err) - } - - if !reflect.DeepEqual(res, wantRes) { - return errors.Errorf("res not equals. want: %v, got: %v", wantRes, err) - } - - return nil + { + name: "initialize success", + want: want{ + want: &ert{ + transport: http.DefaultTransport, }, - } - }(), - + }, + }, func() test { - res := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Body: ioutil.NopCloser(new(bytes.Buffer)), - } - - tr := &roundTripMock{ - RoundTripFunc: func(*http.Request) (*http.Response, error) { - return res, errors.New("faild") - }, - } - - bo := &backoffMock{ - DoFunc: func(ctx context.Context, fn func() (interface{}, error)) (interface{}, error) { - return fn() - }, - } - + b := backoff.New() return test{ - name: "returns error when Do function returns error", + name: "initialize success with option", args: args{ - req: new(http.Request), + opts: []Option{ + WithBackoff(b), + }, }, - field: field{ - transport: tr, - bo: bo, - }, - checkFunc: func(res *http.Response, err error) error { - if err == nil { - return errors.New("err is nil") - } - - if res != nil { - return errors.Errorf("res not nil. res: %v", res) - } - - return nil + want: want{ + want: &ert{ + transport: http.DefaultTransport, + bo: b, + }, }, } }(), + } - func() test { - tr := &roundTripMock{ - RoundTripFunc: func(*http.Request) (*http.Response, error) { - return nil, nil - }, + for _, test := range tests { + t.Run(test.name, func(tt *testing.T) { + defer goleak.VerifyNone(tt, goleakIgnoreOptions...) + if test.beforeFunc != nil { + test.beforeFunc(test.args) } - - bo := &backoffMock{ - DoFunc: func(ctx context.Context, fn func() (interface{}, error)) (interface{}, error) { - _, err := fn() - return "dumy", err - }, + if test.afterFunc != nil { + defer test.afterFunc(test.args) } - - return test{ - name: "returns error when type conversion error occurs", - args: args{ - req: new(http.Request), - }, - field: field{ - transport: tr, - bo: bo, - }, - checkFunc: func(res *http.Response, err error) error { - if err == nil { - return errors.New("err is nil") - } - - if res != nil { - return errors.Errorf("res not nil. res: %v", res) - } - - return nil - }, + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc } - }(), - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - e := &ert{ - transport: tt.field.transport, - bo: tt.field.bo, + got := NewExpBackoff(test.args.opts...) + if err := test.checkFunc(test.want, got); err != nil { + tt.Errorf("error = %v", err) } - res, err := e.RoundTrip(tt.args.req) - if err := tt.checkFunc(res, err); err != nil { - t.Error(err) - } }) } } @@ -263,44 +141,164 @@ func Test_ert_RoundTrip(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - req: nil, - }, - fields: fields { - transport: nil, - bo: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - req: nil, - }, - fields: fields { - transport: nil, - bo: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return roundtrip response if backoff is nil", + args: args{ + req: nil, + }, + fields: fields{ + transport: &roundTripMock{ + RoundTripFunc: func(*http.Request) (*http.Response, error) { + return &http.Response{ + Status: "200", + }, nil + }, + }, + }, + want: want{ + wantRes: &http.Response{ + Status: "200", + }, + }, + }, + { + name: "return backoff response if backoff is not nil", + args: args{ + req: &http.Request{}, + }, + fields: fields{ + transport: &roundTripMock{ + RoundTripFunc: func(*http.Request) (*http.Response, error) { + return &http.Response{ + Status: "200", + }, nil + }, + }, + bo: &backoffMock{ + DoFunc: func(ctx context.Context, fn func() (interface{}, error)) (interface{}, error) { + return fn() + }, + }, + }, + want: want{ + wantRes: &http.Response{ + Status: "200", + }, + }, + }, + { + name: "return default roundtrip response if backoff is not nil", + args: args{ + req: &http.Request{}, + }, + fields: fields{ + transport: &roundTripMock{ + RoundTripFunc: func(*http.Request) (*http.Response, error) { + return &http.Response{ + Status: "200", + }, nil + }, + }, + bo: &backoffMock{ + DoFunc: func(ctx context.Context, fn func() (interface{}, error)) (interface{}, error) { + return fn() + }, + }, + }, + want: want{ + wantRes: &http.Response{ + Status: "200", + }, + }, + }, + { + name: "return default roundtrip response if backoff use the default roundtrip", + args: args{ + req: &http.Request{}, + }, + fields: fields{ + transport: &roundTripMock{ + RoundTripFunc: func(*http.Request) (*http.Response, error) { + return &http.Response{ + Status: "200", + }, nil + }, + }, + bo: &backoffMock{ + DoFunc: func(ctx context.Context, fn func() (interface{}, error)) (interface{}, error) { + return fn() + }, + }, + }, + want: want{ + wantRes: &http.Response{ + Status: "200", + }, + }, + }, + { + name: "return backoff error", + args: args{ + req: &http.Request{}, + }, + fields: fields{ + bo: &backoffMock{ + DoFunc: func(ctx context.Context, fn func() (interface{}, error)) (interface{}, error) { + return nil, errors.New("error") + }, + }, + }, + want: want{ + err: errors.New("error"), + }, + }, + { + name: "return default roundtrip error if backoff use the default roundtrip", + args: args{ + req: &http.Request{}, + }, + fields: fields{ + transport: &roundTripMock{ + RoundTripFunc: func(*http.Request) (*http.Response, error) { + return nil, errors.New("error") + }, + }, + bo: &backoffMock{ + DoFunc: func(ctx context.Context, fn func() (interface{}, error)) (interface{}, error) { + return fn() + }, + }, + }, + want: want{ + err: errors.New("error"), + }, + }, + { + name: "return retryable error", + args: args{ + req: &http.Request{}, + }, + fields: fields{ + transport: &roundTripMock{ + RoundTripFunc: func(*http.Request) (*http.Response, error) { + return nil, errors.Wrap(errors.ErrTransportRetryable, "error") + }, + }, + bo: &backoffMock{ + DoFunc: func(ctx context.Context, fn func() (interface{}, error)) (interface{}, error) { + return fn() + }, + }, + }, + want: want{ + err: errors.Wrap(errors.ErrTransportRetryable, "error"), + }, + }, } for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt, goleakIgnoreOptions...) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -355,44 +353,85 @@ func Test_ert_roundTrip(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - req: nil, - }, - fields: fields { - transport: nil, - bo: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - req: nil, - }, - fields: fields { - transport: nil, - bo: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "roundtrip return success", + args: args{ + req: &http.Request{}, + }, + fields: fields{ + transport: &roundTripMock{ + RoundTripFunc: func(*http.Request) (*http.Response, error) { + return &http.Response{ + Status: "200", + }, nil + }, + }, + }, + want: want{ + wantRes: &http.Response{ + Status: "200", + }, + }, + }, + { + name: "roundtrip return error", + args: args{ + req: &http.Request{}, + }, + fields: fields{ + transport: &roundTripMock{ + RoundTripFunc: func(*http.Request) (*http.Response, error) { + return nil, errors.New("error") + }, + }, + }, + want: want{ + err: errors.New("error"), + }, + }, + { + name: "roundtrip return retryable error", + args: args{ + req: &http.Request{}, + }, + fields: fields{ + transport: &roundTripMock{ + RoundTripFunc: func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadGateway, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("abc"))), + }, nil + }, + }, + }, + want: want{ + err: errors.ErrTransportRetryable, + }, + }, + { + name: "roundtrip return retryable error even when error occurred and roundtrip response is not nil", + args: args{ + req: &http.Request{}, + }, + fields: fields{ + transport: &roundTripMock{ + RoundTripFunc: func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadGateway, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("abc"))), + }, errors.New("dummy") + }, + }, + }, + want: want{ + err: errors.ErrTransportRetryable, + }, + }, } for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt, goleakIgnoreOptions...) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -415,3 +454,123 @@ func Test_ert_roundTrip(t *testing.T) { }) } } + +func Test_retryableStatusCode(t *testing.T) { + type args struct { + status int + } + type want struct { + want bool + } + type test struct { + name string + args args + want want + checkFunc func(want, bool) error + beforeFunc func(args) + afterFunc func(args) + } + defaultCheckFunc := func(w want, got bool) error { + if !reflect.DeepEqual(got, w.want) { + return errors.Errorf("got = %v, want %v", got, w.want) + } + return nil + } + tests := []test{ + { + name: "return true when response status is retryable", + args: args{ + status: http.StatusTooManyRequests, + }, + want: want{ + want: true, + }, + }, + { + name: "return false when response status is not retryable", + args: args{ + status: http.StatusOK, + }, + want: want{ + want: false, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(tt *testing.T) { + defer goleak.VerifyNone(tt, goleakIgnoreOptions...) + if test.beforeFunc != nil { + test.beforeFunc(test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(test.args) + } + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + + got := retryableStatusCode(test.args.status) + if err := test.checkFunc(test.want, got); err != nil { + tt.Errorf("error = %v", err) + } + + }) + } +} + +func Test_closeBody(t *testing.T) { + type args struct { + rc io.ReadCloser + } + type want struct { + } + type test struct { + name string + args args + want want + checkFunc func(io.ReadCloser, want) error + beforeFunc func(args) + afterFunc func(args) + } + defaultCheckFunc := func(rc io.ReadCloser, w want) error { + if i, err := rc.Read([]byte{}); i != 0 || err != io.EOF { + return errors.Errorf("connection not closed, num: %d, err: %v", i, err) + } + return nil + } + tests := []test{ + func() test { + rr := httptest.NewRecorder() + rr.WriteString("abc") + res := rr.Result() + + return test{ + name: "close response body", + args: args{ + rc: res.Body, + }, + } + }(), + } + + for _, test := range tests { + t.Run(test.name, func(tt *testing.T) { + defer goleak.VerifyNone(tt, goleakIgnoreOptions...) + if test.beforeFunc != nil { + test.beforeFunc(test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(test.args) + } + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + + closeBody(test.args.rc) + if err := test.checkFunc(test.args.rc, test.want); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +}