Skip to content

Commit

Permalink
Merge pull request #92 from Chia-Network/peer-protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
cmmarslender authored Jun 28, 2023
2 parents a3e2fad + bbc7029 commit 6e25810
Show file tree
Hide file tree
Showing 17 changed files with 1,272 additions and 0 deletions.
193 changes: 193 additions & 0 deletions pkg/peerprotocol/connection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
package peerprotocol

import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"time"

"github.com/gorilla/websocket"

"github.com/chia-network/go-chia-libs/pkg/config"
"github.com/chia-network/go-chia-libs/pkg/protocols"
)

// Connection represents a connection with a peer and enables communication
type Connection struct {
chiaConfig *config.ChiaConfig

peerIP *net.IP
peerPort uint16
peerKeyPair *tls.Certificate
peerDialer *websocket.Dialer

handshakeTimeout time.Duration
conn *websocket.Conn
}

// PeerResponseHandlerFunc is a function that will be called when a response is returned from a peer
type PeerResponseHandlerFunc func(*protocols.Message, error)

// NewConnection creates a new connection object with the specified peer
func NewConnection(ip *net.IP, options ...ConnectionOptionFunc) (*Connection, error) {
cfg, err := config.GetChiaConfig()
if err != nil {
return nil, err
}

c := &Connection{
chiaConfig: cfg,
peerIP: ip,
peerPort: cfg.FullNode.Port,
}

for _, fn := range options {
if fn == nil {
continue
}
if err := fn(c); err != nil {
return nil, err
}
}

err = c.loadKeyPair()
if err != nil {
return nil, err
}

// Generate the websocket dialer
err = c.generateDialer()
if err != nil {
return nil, err
}

return c, nil
}

func (c *Connection) loadKeyPair() error {
var err error

c.peerKeyPair, err = c.chiaConfig.FullNode.SSL.LoadPublicKeyPair(c.chiaConfig.ChiaRoot)
if err != nil {
return err
}

return nil
}

func (c *Connection) generateDialer() error {
if c.peerDialer == nil {
c.peerDialer = &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: c.handshakeTimeout,
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{*c.peerKeyPair},
InsecureSkipVerify: true,
},
}
}

return nil
}

// ensureConnection ensures there is an open websocket connection
func (c *Connection) ensureConnection() error {
if c.conn == nil {
u := url.URL{Scheme: "wss", Host: fmt.Sprintf("%s:%d", c.peerIP.String(), c.peerPort), Path: "/ws"}
var err error
c.conn, _, err = c.peerDialer.Dial(u.String(), nil)
if err != nil {
return err
}
}

return nil
}

// Close closes the connection, if open
func (c *Connection) Close() {
if c.conn != nil {
err := c.conn.Close()
if err != nil {
return
}
c.conn = nil
}
}

// Handshake performs the RPC handshake. This should be called before any other method
func (c *Connection) Handshake() error {
// Handshake
handshake := &protocols.Handshake{
NetworkID: c.chiaConfig.SelectedNetwork,
ProtocolVersion: protocols.ProtocolVersion,
SoftwareVersion: "2.0.0",
ServerPort: c.peerPort,
NodeType: protocols.NodeTypeFullNode, // I guess we're a full node
Capabilities: []protocols.Capability{
{
Capability: protocols.CapabilityTypeBase,
Value: "1",
},
},
}

return c.Do(protocols.ProtocolMessageTypeHandshake, handshake)
}

// Do send a request over the websocket
func (c *Connection) Do(messageType protocols.ProtocolMessageType, data interface{}) error {
err := c.ensureConnection()
if err != nil {
return err
}

msgBytes, err := protocols.MakeMessageBytes(messageType, data)
if err != nil {
return err
}

return c.conn.WriteMessage(websocket.BinaryMessage, msgBytes)
}

// ReadSync Reads for async responses over the connection in a synchronous fashion, blocking anything else
func (c *Connection) ReadSync(handler PeerResponseHandlerFunc) error {
for {
_, bytes, err := c.conn.ReadMessage()
if err != nil {
// @TODO Handle Error
return err

}
handler(protocols.DecodeMessage(bytes))
}
}

// ReadOne reads and returns one message from the connection
func (c *Connection) ReadOne(timeout time.Duration) (*protocols.Message, error) {
chBytes := make(chan []byte, 1)
chErr := make(chan error, 1)
ctxTimeout, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

go c.readOneCtx(ctxTimeout, chBytes, chErr)

select {
case <-ctxTimeout.Done():
return nil, fmt.Errorf("context cancelled: %v", ctxTimeout.Err())
case result := <-chBytes:
return protocols.DecodeMessage(result)
}
}

func (c *Connection) readOneCtx(ctx context.Context, chBytes chan []byte, chErr chan error) {
_, bytes, err := c.conn.ReadMessage()
if err != nil {
chErr <- err
}

chBytes <- bytes
}
16 changes: 16 additions & 0 deletions pkg/peerprotocol/connectionoptions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package peerprotocol

import (
"time"
)

// ConnectionOptionFunc can be used to customize a new Connection
type ConnectionOptionFunc func(connection *Connection) error

// WithHandshakeTimeout sets the handshake timeout
func WithHandshakeTimeout(timeout time.Duration) ConnectionOptionFunc {
return func(c *Connection) error {
c.handshakeTimeout = timeout
return nil
}
}
22 changes: 22 additions & 0 deletions pkg/peerprotocol/fullnode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package peerprotocol

import (
"github.com/chia-network/go-chia-libs/pkg/protocols"
)

// FullNodeProtocol is for interfacing with full nodes via the peer protocol
type FullNodeProtocol struct {
connection *Connection
}

// NewFullNodeProtocol returns a new instance of the full node protocol
func NewFullNodeProtocol(connection *Connection) (*FullNodeProtocol, error) {
fnp := &FullNodeProtocol{connection: connection}

return fnp, nil
}

// RequestPeers asks the current peer to respond with their current peer list
func (c *FullNodeProtocol) RequestPeers() error {
return c.connection.Do(protocols.ProtocolMessageTypeRequestPeers, &protocols.RequestPeers{})
}
13 changes: 13 additions & 0 deletions pkg/protocols/fullnode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package protocols

import (
"github.com/chia-network/go-chia-libs/pkg/types"
)

// RequestPeers is an empty struct
type RequestPeers struct{}

// RespondPeers is the format for the request_peers response
type RespondPeers struct {
PeerList []types.TimestampedPeerInfo `streamable:""`
}
40 changes: 40 additions & 0 deletions pkg/protocols/fullnode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package protocols_test

import (
"encoding/hex"
"testing"

"github.com/stretchr/testify/assert"

"github.com/chia-network/go-chia-libs/pkg/protocols"
"github.com/chia-network/go-chia-libs/pkg/streamable"
)

func TestRespondPeers(t *testing.T) {
// Has one peer in the list
// IP 1.2.3.4
// Port 8444
// Timestamp 1643913969
hexStr := "0000000100000007312e322e332e3420fc0000000061fc22f1"

// Hex to bytes
encodedBytes, err := hex.DecodeString(hexStr)
assert.NoError(t, err)

rp := &protocols.RespondPeers{}

err = streamable.Unmarshal(encodedBytes, rp)
assert.NoError(t, err)

assert.Len(t, rp.PeerList, 1)

pl1 := rp.PeerList[0]
assert.Equal(t, "1.2.3.4", pl1.Host)
assert.Equal(t, uint16(8444), pl1.Port)
assert.Equal(t, uint64(1643913969), pl1.Timestamp)

// Test going the other direction
reencodedBytes, err := streamable.Marshal(rp)
assert.NoError(t, err)
assert.Equal(t, encodedBytes, reencodedBytes)
}
71 changes: 71 additions & 0 deletions pkg/protocols/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package protocols

import (
"github.com/chia-network/go-chia-libs/pkg/streamable"

"github.com/samber/mo"
)

// Message is a protocol message
type Message struct {
ProtocolMessageType ProtocolMessageType `streamable:""`
ID mo.Option[uint16] `streamable:""`
Data []byte `streamable:""`
}

// DecodeData decodes the data in the message to the provided type
func (m *Message) DecodeData(v interface{}) error {
return streamable.Unmarshal(m.Data, v)
}

// MakeMessage makes a new Message with the given data
func MakeMessage(messageType ProtocolMessageType, data interface{}) (*Message, error) {
msg := &Message{
ProtocolMessageType: messageType,
}

var dataBytes []byte
var err error
if data != nil {
dataBytes, err = streamable.Marshal(data)
if err != nil {
return nil, err
}
}

msg.Data = dataBytes

return msg, nil
}

// MakeMessageBytes calls MakeMessage and converts everything down to bytes
func MakeMessageBytes(messageType ProtocolMessageType, data interface{}) ([]byte, error) {
msg, err := MakeMessage(messageType, data)
if err != nil {
return nil, err
}

return streamable.Marshal(msg)
}

// DecodeMessage is a helper function to quickly decode bytes to Message
func DecodeMessage(bytes []byte) (*Message, error) {
msg := &Message{}

err := streamable.Unmarshal(bytes, msg)
if err != nil {
return nil, err
}

return msg, nil
}

// DecodeMessageData decodes a message.data into the given interface
func DecodeMessageData(bytes []byte, v interface{}) error {
msg, err := DecodeMessage(bytes)
if err != nil {
return err
}

return msg.DecodeData(v)
}
Loading

0 comments on commit 6e25810

Please sign in to comment.