diff --git a/cmd/devp2p/internal/ethtest/large.go b/cmd/devp2p/internal/ethtest/large.go new file mode 100644 index 000000000000..deca00be5384 --- /dev/null +++ b/cmd/devp2p/internal/ethtest/large.go @@ -0,0 +1,80 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package ethtest + +import ( + "crypto/rand" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/core/types" +) + +// largeNumber returns a very large big.Int. +func largeNumber(megabytes int) *big.Int { + buf := make([]byte, megabytes*1024*1024) + rand.Read(buf) + bigint := new(big.Int) + bigint.SetBytes(buf) + return bigint +} + +// largeBuffer returns a very large buffer. +func largeBuffer(megabytes int) []byte { + buf := make([]byte, megabytes*1024*1024) + rand.Read(buf) + return buf +} + +// largeString returns a very large string. +func largeString(megabytes int) string { + buf := make([]byte, megabytes*1024*1024) + rand.Read(buf) + return hexutil.Encode(buf) +} + +func largeBlock() *types.Block { + return types.NewBlockWithHeader(largeHeader()) +} + +// Returns a random hash +func randHash() common.Hash { + var h common.Hash + rand.Read(h[:]) + return h +} + +func largeHeader() *types.Header { + return &types.Header{ + MixDigest: randHash(), + ReceiptHash: randHash(), + TxHash: randHash(), + Nonce: types.BlockNonce{}, + Extra: []byte{}, + Bloom: types.Bloom{}, + GasUsed: 0, + Coinbase: common.Address{}, + GasLimit: 0, + UncleHash: randHash(), + Time: 1337, + ParentHash: randHash(), + Root: randHash(), + Number: largeNumber(2), + Difficulty: largeNumber(2), + } +} diff --git a/cmd/devp2p/internal/ethtest/suite.go b/cmd/devp2p/internal/ethtest/suite.go index d5928bede44f..0348751f88a5 100644 --- a/cmd/devp2p/internal/ethtest/suite.go +++ b/cmd/devp2p/internal/ethtest/suite.go @@ -24,6 +24,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/internal/utesting" + "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/rlpx" "github.com/stretchr/testify/assert" @@ -66,6 +67,9 @@ func (s *Suite) AllTests() []utesting.Test { {Name: "GetBlockHeaders", Fn: s.TestGetBlockHeaders}, {Name: "Broadcast", Fn: s.TestBroadcast}, {Name: "GetBlockBodies", Fn: s.TestGetBlockBodies}, + {Name: "TestLargeAnnounce", Fn: s.TestLargeAnnounce}, + {Name: "TestMaliciousHandshake", Fn: s.TestMaliciousHandshake}, + {Name: "TestMaliciousStatus", Fn: s.TestMaliciousStatus}, } } @@ -80,7 +84,7 @@ func (s *Suite) TestStatus(t *utesting.T) { // get protoHandshake conn.handshake(t) // get status - switch msg := conn.statusExchange(t, s.chain).(type) { + switch msg := conn.statusExchange(t, s.chain, nil).(type) { case *Status: t.Logf("got status message: %s", pretty.Sdump(msg)) default: @@ -88,6 +92,40 @@ func (s *Suite) TestStatus(t *utesting.T) { } } +// TestMaliciousStatus sends a status package with a large total difficulty. +func (s *Suite) TestMaliciousStatus(t *utesting.T) { + conn, err := s.dial() + if err != nil { + t.Fatalf("could not dial: %v", err) + } + // get protoHandshake + conn.handshake(t) + status := &Status{ + ProtocolVersion: uint32(conn.ethProtocolVersion), + NetworkID: s.chain.chainConfig.ChainID.Uint64(), + TD: largeNumber(2), + Head: s.chain.blocks[s.chain.Len()-1].Hash(), + Genesis: s.chain.blocks[0].Hash(), + ForkID: s.chain.ForkID(), + } + // get status + switch msg := conn.statusExchange(t, s.chain, status).(type) { + case *Status: + t.Logf("%+v\n", msg) + default: + t.Fatalf("expected status, got: %#v ", msg) + } + timeout := 20 * time.Second + // wait for disconnect + switch msg := conn.ReadAndServe(s.chain, timeout).(type) { + case *Disconnect: + case *Error: + return + default: + t.Fatalf("expected disconnect, got: %s", pretty.Sdump(msg)) + } +} + // TestGetBlockHeaders tests whether the given node can respond to // a `GetBlockHeaders` request and that the response is accurate. func (s *Suite) TestGetBlockHeaders(t *utesting.T) { @@ -97,7 +135,7 @@ func (s *Suite) TestGetBlockHeaders(t *utesting.T) { } conn.handshake(t) - conn.statusExchange(t, s.chain) + conn.statusExchange(t, s.chain, nil) // get block headers req := &GetBlockHeaders{ @@ -136,7 +174,7 @@ func (s *Suite) TestGetBlockBodies(t *utesting.T) { } conn.handshake(t) - conn.statusExchange(t, s.chain) + conn.statusExchange(t, s.chain, nil) // create block bodies request req := &GetBlockBodies{s.chain.blocks[54].Hash(), s.chain.blocks[75].Hash()} if err := conn.Write(req); err != nil { @@ -155,34 +193,158 @@ func (s *Suite) TestGetBlockBodies(t *utesting.T) { // TestBroadcast tests whether a block announcement is correctly // propagated to the given node's peer(s). func (s *Suite) TestBroadcast(t *utesting.T) { - // create conn to send block announcement - sendConn, err := s.dial() - if err != nil { - t.Fatalf("could not dial: %v", err) + sendConn, receiveConn := s.setupConnection(t), s.setupConnection(t) + nextBlock := len(s.chain.blocks) + blockAnnouncement := &NewBlock{ + Block: s.fullChain.blocks[nextBlock], + TD: s.fullChain.TD(nextBlock + 1), + } + s.testAnnounce(t, sendConn, receiveConn, blockAnnouncement) + // update test suite chain + s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[nextBlock]) + // wait for client to update its chain + if err := receiveConn.waitForBlock(s.chain.Head()); err != nil { + t.Fatal(err) } - // create conn to receive block announcement - receiveConn, err := s.dial() +} + +// TestMaliciousHandshake tries to send malicious data during the handshake. +func (s *Suite) TestMaliciousHandshake(t *utesting.T) { + conn, err := s.dial() if err != nil { t.Fatalf("could not dial: %v", err) } + // write hello to client + pub0 := crypto.FromECDSAPub(&conn.ourKey.PublicKey)[1:] + handshakes := []*Hello{ + { + Version: 5, + Caps: []p2p.Cap{ + {Name: largeString(2), Version: 64}, + }, + ID: pub0, + }, + { + Version: 5, + Caps: []p2p.Cap{ + {Name: "eth", Version: 64}, + {Name: "eth", Version: 65}, + }, + ID: append(pub0, byte(0)), + }, + { + Version: 5, + Caps: []p2p.Cap{ + {Name: "eth", Version: 64}, + {Name: "eth", Version: 65}, + }, + ID: append(pub0, pub0...), + }, + { + Version: 5, + Caps: []p2p.Cap{ + {Name: "eth", Version: 64}, + {Name: "eth", Version: 65}, + }, + ID: largeBuffer(2), + }, + { + Version: 5, + Caps: []p2p.Cap{ + {Name: largeString(2), Version: 64}, + }, + ID: largeBuffer(2), + }, + } + for i, handshake := range handshakes { + fmt.Printf("Testing malicious handshake %v\n", i) + // Init the handshake + if err := conn.Write(handshake); err != nil { + t.Fatalf("could not write to connection: %v", err) + } + // check that the peer disconnected + timeout := 20 * time.Second + // Discard one hello + for i := 0; i < 2; i++ { + switch msg := conn.ReadAndServe(s.chain, timeout).(type) { + case *Disconnect: + case *Error: + case *Hello: + // Hello's are send concurrently, so ignore them + continue + default: + t.Fatalf("unexpected: %s", pretty.Sdump(msg)) + } + } + // Dial for the next round + conn, err = s.dial() + if err != nil { + t.Fatalf("could not dial: %v", err) + } + } +} - sendConn.handshake(t) - receiveConn.handshake(t) - - sendConn.statusExchange(t, s.chain) - receiveConn.statusExchange(t, s.chain) +// TestLargeAnnounce tests the announcement mechanism with a large block. +func (s *Suite) TestLargeAnnounce(t *utesting.T) { + nextBlock := len(s.chain.blocks) + blocks := []*NewBlock{ + { + Block: largeBlock(), + TD: s.fullChain.TD(nextBlock + 1), + }, + { + Block: s.fullChain.blocks[nextBlock], + TD: largeNumber(2), + }, + { + Block: largeBlock(), + TD: largeNumber(2), + }, + { + Block: s.fullChain.blocks[nextBlock], + TD: s.fullChain.TD(nextBlock + 1), + }, + } - // sendConn sends the block announcement - blockAnnouncement := &NewBlock{ - Block: s.fullChain.blocks[1000], - TD: s.fullChain.TD(1001), + for i, blockAnnouncement := range blocks[0:3] { + fmt.Printf("Testing malicious announcement: %v\n", i) + sendConn := s.setupConnection(t) + if err := sendConn.Write(blockAnnouncement); err != nil { + t.Fatalf("could not write to connection: %v", err) + } + // Invalid announcement, check that peer disconnected + timeout := 20 * time.Second + switch msg := sendConn.ReadAndServe(s.chain, timeout).(type) { + case *Disconnect: + case *Error: + break + default: + t.Fatalf("unexpected: %s wanted disconnect", pretty.Sdump(msg)) + } + } + // Test the last block as a valid block + sendConn := s.setupConnection(t) + receiveConn := s.setupConnection(t) + s.testAnnounce(t, sendConn, receiveConn, blocks[3]) + // update test suite chain + s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[nextBlock]) + // wait for client to update its chain + if err := receiveConn.waitForBlock(s.fullChain.blocks[nextBlock]); err != nil { + t.Fatal(err) } +} + +func (s *Suite) testAnnounce(t *utesting.T, sendConn, receiveConn *Conn, blockAnnouncement *NewBlock) { + // Announce the block. if err := sendConn.Write(blockAnnouncement); err != nil { t.Fatalf("could not write to connection: %v", err) } + s.waitAnnounce(t, receiveConn, blockAnnouncement) +} +func (s *Suite) waitAnnounce(t *utesting.T, conn *Conn, blockAnnouncement *NewBlock) { timeout := 20 * time.Second - switch msg := receiveConn.ReadAndServe(s.chain, timeout).(type) { + switch msg := conn.ReadAndServe(s.chain, timeout).(type) { case *NewBlock: t.Logf("received NewBlock message: %s", pretty.Sdump(msg.Block)) assert.Equal(t, @@ -203,12 +365,17 @@ func (s *Suite) TestBroadcast(t *utesting.T) { default: t.Fatalf("unexpected: %s", pretty.Sdump(msg)) } - // update test suite chain - s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[1000]) - // wait for client to update its chain - if err := receiveConn.waitForBlock(s.chain.Head()); err != nil { - t.Fatal(err) +} + +func (s *Suite) setupConnection(t *utesting.T) *Conn { + // create conn + sendConn, err := s.dial() + if err != nil { + t.Fatalf("could not dial: %v", err) } + sendConn.handshake(t) + sendConn.statusExchange(t, s.chain, nil) + return sendConn } // dial attempts to dial the given node and perform a handshake, diff --git a/cmd/devp2p/internal/ethtest/types.go b/cmd/devp2p/internal/ethtest/types.go index 69367cb6cd48..a20e88c3728f 100644 --- a/cmd/devp2p/internal/ethtest/types.go +++ b/cmd/devp2p/internal/ethtest/types.go @@ -304,7 +304,7 @@ func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) { // statusExchange performs a `Status` message exchange with the given // node. -func (c *Conn) statusExchange(t *utesting.T, chain *Chain) Message { +func (c *Conn) statusExchange(t *utesting.T, chain *Chain, status *Status) Message { defer c.SetDeadline(time.Time{}) c.SetDeadline(time.Now().Add(20 * time.Second)) @@ -338,16 +338,19 @@ loop: if c.ethProtocolVersion == 0 { t.Fatalf("eth protocol version must be set in Conn") } - // write status message to client - status := Status{ - ProtocolVersion: uint32(c.ethProtocolVersion), - NetworkID: chain.chainConfig.ChainID.Uint64(), - TD: chain.TD(chain.Len()), - Head: chain.blocks[chain.Len()-1].Hash(), - Genesis: chain.blocks[0].Hash(), - ForkID: chain.ForkID(), + if status == nil { + // write status message to client + status = &Status{ + ProtocolVersion: uint32(c.ethProtocolVersion), + NetworkID: chain.chainConfig.ChainID.Uint64(), + TD: chain.TD(chain.Len()), + Head: chain.blocks[chain.Len()-1].Hash(), + Genesis: chain.blocks[0].Hash(), + ForkID: chain.ForkID(), + } } - if err := c.Write(status); err != nil { + + if err := c.Write(*status); err != nil { t.Fatalf("could not write to connection: %v", err) }