From 268845fc7dcf1614330b214ac6609aaa814b3181 Mon Sep 17 00:00:00 2001 From: Hiroto Funakoshi Date: Sat, 22 Jun 2024 23:29:30 +0900 Subject: [PATCH] Bugfix that caused an error when argument has 3 or more nil arguments (#2517) * fix: bugfix a problem that caused an error when argument has 3 or more nil values Signed-off-by: hlts2 * fix: apply coderabbit suggestion Signed-off-by: hlts2 * Update internal/errors/errors.go Co-authored-by: Yusuke Kato Signed-off-by: Hiroto Funakoshi * fix: create error if e is nil Signed-off-by: hlts2 --------- Signed-off-by: hlts2 Signed-off-by: Hiroto Funakoshi Co-authored-by: Yusuke Kato --- internal/errors/errors.go | 14 ++- internal/errors/errors_test.go | 169 ++++++++++++++++----------------- 2 files changed, 94 insertions(+), 89 deletions(-) diff --git a/internal/errors/errors.go b/internal/errors/errors.go index b0e482ab26..e7165313e1 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -263,12 +263,17 @@ func Join(errs ...error) error { var e *joinError switch x := errs[0].(type) { case *joinError: - e = x + if x != nil && len(x.errs) != 0 { + e = x + } errs = errs[1:] case interface{ Unwrap() []error }: - e = &joinError{errs: x.Unwrap()} + if x != nil && len(x.Unwrap()) != 0 { + e = &joinError{errs: x.Unwrap()} + } errs = errs[1:] - default: + } + if e == nil { e = &joinError{ errs: make([]error, 0, l), } @@ -278,6 +283,9 @@ func Join(errs ...error) error { e.errs = append(e.errs, err) } } + if len(e.errs) == 0 { + return nil + } return e } diff --git a/internal/errors/errors_test.go b/internal/errors/errors_test.go index 43a5457d99..170ff2a5a0 100644 --- a/internal/errors/errors_test.go +++ b/internal/errors/errors_test.go @@ -18,6 +18,8 @@ import ( "math" "reflect" "testing" + + "github.com/vdaas/vald/internal/test/goleak" ) func TestErrTimeoutParseFailed(t *testing.T) { @@ -1649,6 +1651,87 @@ func TestRemoveDuplicates(t *testing.T) { } } +func TestJoin(t *testing.T) { + type args struct { + errs []error + } + type want struct { + err error + } + type test struct { + name string + args args + want want + checkFunc func(want, error) error + beforeFunc func(*testing.T, args) + afterFunc func(*testing.T, args) + } + defaultCheckFunc := func(w want, err error) error { + if !Is(err, w.err) { + return Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + return nil + } + tests := []test{ + { + name: "return nil when all errors are nil", + args: args{ + errs: []error{ + nil, nil, nil, + }, + }, + }, + { + name: "returns an aggregated error when all errors are non-nil and different", + args: args{ + errs: []error{ + New("error1"), New("error2"), New("error3"), + }, + }, + want: want{ + err: &joinError{ + errs: []error{ + New("error1"), New("error2"), New("error3"), + }, + }, + }, + }, + { + name: "returns an error when errors are mixed nil and non-nil", + args: args{ + errs: []error{ + nil, New("error1"), nil, + }, + }, + want: want{ + err: New("error1"), + }, + }, + } + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc(tt, test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(tt, test.args) + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + + err := Join(test.args.errs...) + if err := checkFunc(test.want, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + // NOT IMPLEMENTED BELOW // // func TestUnwrap(t *testing.T) { @@ -1737,92 +1820,6 @@ func TestRemoveDuplicates(t *testing.T) { // } // } // -// func TestJoin(t *testing.T) { -// type args struct { -// errs []error -// } -// type want struct { -// err error -// } -// type test struct { -// name string -// args args -// want want -// checkFunc func(want, error) error -// beforeFunc func(*testing.T, args) -// afterFunc func(*testing.T, args) -// } -// defaultCheckFunc := func(w want, err error) error { -// if !Is(err, w.err) { -// return Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) -// } -// return nil -// } -// tests := []test{ -// // TODO test cases -// /* -// { -// name: "test_case_1", -// args: args { -// errs:nil, -// }, -// want: want{}, -// checkFunc: defaultCheckFunc, -// beforeFunc: func(t *testing.T, args args) { -// t.Helper() -// }, -// afterFunc: func(t *testing.T, args args) { -// t.Helper() -// }, -// }, -// */ -// -// // TODO test cases -// /* -// func() test { -// return test { -// name: "test_case_2", -// args: args { -// errs:nil, -// }, -// want: want{}, -// checkFunc: defaultCheckFunc, -// beforeFunc: func(t *testing.T, args args) { -// t.Helper() -// }, -// afterFunc: func(t *testing.T, args args) { -// t.Helper() -// }, -// } -// }(), -// */ -// } -// -// for _, tc := range tests { -// test := tc -// t.Run(test.name, func(tt *testing.T) { -// tt.Parallel() -// defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) -// if test.beforeFunc != nil { -// test.beforeFunc(tt, test.args) -// } -// if test.afterFunc != nil { -// defer test.afterFunc(tt, test.args) -// } -// checkFunc := test.checkFunc -// if test.checkFunc == nil { -// checkFunc = defaultCheckFunc -// } -// -// err := Join(test.args.errs...) -// if err := checkFunc(test.want, err); err != nil { -// tt.Errorf("error = %v", err) -// } -// -// }) -// } -// } -// // func Test_joinError_Error(t *testing.T) { // type fields struct { // errs []error