From be2c0a28bfaf810310500adcd14fa6af1f1bd0fd Mon Sep 17 00:00:00 2001 From: Unique-Divine Date: Thu, 24 Oct 2024 03:37:23 -0500 Subject: [PATCH] fix: improve StateDB robustness. Make commit debuggable --- x/evm/evmmodule/genesis_test.go | 4 +- x/evm/keeper/erc20.go | 12 +-- x/evm/keeper/erc20_test.go | 4 +- x/evm/keeper/msg_server.go | 4 +- x/evm/precompile/funtoken.go | 2 +- x/evm/precompile/precompile.go | 4 +- x/evm/statedb/journal.go | 44 ++++++++++- x/evm/statedb/journal_test.go | 136 ++++++++++++++++++++++++++------ x/evm/statedb/state_object.go | 22 +++--- x/evm/statedb/statedb.go | 50 ++++++------ 10 files changed, 207 insertions(+), 75 deletions(-) diff --git a/x/evm/evmmodule/genesis_test.go b/x/evm/evmmodule/genesis_test.go index 690745e84..ff19ef46f 100644 --- a/x/evm/evmmodule/genesis_test.go +++ b/x/evm/evmmodule/genesis_test.go @@ -54,11 +54,11 @@ func (s *Suite) TestExportInitGenesis() { s.Require().NoError(err) // Transfer ERC-20 tokens to user A - _, err = deps.EvmKeeper.ERC20().Transfer(erc20Addr, fromUser, toUserA, amountToSendA, deps.Ctx) + _, _, err = deps.EvmKeeper.ERC20().Transfer(erc20Addr, fromUser, toUserA, amountToSendA, deps.Ctx) s.Require().NoError(err) // Transfer ERC-20 tokens to user B - _, err = deps.EvmKeeper.ERC20().Transfer(erc20Addr, fromUser, toUserB, amountToSendB, deps.Ctx) + _, _, err = deps.EvmKeeper.ERC20().Transfer(erc20Addr, fromUser, toUserB, amountToSendB, deps.Ctx) s.Require().NoError(err) // Create fungible token from bank coin diff --git a/x/evm/keeper/erc20.go b/x/evm/keeper/erc20.go index 8a41cb237..59c98b628 100644 --- a/x/evm/keeper/erc20.go +++ b/x/evm/keeper/erc20.go @@ -73,23 +73,23 @@ Transfer implements "ERC20.transfer" func (e erc20Calls) Transfer( contract, from, to gethcommon.Address, amount *big.Int, ctx sdk.Context, -) (out bool, err error) { +) (out bool, evmObj *vm.EVM, err error) { input, err := e.ABI.Pack("transfer", to, amount) if err != nil { - return false, fmt.Errorf("failed to pack ABI args: %w", err) + return false, nil, fmt.Errorf("failed to pack ABI args: %w", err) } - resp, _, err := e.CallContractWithInput(ctx, from, &contract, true, input) + resp, evmObj, err := e.CallContractWithInput(ctx, from, &contract, true, input) if err != nil { - return false, err + return false, nil, err } var erc20Bool ERC20Bool err = e.ABI.UnpackIntoInterface(&erc20Bool, "transfer", resp.Ret) if err != nil { - return false, err + return false, nil, err } - return erc20Bool.Value, nil + return erc20Bool.Value, evmObj, nil } // BalanceOf retrieves the balance of an ERC20 token for a specific account. diff --git a/x/evm/keeper/erc20_test.go b/x/evm/keeper/erc20_test.go index 2b49e8192..15807b714 100644 --- a/x/evm/keeper/erc20_test.go +++ b/x/evm/keeper/erc20_test.go @@ -34,7 +34,7 @@ func (s *Suite) TestERC20Calls() { s.T().Log("Transfer - Not enough funds") { - _, err := deps.EvmKeeper.ERC20().Transfer(contract, deps.Sender.EthAddr, evm.EVM_MODULE_ADDRESS, big.NewInt(9_420), deps.Ctx) + _, _, err := deps.EvmKeeper.ERC20().Transfer(contract, deps.Sender.EthAddr, evm.EVM_MODULE_ADDRESS, big.NewInt(9_420), deps.Ctx) s.ErrorContains(err, "ERC20: transfer amount exceeds balance") // balances unchanged evmtest.AssertERC20BalanceEqual(s.T(), deps, contract, deps.Sender.EthAddr, big.NewInt(0)) @@ -43,7 +43,7 @@ func (s *Suite) TestERC20Calls() { s.T().Log("Transfer - Success (sanity check)") { - _, err := deps.EvmKeeper.ERC20().Transfer(contract, evm.EVM_MODULE_ADDRESS, deps.Sender.EthAddr, big.NewInt(9_420), deps.Ctx) + _, _, err := deps.EvmKeeper.ERC20().Transfer(contract, evm.EVM_MODULE_ADDRESS, deps.Sender.EthAddr, big.NewInt(9_420), deps.Ctx) s.Require().NoError(err) evmtest.AssertERC20BalanceEqual(s.T(), deps, contract, deps.Sender.EthAddr, big.NewInt(9_420)) evmtest.AssertERC20BalanceEqual(s.T(), deps, contract, evm.EVM_MODULE_ADDRESS, big.NewInt(60_000)) diff --git a/x/evm/keeper/msg_server.go b/x/evm/keeper/msg_server.go index 4656f7db2..5c52acc0f 100644 --- a/x/evm/keeper/msg_server.go +++ b/x/evm/keeper/msg_server.go @@ -590,7 +590,7 @@ func (k Keeper) convertCoinNativeERC20( } // unescrow ERC-20 tokens from EVM module address - res, err := k.ERC20().Transfer( + res, _, err := k.ERC20().Transfer( erc20Addr, evm.EVM_MODULE_ADDRESS, recipient, @@ -686,7 +686,7 @@ func (k *Keeper) EmitEthereumTxEvents( // Emit typed events if !evmResp.Failed() { if recipient == nil { // contract creation - var contractAddr = crypto.CreateAddress(msg.From(), msg.Nonce()) + contractAddr := crypto.CreateAddress(msg.From(), msg.Nonce()) _ = ctx.EventManager().EmitTypedEvent(&evm.EventContractDeployed{ Sender: msg.From().Hex(), ContractAddr: contractAddr.String(), diff --git a/x/evm/precompile/funtoken.go b/x/evm/precompile/funtoken.go index b1050dfe2..3b453d597 100644 --- a/x/evm/precompile/funtoken.go +++ b/x/evm/precompile/funtoken.go @@ -141,7 +141,7 @@ func (p precompileFunToken) bankSend( // Caller transfers ERC20 to the EVM account transferTo := evm.EVM_MODULE_ADDRESS - _, err = p.evmKeeper.ERC20().Transfer(erc20, caller, transferTo, amount, ctx) + _, _, err = p.evmKeeper.ERC20().Transfer(erc20, caller, transferTo, amount, ctx) if err != nil { return nil, fmt.Errorf("failed to send from caller to the EVM account: %w", err) } diff --git a/x/evm/precompile/precompile.go b/x/evm/precompile/precompile.go index 8330e0428..a6bbfefc4 100644 --- a/x/evm/precompile/precompile.go +++ b/x/evm/precompile/precompile.go @@ -171,7 +171,7 @@ type OnRunStartResult struct { // } // // ... // // Use res.Ctx for state changes -// // Use res.StateDB.CommitContext() before any non-EVM state changes +// // Use res.StateDB.Commit() before any non-EVM state changes // // to guarantee the context and [statedb.StateDB] are in sync. // } // ``` @@ -189,7 +189,7 @@ func OnRunStart( return } ctx := stateDB.GetContext() - if err = stateDB.CommitContext(ctx); err != nil { + if err = stateDB.Commit(); err != nil { return res, fmt.Errorf("error committing dirty journal entries: %w", err) } diff --git a/x/evm/statedb/journal.go b/x/evm/statedb/journal.go index 0c75cf0e3..2a6ffe953 100644 --- a/x/evm/statedb/journal.go +++ b/x/evm/statedb/journal.go @@ -91,8 +91,46 @@ func (j *journal) Length() int { return len(j.entries) } -func (j *journal) DirtiesLen() int { - return len(j.dirties) +// DirtiesCount is a test helper to inspect how many entries in the journal are +// still dirty (uncommitted). After calling [StateDB.Commit], this function should +// return zero. +func (s *StateDB) DirtiesCount() int { + dirtiesCount := 0 + for _, dirtyCount := range s.Journal.dirties { + dirtiesCount += dirtyCount + } + return dirtiesCount + // for addr := range s.Journal.dirties { + // obj := s.stateObjects[addr] + // // suicided without deletion means obj is dirty + // if obj.Suicided { + // dirtiesCount++ + // // continue + // } + // // dirty code means obj is dirty + // if obj.code != nil && obj.DirtyCode { + // dirtiesCount++ + // // continue + // } + + // // mismatch between dirty storage and origin means obj is dirty + // for k, v := range obj.DirtyStorage { + // // All object (k,v) tuples matching between dirty and origin storage + // // signifies that the entry is committed. + // if v != obj.OriginStorage[k] { + // dirtiesCount++ + // } + // } + // } + // return dirtiesCount +} + +func (s *StateDB) Dirties() map[common.Address]int { + return s.Journal.dirties +} + +func (s *StateDB) Entries() []JournalChange { + return s.Journal.entries } // ------------------------------------------------------ @@ -148,7 +186,7 @@ var _ JournalChange = suicideChange{} func (ch suicideChange) Revert(s *StateDB) { obj := s.getStateObject(*ch.account) if obj != nil { - obj.suicided = ch.prev + obj.Suicided = ch.prev obj.setBalance(ch.prevbalance) } } diff --git a/x/evm/statedb/journal_test.go b/x/evm/statedb/journal_test.go index 0aadf3c53..b7894bb98 100644 --- a/x/evm/statedb/journal_test.go +++ b/x/evm/statedb/journal_test.go @@ -3,10 +3,14 @@ package statedb_test import ( "fmt" "math/big" + "strings" + "testing" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/ethereum/go-ethereum/core/vm" + serverconfig "github.com/NibiruChain/nibiru/v2/app/server/config" + "github.com/NibiruChain/nibiru/v2/x/common/testutil/testapp" "github.com/NibiruChain/nibiru/v2/x/evm" "github.com/NibiruChain/nibiru/v2/x/evm/embeds" @@ -50,8 +54,6 @@ func (s *Suite) TestPrecompileSnapshots() { assertionsBeforeRun(&deps) - s.T().Log("Populate dirty journal entries") - deployArgs := []any{"name", "SYMBOL", uint8(18)} deployResp, err := evmtest.DeployContract( &deps, @@ -60,34 +62,120 @@ func (s *Suite) TestPrecompileSnapshots() { ) s.Require().NoError(err, deployResp) - // deps.EvmKeeper.ERC20().Mint() contract := deployResp.ContractAddr _, evmObj, err := deps.EvmKeeper.ERC20().Mint( - contract, deps.Sender.EthAddr, evm.EVM_MODULE_ADDRESS, + contract, deps.Sender.EthAddr, deps.Sender.EthAddr, big.NewInt(69_420), deps.Ctx, ) s.Require().NoError(err) - stateDB := evmObj.StateDB.(*statedb.StateDB) - s.Equal(50, stateDB.Journal.DirtiesLen()) - // evmtest.TransferWei() - - s.T().Log("Run state transition") - - evmObj = run(&deps) - stateDB, ok := evmObj.StateDB.(*statedb.StateDB) - s.Require().True(ok, "error retrieving StateDB from the EVM") - s.Equal(0, stateDB.Journal.DirtiesLen()) - - _, _ = deps.Ctx.CacheContext() - assertionsAfterRun(&deps) - err = stateDB.Commit() - s.NoError(err) - assertionsAfterRun(&deps) + s.Run("Populate dirty journal entries. Remove with Commit", func() { + stateDB := evmObj.StateDB.(*statedb.StateDB) + s.Equal(0, stateDB.DirtiesCount()) + + randomAcc := evmtest.NewEthPrivAcc().EthAddr + balDelta := evm.NativeToWei(big.NewInt(4)) + // 2 dirties from [createObjectChange, balanceChange] + stateDB.AddBalance(randomAcc, balDelta) + // 1 dirties from [balanceChange] + stateDB.AddBalance(randomAcc, balDelta) + // 1 dirties from [balanceChange] + stateDB.SubBalance(randomAcc, balDelta) + if stateDB.DirtiesCount() != 4 { + debugDirtiesCountMismatch(stateDB, s.T()) + s.FailNow("expected 4 dirty journal changes") + } + + err = stateDB.Commit() // Dirties should be gone + s.NoError(err) + if stateDB.DirtiesCount() != 0 { + debugDirtiesCountMismatch(stateDB, s.T()) + s.FailNow("expected 0 dirty journal changes") + } + }) + + s.Run("Emulate a contract that calls another contract", func() { + randomAcc := evmtest.NewEthPrivAcc().EthAddr + to, amount := randomAcc, big.NewInt(69_000) + input, err := embeds.SmartContract_ERC20Minter.ABI.Pack("transfer", to, amount) + s.Require().NoError(err) + + leftoverGas := serverconfig.DefaultEthCallGasLimit + _, _, err = evmObj.Call( + vm.AccountRef(deps.Sender.EthAddr), + contract, + input, + leftoverGas, + big.NewInt(0), + ) + s.Require().NoError(err) + stateDB := evmObj.StateDB.(*statedb.StateDB) + if stateDB.DirtiesCount() != 2 { + debugDirtiesCountMismatch(stateDB, s.T()) + s.FailNow("expected 2 dirty journal changes") + } + + // The contract calling itself is invalid in this context. + // Note the comment in vm.Contract: + // + // type Contract struct { + // // CallerAddress is the result of the caller which initialized this + // // contract. However when the "call method" is delegated this value + // // needs to be initialized to that of the caller's caller. + // CallerAddress common.Address + // // ... + // } + // // + _, _, err = evmObj.Call( + vm.AccountRef(contract), + contract, + input, + leftoverGas, + big.NewInt(0), + ) + s.Require().ErrorContains(err, vm.ErrExecutionReverted.Error()) + }) + + s.Run("Precompile calls also start and end clean (no dirty changes)", func() { + evmObj = run(&deps) + assertionsAfterRun(&deps) + stateDB, ok := evmObj.StateDB.(*statedb.StateDB) + s.Require().True(ok, "error retrieving StateDB from the EVM") + if stateDB.DirtiesCount() != 0 { + debugDirtiesCountMismatch(stateDB, s.T()) + s.FailNow("expected 0 dirty journal changes") + } + }) +} - s.Equal(0, stateDB.Journal.DirtiesLen()) +func debugDirtiesCountMismatch(db *statedb.StateDB, t *testing.T) string { + lines := []string{} + dirties := db.Dirties() + stateObjects := db.StateObjects() + for addr, dirtyCountForAddr := range dirties { + lines = append(lines, fmt.Sprintf("Dirty addr: %s, dirtyCountForAddr=%d", addr, dirtyCountForAddr)) + + // Inspect the actual state object + maybeObj := stateObjects[addr] + if maybeObj == nil { + lines = append(lines, " no state object found!") + continue + } + obj := *maybeObj + + lines = append(lines, fmt.Sprintf(" balance: %s", obj.Balance())) + lines = append(lines, fmt.Sprintf(" suicided: %v", obj.Suicided)) + lines = append(lines, fmt.Sprintf(" dirtyCode: %v", obj.DirtyCode)) + + // Print storage state + lines = append(lines, fmt.Sprintf(" len(obj.DirtyStorage) entries: %d", len(obj.DirtyStorage))) + for k, v := range obj.DirtyStorage { + lines = append(lines, fmt.Sprintf(" key: %s, value: %s", k.Hex(), v.Hex())) + origVal := obj.OriginStorage[k] + lines = append(lines, fmt.Sprintf(" origin value: %s", origVal.Hex())) + } + } - // s.Require().EqualValues(ctxBefore, deps.Ctx, - // "StateDB should have been committed by the precompile", - // ) + t.Log("debugDirtiesCountMismatch:\n", strings.Join(lines, "\n")) + return "" } diff --git a/x/evm/statedb/state_object.go b/x/evm/statedb/state_object.go index ce9d21f5b..e371beae0 100644 --- a/x/evm/statedb/state_object.go +++ b/x/evm/statedb/state_object.go @@ -115,14 +115,14 @@ type stateObject struct { code []byte // state storage - originStorage Storage - dirtyStorage Storage + OriginStorage Storage + DirtyStorage Storage address common.Address // flags - dirtyCode bool - suicided bool + DirtyCode bool + Suicided bool } // newObject creates a state object. @@ -138,8 +138,8 @@ func newObject(db *StateDB, address common.Address, account Account) *stateObjec address: address, // Reflect the micronibi (unibi) balance in wei account: account.ToWei(), - originStorage: make(Storage), - dirtyStorage: make(Storage), + OriginStorage: make(Storage), + DirtyStorage: make(Storage), } } @@ -223,7 +223,7 @@ func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { func (s *stateObject) setCode(codeHash common.Hash, code []byte) { s.code = code s.account.CodeHash = codeHash[:] - s.dirtyCode = true + s.DirtyCode = true } // SetNonce set nonce to account @@ -256,18 +256,18 @@ func (s *stateObject) Nonce() uint64 { // GetCommittedState query the committed state func (s *stateObject) GetCommittedState(key common.Hash) common.Hash { - if value, cached := s.originStorage[key]; cached { + if value, cached := s.OriginStorage[key]; cached { return value } // If no live objects are available, load it from keeper value := s.db.keeper.GetState(s.db.ctx, s.Address(), key) - s.originStorage[key] = value + s.OriginStorage[key] = value return value } // GetState query the current state (including dirty state) func (s *stateObject) GetState(key common.Hash) common.Hash { - if value, dirty := s.dirtyStorage[key]; dirty { + if value, dirty := s.DirtyStorage[key]; dirty { return value } return s.GetCommittedState(key) @@ -290,5 +290,5 @@ func (s *stateObject) SetState(key common.Hash, value common.Hash) { } func (s *stateObject) setState(key, value common.Hash) { - s.dirtyStorage[key] = value + s.DirtyStorage[key] = value } diff --git a/x/evm/statedb/statedb.go b/x/evm/statedb/statedb.go index ab33cc4ea..a29371431 100644 --- a/x/evm/statedb/statedb.go +++ b/x/evm/statedb/statedb.go @@ -194,7 +194,7 @@ func (s *StateDB) GetRefund() uint64 { func (s *StateDB) HasSuicided(addr common.Address) bool { stateObject := s.getStateObject(addr) if stateObject != nil { - return stateObject.suicided + return stateObject.Suicided } return false } @@ -275,7 +275,7 @@ func (s *StateDB) ForEachStorage(addr common.Address, cb func(key, value common. return nil } s.keeper.ForEachStorage(s.ctx, addr, func(key, value common.Hash) bool { - if value, dirty := so.dirtyStorage[key]; dirty { + if value, dirty := so.DirtyStorage[key]; dirty { return cb(key, value) } if len(value) > 0 { @@ -346,10 +346,10 @@ func (s *StateDB) Suicide(addr common.Address) bool { } s.Journal.append(suicideChange{ account: &addr, - prev: stateObject.suicided, + prev: stateObject.Suicided, prevbalance: new(big.Int).Set(stateObject.Balance()), }) - stateObject.suicided = true + stateObject.Suicided = true stateObject.account.BalanceWei = new(big.Int) return true @@ -450,43 +450,49 @@ func errorf(format string, args ...any) error { return fmt.Errorf("StateDB error: "+format, args...) } -// CommitContext writes the dirty journal state changes to the EVM Keeper. The -// StateDB object cannot be reused after [CommitContext] has completed. A new +// Commit writes the dirty journal state changes to the EVM Keeper. The +// StateDB object cannot be reused after [Commit] has completed. A new // object needs to be created from the EVM. -func (s *StateDB) CommitContext(ctx sdk.Context) error { +func (s *StateDB) Commit() error { + ctx := s.GetContext() for _, addr := range s.Journal.sortedDirties() { - obj := s.stateObjects[addr] - if obj.suicided { + obj := s.getStateObject(addr) + if obj == nil { + continue + } + if obj.Suicided { + // Invariant: After [StateDB.Suicide] for some address, the + // corresponding account's state object is only available until the + // state is committed. if err := s.keeper.DeleteAccount(ctx, obj.Address()); err != nil { return errorf("failed to delete account: %w", err) } + delete(s.stateObjects, addr) } else { - if obj.code != nil && obj.dirtyCode { + if obj.code != nil && obj.DirtyCode { s.keeper.SetCode(ctx, obj.CodeHash(), obj.code) } if err := s.keeper.SetAccount(ctx, obj.Address(), obj.account.ToNative()); err != nil { return errorf("failed to set account: %w", err) } - for _, key := range obj.dirtyStorage.SortedKeys() { - value := obj.dirtyStorage[key] - // Skip noop changes, persist actual changes - if value == obj.originStorage[key] { + for _, key := range obj.DirtyStorage.SortedKeys() { + dirtyVal := obj.DirtyStorage[key] + // Values that match origin storage are not dirty. + if dirtyVal == obj.OriginStorage[key] { continue } - s.keeper.SetState(ctx, obj.Address(), key, value.Bytes()) + // Persist committed changes + s.keeper.SetState(ctx, obj.Address(), key, dirtyVal.Bytes()) + obj.OriginStorage[key] = dirtyVal } } + // Clear the dirty counts because all state changes have been + // committed. + s.Journal.dirties[addr] = 0 } return nil } -// Commit writes the dirty journal state changes to the EVM Keeper. The -// StateDB object cannot be reused after [CommitContext] has completed. A new -// object needs to be created from the EVM. -func (s *StateDB) Commit() error { - return s.CommitContext(s.ctx) -} - // StateObjects: Returns a copy of the [StateDB.stateObjects] map. func (s *StateDB) StateObjects() map[common.Address]*stateObject { copyOfMap := make(map[common.Address]*stateObject)