diff --git a/baseapp/baseapp.go b/baseapp/baseapp.go index c34761276..259b02f2b 100644 --- a/baseapp/baseapp.go +++ b/baseapp/baseapp.go @@ -157,8 +157,9 @@ type BaseApp struct { //nolint: maligned ChainID string - votesInfoLock sync.RWMutex - commitLock *sync.Mutex + votesInfoLock sync.RWMutex + commitLock *sync.Mutex + checkTxStateLock *sync.RWMutex compactionInterval uint64 @@ -274,7 +275,8 @@ func NewBaseApp( TracingInfo: &tracing.Info{ Tracer: &tr, }, - commitLock: &sync.Mutex{}, + commitLock: &sync.Mutex{}, + checkTxStateLock: &sync.RWMutex{}, } app.TracingInfo.SetContext(context.Background()) @@ -529,6 +531,8 @@ func (app *BaseApp) IsSealed() bool { return app.sealed } func (app *BaseApp) setCheckState(header tmproto.Header) { ms := app.cms.CacheMultiStore() ctx := sdk.NewContext(ms, header, true, app.logger).WithMinGasPrices(app.minGasPrices) + app.checkTxStateLock.Lock() + defer app.checkTxStateLock.Unlock() if app.checkState == nil { app.checkState = &state{ ms: ms, @@ -978,6 +982,9 @@ func (app *BaseApp) runTx(ctx sdk.Context, mode runTxMode, tx sdk.Tx, checksum [ // append the events in the order of occurrence result.Events = append(anteEvents, result.Events...) } + if ctx.CheckTxCallback() != nil { + ctx.CheckTxCallback()(err) + } return gInfo, result, anteEvents, priority, pendingTxChecker, err } @@ -1168,5 +1175,7 @@ func (app *BaseApp) ReloadDB() error { } func (app *BaseApp) GetCheckCtx() sdk.Context { + app.checkTxStateLock.RLock() + defer app.checkTxStateLock.RUnlock() return app.checkState.ctx } diff --git a/types/context.go b/types/context.go index 29731f88e..2ee371e13 100644 --- a/types/context.go +++ b/types/context.go @@ -41,6 +41,7 @@ type Context struct { eventManager *EventManager priority int64 // The tx priority, only relevant in CheckTx pendingTxChecker abci.PendingTxChecker // Checker for pending transaction, only relevant in CheckTx + checkTxCallback func(error) // callback to make at the end of CheckTx. Input param is the error (nil-able) of `runMsgs` txBlockingChannels acltypes.MessageAccessOpsChannelMapping txCompletionChannels acltypes.MessageAccessOpsChannelMapping @@ -121,6 +122,10 @@ func (c Context) PendingTxChecker() abci.PendingTxChecker { return c.pendingTxChecker } +func (c Context) CheckTxCallback() func(error) { + return c.checkTxCallback +} + func (c Context) TxCompletionChannels() acltypes.MessageAccessOpsChannelMapping { return c.txCompletionChannels } @@ -359,6 +364,11 @@ func (c Context) WithPendingTxChecker(checker abci.PendingTxChecker) Context { return c } +func (c Context) WithCheckTxCallback(checkTxCallback func(error)) Context { + c.checkTxCallback = checkTxCallback + return c +} + // TODO: remove??? func (c Context) IsZero() bool { return c.ms == nil