Skip to content

Commit

Permalink
fix: improve StateDB robustness. Make commit debuggable
Browse files Browse the repository at this point in the history
  • Loading branch information
Unique-Divine committed Oct 24, 2024
1 parent f9e73d7 commit be2c0a2
Show file tree
Hide file tree
Showing 10 changed files with 207 additions and 75 deletions.
4 changes: 2 additions & 2 deletions x/evm/evmmodule/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ func (s *Suite) TestExportInitGenesis() {
s.Require().NoError(err)

// Transfer ERC-20 tokens to user A
_, err = deps.EvmKeeper.ERC20().Transfer(erc20Addr, fromUser, toUserA, amountToSendA, deps.Ctx)
_, _, err = deps.EvmKeeper.ERC20().Transfer(erc20Addr, fromUser, toUserA, amountToSendA, deps.Ctx)
s.Require().NoError(err)

// Transfer ERC-20 tokens to user B
_, err = deps.EvmKeeper.ERC20().Transfer(erc20Addr, fromUser, toUserB, amountToSendB, deps.Ctx)
_, _, err = deps.EvmKeeper.ERC20().Transfer(erc20Addr, fromUser, toUserB, amountToSendB, deps.Ctx)
s.Require().NoError(err)

// Create fungible token from bank coin
Expand Down
12 changes: 6 additions & 6 deletions x/evm/keeper/erc20.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,23 @@ Transfer implements "ERC20.transfer"
func (e erc20Calls) Transfer(
contract, from, to gethcommon.Address, amount *big.Int,
ctx sdk.Context,
) (out bool, err error) {
) (out bool, evmObj *vm.EVM, err error) {
input, err := e.ABI.Pack("transfer", to, amount)
if err != nil {
return false, fmt.Errorf("failed to pack ABI args: %w", err)
return false, nil, fmt.Errorf("failed to pack ABI args: %w", err)
}
resp, _, err := e.CallContractWithInput(ctx, from, &contract, true, input)
resp, evmObj, err := e.CallContractWithInput(ctx, from, &contract, true, input)
if err != nil {
return false, err
return false, nil, err
}

var erc20Bool ERC20Bool
err = e.ABI.UnpackIntoInterface(&erc20Bool, "transfer", resp.Ret)
if err != nil {
return false, err
return false, nil, err
}

return erc20Bool.Value, nil
return erc20Bool.Value, evmObj, nil
}

// BalanceOf retrieves the balance of an ERC20 token for a specific account.
Expand Down
4 changes: 2 additions & 2 deletions x/evm/keeper/erc20_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (s *Suite) TestERC20Calls() {

s.T().Log("Transfer - Not enough funds")
{
_, err := deps.EvmKeeper.ERC20().Transfer(contract, deps.Sender.EthAddr, evm.EVM_MODULE_ADDRESS, big.NewInt(9_420), deps.Ctx)
_, _, err := deps.EvmKeeper.ERC20().Transfer(contract, deps.Sender.EthAddr, evm.EVM_MODULE_ADDRESS, big.NewInt(9_420), deps.Ctx)
s.ErrorContains(err, "ERC20: transfer amount exceeds balance")
// balances unchanged
evmtest.AssertERC20BalanceEqual(s.T(), deps, contract, deps.Sender.EthAddr, big.NewInt(0))
Expand All @@ -43,7 +43,7 @@ func (s *Suite) TestERC20Calls() {

s.T().Log("Transfer - Success (sanity check)")
{
_, err := deps.EvmKeeper.ERC20().Transfer(contract, evm.EVM_MODULE_ADDRESS, deps.Sender.EthAddr, big.NewInt(9_420), deps.Ctx)
_, _, err := deps.EvmKeeper.ERC20().Transfer(contract, evm.EVM_MODULE_ADDRESS, deps.Sender.EthAddr, big.NewInt(9_420), deps.Ctx)
s.Require().NoError(err)
evmtest.AssertERC20BalanceEqual(s.T(), deps, contract, deps.Sender.EthAddr, big.NewInt(9_420))
evmtest.AssertERC20BalanceEqual(s.T(), deps, contract, evm.EVM_MODULE_ADDRESS, big.NewInt(60_000))
Expand Down
4 changes: 2 additions & 2 deletions x/evm/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ func (k Keeper) convertCoinNativeERC20(
}

// unescrow ERC-20 tokens from EVM module address
res, err := k.ERC20().Transfer(
res, _, err := k.ERC20().Transfer(
erc20Addr,
evm.EVM_MODULE_ADDRESS,
recipient,
Expand Down Expand Up @@ -686,7 +686,7 @@ func (k *Keeper) EmitEthereumTxEvents(
// Emit typed events
if !evmResp.Failed() {
if recipient == nil { // contract creation
var contractAddr = crypto.CreateAddress(msg.From(), msg.Nonce())
contractAddr := crypto.CreateAddress(msg.From(), msg.Nonce())
_ = ctx.EventManager().EmitTypedEvent(&evm.EventContractDeployed{
Sender: msg.From().Hex(),
ContractAddr: contractAddr.String(),
Expand Down
2 changes: 1 addition & 1 deletion x/evm/precompile/funtoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func (p precompileFunToken) bankSend(

// Caller transfers ERC20 to the EVM account
transferTo := evm.EVM_MODULE_ADDRESS
_, err = p.evmKeeper.ERC20().Transfer(erc20, caller, transferTo, amount, ctx)
_, _, err = p.evmKeeper.ERC20().Transfer(erc20, caller, transferTo, amount, ctx)
if err != nil {
return nil, fmt.Errorf("failed to send from caller to the EVM account: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions x/evm/precompile/precompile.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ type OnRunStartResult struct {
// }
// // ...
// // Use res.Ctx for state changes
// // Use res.StateDB.CommitContext() before any non-EVM state changes
// // Use res.StateDB.Commit() before any non-EVM state changes
// // to guarantee the context and [statedb.StateDB] are in sync.
// }
// ```
Expand All @@ -189,7 +189,7 @@ func OnRunStart(
return
}
ctx := stateDB.GetContext()
if err = stateDB.CommitContext(ctx); err != nil {
if err = stateDB.Commit(); err != nil {
return res, fmt.Errorf("error committing dirty journal entries: %w", err)
}

Expand Down
44 changes: 41 additions & 3 deletions x/evm/statedb/journal.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,46 @@ func (j *journal) Length() int {
return len(j.entries)
}

func (j *journal) DirtiesLen() int {
return len(j.dirties)
// DirtiesCount is a test helper to inspect how many entries in the journal are
// still dirty (uncommitted). After calling [StateDB.Commit], this function should
// return zero.
func (s *StateDB) DirtiesCount() int {
dirtiesCount := 0
for _, dirtyCount := range s.Journal.dirties {
dirtiesCount += dirtyCount
}
return dirtiesCount
// for addr := range s.Journal.dirties {
// obj := s.stateObjects[addr]
// // suicided without deletion means obj is dirty
// if obj.Suicided {
// dirtiesCount++
// // continue
// }
// // dirty code means obj is dirty
// if obj.code != nil && obj.DirtyCode {
// dirtiesCount++
// // continue
// }

// // mismatch between dirty storage and origin means obj is dirty
// for k, v := range obj.DirtyStorage {
// // All object (k,v) tuples matching between dirty and origin storage
// // signifies that the entry is committed.
// if v != obj.OriginStorage[k] {
// dirtiesCount++
// }
// }
// }
// return dirtiesCount
}

func (s *StateDB) Dirties() map[common.Address]int {
return s.Journal.dirties
}

func (s *StateDB) Entries() []JournalChange {
return s.Journal.entries
}

// ------------------------------------------------------
Expand Down Expand Up @@ -148,7 +186,7 @@ var _ JournalChange = suicideChange{}
func (ch suicideChange) Revert(s *StateDB) {
obj := s.getStateObject(*ch.account)
if obj != nil {
obj.suicided = ch.prev
obj.Suicided = ch.prev
obj.setBalance(ch.prevbalance)
}
}
Expand Down
136 changes: 112 additions & 24 deletions x/evm/statedb/journal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@ package statedb_test
import (
"fmt"
"math/big"
"strings"
"testing"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/ethereum/go-ethereum/core/vm"

serverconfig "github.com/NibiruChain/nibiru/v2/app/server/config"

"github.com/NibiruChain/nibiru/v2/x/common/testutil/testapp"
"github.com/NibiruChain/nibiru/v2/x/evm"
"github.com/NibiruChain/nibiru/v2/x/evm/embeds"
Expand Down Expand Up @@ -50,8 +54,6 @@ func (s *Suite) TestPrecompileSnapshots() {

assertionsBeforeRun(&deps)

s.T().Log("Populate dirty journal entries")

deployArgs := []any{"name", "SYMBOL", uint8(18)}
deployResp, err := evmtest.DeployContract(
&deps,
Expand All @@ -60,34 +62,120 @@ func (s *Suite) TestPrecompileSnapshots() {
)
s.Require().NoError(err, deployResp)

// deps.EvmKeeper.ERC20().Mint()
contract := deployResp.ContractAddr
_, evmObj, err := deps.EvmKeeper.ERC20().Mint(
contract, deps.Sender.EthAddr, evm.EVM_MODULE_ADDRESS,
contract, deps.Sender.EthAddr, deps.Sender.EthAddr,
big.NewInt(69_420), deps.Ctx,
)
s.Require().NoError(err)

stateDB := evmObj.StateDB.(*statedb.StateDB)
s.Equal(50, stateDB.Journal.DirtiesLen())
// evmtest.TransferWei()

s.T().Log("Run state transition")

evmObj = run(&deps)
stateDB, ok := evmObj.StateDB.(*statedb.StateDB)
s.Require().True(ok, "error retrieving StateDB from the EVM")
s.Equal(0, stateDB.Journal.DirtiesLen())

_, _ = deps.Ctx.CacheContext()
assertionsAfterRun(&deps)
err = stateDB.Commit()
s.NoError(err)
assertionsAfterRun(&deps)
s.Run("Populate dirty journal entries. Remove with Commit", func() {
stateDB := evmObj.StateDB.(*statedb.StateDB)
s.Equal(0, stateDB.DirtiesCount())

randomAcc := evmtest.NewEthPrivAcc().EthAddr
balDelta := evm.NativeToWei(big.NewInt(4))
// 2 dirties from [createObjectChange, balanceChange]
stateDB.AddBalance(randomAcc, balDelta)
// 1 dirties from [balanceChange]
stateDB.AddBalance(randomAcc, balDelta)
// 1 dirties from [balanceChange]
stateDB.SubBalance(randomAcc, balDelta)
if stateDB.DirtiesCount() != 4 {
debugDirtiesCountMismatch(stateDB, s.T())
s.FailNow("expected 4 dirty journal changes")
}

err = stateDB.Commit() // Dirties should be gone
s.NoError(err)
if stateDB.DirtiesCount() != 0 {
debugDirtiesCountMismatch(stateDB, s.T())
s.FailNow("expected 0 dirty journal changes")
}
})

s.Run("Emulate a contract that calls another contract", func() {
randomAcc := evmtest.NewEthPrivAcc().EthAddr
to, amount := randomAcc, big.NewInt(69_000)
input, err := embeds.SmartContract_ERC20Minter.ABI.Pack("transfer", to, amount)
s.Require().NoError(err)

leftoverGas := serverconfig.DefaultEthCallGasLimit
_, _, err = evmObj.Call(
vm.AccountRef(deps.Sender.EthAddr),
contract,
input,
leftoverGas,
big.NewInt(0),
)
s.Require().NoError(err)
stateDB := evmObj.StateDB.(*statedb.StateDB)
if stateDB.DirtiesCount() != 2 {
debugDirtiesCountMismatch(stateDB, s.T())
s.FailNow("expected 2 dirty journal changes")
}

// The contract calling itself is invalid in this context.
// Note the comment in vm.Contract:
//
// type Contract struct {
// // CallerAddress is the result of the caller which initialized this
// // contract. However when the "call method" is delegated this value
// // needs to be initialized to that of the caller's caller.
// CallerAddress common.Address
// // ...
// }
// //
_, _, err = evmObj.Call(
vm.AccountRef(contract),
contract,
input,
leftoverGas,
big.NewInt(0),
)
s.Require().ErrorContains(err, vm.ErrExecutionReverted.Error())
})

s.Run("Precompile calls also start and end clean (no dirty changes)", func() {
evmObj = run(&deps)
assertionsAfterRun(&deps)
stateDB, ok := evmObj.StateDB.(*statedb.StateDB)
s.Require().True(ok, "error retrieving StateDB from the EVM")
if stateDB.DirtiesCount() != 0 {
debugDirtiesCountMismatch(stateDB, s.T())
s.FailNow("expected 0 dirty journal changes")
}
})
}

s.Equal(0, stateDB.Journal.DirtiesLen())
func debugDirtiesCountMismatch(db *statedb.StateDB, t *testing.T) string {
lines := []string{}
dirties := db.Dirties()
stateObjects := db.StateObjects()
for addr, dirtyCountForAddr := range dirties {
lines = append(lines, fmt.Sprintf("Dirty addr: %s, dirtyCountForAddr=%d", addr, dirtyCountForAddr))

// Inspect the actual state object
maybeObj := stateObjects[addr]
if maybeObj == nil {
lines = append(lines, " no state object found!")
continue
}
obj := *maybeObj

lines = append(lines, fmt.Sprintf(" balance: %s", obj.Balance()))
lines = append(lines, fmt.Sprintf(" suicided: %v", obj.Suicided))
lines = append(lines, fmt.Sprintf(" dirtyCode: %v", obj.DirtyCode))

// Print storage state
lines = append(lines, fmt.Sprintf(" len(obj.DirtyStorage) entries: %d", len(obj.DirtyStorage)))
for k, v := range obj.DirtyStorage {
lines = append(lines, fmt.Sprintf(" key: %s, value: %s", k.Hex(), v.Hex()))
origVal := obj.OriginStorage[k]
lines = append(lines, fmt.Sprintf(" origin value: %s", origVal.Hex()))
}
}

// s.Require().EqualValues(ctxBefore, deps.Ctx,
// "StateDB should have been committed by the precompile",
// )
t.Log("debugDirtiesCountMismatch:\n", strings.Join(lines, "\n"))
return ""
}
22 changes: 11 additions & 11 deletions x/evm/statedb/state_object.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,14 @@ type stateObject struct {
code []byte

// state storage
originStorage Storage
dirtyStorage Storage
OriginStorage Storage
DirtyStorage Storage

address common.Address

// flags
dirtyCode bool
suicided bool
DirtyCode bool
Suicided bool
}

// newObject creates a state object.
Expand All @@ -138,8 +138,8 @@ func newObject(db *StateDB, address common.Address, account Account) *stateObjec
address: address,
// Reflect the micronibi (unibi) balance in wei
account: account.ToWei(),
originStorage: make(Storage),
dirtyStorage: make(Storage),
OriginStorage: make(Storage),
DirtyStorage: make(Storage),
}
}

Expand Down Expand Up @@ -223,7 +223,7 @@ func (s *stateObject) SetCode(codeHash common.Hash, code []byte) {
func (s *stateObject) setCode(codeHash common.Hash, code []byte) {
s.code = code
s.account.CodeHash = codeHash[:]
s.dirtyCode = true
s.DirtyCode = true
}

// SetNonce set nonce to account
Expand Down Expand Up @@ -256,18 +256,18 @@ func (s *stateObject) Nonce() uint64 {

// GetCommittedState query the committed state
func (s *stateObject) GetCommittedState(key common.Hash) common.Hash {
if value, cached := s.originStorage[key]; cached {
if value, cached := s.OriginStorage[key]; cached {
return value
}
// If no live objects are available, load it from keeper
value := s.db.keeper.GetState(s.db.ctx, s.Address(), key)
s.originStorage[key] = value
s.OriginStorage[key] = value
return value
}

// GetState query the current state (including dirty state)
func (s *stateObject) GetState(key common.Hash) common.Hash {
if value, dirty := s.dirtyStorage[key]; dirty {
if value, dirty := s.DirtyStorage[key]; dirty {
return value
}
return s.GetCommittedState(key)
Expand All @@ -290,5 +290,5 @@ func (s *stateObject) SetState(key common.Hash, value common.Hash) {
}

func (s *stateObject) setState(key, value common.Hash) {
s.dirtyStorage[key] = value
s.DirtyStorage[key] = value
}
Loading

0 comments on commit be2c0a2

Please sign in to comment.