diff --git a/core/bloombits/matcher.go b/core/bloombits/matcher.go index 0d2f6f950d86..d8f932041bc0 100644 --- a/core/bloombits/matcher.go +++ b/core/bloombits/matcher.go @@ -83,7 +83,7 @@ type Matcher struct { retrievals chan chan *Retrieval // Retriever processes waiting for task allocations deliveries chan *Retrieval // Retriever processes waiting for task response deliveries - running uint32 // Atomic flag whether a session is live or not + running atomic.Bool // Atomic flag whether a session is live or not } // NewMatcher creates a new pipeline for retrieving bloom bit streams and doing @@ -146,10 +146,10 @@ func (m *Matcher) addScheduler(idx uint) { // channel is closed. func (m *Matcher) Start(ctx context.Context, begin, end uint64, results chan uint64) (*MatcherSession, error) { // Make sure we're not creating concurrent sessions - if atomic.SwapUint32(&m.running, 1) == 1 { + if m.running.Swap(true) { return nil, errors.New("matcher already running") } - defer atomic.StoreUint32(&m.running, 0) + defer m.running.Store(false) // Initiate a new matching round session := &MatcherSession{ diff --git a/core/bloombits/matcher_test.go b/core/bloombits/matcher_test.go index 93d4632b8587..36764c3f174b 100644 --- a/core/bloombits/matcher_test.go +++ b/core/bloombits/matcher_test.go @@ -160,7 +160,7 @@ func testMatcher(t *testing.T, filter [][]bloomIndexes, start, blocks uint64, in } } // Track the number of retrieval requests made - var requested uint32 + var requested atomic.Uint32 // Start the matching session for the filter and the retriever goroutines quit := make(chan struct{}) @@ -208,15 +208,15 @@ func testMatcher(t *testing.T, filter [][]bloomIndexes, start, blocks uint64, in session.Close() close(quit) - if retrievals != 0 && requested != retrievals { - t.Errorf("filter = %v blocks = %v intermittent = %v: request count mismatch, have #%v, want #%v", filter, blocks, intermittent, requested, retrievals) + if retrievals != 0 && requested.Load() != retrievals { + t.Errorf("filter = %v blocks = %v intermittent = %v: request count mismatch, have #%v, want #%v", filter, blocks, intermittent, requested.Load(), retrievals) } - return requested + return requested.Load() } // startRetrievers starts a batch of goroutines listening for section requests // and serving them. -func startRetrievers(session *MatcherSession, quit chan struct{}, retrievals *uint32, batch int) { +func startRetrievers(session *MatcherSession, quit chan struct{}, retrievals *atomic.Uint32, batch int) { requests := make(chan chan *Retrieval) for i := 0; i < 10; i++ { @@ -238,7 +238,7 @@ func startRetrievers(session *MatcherSession, quit chan struct{}, retrievals *ui for i, section := range task.Sections { if rand.Int()%4 != 0 { // Handle occasional missing deliveries task.Bitsets[i] = generateBitset(task.Bit, section) - atomic.AddUint32(retrievals, 1) + retrievals.Add(1) } } request <- task diff --git a/core/bloombits/scheduler_test.go b/core/bloombits/scheduler_test.go index 49e113c117ba..dcaaa915258a 100644 --- a/core/bloombits/scheduler_test.go +++ b/core/bloombits/scheduler_test.go @@ -45,13 +45,13 @@ func testScheduler(t *testing.T, clients int, fetchers int, requests int) { fetch := make(chan *request, 16) defer close(fetch) - var delivered uint32 + var delivered atomic.Uint32 for i := 0; i < fetchers; i++ { go func() { defer fetchPend.Done() for req := range fetch { - atomic.AddUint32(&delivered, 1) + delivered.Add(1) f.deliver([]uint64{ req.section + uint64(requests), // Non-requested data (ensure it doesn't go out of bounds) @@ -97,7 +97,7 @@ func testScheduler(t *testing.T, clients int, fetchers int, requests int) { } pend.Wait() - if have := atomic.LoadUint32(&delivered); int(have) != requests { + if have := delivered.Load(); int(have) != requests { t.Errorf("request count mismatch: have %v, want %v", have, requests) } } diff --git a/core/state/database.go b/core/state/database.go index a67c414d9307..82f620b46083 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -68,36 +68,36 @@ type Trie interface { // TODO(fjl): remove this when StateTrie is removed GetKey([]byte) []byte - // TryGetStorage returns the value for key stored in the trie. The value bytes + // GetStorage returns the value for key stored in the trie. The value bytes // must not be modified by the caller. If a node was not found in the database, // a trie.MissingNodeError is returned. - TryGetStorage(addr common.Address, key []byte) ([]byte, error) + GetStorage(addr common.Address, key []byte) ([]byte, error) - // TryGetAccount abstracts an account read from the trie. It retrieves the + // GetAccount abstracts an account read from the trie. It retrieves the // account blob from the trie with provided account address and decodes it // with associated decoding algorithm. If the specified account is not in // the trie, nil will be returned. If the trie is corrupted(e.g. some nodes // are missing or the account blob is incorrect for decoding), an error will // be returned. - TryGetAccount(address common.Address) (*types.StateAccount, error) + GetAccount(address common.Address) (*types.StateAccount, error) - // TryUpdateStorage associates key with value in the trie. If value has length zero, + // UpdateStorage associates key with value in the trie. If value has length zero, // any existing value is deleted from the trie. The value bytes must not be modified // by the caller while they are stored in the trie. If a node was not found in the // database, a trie.MissingNodeError is returned. - TryUpdateStorage(addr common.Address, key, value []byte) error + UpdateStorage(addr common.Address, key, value []byte) error - // TryUpdateAccount abstracts an account write to the trie. It encodes the + // UpdateAccount abstracts an account write to the trie. It encodes the // provided account object with associated algorithm and then updates it // in the trie with provided address. - TryUpdateAccount(address common.Address, account *types.StateAccount) error + UpdateAccount(address common.Address, account *types.StateAccount) error - // TryDeleteStorage removes any existing value for key from the trie. If a node + // DeleteStorage removes any existing value for key from the trie. If a node // was not found in the database, a trie.MissingNodeError is returned. - TryDeleteStorage(addr common.Address, key []byte) error + DeleteStorage(addr common.Address, key []byte) error - // TryDeleteAccount abstracts an account deletion from the trie. - TryDeleteAccount(address common.Address) error + // DeleteAccount abstracts an account deletion from the trie. + DeleteAccount(address common.Address) error // Hash returns the root hash of the trie. It does not write to the database and // can be used even if the trie doesn't have one. diff --git a/core/state/snapshot/difflayer.go b/core/state/snapshot/difflayer.go index f916a020e7bc..4701acccd399 100644 --- a/core/state/snapshot/difflayer.go +++ b/core/state/snapshot/difflayer.go @@ -103,7 +103,7 @@ type diffLayer struct { memory uint64 // Approximate guess as to how much memory we use root common.Hash // Root hash to which this snapshot diff belongs to - stale uint32 // Signals that the layer became stale (state progressed) + stale atomic.Bool // Signals that the layer became stale (state progressed) // destructSet is a very special helper marker. If an account is marked as // deleted, then it's recorded in this set. However it's allowed that an account @@ -267,7 +267,7 @@ func (dl *diffLayer) Parent() snapshot { // Stale return whether this layer has become stale (was flattened across) or if // it's still live. func (dl *diffLayer) Stale() bool { - return atomic.LoadUint32(&dl.stale) != 0 + return dl.stale.Load() } // Account directly retrieves the account associated with a particular hash in @@ -449,7 +449,7 @@ func (dl *diffLayer) flatten() snapshot { // Before actually writing all our data to the parent, first ensure that the // parent hasn't been 'corrupted' by someone else already flattening into it - if atomic.SwapUint32(&parent.stale, 1) != 0 { + if parent.stale.Swap(true) { panic("parent diff layer is stale") // we've flattened into the same parent from two children, boo } // Overwrite all the updated accounts blindly, merge the sorted list diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go index 0f3fa2c7a4f0..2e57a059dd7c 100644 --- a/core/state/snapshot/snapshot.go +++ b/core/state/snapshot/snapshot.go @@ -22,7 +22,6 @@ import ( "errors" "fmt" "sync" - "sync/atomic" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" @@ -272,7 +271,7 @@ func (t *Tree) Disable() { case *diffLayer: // If the layer is a simple diff, simply mark as stale layer.lock.Lock() - atomic.StoreUint32(&layer.stale, 1) + layer.stale.Store(true) layer.lock.Unlock() default: @@ -726,7 +725,7 @@ func (t *Tree) Rebuild(root common.Hash) { case *diffLayer: // If the layer is a simple diff, simply mark as stale layer.lock.Lock() - atomic.StoreUint32(&layer.stale, 1) + layer.stale.Store(true) layer.lock.Unlock() default: diff --git a/core/state/state_object.go b/core/state/state_object.go index 936f6dae2f13..95fe6f52fa84 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -201,7 +201,7 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has s.db.setError(err) return common.Hash{} } - enc, err = tr.TryGetStorage(s.address, key.Bytes()) + enc, err = tr.GetStorage(s.address, key.Bytes()) if metrics.EnabledExpensive { s.db.StorageReads += time.Since(start) } @@ -294,7 +294,7 @@ func (s *stateObject) updateTrie(db Database) (Trie, error) { var v []byte if (value == common.Hash{}) { - if err := tr.TryDeleteStorage(s.address, key[:]); err != nil { + if err := tr.DeleteStorage(s.address, key[:]); err != nil { s.db.setError(err) return nil, err } @@ -302,7 +302,7 @@ func (s *stateObject) updateTrie(db Database) (Trie, error) { } else { // Encoding []byte cannot fail, ok to ignore the error. v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) - if err := tr.TryUpdateStorage(s.address, key[:], v); err != nil { + if err := tr.UpdateStorage(s.address, key[:], v); err != nil { s.db.setError(err) return nil, err } diff --git a/core/state/statedb.go b/core/state/statedb.go index 38583f51db20..b89f4bfd8c32 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -512,7 +512,7 @@ func (s *StateDB) updateStateObject(obj *stateObject) { } // Encode the account and update the account trie addr := obj.Address() - if err := s.trie.TryUpdateAccount(addr, &obj.data); err != nil { + if err := s.trie.UpdateAccount(addr, &obj.data); err != nil { s.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err)) } @@ -533,7 +533,7 @@ func (s *StateDB) deleteStateObject(obj *stateObject) { } // Delete the account from the trie addr := obj.Address() - if err := s.trie.TryDeleteAccount(addr); err != nil { + if err := s.trie.DeleteAccount(addr); err != nil { s.setError(fmt.Errorf("deleteStateObject (%x) error: %v", addr[:], err)) } } @@ -587,7 +587,7 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject { if data == nil { start := time.Now() var err error - data, err = s.trie.TryGetAccount(addr) + data, err = s.trie.GetAccount(addr) if metrics.EnabledExpensive { s.AccountReads += time.Since(start) } diff --git a/core/state/trie_prefetcher.go b/core/state/trie_prefetcher.go index edd370fbdbb6..844f72fc106b 100644 --- a/core/state/trie_prefetcher.go +++ b/core/state/trie_prefetcher.go @@ -339,9 +339,9 @@ func (sf *subfetcher) loop() { sf.dups++ } else { if len(task) == common.AddressLength { - sf.trie.TryGetAccount(common.BytesToAddress(task)) + sf.trie.GetAccount(common.BytesToAddress(task)) } else { - sf.trie.TryGetStorage(sf.addr, task) + sf.trie.GetStorage(sf.addr, task) } sf.seen[string(task)] = struct{}{} } diff --git a/core/types/hashing.go b/core/types/hashing.go index 3df75432a4b4..71295c107cf8 100644 --- a/core/types/hashing.go +++ b/core/types/hashing.go @@ -83,7 +83,7 @@ func encodeForDerive(list DerivableList, i int, buf *bytes.Buffer) []byte { return common.CopyBytes(buf.Bytes()) } -// DeriveSha creates the tree hashes of transactions and receipts in a block header. +// DeriveSha creates the tree hashes of transactions, receipts, and withdrawals in a block header. func DeriveSha(list DerivableList, hasher TrieHasher) common.Hash { hasher.Reset() diff --git a/core/vm/evm.go b/core/vm/evm.go index d78ea0792664..0084e4ca51fc 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -114,8 +114,7 @@ type EVM struct { // used throughout the execution of the tx. interpreter *EVMInterpreter // abort is used to abort the EVM calling operations - // NOTE: must be set atomically - abort int32 + abort atomic.Bool // callGasTemp holds the gas available for the current call. This is needed because the // available gas is calculated in gasCall* according to the 63/64 rule and later // applied in opCall*. @@ -147,12 +146,12 @@ func (evm *EVM) Reset(txCtx TxContext, statedb StateDB) { // Cancel cancels any running EVM operation. This may be called concurrently and // it's safe to be called multiple times. func (evm *EVM) Cancel() { - atomic.StoreInt32(&evm.abort, 1) + evm.abort.Store(true) } // Cancelled returns true if Cancel has been called func (evm *EVM) Cancelled() bool { - return atomic.LoadInt32(&evm.abort) == 1 + return evm.abort.Load() } // Interpreter returns the current interpreter diff --git a/core/vm/instructions.go b/core/vm/instructions.go index 77b6e02bfcc7..21d8cc215ad9 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -17,8 +17,6 @@ package vm import ( - "sync/atomic" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" @@ -531,7 +529,7 @@ func opSstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]b } func opJump(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { - if atomic.LoadInt32(&interpreter.evm.abort) != 0 { + if interpreter.evm.abort.Load() { return nil, errStopToken } pos := scope.Stack.pop() @@ -543,7 +541,7 @@ func opJump(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byt } func opJumpi(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { - if atomic.LoadInt32(&interpreter.evm.abort) != 0 { + if interpreter.evm.abort.Load() { return nil, errStopToken } pos, cond := scope.Stack.pop(), scope.Stack.pop() diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go index d7c940044010..55781ac54b74 100644 --- a/eth/protocols/snap/handler.go +++ b/eth/protocols/snap/handler.go @@ -418,7 +418,7 @@ func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesP if err != nil { return nil, nil } - acc, err := accTrie.TryGetAccountByHash(account) + acc, err := accTrie.GetAccountByHash(account) if err != nil || acc == nil { return nil, nil } @@ -510,7 +510,7 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s case 1: // If we're only retrieving an account trie node, fetch it directly - blob, resolved, err := accTrie.TryGetNode(pathset[0]) + blob, resolved, err := accTrie.GetNode(pathset[0]) loads += resolved // always account database reads, even for failures if err != nil { break @@ -524,7 +524,7 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s if snap == nil { // We don't have the requested state snapshotted yet (or it is stale), // but can look up the account via the trie instead. - account, err := accTrie.TryGetAccountByHash(common.BytesToHash(pathset[0])) + account, err := accTrie.GetAccountByHash(common.BytesToHash(pathset[0])) loads += 8 // We don't know the exact cost of lookup, this is an estimate if err != nil || account == nil { break @@ -545,7 +545,7 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s break } for _, path := range pathset[1:] { - blob, resolved, err := stTrie.TryGetNode(path) + blob, resolved, err := stTrie.GetNode(path) loads += resolved // always account database reads, even for failures if err != nil { break diff --git a/eth/tracers/internal/tracetest/calltrace_test.go b/eth/tracers/internal/tracetest/calltrace_test.go index 62182e3a82b4..c5405e245179 100644 --- a/eth/tracers/internal/tracetest/calltrace_test.go +++ b/eth/tracers/internal/tracetest/calltrace_test.go @@ -31,7 +31,6 @@ import ( "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/tracers" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" @@ -260,75 +259,121 @@ func benchTracer(tracerName string, test *callTracerTest, b *testing.B) { } } -// TestZeroValueToNotExitCall tests the calltracer(s) on the following: -// Tx to A, A calls B with zero value. B does not already exist. -// Expected: that enter/exit is invoked and the inner call is shown in the result -func TestZeroValueToNotExitCall(t *testing.T) { - var to = common.HexToAddress("0x00000000000000000000000000000000deadbeef") - privkey, err := crypto.HexToECDSA("0000000000000000deadbeef00000000000000000000000000000000deadbeef") - if err != nil { - t.Fatalf("err %v", err) - } - signer := types.NewEIP155Signer(big.NewInt(1)) - tx, err := types.SignNewTx(privkey, signer, &types.LegacyTx{ - GasPrice: big.NewInt(0), - Gas: 50000, - To: &to, - }) - if err != nil { - t.Fatalf("err %v", err) - } - origin, _ := signer.Sender(tx) - txContext := vm.TxContext{ - Origin: origin, - GasPrice: big.NewInt(1), - } - context := vm.BlockContext{ - CanTransfer: core.CanTransfer, - Transfer: core.Transfer, - Coinbase: common.Address{}, - BlockNumber: new(big.Int).SetUint64(8000000), - Time: 5, - Difficulty: big.NewInt(0x30000), - GasLimit: uint64(6000000), - } - var code = []byte{ - byte(vm.PUSH1), 0x0, byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), // in and outs zero - byte(vm.DUP1), byte(vm.PUSH1), 0xff, byte(vm.GAS), // value=0,address=0xff, gas=GAS - byte(vm.CALL), +func TestInternals(t *testing.T) { + var ( + to = common.HexToAddress("0x00000000000000000000000000000000deadbeef") + origin = common.HexToAddress("0x00000000000000000000000000000000feed") + txContext = vm.TxContext{ + Origin: origin, + GasPrice: big.NewInt(1), + } + context = vm.BlockContext{ + CanTransfer: core.CanTransfer, + Transfer: core.Transfer, + Coinbase: common.Address{}, + BlockNumber: new(big.Int).SetUint64(8000000), + Time: 5, + Difficulty: big.NewInt(0x30000), + GasLimit: uint64(6000000), + } + ) + mkTracer := func(name string, cfg json.RawMessage) tracers.Tracer { + tr, err := tracers.DefaultDirectory.New(name, nil, cfg) + if err != nil { + t.Fatalf("failed to create call tracer: %v", err) + } + return tr } - var alloc = core.GenesisAlloc{ - to: core.GenesisAccount{ - Nonce: 1, - Code: code, + + for _, tc := range []struct { + name string + code []byte + tracer tracers.Tracer + want string + }{ + { + // TestZeroValueToNotExitCall tests the calltracer(s) on the following: + // Tx to A, A calls B with zero value. B does not already exist. + // Expected: that enter/exit is invoked and the inner call is shown in the result + name: "ZeroValueToNotExitCall", + code: []byte{ + byte(vm.PUSH1), 0x0, byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), // in and outs zero + byte(vm.DUP1), byte(vm.PUSH1), 0xff, byte(vm.GAS), // value=0,address=0xff, gas=GAS + byte(vm.CALL), + }, + tracer: mkTracer("callTracer", nil), + want: `{"from":"0x000000000000000000000000000000000000feed","gas":"0x7148","gasUsed":"0x54d8","to":"0x00000000000000000000000000000000deadbeef","input":"0x","calls":[{"from":"0x00000000000000000000000000000000deadbeef","gas":"0x6cbf","gasUsed":"0x0","to":"0x00000000000000000000000000000000000000ff","input":"0x","value":"0x0","type":"CALL"}],"value":"0x0","type":"CALL"}`, }, - origin: core.GenesisAccount{ - Nonce: 0, - Balance: big.NewInt(500000000000000), + { + name: "Stack depletion in LOG0", + code: []byte{byte(vm.LOG3)}, + tracer: mkTracer("callTracer", json.RawMessage(`{ "withLog": true }`)), + want: `{"from":"0x000000000000000000000000000000000000feed","gas":"0x7148","gasUsed":"0xc350","to":"0x00000000000000000000000000000000deadbeef","input":"0x","error":"stack underflow (0 \u003c=\u003e 5)","value":"0x0","type":"CALL"}`, }, - } - _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), alloc, false) - // Create the tracer, the EVM environment and run it - tracer, err := tracers.DefaultDirectory.New("callTracer", nil, nil) - if err != nil { - t.Fatalf("failed to create call tracer: %v", err) - } - evm := vm.NewEVM(context, txContext, statedb, params.MainnetChainConfig, vm.Config{Debug: true, Tracer: tracer}) - msg, err := core.TransactionToMessage(tx, signer, nil) - if err != nil { - t.Fatalf("failed to prepare transaction for tracing: %v", err) - } - st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas())) - if _, err = st.TransitionDb(); err != nil { - t.Fatalf("failed to execute transaction: %v", err) - } - // Retrieve the trace result and compare against the etalon - res, err := tracer.GetResult() - if err != nil { - t.Fatalf("failed to retrieve trace result: %v", err) - } - wantStr := `{"from":"0x682a80a6f560eec50d54e63cbeda1c324c5f8d1b","gas":"0x7148","gasUsed":"0x54d8","to":"0x00000000000000000000000000000000deadbeef","input":"0x","calls":[{"from":"0x00000000000000000000000000000000deadbeef","gas":"0x6cbf","gasUsed":"0x0","to":"0x00000000000000000000000000000000000000ff","input":"0x","value":"0x0","type":"CALL"}],"value":"0x0","type":"CALL"}` - if string(res) != wantStr { - t.Fatalf("trace mismatch\n have: %v\n want: %v\n", string(res), wantStr) + { + name: "Mem expansion in LOG0", + code: []byte{ + byte(vm.PUSH1), 0x1, + byte(vm.PUSH1), 0x0, + byte(vm.MSTORE), + byte(vm.PUSH1), 0xff, + byte(vm.PUSH1), 0x0, + byte(vm.LOG0), + }, + tracer: mkTracer("callTracer", json.RawMessage(`{ "withLog": true }`)), + want: `{"from":"0x000000000000000000000000000000000000feed","gas":"0x7148","gasUsed":"0x5b9e","to":"0x00000000000000000000000000000000deadbeef","input":"0x","logs":[{"address":"0x00000000000000000000000000000000deadbeef","topics":[],"data":"0x000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}],"value":"0x0","type":"CALL"}`, + }, + { + // Leads to OOM on the prestate tracer + name: "Prestate-tracer - mem expansion in CREATE2", + code: []byte{ + byte(vm.PUSH1), 0x1, + byte(vm.PUSH1), 0x0, + byte(vm.MSTORE), + byte(vm.PUSH1), 0x1, + byte(vm.PUSH5), 0xff, 0xff, 0xff, 0xff, 0xff, + byte(vm.PUSH1), 0x1, + byte(vm.PUSH1), 0x0, + byte(vm.CREATE2), + byte(vm.PUSH1), 0xff, + byte(vm.PUSH1), 0x0, + byte(vm.LOG0), + }, + tracer: mkTracer("prestateTracer", json.RawMessage(`{ "withLog": true }`)), + want: `{"0x0000000000000000000000000000000000000000":{"balance":"0x0"},"0x000000000000000000000000000000000000feed":{"balance":"0x1c6bf52640350"},"0x00000000000000000000000000000000deadbeef":{"balance":"0x0","code":"0x6001600052600164ffffffffff60016000f560ff6000a0"}}`, + }, + } { + _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), + core.GenesisAlloc{ + to: core.GenesisAccount{ + Code: tc.code, + }, + origin: core.GenesisAccount{ + Balance: big.NewInt(500000000000000), + }, + }, false) + evm := vm.NewEVM(context, txContext, statedb, params.MainnetChainConfig, vm.Config{Debug: true, Tracer: tc.tracer}) + msg := &core.Message{ + To: &to, + From: origin, + Value: big.NewInt(0), + GasLimit: 50000, + GasPrice: big.NewInt(0), + GasFeeCap: big.NewInt(0), + GasTipCap: big.NewInt(0), + SkipAccountChecks: false, + } + st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(msg.GasLimit)) + if _, err := st.TransitionDb(); err != nil { + t.Fatalf("test %v: failed to execute transaction: %v", tc.name, err) + } + // Retrieve the trace result and compare against the expected + res, err := tc.tracer.GetResult() + if err != nil { + t.Fatalf("test %v: failed to retrieve trace result: %v", tc.name, err) + } + if string(res) != tc.want { + t.Fatalf("test %v: trace mismatch\n have: %v\n want: %v\n", tc.name, string(res), tc.want) + } } } diff --git a/eth/tracers/js/goja.go b/eth/tracers/js/goja.go index 8e52f5b21077..33298f3097fa 100644 --- a/eth/tracers/js/goja.go +++ b/eth/tracers/js/goja.go @@ -32,10 +32,6 @@ import ( jsassets "github.com/ethereum/go-ethereum/eth/tracers/js/internal/tracers" ) -const ( - memoryPadLimit = 1024 * 1024 -) - var assetTracers = make(map[string]string) // init retrieves the JavaScript transaction tracers included in go-ethereum. @@ -571,14 +567,10 @@ func (mo *memoryObj) slice(begin, end int64) ([]byte, error) { if end < begin || begin < 0 { return nil, fmt.Errorf("tracer accessed out of bound memory: offset %d, end %d", begin, end) } - mlen := mo.memory.Len() - if end-int64(mlen) > memoryPadLimit { - return nil, fmt.Errorf("tracer reached limit for padding memory slice: end %d, memorySize %d", end, mlen) + slice, err := tracers.GetMemoryCopyPadded(mo.memory, begin, end-begin) + if err != nil { + return nil, err } - slice := make([]byte, end-begin) - end = min(end, int64(mo.memory.Len())) - ptr := mo.memory.GetPtr(begin, end-begin) - copy(slice[:], ptr[:]) return slice, nil } @@ -959,10 +951,3 @@ func (l *steplog) setupObject() *goja.Object { o.Set("contract", l.contract.setupObject()) return o } - -func min(a, b int64) int64 { - if a < b { - return a - } - return b -} diff --git a/eth/tracers/js/tracer_test.go b/eth/tracers/js/tracer_test.go index 524d17474930..1f3753721116 100644 --- a/eth/tracers/js/tracer_test.go +++ b/eth/tracers/js/tracer_test.go @@ -150,12 +150,12 @@ func TestTracer(t *testing.T) { }, { code: "{res: [], step: function(log) { if (log.op.toString() === 'STOP') { this.res.push(log.memory.slice(5, 1025 * 1024)) } }, fault: function() {}, result: function() { return this.res }}", want: "", - fail: "tracer reached limit for padding memory slice: end 1049600, memorySize 32 at step (:1:83(20)) in server-side tracer function 'step'", + fail: "reached limit for padding memory slice: 1049568 at step (:1:83(20)) in server-side tracer function 'step'", contract: []byte{byte(vm.PUSH1), byte(0xff), byte(vm.PUSH1), byte(0x00), byte(vm.MSTORE8), byte(vm.STOP)}, }, } { if have, err := execTracer(tt.code, tt.contract); tt.want != string(have) || tt.fail != err { - t.Errorf("testcase %d: expected return value to be '%s' got '%s', error to be '%s' got '%s'\n\tcode: %v", i, tt.want, string(have), tt.fail, err, tt.code) + t.Errorf("testcase %d: expected return value to be \n'%s'\n\tgot\n'%s'\nerror to be\n'%s'\n\tgot\n'%s'\n\tcode: %v", i, tt.want, string(have), tt.fail, err, tt.code) } } } diff --git a/eth/tracers/native/call.go b/eth/tracers/native/call.go index 02ee152a5aa1..88fc2be40c25 100644 --- a/eth/tracers/native/call.go +++ b/eth/tracers/native/call.go @@ -148,6 +148,10 @@ func (t *callTracer) CaptureEnd(output []byte, gasUsed uint64, err error) { // CaptureState implements the EVMLogger interface to trace a single step of VM execution. func (t *callTracer) CaptureState(pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, rData []byte, depth int, err error) { + // skip if the previous op caused an error + if err != nil { + return + } // Only logs need to be captured via opcode processing if !t.config.WithLog { return @@ -176,7 +180,12 @@ func (t *callTracer) CaptureState(pc uint64, op vm.OpCode, gas, cost uint64, sco topics[i] = common.Hash(topic.Bytes32()) } - data := scope.Memory.GetCopy(int64(mStart.Uint64()), int64(mSize.Uint64())) + data, err := tracers.GetMemoryCopyPadded(scope.Memory, int64(mStart.Uint64()), int64(mSize.Uint64())) + if err != nil { + // mSize was unrealistically large + return + } + log := callLog{Address: scope.Contract.Address(), Topics: topics, Data: hexutil.Bytes(data)} t.callstack[len(t.callstack)-1].Logs = append(t.callstack[len(t.callstack)-1].Logs, log) } diff --git a/eth/tracers/native/prestate.go b/eth/tracers/native/prestate.go index 948d09ef767c..42ec4fe7828b 100644 --- a/eth/tracers/native/prestate.go +++ b/eth/tracers/native/prestate.go @@ -133,6 +133,9 @@ func (t *prestateTracer) CaptureEnd(output []byte, gasUsed uint64, err error) { // CaptureState implements the EVMLogger interface to trace a single step of VM execution. func (t *prestateTracer) CaptureState(pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, rData []byte, depth int, err error) { + if err != nil { + return + } stack := scope.Stack stackData := stack.Data() stackLen := len(stackData) diff --git a/eth/tracers/tracers.go b/eth/tracers/tracers.go index 856f52a10de3..023f73ef3708 100644 --- a/eth/tracers/tracers.go +++ b/eth/tracers/tracers.go @@ -19,6 +19,7 @@ package tracers import ( "encoding/json" + "fmt" "math/big" "github.com/ethereum/go-ethereum/common" @@ -95,3 +96,27 @@ func (d *directory) IsJS(name string) bool { // JS eval will execute JS code return true } + +const ( + memoryPadLimit = 1024 * 1024 +) + +// GetMemoryCopyPadded returns offset + size as a new slice. +// It zero-pads the slice if it extends beyond memory bounds. +func GetMemoryCopyPadded(m *vm.Memory, offset, size int64) ([]byte, error) { + if offset < 0 || size < 0 { + return nil, fmt.Errorf("offset or size must not be negative") + } + if int(offset+size) < m.Len() { // slice fully inside memory + return m.GetCopy(offset, size), nil + } + paddingNeeded := int(offset+size) - m.Len() + if paddingNeeded > memoryPadLimit { + return nil, fmt.Errorf("reached limit for padding memory slice: %d", paddingNeeded) + } + cpy := make([]byte, size) + if overlap := int64(m.Len()) - offset; overlap > 0 { + copy(cpy, m.GetPtr(offset, overlap)) + } + return cpy, nil +} diff --git a/eth/tracers/tracers_test.go b/eth/tracers/tracers_test.go index 7c5ec65650ee..46182db48b98 100644 --- a/eth/tracers/tracers_test.go +++ b/eth/tracers/tracers_test.go @@ -109,3 +109,41 @@ func BenchmarkTransactionTrace(b *testing.B) { tracer.Reset() } } + +func TestMemCopying(t *testing.T) { + for i, tc := range []struct { + memsize int64 + offset int64 + size int64 + wantErr string + wantSize int + }{ + {0, 0, 100, "", 100}, // Should pad up to 100 + {0, 100, 0, "", 0}, // No need to pad (0 size) + {100, 50, 100, "", 100}, // Should pad 100-150 + {100, 50, 5, "", 5}, // Wanted range fully within memory + {100, -50, 0, "offset or size must not be negative", 0}, // Errror + {0, 1, 1024*1024 + 1, "reached limit for padding memory slice: 1048578", 0}, // Errror + {10, 0, 1024*1024 + 100, "reached limit for padding memory slice: 1048666", 0}, // Errror + + } { + mem := vm.NewMemory() + mem.Resize(uint64(tc.memsize)) + cpy, err := GetMemoryCopyPadded(mem, tc.offset, tc.size) + if want := tc.wantErr; want != "" { + if err == nil { + t.Fatalf("test %d: want '%v' have no error", i, want) + } + if have := err.Error(); want != have { + t.Fatalf("test %d: want '%v' have '%v'", i, want, have) + } + continue + } + if err != nil { + t.Fatalf("test %d: unexpected error: %v", i, err) + } + if want, have := tc.wantSize, len(cpy); have != want { + t.Fatalf("test %d: want %v have %v", i, want, have) + } + } +} diff --git a/ethclient/ethclient.go b/ethclient/ethclient.go index c8353b25ae21..47beaf63cc0e 100644 --- a/ethclient/ethclient.go +++ b/ethclient/ethclient.go @@ -320,7 +320,14 @@ func (ec *Client) SyncProgress(ctx context.Context) (*ethereum.SyncProgress, err // SubscribeNewHead subscribes to notifications about the current blockchain head // on the given channel. func (ec *Client) SubscribeNewHead(ctx context.Context, ch chan<- *types.Header) (ethereum.Subscription, error) { - return ec.c.EthSubscribe(ctx, ch, "newHeads") + sub, err := ec.c.EthSubscribe(ctx, ch, "newHeads") + if err != nil { + // Defensively prefer returning nil interface explicitly on error-path, instead + // of letting default golang behavior wrap it with non-nil interface that stores + // nil concrete type value. + return nil, err + } + return sub, nil } // State Access @@ -389,7 +396,14 @@ func (ec *Client) SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuer if err != nil { return nil, err } - return ec.c.EthSubscribe(ctx, ch, "logs", arg) + sub, err := ec.c.EthSubscribe(ctx, ch, "logs", arg) + if err != nil { + // Defensively prefer returning nil interface explicitly on error-path, instead + // of letting default golang behavior wrap it with non-nil interface that stores + // nil concrete type value. + return nil, err + } + return sub, nil } func toFilterArg(q ethereum.FilterQuery) (interface{}, error) { diff --git a/graphql/graphql.go b/graphql/graphql.go index 0c13cc80f555..3c637d2d2c03 100644 --- a/graphql/graphql.go +++ b/graphql/graphql.go @@ -24,6 +24,7 @@ import ( "math/big" "sort" "strconv" + "sync" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" @@ -80,12 +81,26 @@ type Account struct { r *Resolver address common.Address blockNrOrHash rpc.BlockNumberOrHash + state *state.StateDB + mu sync.Mutex } // getState fetches the StateDB object for an account. func (a *Account) getState(ctx context.Context) (*state.StateDB, error) { + a.mu.Lock() + defer a.mu.Unlock() + if a.state != nil { + return a.state, nil + } state, _, err := a.r.backend.StateAndHeaderByNumberOrHash(ctx, a.blockNrOrHash) - return state, err + if err != nil { + return nil, err + } + a.state = state + // Cache the state object. This is done so that concurrent resolvers + // don't have to fetch the object from DB individually. + a.state.GetOrNewStateObject(a.address) + return a.state, nil } func (a *Account) Address(ctx context.Context) (common.Address, error) { @@ -184,32 +199,39 @@ func (at *AccessTuple) StorageKeys(ctx context.Context) []common.Hash { // Transaction represents an Ethereum transaction. // backend and hash are mandatory; all others will be fetched when required. type Transaction struct { - r *Resolver - hash common.Hash + r *Resolver + hash common.Hash // Must be present after initialization + mu sync.Mutex + // mu protects following resources tx *types.Transaction block *Block index uint64 } // resolve returns the internal transaction object, fetching it if needed. -func (t *Transaction) resolve(ctx context.Context) (*types.Transaction, error) { - if t.tx == nil { - // Try to return an already finalized transaction - tx, blockHash, _, index, err := t.r.backend.GetTransaction(ctx, t.hash) - if err == nil && tx != nil { - t.tx = tx - blockNrOrHash := rpc.BlockNumberOrHashWithHash(blockHash, false) - t.block = &Block{ - r: t.r, - numberOrHash: &blockNrOrHash, - } - t.index = index - return t.tx, nil +// It also returns the block the tx blongs to, unless it is a pending tx. +func (t *Transaction) resolve(ctx context.Context) (*types.Transaction, *Block, error) { + t.mu.Lock() + defer t.mu.Unlock() + if t.tx != nil { + return t.tx, t.block, nil + } + // Try to return an already finalized transaction + tx, blockHash, _, index, err := t.r.backend.GetTransaction(ctx, t.hash) + if err == nil && tx != nil { + t.tx = tx + blockNrOrHash := rpc.BlockNumberOrHashWithHash(blockHash, false) + t.block = &Block{ + r: t.r, + numberOrHash: &blockNrOrHash, + hash: blockHash, } - // No finalized transaction, try to retrieve it from the pool - t.tx = t.r.backend.GetPoolTransaction(t.hash) + t.index = index + return t.tx, t.block, nil } - return t.tx, nil + // No finalized transaction, try to retrieve it from the pool + t.tx = t.r.backend.GetPoolTransaction(t.hash) + return t.tx, nil, nil } func (t *Transaction) Hash(ctx context.Context) common.Hash { @@ -217,7 +239,7 @@ func (t *Transaction) Hash(ctx context.Context) common.Hash { } func (t *Transaction) InputData(ctx context.Context) (hexutil.Bytes, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return hexutil.Bytes{}, err } @@ -225,7 +247,7 @@ func (t *Transaction) InputData(ctx context.Context) (hexutil.Bytes, error) { } func (t *Transaction) Gas(ctx context.Context) (hexutil.Uint64, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return 0, err } @@ -233,7 +255,7 @@ func (t *Transaction) Gas(ctx context.Context) (hexutil.Uint64, error) { } func (t *Transaction) GasPrice(ctx context.Context) (hexutil.Big, error) { - tx, err := t.resolve(ctx) + tx, block, err := t.resolve(ctx) if err != nil || tx == nil { return hexutil.Big{}, err } @@ -241,8 +263,8 @@ func (t *Transaction) GasPrice(ctx context.Context) (hexutil.Big, error) { case types.AccessListTxType: return hexutil.Big(*tx.GasPrice()), nil case types.DynamicFeeTxType: - if t.block != nil { - if baseFee, _ := t.block.BaseFeePerGas(ctx); baseFee != nil { + if block != nil { + if baseFee, _ := block.BaseFeePerGas(ctx); baseFee != nil { // price = min(tip, gasFeeCap - baseFee) + baseFee return (hexutil.Big)(*math.BigMin(new(big.Int).Add(tx.GasTipCap(), baseFee.ToInt()), tx.GasFeeCap())), nil } @@ -254,15 +276,15 @@ func (t *Transaction) GasPrice(ctx context.Context) (hexutil.Big, error) { } func (t *Transaction) EffectiveGasPrice(ctx context.Context) (*hexutil.Big, error) { - tx, err := t.resolve(ctx) + tx, block, err := t.resolve(ctx) if err != nil || tx == nil { return nil, err } // Pending tx - if t.block == nil { + if block == nil { return nil, nil } - header, err := t.block.resolveHeader(ctx) + header, err := block.resolveHeader(ctx) if err != nil || header == nil { return nil, err } @@ -273,7 +295,7 @@ func (t *Transaction) EffectiveGasPrice(ctx context.Context) (*hexutil.Big, erro } func (t *Transaction) MaxFeePerGas(ctx context.Context) (*hexutil.Big, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return nil, err } @@ -288,7 +310,7 @@ func (t *Transaction) MaxFeePerGas(ctx context.Context) (*hexutil.Big, error) { } func (t *Transaction) MaxPriorityFeePerGas(ctx context.Context) (*hexutil.Big, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return nil, err } @@ -303,15 +325,15 @@ func (t *Transaction) MaxPriorityFeePerGas(ctx context.Context) (*hexutil.Big, e } func (t *Transaction) EffectiveTip(ctx context.Context) (*hexutil.Big, error) { - tx, err := t.resolve(ctx) + tx, block, err := t.resolve(ctx) if err != nil || tx == nil { return nil, err } // Pending tx - if t.block == nil { + if block == nil { return nil, nil } - header, err := t.block.resolveHeader(ctx) + header, err := block.resolveHeader(ctx) if err != nil || header == nil { return nil, err } @@ -327,7 +349,7 @@ func (t *Transaction) EffectiveTip(ctx context.Context) (*hexutil.Big, error) { } func (t *Transaction) Value(ctx context.Context) (hexutil.Big, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return hexutil.Big{}, err } @@ -338,7 +360,7 @@ func (t *Transaction) Value(ctx context.Context) (hexutil.Big, error) { } func (t *Transaction) Nonce(ctx context.Context) (hexutil.Uint64, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return 0, err } @@ -346,7 +368,7 @@ func (t *Transaction) Nonce(ctx context.Context) (hexutil.Uint64, error) { } func (t *Transaction) To(ctx context.Context, args BlockNumberArgs) (*Account, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return nil, err } @@ -362,7 +384,7 @@ func (t *Transaction) To(ctx context.Context, args BlockNumberArgs) (*Account, e } func (t *Transaction) From(ctx context.Context, args BlockNumberArgs) (*Account, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return nil, err } @@ -376,17 +398,20 @@ func (t *Transaction) From(ctx context.Context, args BlockNumberArgs) (*Account, } func (t *Transaction) Block(ctx context.Context) (*Block, error) { - if _, err := t.resolve(ctx); err != nil { + _, block, err := t.resolve(ctx) + if err != nil { return nil, err } - return t.block, nil + return block, nil } func (t *Transaction) Index(ctx context.Context) (*int32, error) { - if _, err := t.resolve(ctx); err != nil { + _, block, err := t.resolve(ctx) + if err != nil { return nil, err } - if t.block == nil { + // Pending tx + if block == nil { return nil, nil } index := int32(t.index) @@ -395,13 +420,15 @@ func (t *Transaction) Index(ctx context.Context) (*int32, error) { // getReceipt returns the receipt associated with this transaction, if any. func (t *Transaction) getReceipt(ctx context.Context) (*types.Receipt, error) { - if _, err := t.resolve(ctx); err != nil { + _, block, err := t.resolve(ctx) + if err != nil { return nil, err } - if t.block == nil { + // Pending tx + if block == nil { return nil, nil } - receipts, err := t.block.resolveReceipts(ctx) + receipts, err := block.resolveReceipts(ctx) if err != nil { return nil, err } @@ -451,28 +478,25 @@ func (t *Transaction) CreatedContract(ctx context.Context, args BlockNumberArgs) } func (t *Transaction) Logs(ctx context.Context) (*[]*Log, error) { - if _, err := t.resolve(ctx); err != nil { + _, block, err := t.resolve(ctx) + if err != nil { return nil, err } - if t.block == nil { + // Pending tx + if block == nil { return nil, nil } - if _, ok := t.block.numberOrHash.Hash(); !ok { - header, err := t.r.backend.HeaderByNumberOrHash(ctx, *t.block.numberOrHash) - if err != nil { - return nil, err - } - hash := header.Hash() - t.block.numberOrHash.BlockHash = &hash + h, err := block.Hash(ctx) + if err != nil { + return nil, err } - return t.getLogs(ctx) + return t.getLogs(ctx, h) } // getLogs returns log objects for the given tx. // Assumes block hash is resolved. -func (t *Transaction) getLogs(ctx context.Context) (*[]*Log, error) { +func (t *Transaction) getLogs(ctx context.Context, hash common.Hash) (*[]*Log, error) { var ( - hash, _ = t.block.numberOrHash.Hash() filter = t.r.filterSystem.NewBlockFilter(hash, nil, nil) logs, err = filter.Logs(ctx) ) @@ -494,7 +518,7 @@ func (t *Transaction) getLogs(ctx context.Context) (*[]*Log, error) { } func (t *Transaction) Type(ctx context.Context) (*int32, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil { return nil, err } @@ -503,7 +527,7 @@ func (t *Transaction) Type(ctx context.Context) (*int32, error) { } func (t *Transaction) AccessList(ctx context.Context) (*[]*AccessTuple, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return nil, err } @@ -519,7 +543,7 @@ func (t *Transaction) AccessList(ctx context.Context) (*[]*AccessTuple, error) { } func (t *Transaction) R(ctx context.Context) (hexutil.Big, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return hexutil.Big{}, err } @@ -528,7 +552,7 @@ func (t *Transaction) R(ctx context.Context) (hexutil.Big, error) { } func (t *Transaction) S(ctx context.Context) (hexutil.Big, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return hexutil.Big{}, err } @@ -537,7 +561,7 @@ func (t *Transaction) S(ctx context.Context) (hexutil.Big, error) { } func (t *Transaction) V(ctx context.Context) (hexutil.Big, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return hexutil.Big{}, err } @@ -546,7 +570,7 @@ func (t *Transaction) V(ctx context.Context) (hexutil.Big, error) { } func (t *Transaction) Raw(ctx context.Context) (hexutil.Bytes, error) { - tx, err := t.resolve(ctx) + tx, _, err := t.resolve(ctx) if err != nil || tx == nil { return hexutil.Bytes{}, err } @@ -568,16 +592,20 @@ type BlockType int // when required. type Block struct { r *Resolver - numberOrHash *rpc.BlockNumberOrHash - hash common.Hash - header *types.Header - block *types.Block - receipts []*types.Receipt + numberOrHash *rpc.BlockNumberOrHash // Field resolvers assume numberOrHash is always present + mu sync.Mutex + // mu protects following resources + hash common.Hash // Must be resolved during initialization + header *types.Header + block *types.Block + receipts []*types.Receipt } // resolve returns the internal Block object representing this block, fetching // it if necessary. func (b *Block) resolve(ctx context.Context) (*types.Block, error) { + b.mu.Lock() + defer b.mu.Unlock() if b.block != nil { return b.block, nil } @@ -587,10 +615,10 @@ func (b *Block) resolve(ctx context.Context) (*types.Block, error) { } var err error b.block, err = b.r.backend.BlockByNumberOrHash(ctx, *b.numberOrHash) - if b.block != nil && b.header == nil { - b.header = b.block.Header() - if hash, ok := b.numberOrHash.Hash(); ok { - b.hash = hash + if b.block != nil { + b.hash = b.block.Hash() + if b.header == nil { + b.header = b.block.Header() } } return b.block, err @@ -600,39 +628,39 @@ func (b *Block) resolve(ctx context.Context) (*types.Block, error) { // if necessary. Call this function instead of `resolve` unless you need the // additional data (transactions and uncles). func (b *Block) resolveHeader(ctx context.Context) (*types.Header, error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.header != nil { + return b.header, nil + } if b.numberOrHash == nil && b.hash == (common.Hash{}) { return nil, errBlockInvariant } var err error - if b.header == nil { - if b.hash != (common.Hash{}) { - b.header, err = b.r.backend.HeaderByHash(ctx, b.hash) - } else { - b.header, err = b.r.backend.HeaderByNumberOrHash(ctx, *b.numberOrHash) - } + b.header, err = b.r.backend.HeaderByNumberOrHash(ctx, *b.numberOrHash) + if err != nil { + return nil, err + } + if b.hash == (common.Hash{}) { + b.hash = b.header.Hash() } - return b.header, err + return b.header, nil } // resolveReceipts returns the list of receipts for this block, fetching them // if necessary. func (b *Block) resolveReceipts(ctx context.Context) ([]*types.Receipt, error) { - if b.receipts == nil { - hash := b.hash - if hash == (common.Hash{}) { - header, err := b.resolveHeader(ctx) - if err != nil { - return nil, err - } - hash = header.Hash() - } - receipts, err := b.r.backend.GetReceipts(ctx, hash) - if err != nil { - return nil, err - } - b.receipts = receipts + b.mu.Lock() + defer b.mu.Unlock() + if b.receipts != nil { + return b.receipts, nil } - return b.receipts, nil + receipts, err := b.r.backend.GetReceipts(ctx, b.hash) + if err != nil { + return nil, err + } + b.receipts = receipts + return receipts, nil } func (b *Block) Number(ctx context.Context) (Long, error) { @@ -645,13 +673,8 @@ func (b *Block) Number(ctx context.Context) (Long, error) { } func (b *Block) Hash(ctx context.Context) (common.Hash, error) { - if b.hash == (common.Hash{}) { - header, err := b.resolveHeader(ctx) - if err != nil { - return common.Hash{}, err - } - b.hash = header.Hash() - } + b.mu.Lock() + defer b.mu.Unlock() return b.hash, nil } @@ -705,11 +728,18 @@ func (b *Block) Parent(ctx context.Context) (*Block, error) { if b.header == nil || b.header.Number.Uint64() < 1 { return nil, nil } - num := rpc.BlockNumberOrHashWithNumber(rpc.BlockNumber(b.header.Number.Uint64() - 1)) + var ( + num = rpc.BlockNumber(b.header.Number.Uint64() - 1) + hash = b.header.ParentHash + numOrHash = rpc.BlockNumberOrHash{ + BlockNumber: &num, + BlockHash: &hash, + } + ) return &Block{ r: b.r, - numberOrHash: &num, - hash: b.header.ParentHash, + numberOrHash: &numOrHash, + hash: hash, }, nil } @@ -798,6 +828,7 @@ func (b *Block) Ommers(ctx context.Context) (*[]*Block, error) { r: b.r, numberOrHash: &blockNumberOrHash, header: uncle, + hash: uncle.Hash(), }) } return &ret, nil @@ -820,17 +851,13 @@ func (b *Block) LogsBloom(ctx context.Context) (hexutil.Bytes, error) { } func (b *Block) TotalDifficulty(ctx context.Context) (hexutil.Big, error) { - h := b.hash - if h == (common.Hash{}) { - header, err := b.resolveHeader(ctx) - if err != nil { - return hexutil.Big{}, err - } - h = header.Hash() + hash, err := b.Hash(ctx) + if err != nil { + return hexutil.Big{}, err } - td := b.r.backend.GetTd(ctx, h) + td := b.r.backend.GetTd(ctx, hash) if td == nil { - return hexutil.Big{}, fmt.Errorf("total difficulty not found %x", b.hash) + return hexutil.Big{}, fmt.Errorf("total difficulty not found %x", hash) } return hexutil.Big(*td), nil } @@ -948,6 +975,7 @@ func (b *Block) OmmerAt(ctx context.Context, args struct{ Index int32 }) (*Block r: b.r, numberOrHash: &blockNumberOrHash, header: uncle, + hash: uncle.Hash(), }, nil } @@ -997,15 +1025,11 @@ func (b *Block) Logs(ctx context.Context, args struct{ Filter BlockFilterCriteri if args.Filter.Topics != nil { topics = *args.Filter.Topics } - hash := b.hash - if hash == (common.Hash{}) { - header, err := b.resolveHeader(ctx) - if err != nil { - return nil, err - } - hash = header.Hash() - } // Construct the range filter + hash, err := b.Hash(ctx) + if err != nil { + return nil, err + } filter := b.r.filterSystem.NewBlockFilter(hash, addresses, topics) // Run the filter and return all the logs @@ -1015,12 +1039,6 @@ func (b *Block) Logs(ctx context.Context, args struct{ Filter BlockFilterCriteri func (b *Block) Account(ctx context.Context, args struct { Address common.Address }) (*Account, error) { - if b.numberOrHash == nil { - _, err := b.resolveHeader(ctx) - if err != nil { - return nil, err - } - } return &Account{ r: b.r, address: args.Address, @@ -1063,12 +1081,6 @@ func (c *CallResult) Status() Long { func (b *Block) Call(ctx context.Context, args struct { Data ethapi.TransactionArgs }) (*CallResult, error) { - if b.numberOrHash == nil { - _, err := b.resolve(ctx) - if err != nil { - return nil, err - } - } result, err := ethapi.DoCall(ctx, b.r.backend, args.Data, *b.numberOrHash, nil, b.r.backend.RPCEVMTimeout(), b.r.backend.RPCGasCap()) if err != nil { return nil, err @@ -1088,12 +1100,6 @@ func (b *Block) Call(ctx context.Context, args struct { func (b *Block) EstimateGas(ctx context.Context, args struct { Data ethapi.TransactionArgs }) (Long, error) { - if b.numberOrHash == nil { - _, err := b.resolveHeader(ctx) - if err != nil { - return 0, err - } - } gas, err := ethapi.DoEstimateGas(ctx, b.r.backend, args.Data, *b.numberOrHash, b.r.backend.RPCGasCap()) return Long(gas), err } @@ -1173,29 +1179,21 @@ func (r *Resolver) Block(ctx context.Context, args struct { Number *Long Hash *common.Hash }) (*Block, error) { - var block *Block + var numberOrHash rpc.BlockNumberOrHash if args.Number != nil { if *args.Number < 0 { return nil, nil } number := rpc.BlockNumber(*args.Number) - numberOrHash := rpc.BlockNumberOrHashWithNumber(number) - block = &Block{ - r: r, - numberOrHash: &numberOrHash, - } + numberOrHash = rpc.BlockNumberOrHashWithNumber(number) } else if args.Hash != nil { - numberOrHash := rpc.BlockNumberOrHashWithHash(*args.Hash, false) - block = &Block{ - r: r, - numberOrHash: &numberOrHash, - } + numberOrHash = rpc.BlockNumberOrHashWithHash(*args.Hash, false) } else { - numberOrHash := rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber) - block = &Block{ - r: r, - numberOrHash: &numberOrHash, - } + numberOrHash = rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber) + } + block := &Block{ + r: r, + numberOrHash: &numberOrHash, } // Resolve the header, return nil if it doesn't exist. // Note we don't resolve block directly here since it will require an @@ -1256,7 +1254,7 @@ func (r *Resolver) Transaction(ctx context.Context, args struct{ Hash common.Has hash: args.Hash, } // Resolve the transaction; if it doesn't exist, return nil. - t, err := tx.resolve(ctx) + t, _, err := tx.resolve(ctx) if err != nil { return nil, err } else if t == nil { diff --git a/graphql/graphql_test.go b/graphql/graphql_test.go index 46acd1529342..9354eac0f359 100644 --- a/graphql/graphql_test.go +++ b/graphql/graphql_test.go @@ -156,6 +156,7 @@ func TestGraphQLBlockSerialization(t *testing.T) { t.Fatalf("could not post: %v", err) } bodyBytes, err := io.ReadAll(resp.Body) + resp.Body.Close() if err != nil { t.Fatalf("could not read from response body: %v", err) } @@ -239,6 +240,7 @@ func TestGraphQLBlockSerializationEIP2718(t *testing.T) { t.Fatalf("could not post: %v", err) } bodyBytes, err := io.ReadAll(resp.Body) + resp.Body.Close() if err != nil { t.Fatalf("could not read from response body: %v", err) } @@ -263,11 +265,12 @@ func TestGraphQLHTTPOnSamePort_GQLRequest_Unsuccessful(t *testing.T) { if err != nil { t.Fatalf("could not post: %v", err) } + resp.Body.Close() // make sure the request is not handled successfully assert.Equal(t, http.StatusNotFound, resp.StatusCode) } -func TestGraphQLTransactionLogs(t *testing.T) { +func TestGraphQLConcurrentResolvers(t *testing.T) { var ( key, _ = crypto.GenerateKey() addr = crypto.PubkeyToAddress(key.PublicKey) @@ -292,8 +295,9 @@ func TestGraphQLTransactionLogs(t *testing.T) { ) defer stack.Close() - handler := newGQLService(t, stack, genesis, 1, func(i int, gen *core.BlockGen) { - tx, _ := types.SignNewTx(key, signer, &types.LegacyTx{To: &dad, Gas: 100000, GasPrice: big.NewInt(params.InitialBaseFee)}) + var tx *types.Transaction + handler, chain := newGQLService(t, stack, genesis, 1, func(i int, gen *core.BlockGen) { + tx, _ = types.SignNewTx(key, signer, &types.LegacyTx{To: &dad, Gas: 100000, GasPrice: big.NewInt(params.InitialBaseFee)}) gen.AddTx(tx) tx, _ = types.SignNewTx(key, signer, &types.LegacyTx{To: &dad, Nonce: 1, Gas: 100000, GasPrice: big.NewInt(params.InitialBaseFee)}) gen.AddTx(tx) @@ -304,18 +308,59 @@ func TestGraphQLTransactionLogs(t *testing.T) { if err := stack.Start(); err != nil { t.Fatalf("could not start node: %v", err) } - query := `{block { transactions { logs { account { address } } } } }` - res := handler.Schema.Exec(context.Background(), query, "", map[string]interface{}{}) - if res.Errors != nil { - t.Fatalf("graphql query failed: %v", res.Errors) - } - have, err := json.Marshal(res.Data) - if err != nil { - t.Fatalf("failed to encode graphql response: %s", err) - } - want := fmt.Sprintf(`{"block":{"transactions":[{"logs":[{"account":{"address":"%s"}},{"account":{"address":"%s"}}]},{"logs":[{"account":{"address":"%s"}},{"account":{"address":"%s"}}]},{"logs":[{"account":{"address":"%s"}},{"account":{"address":"%s"}}]}]}}`, dadStr, dadStr, dadStr, dadStr, dadStr, dadStr) - if string(have) != want { - t.Errorf("response unmatch. expected %s, got %s", want, have) + + for i, tt := range []struct { + body string + want string + }{ + // Multiple txes race to get/set the block hash. + { + body: "{block { transactions { logs { account { address } } } } }", + want: fmt.Sprintf(`{"block":{"transactions":[{"logs":[{"account":{"address":"%s"}},{"account":{"address":"%s"}}]},{"logs":[{"account":{"address":"%s"}},{"account":{"address":"%s"}}]},{"logs":[{"account":{"address":"%s"}},{"account":{"address":"%s"}}]}]}}`, dadStr, dadStr, dadStr, dadStr, dadStr, dadStr), + }, + // Multiple fields of a tx race to resolve it. Happens in this case + // because resolving the tx body belonging to a log is delayed. + { + body: `{block { logs(filter: {}) { transaction { nonce value gasPrice }}}}`, + want: `{"block":{"logs":[{"transaction":{"nonce":"0x0","value":"0x0","gasPrice":"0x3b9aca00"}},{"transaction":{"nonce":"0x0","value":"0x0","gasPrice":"0x3b9aca00"}},{"transaction":{"nonce":"0x1","value":"0x0","gasPrice":"0x3b9aca00"}},{"transaction":{"nonce":"0x1","value":"0x0","gasPrice":"0x3b9aca00"}},{"transaction":{"nonce":"0x2","value":"0x0","gasPrice":"0x3b9aca00"}},{"transaction":{"nonce":"0x2","value":"0x0","gasPrice":"0x3b9aca00"}}]}}`, + }, + // Multiple txes of a block race to set/retrieve receipts of a block. + { + body: "{block { transactions { status gasUsed } } }", + want: `{"block":{"transactions":[{"status":1,"gasUsed":21768},{"status":1,"gasUsed":21768},{"status":1,"gasUsed":21768}]}}`, + }, + // Multiple fields of block race to resolve header and body. + { + body: "{ block { number hash gasLimit ommerCount transactionCount totalDifficulty } }", + want: fmt.Sprintf(`{"block":{"number":1,"hash":"%s","gasLimit":11500000,"ommerCount":0,"transactionCount":3,"totalDifficulty":"0x200000"}}`, chain[len(chain)-1].Hash()), + }, + // Multiple fields of a block race to resolve the header and body. + { + body: fmt.Sprintf(`{ transaction(hash: "%s") { block { number hash gasLimit ommerCount transactionCount } } }`, tx.Hash()), + want: fmt.Sprintf(`{"transaction":{"block":{"number":1,"hash":"%s","gasLimit":11500000,"ommerCount":0,"transactionCount":3}}}`, chain[len(chain)-1].Hash()), + }, + // Account fields race the resolve the state object. + { + body: fmt.Sprintf(`{ block { account(address: "%s") { balance transactionCount code } } }`, dadStr), + want: `{"block":{"account":{"balance":"0x0","transactionCount":"0x0","code":"0x60006000a060006000a060006000f3"}}}`, + }, + // Test values for a non-existent account. + { + body: fmt.Sprintf(`{ block { account(address: "%s") { balance transactionCount code } } }`, "0x1111111111111111111111111111111111111111"), + want: `{"block":{"account":{"balance":"0x0","transactionCount":"0x0","code":"0x"}}}`, + }, + } { + res := handler.Schema.Exec(context.Background(), tt.body, "", map[string]interface{}{}) + if res.Errors != nil { + t.Fatalf("failed to execute query for testcase #%d: %v", i, res.Errors) + } + have, err := json.Marshal(res.Data) + if err != nil { + t.Fatalf("failed to encode graphql response for testcase #%d: %s", i, err) + } + if string(have) != tt.want { + t.Errorf("response unmatch for testcase #%d.\nExpected:\n%s\nGot:\n%s\n", i, tt.want, have) + } } } @@ -333,7 +378,7 @@ func createNode(t *testing.T) *node.Node { return stack } -func newGQLService(t *testing.T, stack *node.Node, gspec *core.Genesis, genBlocks int, genfunc func(i int, gen *core.BlockGen)) *handler { +func newGQLService(t *testing.T, stack *node.Node, gspec *core.Genesis, genBlocks int, genfunc func(i int, gen *core.BlockGen)) (*handler, []*types.Block) { ethConf := ðconfig.Config{ Genesis: gspec, Ethash: ethash.Config{ @@ -364,5 +409,5 @@ func newGQLService(t *testing.T, stack *node.Node, gspec *core.Genesis, genBlock if err != nil { t.Fatalf("could not create graphql service: %v", err) } - return handler + return handler, chain } diff --git a/internal/debug/flags.go b/internal/debug/flags.go index e014a85d7fc8..0bae9883ec3b 100644 --- a/internal/debug/flags.go +++ b/internal/debug/flags.go @@ -54,6 +54,11 @@ var ( Usage: "Format logs with JSON", Category: flags.LoggingCategory, } + logfmtFlag = &cli.BoolFlag{ + Name: "log.logfmt", + Usage: "Format logs with logfmt", + Category: flags.LoggingCategory, + } logFileFlag = &cli.StringFlag{ Name: "log.file", Usage: "Write logs to a file", @@ -115,6 +120,7 @@ var Flags = []cli.Flag{ verbosityFlag, vmoduleFlag, logjsonFlag, + logfmtFlag, logFileFlag, backtraceAtFlag, debugFlag, @@ -147,6 +153,8 @@ func Setup(ctx *cli.Context) error { var logfmt log.Format if ctx.Bool(logjsonFlag.Name) { logfmt = log.JSONFormat() + } else if ctx.Bool(logfmtFlag.Name) { + logfmt = log.LogfmtFormat() } else { logfmt = log.TerminalFormat(useColor) } diff --git a/light/trie.go b/light/trie.go index 0f05a80004ab..2f8f71c61e9a 100644 --- a/light/trie.go +++ b/light/trie.go @@ -105,7 +105,7 @@ type odrTrie struct { trie *trie.Trie } -func (t *odrTrie) TryGetStorage(_ common.Address, key []byte) ([]byte, error) { +func (t *odrTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { key = crypto.Keccak256(key) var res []byte err := t.do(key, func() (err error) { @@ -115,7 +115,7 @@ func (t *odrTrie) TryGetStorage(_ common.Address, key []byte) ([]byte, error) { return res, err } -func (t *odrTrie) TryGetAccount(address common.Address) (*types.StateAccount, error) { +func (t *odrTrie) GetAccount(address common.Address) (*types.StateAccount, error) { var res types.StateAccount key := crypto.Keccak256(address.Bytes()) err := t.do(key, func() (err error) { @@ -131,7 +131,7 @@ func (t *odrTrie) TryGetAccount(address common.Address) (*types.StateAccount, er return &res, err } -func (t *odrTrie) TryUpdateAccount(address common.Address, acc *types.StateAccount) error { +func (t *odrTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error { key := crypto.Keccak256(address.Bytes()) value, err := rlp.EncodeToBytes(acc) if err != nil { @@ -142,14 +142,14 @@ func (t *odrTrie) TryUpdateAccount(address common.Address, acc *types.StateAccou }) } -func (t *odrTrie) TryUpdateStorage(_ common.Address, key, value []byte) error { +func (t *odrTrie) UpdateStorage(_ common.Address, key, value []byte) error { key = crypto.Keccak256(key) return t.do(key, func() error { return t.trie.TryUpdate(key, value) }) } -func (t *odrTrie) TryDeleteStorage(_ common.Address, key []byte) error { +func (t *odrTrie) DeleteStorage(_ common.Address, key []byte) error { key = crypto.Keccak256(key) return t.do(key, func() error { return t.trie.TryDelete(key) @@ -157,7 +157,7 @@ func (t *odrTrie) TryDeleteStorage(_ common.Address, key []byte) error { } // TryDeleteAccount abstracts an account deletion from the trie. -func (t *odrTrie) TryDeleteAccount(address common.Address) error { +func (t *odrTrie) DeleteAccount(address common.Address) error { key := crypto.Keccak256(address.Bytes()) return t.do(key, func() error { return t.trie.TryDelete(key) diff --git a/metrics/librato/client.go b/metrics/librato/client.go index 729c2da9a9bc..f1b9e1e91669 100644 --- a/metrics/librato/client.go +++ b/metrics/librato/client.go @@ -87,9 +87,11 @@ func (c *LibratoClient) PostMetrics(batch Batch) (err error) { req.Header.Set("Content-Type", "application/json") req.SetBasicAuth(c.Email, c.Token) - if resp, err = http.DefaultClient.Do(req); err != nil { + resp, err = http.DefaultClient.Do(req) + if err != nil { return } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { var body []byte diff --git a/node/node_test.go b/node/node_test.go index 560d487fa823..04810a815bf6 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -615,6 +615,7 @@ func doHTTPRequest(t *testing.T, req *http.Request) *http.Response { if err != nil { t.Fatalf("could not issue a GET request to the given endpoint: %v", err) } + t.Cleanup(func() { resp.Body.Close() }) return resp } diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index 4d10e61e2dec..0790dddec411 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -320,6 +320,7 @@ func baseRpcRequest(t *testing.T, url, bodyStr string, extraHeaders ...string) * if err != nil { t.Fatal(err) } + t.Cleanup(func() { resp.Body.Close() }) return resp } diff --git a/p2p/simulations/adapters/exec.go b/p2p/simulations/adapters/exec.go index 7bfa8aab6d10..1d812514dee7 100644 --- a/p2p/simulations/adapters/exec.go +++ b/p2p/simulations/adapters/exec.go @@ -428,9 +428,11 @@ func execP2PNode() { // Send status to the host. statusJSON, _ := json.Marshal(status) - if _, err := http.Post(statusURL, "application/json", bytes.NewReader(statusJSON)); err != nil { + resp, err := http.Post(statusURL, "application/json", bytes.NewReader(statusJSON)) + if err != nil { log.Crit("Can't post startup info", "url", statusURL, "err", err) } + resp.Body.Close() if stackErr != nil { os.Exit(1) } diff --git a/p2p/simulations/mocker_test.go b/p2p/simulations/mocker_test.go index 56d81942bbd0..0112ee5cfd6e 100644 --- a/p2p/simulations/mocker_test.go +++ b/p2p/simulations/mocker_test.go @@ -124,6 +124,7 @@ func TestMocker(t *testing.T) { if err != nil { t.Fatalf("Could not start mocker: %s", err) } + resp.Body.Close() if resp.StatusCode != 200 { t.Fatalf("Invalid Status Code received for starting mocker, expected 200, got %d", resp.StatusCode) } @@ -145,15 +146,17 @@ func TestMocker(t *testing.T) { if err != nil { t.Fatalf("Could not stop mocker: %s", err) } + resp.Body.Close() if resp.StatusCode != 200 { t.Fatalf("Invalid Status Code received for stopping mocker, expected 200, got %d", resp.StatusCode) } //reset the network - _, err = http.Post(s.URL+"/reset", "", nil) + resp, err = http.Post(s.URL+"/reset", "", nil) if err != nil { t.Fatalf("Could not reset network: %s", err) } + resp.Body.Close() //now the number of nodes in the network should be zero nodesInfo, err = client.GetNodes() diff --git a/rpc/http_test.go b/rpc/http_test.go index 528e1bcfc5e7..584842a9aa8d 100644 --- a/rpc/http_test.go +++ b/rpc/http_test.go @@ -94,6 +94,7 @@ func confirmHTTPRequestYieldsStatusCode(t *testing.T, method, contentType, body if err != nil { t.Fatalf("request failed: %v", err) } + resp.Body.Close() confirmStatusCode(t, resp.StatusCode, expectedStatusCode) } diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 01e004d02ed5..b25bf758c15d 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -75,25 +75,25 @@ func NewStateTrie(id *ID, db *Database) (*StateTrie, error) { // Get returns the value for key stored in the trie. // The value bytes must not be modified by the caller. func (t *StateTrie) Get(key []byte) []byte { - res, err := t.TryGetStorage(common.Address{}, key) + res, err := t.GetStorage(common.Address{}, key) if err != nil { log.Error("Unhandled trie error in StateTrie.Get", "err", err) } return res } -// TryGet returns the value for key stored in the trie. -// The value bytes must not be modified by the caller. -// If the specified node is not in the trie, nil will be returned. +// GetStorage attempts to retrieve a storage slot with provided account address +// and slot key. The value bytes must not be modified by the caller. +// If the specified storage slot is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. -func (t *StateTrie) TryGetStorage(_ common.Address, key []byte) ([]byte, error) { +func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { return t.trie.TryGet(t.hashKey(key)) } -// TryGetAccount attempts to retrieve an account with provided account address. +// GetAccount attempts to retrieve an account with provided account address. // If the specified account is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. -func (t *StateTrie) TryGetAccount(address common.Address) (*types.StateAccount, error) { +func (t *StateTrie) GetAccount(address common.Address) (*types.StateAccount, error) { res, err := t.trie.TryGet(t.hashKey(address.Bytes())) if res == nil || err != nil { return nil, err @@ -103,10 +103,10 @@ func (t *StateTrie) TryGetAccount(address common.Address) (*types.StateAccount, return ret, err } -// TryGetAccountByHash does the same thing as TryGetAccount, however -// it expects an account hash that is the hash of address. This constitutes an -// abstraction leak, since the client code needs to know the key format. -func (t *StateTrie) TryGetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) { +// GetAccountByHash does the same thing as GetAccount, however it expects an +// account hash that is the hash of address. This constitutes an abstraction +// leak, since the client code needs to know the key format. +func (t *StateTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) { res, err := t.trie.TryGet(addrHash.Bytes()) if res == nil || err != nil { return nil, err @@ -116,11 +116,11 @@ func (t *StateTrie) TryGetAccountByHash(addrHash common.Hash) (*types.StateAccou return ret, err } -// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not +// GetNode attempts to retrieve a trie node by compact-encoded path. It is not // possible to use keybyte-encoding as the path might contain odd nibbles. // If the specified trie node is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. -func (t *StateTrie) TryGetNode(path []byte) ([]byte, int, error) { +func (t *StateTrie) GetNode(path []byte) ([]byte, int, error) { return t.trie.TryGetNode(path) } @@ -131,12 +131,12 @@ func (t *StateTrie) TryGetNode(path []byte) ([]byte, int, error) { // The value bytes must not be modified by the caller while they are // stored in the trie. func (t *StateTrie) Update(key, value []byte) { - if err := t.TryUpdateStorage(common.Address{}, key, value); err != nil { + if err := t.UpdateStorage(common.Address{}, key, value); err != nil { log.Error("Unhandled trie error in StateTrie.Update", "err", err) } } -// TryUpdate associates key with value in the trie. Subsequent calls to +// UpdateStorage associates key with value in the trie. Subsequent calls to // Get will return value. If value has length zero, any existing value // is deleted from the trie and calls to Get will return nil. // @@ -144,7 +144,7 @@ func (t *StateTrie) Update(key, value []byte) { // stored in the trie. // // If a node is not found in the database, a MissingNodeError is returned. -func (t *StateTrie) TryUpdateStorage(_ common.Address, key, value []byte) error { +func (t *StateTrie) UpdateStorage(_ common.Address, key, value []byte) error { hk := t.hashKey(key) err := t.trie.TryUpdate(hk, value) if err != nil { @@ -154,9 +154,8 @@ func (t *StateTrie) TryUpdateStorage(_ common.Address, key, value []byte) error return nil } -// TryUpdateAccount account will abstract the write of an account to the -// secure trie. -func (t *StateTrie) TryUpdateAccount(address common.Address, acc *types.StateAccount) error { +// UpdateAccount will abstract the write of an account to the secure trie. +func (t *StateTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error { hk := t.hashKey(address.Bytes()) data, err := rlp.EncodeToBytes(acc) if err != nil { @@ -171,22 +170,22 @@ func (t *StateTrie) TryUpdateAccount(address common.Address, acc *types.StateAcc // Delete removes any existing value for key from the trie. func (t *StateTrie) Delete(key []byte) { - if err := t.TryDeleteStorage(common.Address{}, key); err != nil { + if err := t.DeleteStorage(common.Address{}, key); err != nil { log.Error("Unhandled trie error in StateTrie.Delete", "err", err) } } -// TryDelete removes any existing value for key from the trie. +// DeleteStorage removes any existing storage slot from the trie. // If the specified trie node is not in the trie, nothing will be changed. // If a node is not found in the database, a MissingNodeError is returned. -func (t *StateTrie) TryDeleteStorage(_ common.Address, key []byte) error { +func (t *StateTrie) DeleteStorage(_ common.Address, key []byte) error { hk := t.hashKey(key) delete(t.getSecKeyCache(), string(hk)) return t.trie.TryDelete(hk) } -// TryDeleteAccount abstracts an account deletion from the trie. -func (t *StateTrie) TryDeleteAccount(address common.Address) error { +// DeleteAccount abstracts an account deletion from the trie. +func (t *StateTrie) DeleteAccount(address common.Address) error { hk := t.hashKey(address.Bytes()) delete(t.getSecKeyCache(), string(hk)) return t.trie.TryDelete(hk) diff --git a/trie/sync_test.go b/trie/sync_test.go index 8fec378333da..f404dc1cda80 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -154,7 +154,7 @@ func testIterativeSync(t *testing.T, count int, bypath bool) { } } else { for i, element := range elements { - data, _, err := srcTrie.TryGetNode(element.syncPath[len(element.syncPath)-1]) + data, _, err := srcTrie.GetNode(element.syncPath[len(element.syncPath)-1]) if err != nil { t.Fatalf("failed to retrieve node data for path %x: %v", element.path, err) }