diff --git a/x/dex/contract/abci.go b/x/dex/contract/abci.go index e270f937ec..7a1dd30786 100644 --- a/x/dex/contract/abci.go +++ b/x/dex/contract/abci.go @@ -58,7 +58,7 @@ func EndBlockerAtomic(ctx sdk.Context, keeper *keeper.Keeper, validContractsInfo handleDeposits(spanCtx, cachedCtx, env, keeper, tracer) runner := NewParallelRunner(func(contract types.ContractInfoV2) { - orderMatchingRunnable(spanCtx, cachedCtx, env, keeper, contract, tracer) + OrderMatchingRunnable(spanCtx, cachedCtx, env, keeper, contract, tracer) }, validContractsInfo, cachedCtx) _, err := logging.LogIfNotDoneAfter(ctx.Logger(), func() (struct{}, error) { @@ -227,7 +227,7 @@ func handleUnfulfilledMarketOrders(ctx context.Context, sdkCtx sdk.Context, env } } -func orderMatchingRunnable(ctx context.Context, sdkContext sdk.Context, env *environment, keeper *keeper.Keeper, contractInfo types.ContractInfoV2, tracer *otrace.Tracer) { +func OrderMatchingRunnable(ctx context.Context, sdkContext sdk.Context, env *environment, keeper *keeper.Keeper, contractInfo types.ContractInfoV2, tracer *otrace.Tracer) { defer func() { if err := recover(); err != nil { msg := fmt.Sprintf("PANIC RECOVERED during order matching: %s", err) @@ -237,7 +237,7 @@ func orderMatchingRunnable(ctx context.Context, sdkContext sdk.Context, env *env } } }() - _, span := (*tracer).Start(ctx, "orderMatchingRunnable") + _, span := (*tracer).Start(ctx, "OrderMatchingRunnable") defer span.End() defer telemetry.MeasureSince(time.Now(), "dex", "order_matching_runnable") defer func() { diff --git a/x/dex/contract/abci_test.go b/x/dex/contract/abci_test.go index 194743acd4..3d0b7650b4 100644 --- a/x/dex/contract/abci_test.go +++ b/x/dex/contract/abci_test.go @@ -1,6 +1,7 @@ package contract_test import ( + "context" "testing" "time" @@ -32,3 +33,11 @@ func TestTransferRentFromDexToCollector(t *testing.T) { collectorBalance := bankkeeper.GetBalance(ctx, testApp.AccountKeeper.GetModuleAddress(authtypes.FeeCollectorName), "usei") require.Equal(t, int64(80), collectorBalance.Amount.Int64()) } + +func TestOrderMatchingRunnablePanicHandler(t *testing.T) { + testApp := keepertest.TestApp() + ctx := testApp.BaseApp.NewContext(false, tmproto.Header{Time: time.Now()}) + require.NotPanics(t, func() { + contract.OrderMatchingRunnable(context.Background(), ctx, nil, nil, types.ContractInfoV2{}, nil) + }) +} diff --git a/x/dex/contract/runner.go b/x/dex/contract/runner.go index e0c2a2623e..36b02ac7b2 100644 --- a/x/dex/contract/runner.go +++ b/x/dex/contract/runner.go @@ -138,6 +138,14 @@ func (r *ParallelRunner) wrapRunnable(contractAddr types.ContractAddress) { if err := recover(); err != nil { r.sdkCtx.Logger().Error(fmt.Sprintf("panic in parallel runner recovered: %s", err)) } + + atomic.AddInt64(&r.inProgressCnt, -1) // this has to happen after any potential increment to readyCnt + select { + case r.someContractFinished <- struct{}{}: + case <-r.done: + // make sure other goroutines can also receive from 'done' + r.done <- struct{}{} + } }() contractInfo, _ := r.contractAddrToInfo.Load(contractAddr) @@ -165,12 +173,4 @@ func (r *ParallelRunner) wrapRunnable(contractAddr types.ContractAddress) { } } } - - atomic.AddInt64(&r.inProgressCnt, -1) // this has to happen after any potential increment to readyCnt - select { - case r.someContractFinished <- struct{}{}: - case <-r.done: - // make sure other goroutines can also receive from 'done' - r.done <- struct{}{} - } } diff --git a/x/dex/contract/runner_test.go b/x/dex/contract/runner_test.go index 5a45d776f2..bf7e8c1d70 100644 --- a/x/dex/contract/runner_test.go +++ b/x/dex/contract/runner_test.go @@ -29,6 +29,10 @@ func idleRunnable(_ types.ContractInfoV2) { atomic.AddInt64(&counter, 1) } +func panicRunnable(_ types.ContractInfoV2) { + panic("") +} + func dependencyCheckRunnable(contractInfo types.ContractInfoV2) { if contractInfo.ContractAddr == "C" { _, hasA := dependencyCheck.Load("A") @@ -126,3 +130,12 @@ func TestRunnerParallelContractWithInvalidDependency(t *testing.T) { _, hasC := dependencyCheck.Load("C") require.False(t, hasC) } + +func TestRunnerPanicContract(t *testing.T) { + contractInfo := types.ContractInfoV2{ + ContractAddr: "A", + NumIncomingDependencies: 0, + } + runner := contract.NewParallelRunner(panicRunnable, []types.ContractInfoV2{contractInfo}, sdkCtx) + require.NotPanics(t, runner.Run) +}