From eafe1f713698bf53f797fe92f234acbfa0131771 Mon Sep 17 00:00:00 2001 From: Steven Landers Date: Wed, 22 Nov 2023 13:46:38 -0500 Subject: [PATCH] add pool optimizations --- tasks/scheduler.go | 172 +++++++++++++++++++++++----------------- tasks/scheduler_test.go | 35 ++++++++ 2 files changed, 136 insertions(+), 71 deletions(-) diff --git a/tasks/scheduler.go b/tasks/scheduler.go index c00e70dbe..3d289bee2 100644 --- a/tasks/scheduler.go +++ b/tasks/scheduler.go @@ -1,21 +1,20 @@ package tasks import ( + "context" "crypto/sha256" "fmt" - "sort" - "sync" - - "github.com/tendermint/tendermint/abci/types" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - "golang.org/x/sync/errgroup" + "sort" + "sync" "github.com/cosmos/cosmos-sdk/store/multiversion" store "github.com/cosmos/cosmos-sdk/store/types" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/occ" "github.com/cosmos/cosmos-sdk/utils/tracing" + "github.com/tendermint/tendermint/abci/types" ) type status string @@ -49,7 +48,7 @@ type deliverTxTask struct { Request types.RequestDeliverTx Response *types.ResponseDeliverTx VersionStores map[sdk.StoreKey]*multiversion.VersionIndexedStore - ValidateCh chan struct{} + ValidateCh chan status } func (dt *deliverTxTask) Reset() { @@ -63,7 +62,7 @@ func (dt *deliverTxTask) Reset() { func (dt *deliverTxTask) Increment() { dt.Incarnation++ - dt.ValidateCh = make(chan struct{}, 1) + dt.ValidateCh = make(chan status, 1) } // Scheduler processes tasks concurrently @@ -77,6 +76,8 @@ type scheduler struct { multiVersionStores map[sdk.StoreKey]multiversion.MultiVersionStore tracingInfo *tracing.Info allTasks []*deliverTxTask + executeCh chan func() + validateCh chan func() } // NewScheduler creates a new scheduler @@ -94,6 +95,29 @@ func (s *scheduler) invalidateTask(task *deliverTxTask) { } } +func start(ctx context.Context, ch chan func(), workers int) { + for i := 0; i < workers; i++ { + go func() { + for { + select { + case <-ctx.Done(): + return + case work := <-ch: + work() + } + } + }() + } +} + +func (s *scheduler) DoValidate(work func()) { + s.validateCh <- work +} + +func (s *scheduler) DoExecute(work func()) { + s.executeCh <- work +} + func (s *scheduler) findConflicts(task *deliverTxTask) (bool, []int) { var conflicts []int uniq := make(map[int]struct{}) @@ -120,7 +144,7 @@ func toTasks(reqs []*sdk.DeliverTxEntry) []*deliverTxTask { Request: r.Request, Index: idx, Status: statusPending, - ValidateCh: make(chan struct{}, 1), + ValidateCh: make(chan status, 1), }) } return res @@ -164,7 +188,7 @@ func allValidated(tasks []*deliverTxTask) bool { return true } -func (s *scheduler) PrefillEstimates(ctx sdk.Context, reqs []*sdk.DeliverTxEntry) { +func (s *scheduler) PrefillEstimates(reqs []*sdk.DeliverTxEntry) { // iterate over TXs, update estimated writesets where applicable for i, req := range reqs { mappedWritesets := req.EstimatedWritesets @@ -180,9 +204,27 @@ func (s *scheduler) ProcessAll(ctx sdk.Context, reqs []*sdk.DeliverTxEntry) ([]t // initialize mutli-version stores if they haven't been initialized yet s.tryInitMultiVersionStore(ctx) // prefill estimates - s.PrefillEstimates(ctx, reqs) + s.PrefillEstimates(reqs) tasks := toTasks(reqs) s.allTasks = tasks + s.executeCh = make(chan func(), len(tasks)) + s.validateCh = make(chan func(), len(tasks)) + + // default to number of tasks if workers is negative or 0 by this point + workers := s.workers + if s.workers < 1 { + workers = len(tasks) + } + + workerCtx, cancel := context.WithCancel(ctx.Context()) + defer cancel() + + // execution tasks are limited by workers + start(workerCtx, s.executeCh, workers) + + // validation tasks uses length of tasks to avoid blocking on validation + start(workerCtx, s.validateCh, len(tasks)) + toExecute := tasks for !allValidated(tasks) { var err error @@ -269,19 +311,26 @@ func (s *scheduler) validateAll(ctx sdk.Context, tasks []*deliverTxTask) ([]*del var mx sync.Mutex var res []*deliverTxTask + startIdx, anyLeft := s.findFirstNonValidated() + + if !anyLeft { + return nil, nil + } + wg := sync.WaitGroup{} - for i := 0; i < len(tasks); i++ { + for i := startIdx; i < len(tasks); i++ { + t := tasks[i] wg.Add(1) - go func(task *deliverTxTask) { + s.DoValidate(func() { defer wg.Done() - if !s.validateTask(ctx, task) { - task.Reset() - task.Increment() + if !s.validateTask(ctx, t) { + t.Reset() + t.Increment() mx.Lock() - res = append(res, task) + res = append(res, t) mx.Unlock() } - }(tasks[i]) + }) } wg.Wait() @@ -289,56 +338,47 @@ func (s *scheduler) validateAll(ctx sdk.Context, tasks []*deliverTxTask) ([]*del } // ExecuteAll executes all tasks concurrently -// Tasks are updated with their status -// TODO: error scenarios func (s *scheduler) executeAll(ctx sdk.Context, tasks []*deliverTxTask) error { ctx, span := s.traceSpan(ctx, "SchedulerExecuteAll", nil) defer span.End() - ch := make(chan *deliverTxTask, len(tasks)) - grp, gCtx := errgroup.WithContext(ctx.Context()) - - // a workers value < 1 means no limit - workers := s.workers - if s.workers < 1 { - workers = len(tasks) - } - // validationWg waits for all validations to complete // validations happen in separate goroutines in order to wait on previous index validationWg := &sync.WaitGroup{} validationWg.Add(len(tasks)) - grp.Go(func() error { - validationWg.Wait() - return nil - }) - for i := 0; i < workers; i++ { - grp.Go(func() error { - for { - select { - case <-gCtx.Done(): - return gCtx.Err() - case task, ok := <-ch: - if !ok { - return nil - } - s.prepareAndRunTask(validationWg, ctx, task) - } - } + for _, task := range tasks { + t := task + s.DoExecute(func() { + s.prepareAndRunTask(validationWg, ctx, t) }) } - for _, task := range tasks { - ch <- task - } - close(ch) + validationWg.Wait() + + return nil +} - if err := grp.Wait(); err != nil { - return err +func (s *scheduler) waitOnPreviousAndValidate(wg *sync.WaitGroup, task *deliverTxTask) { + defer wg.Done() + defer close(task.ValidateCh) + // wait on previous task to finish validation + // if a previous task fails validation, then subsequent should fail too (cascade) + if task.Index > 0 { + res, ok := <-s.allTasks[task.Index-1].ValidateCh + if ok && res != statusValidated { + task.Reset() + task.ValidateCh <- task.Status + return + } + } + // if not validated, reset the task + if !s.validateTask(task.Ctx, task) { + task.Reset() } - return nil + // notify next task of this one's status + task.ValidateCh <- task.Status } func (s *scheduler) prepareAndRunTask(wg *sync.WaitGroup, ctx sdk.Context, task *deliverTxTask) { @@ -346,19 +386,12 @@ func (s *scheduler) prepareAndRunTask(wg *sync.WaitGroup, ctx sdk.Context, task defer eSpan.End() task.Ctx = eCtx - s.executeTask(task.Ctx, task) - go func() { - defer wg.Done() - defer close(task.ValidateCh) - // wait on previous task to finish validation - if task.Index > 0 { - <-s.allTasks[task.Index-1].ValidateCh - } - if !s.validateTask(task.Ctx, task) { - task.Reset() - } - task.ValidateCh <- struct{}{} - }() + s.prepareTask(task) + s.executeTask(task) + + s.DoValidate(func() { + s.waitOnPreviousAndValidate(wg, task) + }) } func (s *scheduler) traceSpan(ctx sdk.Context, name string, task *deliverTxTask) (sdk.Context, trace.Span) { @@ -373,8 +406,8 @@ func (s *scheduler) traceSpan(ctx sdk.Context, name string, task *deliverTxTask) } // prepareTask initializes the context and version stores for a task -func (s *scheduler) prepareTask(ctx sdk.Context, task *deliverTxTask) { - ctx = ctx.WithTxIndex(task.Index) +func (s *scheduler) prepareTask(task *deliverTxTask) { + ctx := task.Ctx.WithTxIndex(task.Index) _, span := s.traceSpan(ctx, "SchedulerPrepare", task) defer span.End() @@ -407,10 +440,7 @@ func (s *scheduler) prepareTask(ctx sdk.Context, task *deliverTxTask) { } // executeTask executes a single task -func (s *scheduler) executeTask(ctx sdk.Context, task *deliverTxTask) { - - s.prepareTask(ctx, task) - +func (s *scheduler) executeTask(task *deliverTxTask) { dCtx, dSpan := s.traceSpan(task.Ctx, "SchedulerDeliverTx", task) defer dSpan.End() task.Ctx = dCtx diff --git a/tasks/scheduler_test.go b/tasks/scheduler_test.go index 9d24b54a8..886bbe5ce 100644 --- a/tasks/scheduler_test.go +++ b/tasks/scheduler_test.go @@ -68,6 +68,41 @@ func TestProcessAll(t *testing.T) { workers: 50, runs: 50, addStores: true, + requests: requestList(100), + deliverTxFunc: func(ctx sdk.Context, req types.RequestDeliverTx) types.ResponseDeliverTx { + // all txs read and write to the same key to maximize conflicts + kv := ctx.MultiStore().GetKVStore(testStoreKey) + val := string(kv.Get(itemKey)) + + // write to the store with this tx's index + kv.Set(itemKey, req.Tx) + + // return what was read from the store (final attempt should be index-1) + return types.ResponseDeliverTx{ + Info: val, + } + }, + assertions: func(t *testing.T, ctx sdk.Context, res []types.ResponseDeliverTx) { + for idx, response := range res { + if idx == 0 { + require.Equal(t, "", response.Info) + } else { + // the info is what was read from the kv store by the tx + // each tx writes its own index, so the info should be the index of the previous tx + require.Equal(t, fmt.Sprintf("%d", idx-1), response.Info) + } + } + // confirm last write made it to the parent store + latest := ctx.MultiStore().GetKVStore(testStoreKey).Get(itemKey) + require.Equal(t, []byte(fmt.Sprintf("%d", len(res)-1)), latest) + }, + expectedErr: nil, + }, + { + name: "Test few workers many txs", + workers: 5, + runs: 10, + addStores: true, requests: requestList(50), deliverTxFunc: func(ctx sdk.Context, req types.RequestDeliverTx) types.ResponseDeliverTx { // all txs read and write to the same key to maximize conflicts