diff --git a/main.go b/main.go index cde2660..caa0e7e 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,13 @@ import ( func main() { - tr := p2p.NewTCPTransport(":3000") + tcpOpts := p2p.TCPTransportOptions{ + ListenAddr: ":3000", + HandshakeFunc: p2p.NOPHandshakeFunc, + Decoder: p2p.NOPDecoder{}, + } + + tr := p2p.NewTCPTransport(tcpOpts) if err := tr.ListenAndAccept(); err != nil { log.Fatal(err) diff --git a/p2p/encoding.go b/p2p/encoding.go index 7feaef6..514153f 100644 --- a/p2p/encoding.go +++ b/p2p/encoding.go @@ -1,7 +1,31 @@ package p2p -import "io" +import ( + "bytes" + "encoding/gob" + "io" +) type Decoder interface { - Decode(io.Reader, any) error + Decode(io.Reader, *RPC) error +} + +type GOBDecoder struct{} + +func (dec GOBDecoder) Decode(r io.Reader, rpc *RPC) error { + return gob.NewDecoder(r).Decode(rpc) +} + +type NOPDecoder struct{} + +func (dec NOPDecoder) Decode(r io.Reader, rpc *RPC) error { + buf := new(bytes.Buffer) + n, err := buf.ReadFrom(r) + if err != nil { + return err + } + + rpc.Payload = buf.Bytes()[:n] + + return nil } diff --git a/p2p/handshake.go b/p2p/handshake.go index 560a683..e8c834d 100644 --- a/p2p/handshake.go +++ b/p2p/handshake.go @@ -1,5 +1,11 @@ package p2p +import "errors" + +//ErrInvalidHandshake is returned if the handshake between +// the local and remote node could not be established. +var ErrInvalidHandShake = errors.New("invalid handshake") + type HandshakeFunc func(Peer) error func NOPHandshakeFunc(Peer) error { diff --git a/p2p/message.go b/p2p/message.go new file mode 100644 index 0000000..5d8650e --- /dev/null +++ b/p2p/message.go @@ -0,0 +1,11 @@ +package p2p + +import "net" + +// Message represents any arbitrary data +// that is being sent over the transport +// between two nodes in the network. +type RPC struct { + From net.Addr + Payload []byte +} diff --git a/p2p/tcp-transport_test.go b/p2p/tcp-transport_test.go index 782dc80..117c2d1 100644 --- a/p2p/tcp-transport_test.go +++ b/p2p/tcp-transport_test.go @@ -7,10 +7,14 @@ import ( ) func TestTCPTransport(t *testing.T) { - listenAddr := ":3000" - tr := NewTCPTransport(listenAddr) + tcpOpts := TCPTransportOptions{ + ListenAddr: ":3000", + HandshakeFunc: NOPHandshakeFunc, + } - assert.Equal(t, tr.listenAddr, listenAddr) + tr := NewTCPTransport(tcpOpts) + + assert.Equal(t, tr.ListenAddr, ":3000") assert.Nil(t, tr.ListenAndAccept()) } diff --git a/p2p/tcp_transport.go b/p2p/tcp_transport.go index 818789b..25af96a 100644 --- a/p2p/tcp_transport.go +++ b/p2p/tcp_transport.go @@ -23,26 +23,29 @@ func NewTCPPeer(conn net.Conn, outbound bool) *TCPPeer { } } +type TCPTransportOptions struct { + ListenAddr string + HandshakeFunc HandshakeFunc + Decoder Decoder +} + type TCPTransport struct { - listenAddr string - listener net.Listener - shakeHands HandshakeFunc - decoder Decoder + TCPTransportOptions + listener net.Listener mu sync.RWMutex peers map[net.Addr]Peer } -func NewTCPTransport(listenAddr string) *TCPTransport { +func NewTCPTransport(opts TCPTransportOptions) *TCPTransport { return &TCPTransport{ - listenAddr: listenAddr, - shakeHands: NOPHandshakeFunc, + TCPTransportOptions: opts, } } func (t *TCPTransport) ListenAndAccept() error { var err error - t.listener, err = net.Listen("tcp", t.listenAddr) + t.listener, err = net.Listen("tcp", t.ListenAddr) if err != nil { return err } @@ -65,31 +68,33 @@ func (t *TCPTransport) startAcceptLoop() { } } -type Msg struct{} - func (t *TCPTransport) handleConn(conn net.Conn) { peer := NewTCPPeer(conn, true) - fmt.Printf("handling connection %+v\n", peer) + fmt.Printf("TCP handling connection %+v\n", peer) - if err := t.shakeHands(peer); err != nil { - fmt.Printf("handshake error: %v\n", err) + if err := t.HandshakeFunc(peer); err != nil { + fmt.Printf("TCP handshake error: %v\n", err) + conn.Close() return } lenDecodeError := 0 //Read Loop - msg := &Msg{} + rpc := &RPC{} for { // read from the connection - if err := t.decoder.Decode(conn, msg); err != nil { + if err := t.Decoder.Decode(conn, rpc); err != nil { lenDecodeError++ if lenDecodeError == 5 { - fmt.Printf("dropping connection due to multiple decode errors: %+v\n", peer) + fmt.Printf("TCP dropping connection due to multiple decode errors: %+v\n", peer) return } fmt.Printf("decode error: %v\n", err) continue } + + rpc.From = conn.RemoteAddr() + fmt.Printf("message: %+v\n", rpc) // write to the connection // close the connection }