From 5e6cd52db1122dd65107db936a4ce8dd86db80d1 Mon Sep 17 00:00:00 2001 From: Nick Ethier Date: Mon, 11 Nov 2019 12:58:22 -0500 Subject: [PATCH] add Group struct for collecting errors from goroutines into a multierror --- group.go | 38 ++++++++++++++++++++++++++++++++++++++ group_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 group.go create mode 100644 group_test.go diff --git a/group.go b/group.go new file mode 100644 index 0000000..9c29efb --- /dev/null +++ b/group.go @@ -0,0 +1,38 @@ +package multierror + +import "sync" + +// Group is a collection of goroutines which return errors that need to be +// coalesced. +type Group struct { + mutex sync.Mutex + err *Error + wg sync.WaitGroup +} + +// Go calls the given function in a new goroutine. +// +// If the function returns an error it is added to the group multierror which +// is returned by Wait. +func (g *Group) Go(f func() error) { + g.wg.Add(1) + + go func() { + defer g.wg.Done() + + if err := f(); err != nil { + g.mutex.Lock() + g.err = Append(g.err, err) + g.mutex.Unlock() + } + }() +} + +// Wait blocks until all function calls from the Go method have returned, then +// returns the multierror. +func (g *Group) Wait() *Error { + g.wg.Wait() + g.mutex.Lock() + defer g.mutex.Unlock() + return g.err +} diff --git a/group_test.go b/group_test.go new file mode 100644 index 0000000..9d472fd --- /dev/null +++ b/group_test.go @@ -0,0 +1,44 @@ +package multierror + +import ( + "errors" + "strings" + "testing" +) + +func TestGroup(t *testing.T) { + err1 := errors.New("group_test: 1") + err2 := errors.New("group_test: 2") + + cases := []struct { + errs []error + nilResult bool + }{ + {errs: []error{}, nilResult: true}, + {errs: []error{nil}, nilResult: true}, + {errs: []error{err1}}, + {errs: []error{err1, nil}}, + {errs: []error{err1, nil, err2}}, + } + + for _, tc := range cases { + var g Group + + for _, err := range tc.errs { + err := err + g.Go(func() error { return err }) + + } + + gErr := g.Wait() + if gErr != nil { + for i := range tc.errs { + if tc.errs[i] != nil && !strings.Contains(gErr.Error(), tc.errs[i].Error()) { + t.Fatalf("expected error to contain %q, actual: %v", tc.errs[i].Error(), gErr) + } + } + } else if !tc.nilResult { + t.Fatalf("Group.Wait() should not have returned nil for errs: %v", tc.errs) + } + } +}