Skip to content

Commit

Permalink
w3vm: Added VM Snapshotting (#147)
Browse files Browse the repository at this point in the history
* w3vm: added `ApplyWithSnapshot`

* changed snapshot logic

* added test

* updated testdata

* added benchmark

* improved doc

---------

Co-authored-by: lmittmann <[email protected]>
  • Loading branch information
lmittmann and lmittmann authored Jun 4, 2024
1 parent 56314e0 commit e6659f9
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 22 deletions.
55 changes: 55 additions & 0 deletions w3vm/bench_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package w3vm_test

import (
"fmt"
"math/big"
"testing"

Expand Down Expand Up @@ -102,3 +103,57 @@ func BenchmarkVM(b *testing.B) {
dur := b.Elapsed()
b.ReportMetric(float64(gasSimulated)/dur.Seconds(), "gas/s")
}

func BenchmarkVMSnapshot(b *testing.B) {
depositMsg := &w3types.Message{
From: addr0,
To: &addrWETH,
Value: w3.I("1 ether"),
}

runs := 2
b.Run(fmt.Sprintf("re-run %d", runs), func(b *testing.B) {
for range b.N {
vm, _ := w3vm.New(
w3vm.WithState(w3types.State{
addrWETH: {Code: codeWETH},
addr0: {Balance: w3.I("2 ether")},
}),
)

for i := 0; i < runs; i++ {
_, err := vm.Apply(depositMsg)
if err != nil {
b.Fatalf("Failed to deposit: %v", err)
}
}
}
})

b.Run(fmt.Sprintf("snapshot %d", runs), func(b *testing.B) {
vm, _ := w3vm.New(
w3vm.WithState(w3types.State{
addrWETH: {Code: codeWETH},
addr0: {Balance: w3.I("2 ether")},
}),
)

for i := 0; i < runs-1; i++ {
_, err := vm.Apply(depositMsg)
if err != nil {
b.Fatalf("Failed to deposit: %v", err)
}
}

snap := vm.Snapshot()

for range b.N {
_, err := vm.Apply(depositMsg)
if err != nil {
b.Fatalf("Failed to deposit: %v", err)
}

vm.Rollback(snap.Copy())
}
})
}
13 changes: 9 additions & 4 deletions w3vm/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"errors"

"github.com/ethereum/go-ethereum/common"
gethState "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie"
Expand All @@ -28,13 +28,18 @@ func newDB(fetcher Fetcher) *db {
// state.Database methods //////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

func (db *db) OpenTrie(root common.Hash) (gethState.Trie, error) { return db, nil }
func (db *db) OpenTrie(root common.Hash) (state.Trie, error) { return db, nil }

func (db *db) OpenStorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash, trie gethState.Trie) (gethState.Trie, error) {
func (db *db) OpenStorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash, trie state.Trie) (state.Trie, error) {
return db, nil
}

func (*db) CopyTrie(gethState.Trie) gethState.Trie { panic("not implemented") }
func (*db) CopyTrie(trie state.Trie) state.Trie {
if db, ok := trie.(*db); ok {
return db
}
panic("not implemented")
}

func (db *db) ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error) {
if db.fetcher == nil {
Expand Down
14 changes: 7 additions & 7 deletions w3vm/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func NewTestingRPCFetcher(tb testing.TB, chainID uint64, client *w3.Client, bloc

var (
globalStateStoreMux sync.RWMutex
globalStateStore = make(map[string]*state)
globalStateStore = make(map[string]*testdataState)
)

func (f *rpcFetcher) loadTestdataState(tb testing.TB, chainID uint64) error {
Expand All @@ -224,7 +224,7 @@ func (f *rpcFetcher) loadTestdataState(tb testing.TB, chainID uint64) error {
fmt.Sprintf("%d_%v.json", chainID, f.blockNumber),
)

var s *state
var s *testdataState

// check if the state has already been loaded
globalStateStoreMux.RLock()
Expand Down Expand Up @@ -307,7 +307,7 @@ func (f *rpcFetcher) storeTestdataState(tb testing.TB, chainID uint64) error {
defer f.mux2.RUnlock()
defer f.mux3.RUnlock()

s := &state{
s := &testdataState{
Accounts: make(map[common.Address]*account, len(f.accounts)),
HeaderHashes: make(map[hexutil.Uint64]common.Hash, len(f.headerHashes)),
}
Expand Down Expand Up @@ -363,12 +363,12 @@ func (f *rpcFetcher) storeTestdataState(tb testing.TB, chainID uint64) error {
// create directory, if it does not exist
dirPath := filepath.Dir(fn)
if _, err := os.Stat(dirPath); errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(dirPath, 0775); err != nil {
if err := os.MkdirAll(dirPath, 0o775); err != nil {
return err
}
}

file, err := os.OpenFile(fn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0664)
file, err := os.OpenFile(fn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o664)
if err != nil {
return err
}
Expand All @@ -382,7 +382,7 @@ func (f *rpcFetcher) storeTestdataState(tb testing.TB, chainID uint64) error {
return nil
}

type state struct {
type testdataState struct {
Accounts map[common.Address]*account `json:"accounts"`
HeaderHashes map[hexutil.Uint64]common.Hash `json:"headerHashes,omitempty"`
}
Expand All @@ -396,7 +396,7 @@ type account struct {

// mergeStates merges the source state into the destination state and returns
// whether the destination state has been modified.
func mergeStates(dst, src *state) (modified bool) {
func mergeStates(dst, src *testdataState) (modified bool) {
// merge accounts
for addr, acc := range src.Accounts {
if dstAcc, ok := dst.Accounts[addr]; !ok {
Expand Down
5 changes: 0 additions & 5 deletions w3vm/testdata/w3vm/1_17034867.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
{
"accounts": {
"0x0000000000000000000000000000000000000000": {
"nonce": "0x0",
"balance": "0x272392e2b6e127d35e3",
"code": "0x"
},
"0x0000000000000000000000000000000000000001": {
"nonce": "0x0",
"balance": "0x2e58c20c74febd3b7",
Expand Down
15 changes: 11 additions & 4 deletions w3vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/core"
gethState "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/tracing"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
Expand All @@ -39,7 +39,7 @@ type VM struct {
opts *options

txIndex uint64
db *gethState.StateDB
db *state.StateDB
}

// New creates a new VM, that is configured with the given options.
Expand All @@ -58,7 +58,7 @@ func New(opts ...Option) (*VM, error) {

// set DB
db := newDB(vm.opts.fetcher)
vm.db, _ = gethState.New(w3.Hash0, db, nil)
vm.db, _ = state.New(w3.Hash0, db, nil)
for addr, acc := range vm.opts.preState {
vm.db.SetNonce(addr, acc.Nonce)
if acc.Balance != nil {
Expand Down Expand Up @@ -229,9 +229,16 @@ func (vm *VM) StorageAt(addr common.Address, slot common.Hash) (common.Hash, err
return val, nil
}

// Snapshot the current state of the VM. The returned state can only be rolled
// back to once. Use [state.StateDB.Copy] if you need to rollback multiple times.
func (vm *VM) Snapshot() *state.StateDB { return vm.db.Copy() }

// Rollback the state of the VM to the given snapshot.
func (vm *VM) Rollback(snapshot *state.StateDB) { vm.db = snapshot }

func (v *VM) buildMessage(msg *w3types.Message, skipAccChecks bool) (*core.Message, *vm.TxContext, error) {
nonce := msg.Nonce
if !skipAccChecks && nonce == 0 && msg.From != w3.Addr0 {
if !skipAccChecks && nonce == 0 {
var err error
nonce, err = v.Nonce(msg.From)
if err != nil {
Expand Down
56 changes: 54 additions & 2 deletions w3vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
coreState "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
Expand Down Expand Up @@ -214,6 +214,58 @@ func TestVMApply(t *testing.T) {
}
}

func TestVMSnapshot(t *testing.T) {
vm, _ := w3vm.New(
w3vm.WithState(w3types.State{
addrWETH: {Code: codeWETH},
addr0: {Balance: w3.I("100 ether")},
}),
)

depositMsg := &w3types.Message{
From: addr0,
To: &addrWETH,
Value: w3.I("1 ether"),
}

getBalanceOf := func(t *testing.T, token, acc common.Address) *big.Int {
t.Helper()

var balance *big.Int
if err := vm.CallFunc(token, funcBalanceOf, acc).Returns(&balance); err != nil {
t.Fatalf("Failed to call balanceOf: %v", err)
}
return balance
}

if got := getBalanceOf(t, addrWETH, addr0); got.Sign() != 0 {
t.Fatalf("Balance: want 0 WETH, got %s WETH", w3.FromWei(got, 18))
}

var snap *state.StateDB
for i := range 100 {
if i == 42 {
snap = vm.Snapshot()
}

if _, err := vm.Apply(depositMsg); err != nil {
t.Fatalf("Failed to apply deposit msg: %v", err)
}

want := w3.I(strconv.Itoa(i+1) + " ether")
if got := getBalanceOf(t, addrWETH, addr0); want.Cmp(got) != 0 {
t.Fatalf("Balance: want %s WETH, got %s WETH", w3.FromWei(want, 18), w3.FromWei(got, 18))
}
}

vm.Rollback(snap)

want := w3.I("42 ether")
if got := getBalanceOf(t, addrWETH, addr0); got.Cmp(want) != 0 {
t.Fatalf("Balance: want %s WETH, got %s WETH", w3.FromWei(want, 18), w3.FromWei(got, 18))
}
}

func TestVMCall(t *testing.T) {
tests := []struct {
PreState w3types.State
Expand Down Expand Up @@ -506,7 +558,7 @@ func BenchmarkTransferWETH9(b *testing.B) {
})

b.Run("geth", func(b *testing.B) {
stateDB, _ := coreState.New(common.Hash{}, coreState.NewDatabase(rawdb.NewMemoryDatabase()), nil)
stateDB, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil)
stateDB.SetCode(addrWETH, codeWETH)
stateDB.SetState(addrWETH, w3vm.WETHBalanceSlot(addr0), common.BigToHash(w3.I("1 ether")))

Expand Down

0 comments on commit e6659f9

Please sign in to comment.