Skip to content

Commit

Permalink
[OCC] Add scheduler goroutine pool and optimizations (#362)
Browse files Browse the repository at this point in the history
## Describe your changes and provide context
- adds pool optimizations (bounds by tasks / workers)
- adds validateAll shortcut (starts at first non-validated entry)
- adds invalidation of future tasks on invalidation

## Testing performed to validate your change
- unit tests are passing with full conflicting txs
  • Loading branch information
stevenlanders authored and codchen committed Feb 6, 2024
1 parent dca8696 commit 888cf6c
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 71 deletions.
172 changes: 101 additions & 71 deletions tasks/scheduler.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
package tasks

import (
"context"
"crypto/sha256"
"fmt"
"sort"
"sync"

"github.com/tendermint/tendermint/abci/types"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"

"github.com/cosmos/cosmos-sdk/store/multiversion"
store "github.com/cosmos/cosmos-sdk/store/types"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/occ"
"github.com/cosmos/cosmos-sdk/utils/tracing"
"github.com/tendermint/tendermint/abci/types"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)

type status string
Expand Down Expand Up @@ -49,7 +48,7 @@ type deliverTxTask struct {
Request types.RequestDeliverTx
Response *types.ResponseDeliverTx
VersionStores map[sdk.StoreKey]*multiversion.VersionIndexedStore
ValidateCh chan struct{}
ValidateCh chan status
}

func (dt *deliverTxTask) Reset() {
Expand All @@ -63,7 +62,7 @@ func (dt *deliverTxTask) Reset() {

func (dt *deliverTxTask) Increment() {
dt.Incarnation++
dt.ValidateCh = make(chan struct{}, 1)
dt.ValidateCh = make(chan status, 1)
}

// Scheduler processes tasks concurrently
Expand All @@ -77,6 +76,8 @@ type scheduler struct {
multiVersionStores map[sdk.StoreKey]multiversion.MultiVersionStore
tracingInfo *tracing.Info
allTasks []*deliverTxTask
executeCh chan func()
validateCh chan func()
}

// NewScheduler creates a new scheduler
Expand All @@ -96,6 +97,29 @@ func (s *scheduler) invalidateTask(task *deliverTxTask) {
}
}

func start(ctx context.Context, ch chan func(), workers int) {
for i := 0; i < workers; i++ {
go func() {
for {
select {
case <-ctx.Done():
return
case work := <-ch:
work()
}
}
}()
}
}

func (s *scheduler) DoValidate(work func()) {
s.validateCh <- work
}

func (s *scheduler) DoExecute(work func()) {
s.executeCh <- work
}

func (s *scheduler) findConflicts(task *deliverTxTask) (bool, []int) {
var conflicts []int
uniq := make(map[int]struct{})
Expand All @@ -122,7 +146,7 @@ func toTasks(reqs []*sdk.DeliverTxEntry) []*deliverTxTask {
Request: r.Request,
Index: idx,
Status: statusPending,
ValidateCh: make(chan struct{}, 1),
ValidateCh: make(chan status, 1),
})
}
return res
Expand Down Expand Up @@ -166,7 +190,7 @@ func allValidated(tasks []*deliverTxTask) bool {
return true
}

func (s *scheduler) PrefillEstimates(ctx sdk.Context, reqs []*sdk.DeliverTxEntry) {
func (s *scheduler) PrefillEstimates(reqs []*sdk.DeliverTxEntry) {
// iterate over TXs, update estimated writesets where applicable
for i, req := range reqs {
mappedWritesets := req.EstimatedWritesets
Expand All @@ -182,9 +206,27 @@ func (s *scheduler) ProcessAll(ctx sdk.Context, reqs []*sdk.DeliverTxEntry) ([]t
// initialize mutli-version stores if they haven't been initialized yet
s.tryInitMultiVersionStore(ctx)
// prefill estimates
s.PrefillEstimates(ctx, reqs)
s.PrefillEstimates(reqs)
tasks := toTasks(reqs)
s.allTasks = tasks
s.executeCh = make(chan func(), len(tasks))
s.validateCh = make(chan func(), len(tasks))

// default to number of tasks if workers is negative or 0 by this point
workers := s.workers
if s.workers < 1 {
workers = len(tasks)
}

workerCtx, cancel := context.WithCancel(ctx.Context())
defer cancel()

// execution tasks are limited by workers
start(workerCtx, s.executeCh, workers)

// validation tasks uses length of tasks to avoid blocking on validation
start(workerCtx, s.validateCh, len(tasks))

toExecute := tasks
for !allValidated(tasks) {
var err error
Expand Down Expand Up @@ -271,96 +313,87 @@ func (s *scheduler) validateAll(ctx sdk.Context, tasks []*deliverTxTask) ([]*del
var mx sync.Mutex
var res []*deliverTxTask

startIdx, anyLeft := s.findFirstNonValidated()

if !anyLeft {
return nil, nil
}

wg := sync.WaitGroup{}
for i := 0; i < len(tasks); i++ {
for i := startIdx; i < len(tasks); i++ {
t := tasks[i]
wg.Add(1)
go func(task *deliverTxTask) {
s.DoValidate(func() {
defer wg.Done()
if !s.validateTask(ctx, task) {
task.Reset()
task.Increment()
if !s.validateTask(ctx, t) {
t.Reset()
t.Increment()
mx.Lock()
res = append(res, task)
res = append(res, t)
mx.Unlock()
}
}(tasks[i])
})
}
wg.Wait()

return res, nil
}

// ExecuteAll executes all tasks concurrently
// Tasks are updated with their status
// TODO: error scenarios
func (s *scheduler) executeAll(ctx sdk.Context, tasks []*deliverTxTask) error {
ctx, span := s.traceSpan(ctx, "SchedulerExecuteAll", nil)
defer span.End()

ch := make(chan *deliverTxTask, len(tasks))
grp, gCtx := errgroup.WithContext(ctx.Context())

// a workers value < 1 means no limit
workers := s.workers
if s.workers < 1 {
workers = len(tasks)
}

// validationWg waits for all validations to complete
// validations happen in separate goroutines in order to wait on previous index
validationWg := &sync.WaitGroup{}
validationWg.Add(len(tasks))
grp.Go(func() error {
validationWg.Wait()
return nil
})

for i := 0; i < workers; i++ {
grp.Go(func() error {
for {
select {
case <-gCtx.Done():
return gCtx.Err()
case task, ok := <-ch:
if !ok {
return nil
}
s.prepareAndRunTask(validationWg, ctx, task)
}
}
for _, task := range tasks {
t := task
s.DoExecute(func() {
s.prepareAndRunTask(validationWg, ctx, t)
})
}

for _, task := range tasks {
ch <- task
}
close(ch)
validationWg.Wait()

return nil
}

if err := grp.Wait(); err != nil {
return err
func (s *scheduler) waitOnPreviousAndValidate(wg *sync.WaitGroup, task *deliverTxTask) {
defer wg.Done()
defer close(task.ValidateCh)
// wait on previous task to finish validation
// if a previous task fails validation, then subsequent should fail too (cascade)
if task.Index > 0 {
res, ok := <-s.allTasks[task.Index-1].ValidateCh
if ok && res != statusValidated {
task.Reset()
task.ValidateCh <- task.Status
return
}
}
// if not validated, reset the task
if !s.validateTask(task.Ctx, task) {
task.Reset()
}

return nil
// notify next task of this one's status
task.ValidateCh <- task.Status
}

func (s *scheduler) prepareAndRunTask(wg *sync.WaitGroup, ctx sdk.Context, task *deliverTxTask) {
eCtx, eSpan := s.traceSpan(ctx, "SchedulerExecute", task)
defer eSpan.End()
task.Ctx = eCtx

s.executeTask(task.Ctx, task)
go func() {
defer wg.Done()
defer close(task.ValidateCh)
// wait on previous task to finish validation
if task.Index > 0 {
<-s.allTasks[task.Index-1].ValidateCh
}
if !s.validateTask(task.Ctx, task) {
task.Reset()
}
task.ValidateCh <- struct{}{}
}()
s.prepareTask(task)
s.executeTask(task)

s.DoValidate(func() {
s.waitOnPreviousAndValidate(wg, task)
})
}

func (s *scheduler) traceSpan(ctx sdk.Context, name string, task *deliverTxTask) (sdk.Context, trace.Span) {
Expand All @@ -375,8 +408,8 @@ func (s *scheduler) traceSpan(ctx sdk.Context, name string, task *deliverTxTask)
}

// prepareTask initializes the context and version stores for a task
func (s *scheduler) prepareTask(ctx sdk.Context, task *deliverTxTask) {
ctx = ctx.WithTxIndex(task.Index)
func (s *scheduler) prepareTask(task *deliverTxTask) {
ctx := task.Ctx.WithTxIndex(task.Index)

_, span := s.traceSpan(ctx, "SchedulerPrepare", task)
defer span.End()
Expand Down Expand Up @@ -409,10 +442,7 @@ func (s *scheduler) prepareTask(ctx sdk.Context, task *deliverTxTask) {
}

// executeTask executes a single task
func (s *scheduler) executeTask(ctx sdk.Context, task *deliverTxTask) {

s.prepareTask(ctx, task)

func (s *scheduler) executeTask(task *deliverTxTask) {
dCtx, dSpan := s.traceSpan(task.Ctx, "SchedulerDeliverTx", task)
defer dSpan.End()
task.Ctx = dCtx
Expand Down
35 changes: 35 additions & 0 deletions tasks/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,41 @@ func TestProcessAll(t *testing.T) {
workers: 50,
runs: 50,
addStores: true,
requests: requestList(100),
deliverTxFunc: func(ctx sdk.Context, req types.RequestDeliverTx) types.ResponseDeliverTx {
// all txs read and write to the same key to maximize conflicts
kv := ctx.MultiStore().GetKVStore(testStoreKey)
val := string(kv.Get(itemKey))

// write to the store with this tx's index
kv.Set(itemKey, req.Tx)

// return what was read from the store (final attempt should be index-1)
return types.ResponseDeliverTx{
Info: val,
}
},
assertions: func(t *testing.T, ctx sdk.Context, res []types.ResponseDeliverTx) {
for idx, response := range res {
if idx == 0 {
require.Equal(t, "", response.Info)
} else {
// the info is what was read from the kv store by the tx
// each tx writes its own index, so the info should be the index of the previous tx
require.Equal(t, fmt.Sprintf("%d", idx-1), response.Info)
}
}
// confirm last write made it to the parent store
latest := ctx.MultiStore().GetKVStore(testStoreKey).Get(itemKey)
require.Equal(t, []byte(fmt.Sprintf("%d", len(res)-1)), latest)
},
expectedErr: nil,
},
{
name: "Test few workers many txs",
workers: 5,
runs: 10,
addStores: true,
requests: requestList(50),
deliverTxFunc: func(ctx sdk.Context, req types.RequestDeliverTx) types.ResponseDeliverTx {
// all txs read and write to the same key to maximize conflicts
Expand Down

0 comments on commit 888cf6c

Please sign in to comment.