From 40b799264c8a2969a03f84d6d909a051fdd96ff0 Mon Sep 17 00:00:00 2001 From: zhyass <34016424+zhyass@users.noreply.github.com> Date: Mon, 1 Jun 2020 16:56:54 +0800 Subject: [PATCH] *: use errgroup to handle error under concurrency #480 [summary] Package errgroup provides synchronization, error propagation, and Context cancelation for groups of goroutines working on subtasks of a common task. We can use errgroup to handle errors under concurrency. [test case] N/A [patch codecov] src/backend/txn.go 89.8% src/backend/xa.go 91.1% src/executor/engine/join_engine.go 95.7% src/executor/engine/union_engine.go 93.9% --- src/backend/txn.go | 113 +++++------ src/backend/xa.go | 40 ++-- src/executor/engine/join_engine.go | 33 ++-- src/executor/engine/union_engine.go | 31 ++- .../golang/sync/errgroup/errgroup.go | 66 +++++++ .../errgroup/errgroup_example_md5all_test.go | 101 ++++++++++ .../golang/sync/errgroup/errgroup_test.go | 176 ++++++++++++++++++ 7 files changed, 427 insertions(+), 133 deletions(-) create mode 100644 src/vendor/github.com/golang/sync/errgroup/errgroup.go create mode 100644 src/vendor/github.com/golang/sync/errgroup/errgroup_example_md5all_test.go create mode 100644 src/vendor/github.com/golang/sync/errgroup/errgroup_test.go diff --git a/src/backend/txn.go b/src/backend/txn.go index 9d423189..0332d42c 100644 --- a/src/backend/txn.go +++ b/src/backend/txn.go @@ -17,6 +17,7 @@ import ( "config" "xbase/sync2" + "github.com/golang/sync/errgroup" "github.com/pkg/errors" "github.com/xelabs/go-mysqlstack/driver" "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" @@ -436,13 +437,11 @@ func (txn *Txn) Execute(req *xcontext.RequestContext) (*sqltypes.Result, error) // Execute used to execute a query to backends. 
func (txn *Txn) execute(req *xcontext.RequestContext) (*sqltypes.Result, error) { - var err error var mu sync.Mutex - var wg sync.WaitGroup + var eg errgroup.Group log := txn.log qr := &sqltypes.Result{} - allErrors := make([]error, 0, 8) if txn.twopc { defer queryStats.Record("txn.2pc.execute", time.Now()) @@ -453,10 +452,9 @@ func (txn *Txn) execute(req *xcontext.RequestContext) (*sqltypes.Result, error) } // Execute backend-querys. - oneShard := func(back string, txn *Txn, querys []string) { + oneShard := func(back string, txn *Txn, querys []string) error { var x error var c Connection - defer wg.Done() if c, x = txn.fetchOneConnection(back); x != nil { log.Error("txn.fetch.connection.on[%s].querys[%v].error:%+v", back, querys, x) @@ -475,12 +473,7 @@ func (txn *Txn) execute(req *xcontext.RequestContext) (*sqltypes.Result, error) mu.Unlock() } } - - if x != nil { - mu.Lock() - allErrors = append(allErrors, x) - mu.Unlock() - } + return x } switch req.Mode { @@ -492,25 +485,24 @@ func (txn *Txn) execute(req *xcontext.RequestContext) (*sqltypes.Result, error) if poolz.conf.Role != config.NormalBackend { continue } - - wg.Add(1) - oneShard(back, txn, qs) - break + return qr, oneShard(back, txn, qs) } // ReqScatter mode: execute on the all shards of txn.backends. case xcontext.ReqScatter: qs := []string{req.RawQuery} beLen := len(txn.backends) - for back, poolz := range txn.backends { + for b, poolz := range txn.backends { if poolz.conf.Role != config.NormalBackend { continue } - wg.Add(1) + back := b if beLen > 1 { - go oneShard(back, txn, qs) + eg.Go(func() error { + return oneShard(back, txn, qs) + }) } else { - oneShard(back, txn, qs) + return qr, oneShard(back, txn, qs) } } // ReqNormal mode: execute on the some shards of txn.backends. 
@@ -527,32 +519,29 @@ func (txn *Txn) execute(req *xcontext.RequestContext) (*sqltypes.Result, error) queryMap[query.Backend] = v } beLen := len(queryMap) - for back, qs := range queryMap { - wg.Add(1) + for b, qs := range queryMap { + back := b + querys := qs if beLen > 1 { - go oneShard(back, txn, qs) + eg.Go(func() error { + return oneShard(back, txn, querys) + }) } else { - oneShard(back, txn, qs) + return qr, oneShard(back, txn, qs) } } } - - wg.Wait() - if len(allErrors) > 0 { - err = allErrors[0] - } - return qr, err + return qr, eg.Wait() } // ExecuteStreamFetch used to execute stream fetch query. func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(*sqltypes.Result) error, streamBufferSize int) error { var err error var mu sync.Mutex - var wg sync.WaitGroup + var eg errgroup.Group log := txn.log cursors := make([]driver.Rows, 0, 8) - allErrors := make([]error, 0, 8) defer func() { for _, cursor := range cursors { @@ -560,18 +549,14 @@ func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(* } }() - oneShard := func(c Connection, query string) { - defer wg.Done() + oneShard := func(c Connection, query string) error { cursor, x := c.ExecuteStreamFetch(query) - if x != nil { + if x == nil { mu.Lock() - allErrors = append(allErrors, x) + cursors = append(cursors, cursor) mu.Unlock() - return } - mu.Lock() - cursors = append(cursors, cursor) - mu.Unlock() + return x } for _, qt := range req.Querys { @@ -579,12 +564,13 @@ func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(* if conn, err = txn.fetchOneConnection(qt.Backend); err != nil { return err } - wg.Add(1) - go oneShard(conn, qt.Query) + query := qt.Query + eg.Go(func() error { + return oneShard(conn, query) + }) } - wg.Wait() - if len(allErrors) > 0 { - return allErrors[0] + if err = eg.Wait(); err != nil { + return err } // Send Fields. 
@@ -598,25 +584,23 @@ func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(* cursorFinished := 0 rows := make(chan []sqltypes.Value, 65536) stop := make(chan bool) - oneFetch := func(name string, cursor driver.Rows) { - defer wg.Done() + oneFetch := func(name string, cursor driver.Rows) error { for { if cursor.Next() { row, err := cursor.RowValues() if err != nil { log.Error("txn.stream.cursor[%s].RowValues.error:%+v", name, err) mu.Lock() - allErrors = append(allErrors, err) cursorFinished++ if cursorFinished == len(cursors) { close(rows) } mu.Unlock() - return + return err } select { case <-stop: - return + return nil case rows <- row: } } else { @@ -626,7 +610,7 @@ func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(* close(rows) } mu.Unlock() - return + return nil } } } @@ -634,19 +618,19 @@ func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(* // producer. for i, cursor := range cursors { name := req.Querys[i].Backend - wg.Add(1) - go oneFetch(name, cursor) + rows := cursor + eg.Go(func() error { + return oneFetch(name, rows) + }) } // consumer. 
var allRowCount uint64 - wg.Add(1) - go func() { + eg.Go(func() error { var allByteCount, allBatchCount uint64 byteCount := 0 qr := &sqltypes.Result{Fields: fields, Rows: make([][]sqltypes.Value, 0, 256), State: sqltypes.RStateRows} defer func() { close(stop) - wg.Done() }() for { if row, ok := <-rows; ok { @@ -659,10 +643,7 @@ func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(* if byteCount >= streamBufferSize { if x := callback(qr); x != nil { log.Error("txn.stream.cursor.send1.error:%+v", x) - mu.Lock() - allErrors = append(allErrors, x) - mu.Unlock() - return + return x } qr.Rows = qr.Rows[:0] allBatchCount++ @@ -672,20 +653,16 @@ func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(* if len(qr.Rows) > 0 { if x := callback(qr); x != nil { log.Error("txn.stream.cursor.send2.error:%+v", x) - mu.Lock() - allErrors = append(allErrors, x) - mu.Unlock() - return + return x } } log.Warning("txn.stream.send.done[allRows:%v, allBytes:%v, allBatches:%v]", allRowCount, allByteCount, allBatchCount) - return + return nil } } - }() - wg.Wait() - if len(allErrors) > 0 { - return allErrors[0] + }) + if err = eg.Wait(); err != nil { + return err } // Send finished. diff --git a/src/backend/xa.go b/src/backend/xa.go index c1ff1500..d6812de6 100644 --- a/src/backend/xa.go +++ b/src/backend/xa.go @@ -10,10 +10,10 @@ package backend import ( "fmt" - "sync" "time" "xcontext" + "github.com/golang/sync/errgroup" "github.com/xelabs/go-mysqlstack/sqldb" ) @@ -64,19 +64,15 @@ func (txn *Txn) executeXACommand(query string, state txnXAState) error { // executeXA only used to execute the 'XA START','XA END', 'XA PREPARE', 'XA COMMIT'/'XA ROLLBACK' statements. 
func (txn *Txn) executeXA(req *xcontext.RequestContext, state txnXAState) error { - var err error - var mu sync.Mutex - var wg sync.WaitGroup + var eg errgroup.Group log := txn.log - allErrors := make([]error, 0, 8) txn.state.Set(int32(txnStateExecutingTwoPC)) defer queryStats.Record("txn.2pc.execute", time.Now()) - oneShard := func(state txnXAState, back string, txn *Txn, query string) { + oneShard := func(state txnXAState, back string, txn *Txn, query string) error { var x error var c Connection - defer wg.Done() switch state { case txnXAStateStart, txnXAStateEnd, txnXAStatePrepare: @@ -124,12 +120,7 @@ func (txn *Txn) executeXA(req *xcontext.RequestContext, state txnXAState) error break } } - - if x != nil { - mu.Lock() - allErrors = append(allErrors, x) - mu.Unlock() - } + return x } switch req.Mode { @@ -152,9 +143,11 @@ func (txn *Txn) executeXA(req *xcontext.RequestContext, state txnXAState) error defer txn.mgr.CommitUnlock() } - for back := range backends { - wg.Add(1) - go oneShard(state, back, txn, req.RawQuery) + for b := range backends { + back := b + eg.Go(func() error { + return oneShard(state, back, txn, req.RawQuery) + }) } } case xcontext.ReqScatter: @@ -166,17 +159,14 @@ func (txn *Txn) executeXA(req *xcontext.RequestContext, state txnXAState) error defer txn.mgr.CommitUnlock() } - for back := range backends { - wg.Add(1) - go oneShard(state, back, txn, req.RawQuery) + for b := range backends { + back := b + eg.Go(func() error { + return oneShard(state, back, txn, req.RawQuery) + }) } } - - wg.Wait() - if len(allErrors) > 0 { - err = allErrors[0] - } - return err + return eg.Wait() } func (txn *Txn) xaStart() error { diff --git a/src/executor/engine/join_engine.go b/src/executor/engine/join_engine.go index 5c75de1a..fc526ff4 100644 --- a/src/executor/engine/join_engine.go +++ b/src/executor/engine/join_engine.go @@ -9,13 +9,12 @@ package engine import ( - "sync" - "backend" "executor/engine/operator" "planner/builder" "xcontext" + 
"github.com/golang/sync/errgroup" "github.com/pkg/errors" querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" @@ -45,17 +44,8 @@ func NewJoinEngine(log *xlog.Log, node *builder.JoinNode, txn backend.Transactio // Execute used to execute the executor. func (j *JoinEngine) Execute(ctx *xcontext.ResultContext) error { - var mu sync.Mutex - var wg sync.WaitGroup - allErrors := make([]error, 0, 2) - oneExec := func(exec PlanEngine, ctx *xcontext.ResultContext) { - defer wg.Done() - if err := exec.Execute(ctx); err != nil { - mu.Lock() - allErrors = append(allErrors, err) - mu.Unlock() - } - } + var eg errgroup.Group + var err error maxrow := j.txn.MaxJoinRows() if j.node.Strategy == builder.NestLoop { @@ -66,13 +56,15 @@ func (j *JoinEngine) Execute(ctx *xcontext.ResultContext) error { } else { lctx := xcontext.NewResultContext() rctx := xcontext.NewResultContext() - wg.Add(1) - go oneExec(j.left, lctx) - wg.Add(1) - go oneExec(j.right, rctx) - wg.Wait() - if len(allErrors) > 0 { - return allErrors[0] + + eg.Go(func() error { + return j.left.Execute(lctx) + }) + eg.Go(func() error { + return j.right.Execute(rctx) + }) + if err = eg.Wait(); err != nil { + return err } ctx.Results = &sqltypes.Result{} @@ -81,7 +73,6 @@ func (j *JoinEngine) Execute(ctx *xcontext.ResultContext) error { return nil } - var err error if len(rctx.Results.Rows) == 0 { err = concatLeftAndNil(lctx.Results.Rows, j.node, ctx.Results, maxrow) } else { diff --git a/src/executor/engine/union_engine.go b/src/executor/engine/union_engine.go index c435b051..394ab971 100644 --- a/src/executor/engine/union_engine.go +++ b/src/executor/engine/union_engine.go @@ -10,13 +10,13 @@ package engine import ( "errors" - "sync" "backend" "executor/engine/operator" "planner/builder" "xcontext" + "github.com/golang/sync/errgroup" "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" querypb 
"github.com/xelabs/go-mysqlstack/sqlparser/depends/query" "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" @@ -46,26 +46,19 @@ func NewUnionEngine(log *xlog.Log, node *builder.UnionNode, txn backend.Transact // Execute used to execute the executor. func (u *UnionEngine) Execute(ctx *xcontext.ResultContext) error { - var mu sync.Mutex - var wg sync.WaitGroup - allErrors := make([]error, 0, 2) - oneExec := func(exec PlanEngine, ctx *xcontext.ResultContext) { - defer wg.Done() - if err := exec.Execute(ctx); err != nil { - mu.Lock() - allErrors = append(allErrors, err) - mu.Unlock() - } - } + var eg errgroup.Group + lctx := xcontext.NewResultContext() rctx := xcontext.NewResultContext() - wg.Add(1) - go oneExec(u.left, lctx) - wg.Add(1) - go oneExec(u.right, rctx) - wg.Wait() - if len(allErrors) > 0 { - return allErrors[0] + + eg.Go(func() error { + return u.left.Execute(lctx) + }) + eg.Go(func() error { + return u.right.Execute(rctx) + }) + if err := eg.Wait(); err != nil { + return err } if len(lctx.Results.Fields) != len(rctx.Results.Fields) { diff --git a/src/vendor/github.com/golang/sync/errgroup/errgroup.go b/src/vendor/github.com/golang/sync/errgroup/errgroup.go new file mode 100644 index 00000000..9857fe53 --- /dev/null +++ b/src/vendor/github.com/golang/sync/errgroup/errgroup.go @@ -0,0 +1,66 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package errgroup provides synchronization, error propagation, and Context +// cancelation for groups of goroutines working on subtasks of a common task. +package errgroup + +import ( + "context" + "sync" +) + +// A Group is a collection of goroutines working on subtasks that are part of +// the same overall task. +// +// A zero Group is valid and does not cancel on error. 
+type Group struct { + cancel func() + + wg sync.WaitGroup + + errOnce sync.Once + err error +} + +// WithContext returns a new Group and an associated Context derived from ctx. +// +// The derived Context is canceled the first time a function passed to Go +// returns a non-nil error or the first time Wait returns, whichever occurs +// first. +func WithContext(ctx context.Context) (*Group, context.Context) { + ctx, cancel := context.WithCancel(ctx) + return &Group{cancel: cancel}, ctx +} + +// Wait blocks until all function calls from the Go method have returned, then +// returns the first non-nil error (if any) from them. +func (g *Group) Wait() error { + g.wg.Wait() + if g.cancel != nil { + g.cancel() + } + return g.err +} + +// Go calls the given function in a new goroutine. +// +// The first call to return a non-nil error cancels the group; its error will be +// returned by Wait. +func (g *Group) Go(f func() error) { + g.wg.Add(1) + + go func() { + defer g.wg.Done() + + if err := f(); err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + }() +} diff --git a/src/vendor/github.com/golang/sync/errgroup/errgroup_example_md5all_test.go b/src/vendor/github.com/golang/sync/errgroup/errgroup_example_md5all_test.go new file mode 100644 index 00000000..816278fd --- /dev/null +++ b/src/vendor/github.com/golang/sync/errgroup/errgroup_example_md5all_test.go @@ -0,0 +1,101 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package errgroup_test + +import ( + "context" + "crypto/md5" + "fmt" + "io/ioutil" + "log" + "os" + "path/filepath" + + "github.com/golang/sync/errgroup" +) + +// Pipeline demonstrates the use of a Group to implement a multi-stage +// pipeline: a version of the MD5All function with bounded parallelism from +// https://blog.golang.org/pipelines. 
+func ExampleGroup_pipeline() { + m, err := MD5All(context.Background(), ".") + if err != nil { + log.Fatal(err) + } + + for k, sum := range m { + fmt.Printf("%s:\t%x\n", k, sum) + } +} + +type result struct { + path string + sum [md5.Size]byte +} + +// MD5All reads all the files in the file tree rooted at root and returns a map +// from file path to the MD5 sum of the file's contents. If the directory walk +// fails or any read operation fails, MD5All returns an error. +func MD5All(ctx context.Context, root string) (map[string][md5.Size]byte, error) { + // ctx is canceled when g.Wait() returns. When this version of MD5All returns + // - even in case of error! - we know that all of the goroutines have finished + // and the memory they were using can be garbage-collected. + g, ctx := errgroup.WithContext(ctx) + paths := make(chan string) + + g.Go(func() error { + defer close(paths) + return filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.Mode().IsRegular() { + return nil + } + select { + case paths <- path: + case <-ctx.Done(): + return ctx.Err() + } + return nil + }) + }) + + // Start a fixed number of goroutines to read and digest files. + c := make(chan result) + const numDigesters = 20 + for i := 0; i < numDigesters; i++ { + g.Go(func() error { + for path := range paths { + data, err := ioutil.ReadFile(path) + if err != nil { + return err + } + select { + case c <- result{path, md5.Sum(data)}: + case <-ctx.Done(): + return ctx.Err() + } + } + return nil + }) + } + go func() { + g.Wait() + close(c) + }() + + m := make(map[string][md5.Size]byte) + for r := range c { + m[r.path] = r.sum + } + // Check whether any of the goroutines failed. Since g is accumulating the + // errors, we don't need to send them (or check for them) in the individual + // results sent on the channel. 
+ if err := g.Wait(); err != nil { + return nil, err + } + return m, nil +} diff --git a/src/vendor/github.com/golang/sync/errgroup/errgroup_test.go b/src/vendor/github.com/golang/sync/errgroup/errgroup_test.go new file mode 100644 index 00000000..62715c6d --- /dev/null +++ b/src/vendor/github.com/golang/sync/errgroup/errgroup_test.go @@ -0,0 +1,176 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package errgroup_test + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "testing" + + "github.com/golang/sync/errgroup" +) + +var ( + Web = fakeSearch("web") + Image = fakeSearch("image") + Video = fakeSearch("video") +) + +type Result string +type Search func(ctx context.Context, query string) (Result, error) + +func fakeSearch(kind string) Search { + return func(_ context.Context, query string) (Result, error) { + return Result(fmt.Sprintf("%s result for %q", kind, query)), nil + } +} + +// JustErrors illustrates the use of a Group in place of a sync.WaitGroup to +// simplify goroutine counting and error handling. This example is derived from +// the sync.WaitGroup example at https://golang.org/pkg/sync/#example_WaitGroup. +func ExampleGroup_justErrors() { + var g errgroup.Group + var urls = []string{ + "http://www.golang.org/", + "http://www.google.com/", + "http://www.somestupidname.com/", + } + for _, url := range urls { + // Launch a goroutine to fetch the URL. + url := url // https://golang.org/doc/faq#closures_and_goroutines + g.Go(func() error { + // Fetch the URL. + resp, err := http.Get(url) + if err == nil { + resp.Body.Close() + } + return err + }) + } + // Wait for all HTTP fetches to complete. 
+ if err := g.Wait(); err == nil { + fmt.Println("Successfully fetched all URLs.") + } +} + +// Parallel illustrates the use of a Group for synchronizing a simple parallel +// task: the "Google Search 2.0" function from +// https://talks.golang.org/2012/concurrency.slide#46, augmented with a Context +// and error-handling. +func ExampleGroup_parallel() { + Google := func(ctx context.Context, query string) ([]Result, error) { + g, ctx := errgroup.WithContext(ctx) + + searches := []Search{Web, Image, Video} + results := make([]Result, len(searches)) + for i, search := range searches { + i, search := i, search // https://golang.org/doc/faq#closures_and_goroutines + g.Go(func() error { + result, err := search(ctx, query) + if err == nil { + results[i] = result + } + return err + }) + } + if err := g.Wait(); err != nil { + return nil, err + } + return results, nil + } + + results, err := Google(context.Background(), "golang") + if err != nil { + fmt.Fprintln(os.Stderr, err) + return + } + for _, result := range results { + fmt.Println(result) + } + + // Output: + // web result for "golang" + // image result for "golang" + // video result for "golang" +} + +func TestZeroGroup(t *testing.T) { + err1 := errors.New("errgroup_test: 1") + err2 := errors.New("errgroup_test: 2") + + cases := []struct { + errs []error + }{ + {errs: []error{}}, + {errs: []error{nil}}, + {errs: []error{err1}}, + {errs: []error{err1, nil}}, + {errs: []error{err1, nil, err2}}, + } + + for _, tc := range cases { + var g errgroup.Group + + var firstErr error + for i, err := range tc.errs { + err := err + g.Go(func() error { return err }) + + if firstErr == nil && err != nil { + firstErr = err + } + + if gErr := g.Wait(); gErr != firstErr { + t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+ + "g.Wait() = %v; want %v", + g, tc.errs[:i+1], err, firstErr) + } + } + } +} + +func TestWithContext(t *testing.T) { + errDoom := errors.New("group_test: doomed") + + cases := []struct { + errs 
[]error + want error + }{ + {want: nil}, + {errs: []error{nil}, want: nil}, + {errs: []error{errDoom}, want: errDoom}, + {errs: []error{errDoom, nil}, want: errDoom}, + } + + for _, tc := range cases { + g, ctx := errgroup.WithContext(context.Background()) + + for _, err := range tc.errs { + err := err + g.Go(func() error { return err }) + } + + if err := g.Wait(); err != tc.want { + t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+ + "g.Wait() = %v; want %v", + g, tc.errs, err, tc.want) + } + + canceled := false + select { + case <-ctx.Done(): + canceled = true + default: + } + if !canceled { + t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+ + "ctx.Done() was not closed", + g, tc.errs) + } + } +}