Skip to content

Commit

Permalink
*: use errgroup to handle error under concurrency radondb#480
Browse files Browse the repository at this point in the history
[summary]
Package errgroup provides synchronization, error propagation, and Context cancelation for groups of goroutines working on subtasks of a common task.
We can use the errgroup to handle error under concurrency.
[test case]
N/A
[patch codecov]
src/backend/txn.go 89.8%
src/backend/xa.go 91.1%
src/executor/engine/join_engine.go 95.7%
src/executor/engine/union_engine.go 93.9%
  • Loading branch information
zhyass committed Jun 2, 2020
1 parent cbf377f commit 40b7992
Show file tree
Hide file tree
Showing 7 changed files with 427 additions and 133 deletions.
113 changes: 45 additions & 68 deletions src/backend/txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"config"
"xbase/sync2"

"github.com/golang/sync/errgroup"
"github.com/pkg/errors"
"github.com/xelabs/go-mysqlstack/driver"
"github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
Expand Down Expand Up @@ -436,13 +437,11 @@ func (txn *Txn) Execute(req *xcontext.RequestContext) (*sqltypes.Result, error)

// Execute used to execute a query to backends.
func (txn *Txn) execute(req *xcontext.RequestContext) (*sqltypes.Result, error) {
var err error
var mu sync.Mutex
var wg sync.WaitGroup
var eg errgroup.Group

log := txn.log
qr := &sqltypes.Result{}
allErrors := make([]error, 0, 8)

if txn.twopc {
defer queryStats.Record("txn.2pc.execute", time.Now())
Expand All @@ -453,10 +452,9 @@ func (txn *Txn) execute(req *xcontext.RequestContext) (*sqltypes.Result, error)
}

// Execute backend-querys.
oneShard := func(back string, txn *Txn, querys []string) {
oneShard := func(back string, txn *Txn, querys []string) error {
var x error
var c Connection
defer wg.Done()

if c, x = txn.fetchOneConnection(back); x != nil {
log.Error("txn.fetch.connection.on[%s].querys[%v].error:%+v", back, querys, x)
Expand All @@ -475,12 +473,7 @@ func (txn *Txn) execute(req *xcontext.RequestContext) (*sqltypes.Result, error)
mu.Unlock()
}
}

if x != nil {
mu.Lock()
allErrors = append(allErrors, x)
mu.Unlock()
}
return x
}

switch req.Mode {
Expand All @@ -492,25 +485,24 @@ func (txn *Txn) execute(req *xcontext.RequestContext) (*sqltypes.Result, error)
if poolz.conf.Role != config.NormalBackend {
continue
}

wg.Add(1)
oneShard(back, txn, qs)
break
return qr, oneShard(back, txn, qs)
}
// ReqScatter mode: execute on the all shards of txn.backends.
case xcontext.ReqScatter:
qs := []string{req.RawQuery}
beLen := len(txn.backends)
for back, poolz := range txn.backends {
for b, poolz := range txn.backends {
if poolz.conf.Role != config.NormalBackend {
continue
}

wg.Add(1)
back := b
if beLen > 1 {
go oneShard(back, txn, qs)
eg.Go(func() error {
return oneShard(back, txn, qs)
})
} else {
oneShard(back, txn, qs)
return qr, oneShard(back, txn, qs)
}
}
// ReqNormal mode: execute on the some shards of txn.backends.
Expand All @@ -527,64 +519,58 @@ func (txn *Txn) execute(req *xcontext.RequestContext) (*sqltypes.Result, error)
queryMap[query.Backend] = v
}
beLen := len(queryMap)
for back, qs := range queryMap {
wg.Add(1)
for b, qs := range queryMap {
back := b
querys := qs
if beLen > 1 {
go oneShard(back, txn, qs)
eg.Go(func() error {
return oneShard(back, txn, querys)
})
} else {
oneShard(back, txn, qs)
return qr, oneShard(back, txn, qs)
}
}
}

wg.Wait()
if len(allErrors) > 0 {
err = allErrors[0]
}
return qr, err
return qr, eg.Wait()
}

// ExecuteStreamFetch used to execute stream fetch query.
func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(*sqltypes.Result) error, streamBufferSize int) error {
var err error
var mu sync.Mutex
var wg sync.WaitGroup
var eg errgroup.Group

log := txn.log
cursors := make([]driver.Rows, 0, 8)
allErrors := make([]error, 0, 8)

defer func() {
for _, cursor := range cursors {
cursor.Close()
}
}()

oneShard := func(c Connection, query string) {
defer wg.Done()
oneShard := func(c Connection, query string) error {
cursor, x := c.ExecuteStreamFetch(query)
if x != nil {
if x == nil {
mu.Lock()
allErrors = append(allErrors, x)
cursors = append(cursors, cursor)
mu.Unlock()
return
}
mu.Lock()
cursors = append(cursors, cursor)
mu.Unlock()
return x
}

for _, qt := range req.Querys {
var conn Connection
if conn, err = txn.fetchOneConnection(qt.Backend); err != nil {
return err
}
wg.Add(1)
go oneShard(conn, qt.Query)
query := qt.Query
eg.Go(func() error {
return oneShard(conn, query)
})
}
wg.Wait()
if len(allErrors) > 0 {
return allErrors[0]
if err = eg.Wait(); err != nil {
return err
}

// Send Fields.
Expand All @@ -598,25 +584,23 @@ func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(*
cursorFinished := 0
rows := make(chan []sqltypes.Value, 65536)
stop := make(chan bool)
oneFetch := func(name string, cursor driver.Rows) {
defer wg.Done()
oneFetch := func(name string, cursor driver.Rows) error {
for {
if cursor.Next() {
row, err := cursor.RowValues()
if err != nil {
log.Error("txn.stream.cursor[%s].RowValues.error:%+v", name, err)
mu.Lock()
allErrors = append(allErrors, err)
cursorFinished++
if cursorFinished == len(cursors) {
close(rows)
}
mu.Unlock()
return
return err
}
select {
case <-stop:
return
return nil
case rows <- row:
}
} else {
Expand All @@ -626,27 +610,27 @@ func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(*
close(rows)
}
mu.Unlock()
return
return nil
}
}
}

// producer.
for i, cursor := range cursors {
name := req.Querys[i].Backend
wg.Add(1)
go oneFetch(name, cursor)
rows := cursor
eg.Go(func() error {
return oneFetch(name, rows)
})
}
// consumer.
var allRowCount uint64
wg.Add(1)
go func() {
eg.Go(func() error {
var allByteCount, allBatchCount uint64
byteCount := 0
qr := &sqltypes.Result{Fields: fields, Rows: make([][]sqltypes.Value, 0, 256), State: sqltypes.RStateRows}
defer func() {
close(stop)
wg.Done()
}()
for {
if row, ok := <-rows; ok {
Expand All @@ -659,10 +643,7 @@ func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(*
if byteCount >= streamBufferSize {
if x := callback(qr); x != nil {
log.Error("txn.stream.cursor.send1.error:%+v", x)
mu.Lock()
allErrors = append(allErrors, x)
mu.Unlock()
return
return x
}
qr.Rows = qr.Rows[:0]
allBatchCount++
Expand All @@ -672,20 +653,16 @@ func (txn *Txn) ExecuteStreamFetch(req *xcontext.RequestContext, callback func(*
if len(qr.Rows) > 0 {
if x := callback(qr); x != nil {
log.Error("txn.stream.cursor.send2.error:%+v", x)
mu.Lock()
allErrors = append(allErrors, x)
mu.Unlock()
return
return x
}
}
log.Warning("txn.stream.send.done[allRows:%v, allBytes:%v, allBatches:%v]", allRowCount, allByteCount, allBatchCount)
return
return nil
}
}
}()
wg.Wait()
if len(allErrors) > 0 {
return allErrors[0]
})
if err = eg.Wait(); err != nil {
return err
}

// Send finished.
Expand Down
40 changes: 15 additions & 25 deletions src/backend/xa.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ package backend

import (
"fmt"
"sync"
"time"
"xcontext"

"github.com/golang/sync/errgroup"
"github.com/xelabs/go-mysqlstack/sqldb"
)

Expand Down Expand Up @@ -64,19 +64,15 @@ func (txn *Txn) executeXACommand(query string, state txnXAState) error {

// executeXA only used to execute the 'XA START','XA END', 'XA PREPARE', 'XA COMMIT'/'XA ROLLBACK' statements.
func (txn *Txn) executeXA(req *xcontext.RequestContext, state txnXAState) error {
var err error
var mu sync.Mutex
var wg sync.WaitGroup
var eg errgroup.Group

log := txn.log
allErrors := make([]error, 0, 8)

txn.state.Set(int32(txnStateExecutingTwoPC))
defer queryStats.Record("txn.2pc.execute", time.Now())
oneShard := func(state txnXAState, back string, txn *Txn, query string) {
oneShard := func(state txnXAState, back string, txn *Txn, query string) error {
var x error
var c Connection
defer wg.Done()

switch state {
case txnXAStateStart, txnXAStateEnd, txnXAStatePrepare:
Expand Down Expand Up @@ -124,12 +120,7 @@ func (txn *Txn) executeXA(req *xcontext.RequestContext, state txnXAState) error
break
}
}

if x != nil {
mu.Lock()
allErrors = append(allErrors, x)
mu.Unlock()
}
return x
}

switch req.Mode {
Expand All @@ -152,9 +143,11 @@ func (txn *Txn) executeXA(req *xcontext.RequestContext, state txnXAState) error
defer txn.mgr.CommitUnlock()
}

for back := range backends {
wg.Add(1)
go oneShard(state, back, txn, req.RawQuery)
for b := range backends {
back := b
eg.Go(func() error {
return oneShard(state, back, txn, req.RawQuery)
})
}
}
case xcontext.ReqScatter:
Expand All @@ -166,17 +159,14 @@ func (txn *Txn) executeXA(req *xcontext.RequestContext, state txnXAState) error
defer txn.mgr.CommitUnlock()
}

for back := range backends {
wg.Add(1)
go oneShard(state, back, txn, req.RawQuery)
for b := range backends {
back := b
eg.Go(func() error {
return oneShard(state, back, txn, req.RawQuery)
})
}
}

wg.Wait()
if len(allErrors) > 0 {
err = allErrors[0]
}
return err
return eg.Wait()
}

func (txn *Txn) xaStart() error {
Expand Down
Loading

0 comments on commit 40b7992

Please sign in to comment.