From f79a9eb595a09158bc9467edde38dbe746a70d7a Mon Sep 17 00:00:00 2001 From: Chris Marslender Date: Wed, 28 Jun 2023 10:59:22 -0500 Subject: [PATCH 1/3] Initial port of the peer protocol/streamable work from cmmarslender/go-chia-lib and cmmarslender/go-chia-protocol --- pkg/peerprotocol/connection.go | 193 ++++++++++++++ pkg/peerprotocol/connectionoptions.go | 16 ++ pkg/protocols/fullnode.go | 13 + pkg/protocols/fullnode_test.go | 40 +++ pkg/protocols/message.go | 71 +++++ pkg/protocols/message_test.go | 72 ++++++ pkg/protocols/messagetypes.go | 17 ++ pkg/protocols/shared.go | 55 ++++ pkg/streamable/errors.go | 23 ++ pkg/streamable/streamable.go | 358 ++++++++++++++++++++++++++ pkg/streamable/streamable_test.go | 221 ++++++++++++++++ pkg/types/peerinfo.go | 8 + pkg/util/bytes.go | 15 ++ pkg/util/bytes_test.go | 48 ++++ pkg/util/uints.go | 59 +++++ 15 files changed, 1209 insertions(+) create mode 100644 pkg/peerprotocol/connection.go create mode 100644 pkg/peerprotocol/connectionoptions.go create mode 100644 pkg/protocols/fullnode.go create mode 100644 pkg/protocols/fullnode_test.go create mode 100644 pkg/protocols/message.go create mode 100644 pkg/protocols/message_test.go create mode 100644 pkg/protocols/messagetypes.go create mode 100644 pkg/protocols/shared.go create mode 100644 pkg/streamable/errors.go create mode 100644 pkg/streamable/streamable.go create mode 100644 pkg/streamable/streamable_test.go create mode 100644 pkg/types/peerinfo.go create mode 100644 pkg/util/bytes_test.go create mode 100644 pkg/util/uints.go diff --git a/pkg/peerprotocol/connection.go b/pkg/peerprotocol/connection.go new file mode 100644 index 0000000..b801736 --- /dev/null +++ b/pkg/peerprotocol/connection.go @@ -0,0 +1,193 @@ +package protocol + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/url" + "time" + + "github.com/gorilla/websocket" + + "github.com/chia-network/go-chia-libs/pkg/config" + "github.com/chia-network/go-chia-libs/pkg/protocols" +) + +// Connection represents a connection with a peer and enables communication +type Connection struct { + chiaConfig *config.ChiaConfig + + peerIP *net.IP + peerPort uint16 + peerKeyPair *tls.Certificate + peerDialer *websocket.Dialer + + handshakeTimeout time.Duration + conn *websocket.Conn +} + +// PeerResponseHandlerFunc is a function that will be called when a response is returned from a peer +type PeerResponseHandlerFunc func(*protocols.Message, error) + +// NewConnection creates a new connection object with the specified peer +func NewConnection(ip *net.IP, options ...ConnectionOptionFunc) (*Connection, error) { + cfg, err := config.GetChiaConfig() + if err != nil { + return nil, err + } + + c := &Connection{ + chiaConfig: cfg, + peerIP: ip, + peerPort: cfg.FullNode.Port, + } + + for _, fn := range options { + if fn == nil { + continue + } + if err := fn(c); err != nil { + return nil, err + } + } + + err = c.loadKeyPair() + if err != nil { + return nil, err + } + + // Generate the websocket dialer + err = c.generateDialer() + if err != nil { + return nil, err + } + + return c, nil +} + +func (c *Connection) loadKeyPair() error { + var err error + + c.peerKeyPair, err = c.chiaConfig.FullNode.SSL.LoadPublicKeyPair(c.chiaConfig.ChiaRoot) + if err != nil { + return err + } + + return nil +} + +func (c *Connection) generateDialer() error { + if c.peerDialer == nil { + c.peerDialer = &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: c.handshakeTimeout, + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{*c.peerKeyPair}, + InsecureSkipVerify: true, + }, + } + } + + return nil +} + +// ensureConnection ensures there is an open websocket connection +func (c *Connection) ensureConnection() error { + if c.conn == nil { + u := url.URL{Scheme: "wss", Host: fmt.Sprintf("%s:%d", c.peerIP.String(), c.peerPort), Path: "/ws"} + var err error + c.conn, _, err = c.peerDialer.Dial(u.String(), nil) + if err != nil { + return err + } + } + + return nil +} + +// Close closes the connection, if open +func (c *Connection) Close() { + if c.conn != nil { + err := c.conn.Close() + if err != nil { + return + } + c.conn = nil + } +} + +// Handshake performs the RPC handshake. This should be called before any other method +func (c *Connection) Handshake() error { + // Handshake + handshake := &protocols.Handshake{ + NetworkID: c.chiaConfig.SelectedNetwork, + ProtocolVersion: protocols.ProtocolVersion, + SoftwareVersion: "2.0.0", + ServerPort: c.peerPort, + NodeType: protocols.NodeTypeFullNode, // I guess we're a full node + Capabilities: []protocols.Capability{ + { + Capability: protocols.CapabilityTypeBase, + Value: "1", + }, + }, + } + + return c.Do(protocols.ProtocolMessageTypeHandshake, handshake) +} + +// Do send a request over the websocket +func (c *Connection) Do(messageType protocols.ProtocolMessageType, data interface{}) error { + err := c.ensureConnection() + if err != nil { + return err + } + + msgBytes, err := protocols.MakeMessageBytes(messageType, data) + if err != nil { + return err + } + + return c.conn.WriteMessage(websocket.BinaryMessage, msgBytes) +} + +// ReadSync Reads for async responses over the connection in a synchronous fashion, blocking anything else +func (c *Connection) ReadSync(handler PeerResponseHandlerFunc) error { + for { + _, bytes, err := c.conn.ReadMessage() + if err != nil { + // @TODO Handle Error + return err + + } + handler(protocols.DecodeMessage(bytes)) + } +} + +// ReadOne reads and returns one message from the connection +func (c *Connection) ReadOne(timeout time.Duration) (*protocols.Message, error) { + chBytes := make(chan []byte, 1) + chErr := make(chan error, 1) + ctxTimeout, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + go c.readOneCtx(ctxTimeout, chBytes, chErr) + + select { + case <-ctxTimeout.Done(): + return nil, fmt.Errorf("context cancelled: %v", ctxTimeout.Err()) + case result := <-chBytes: + return protocols.DecodeMessage(result) + } +} + +func (c *Connection) readOneCtx(ctx context.Context, chBytes chan []byte, chErr chan error) { + _, bytes, err := c.conn.ReadMessage() + if err != nil { + chErr <- err + } + + chBytes <- bytes +} diff --git a/pkg/peerprotocol/connectionoptions.go b/pkg/peerprotocol/connectionoptions.go new file mode 100644 index 0000000..ba53393 --- /dev/null +++ b/pkg/peerprotocol/connectionoptions.go @@ -0,0 +1,16 @@ +package protocol + +import ( + "time" +) + +// ConnectionOptionFunc can be used to customize a new Connection +type ConnectionOptionFunc func(connection *Connection) error + +// WithHandshakeTimeout sets the handshake timeout +func WithHandshakeTimeout(timeout time.Duration) ConnectionOptionFunc { + return func(c *Connection) error { + c.handshakeTimeout = timeout + return nil + } +} diff --git a/pkg/protocols/fullnode.go b/pkg/protocols/fullnode.go new file mode 100644 index 0000000..7cfad67 --- /dev/null +++ b/pkg/protocols/fullnode.go @@ -0,0 +1,13 @@ +package protocols + +import ( + "github.com/chia-network/go-chia-libs/pkg/types" +) + +// RequestPeers is an empty struct +type RequestPeers struct{} + +// RespondPeers is the format for the request_peers response +type RespondPeers struct { + PeerList []types.TimestampedPeerInfo `streamable:""` +} diff --git a/pkg/protocols/fullnode_test.go b/pkg/protocols/fullnode_test.go new file mode 100644 index 0000000..3e996d8 --- /dev/null +++ b/pkg/protocols/fullnode_test.go @@ -0,0 +1,40 @@ +package protocols_test + +import ( + "encoding/hex" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/chia-network/go-chia-libs/pkg/protocols" + "github.com/chia-network/go-chia-libs/pkg/streamable" +) + +func TestRespondPeers(t *testing.T) { + // Has one peer in the list + // IP 1.2.3.4 + // Port 8444 + // Timestamp 1643913969 + hexStr := "0000000100000007312e322e332e3420fc0000000061fc22f1" + + // Hex to bytes + encodedBytes, err := hex.DecodeString(hexStr) + assert.NoError(t, err) + + rp := &protocols.RespondPeers{} + + err = streamable.Unmarshal(encodedBytes, rp) + assert.NoError(t, err) + + assert.Len(t, rp.PeerList, 1) + + pl1 := rp.PeerList[0] + assert.Equal(t, "1.2.3.4", pl1.Host) + assert.Equal(t, uint16(8444), pl1.Port) + assert.Equal(t, uint64(1643913969), pl1.Timestamp) + + // Test going the other direction + reencodedBytes, err := streamable.Marshal(rp) + assert.NoError(t, err) + assert.Equal(t, encodedBytes, reencodedBytes) +} diff --git a/pkg/protocols/message.go b/pkg/protocols/message.go new file mode 100644 index 0000000..33d9d11 --- /dev/null +++ b/pkg/protocols/message.go @@ -0,0 +1,71 @@ +package protocols + +import ( + "github.com/chia-network/go-chia-libs/pkg/streamable" + + "github.com/samber/mo" +) + +// Message is a protocol message +type Message struct { + ProtocolMessageType ProtocolMessageType `streamable:""` + ID mo.Option[uint16] `streamable:""` + Data []byte `streamable:""` +} + +// DecodeData decodes the data in the message to the provided type +func (m *Message) DecodeData(v interface{}) error { + return streamable.Unmarshal(m.Data, v) +} + +// MakeMessage makes a new Message with the given data +func MakeMessage(messageType ProtocolMessageType, data interface{}) (*Message, error) { + msg := &Message{ + ProtocolMessageType: messageType, + } + + var dataBytes []byte + var err error + if data != nil { + dataBytes, err = streamable.Marshal(data) + if err != nil { + return nil, err + } + } + + msg.Data = dataBytes + + return msg, nil +} + +// MakeMessageBytes calls MakeMessage and converts everything down to bytes +func MakeMessageBytes(messageType ProtocolMessageType, data interface{}) ([]byte, error) { + msg, err := MakeMessage(messageType, data) + if err != nil { + return nil, err + } + + return streamable.Marshal(msg) +} + +// DecodeMessage is a helper function to quickly decode bytes to Message +func DecodeMessage(bytes []byte) (*Message, error) { + msg := &Message{} + + err := streamable.Unmarshal(bytes, msg) + if err != nil { + return nil, err + } + + return msg, nil +} + +// DecodeMessageData decodes a message.data into the given interface +func DecodeMessageData(bytes []byte, v interface{}) error { + msg, err := DecodeMessage(bytes) + if err != nil { + return err + } + + return msg.DecodeData(v) +} diff --git a/pkg/protocols/message_test.go b/pkg/protocols/message_test.go new file mode 100644 index 0000000..fae8104 --- /dev/null +++ b/pkg/protocols/message_test.go @@ -0,0 +1,72 @@ +package protocols_test + +import ( + "encoding/hex" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/chia-network/go-chia-libs/pkg/protocols" +) + +func TestMakeMessage(t *testing.T) { + +} + +func TestMakeMessage_NilData(t *testing.T) { + msg, err := protocols.MakeMessage(protocols.ProtocolMessageTypeHandshake, nil) + assert.NoError(t, err) + assert.Equal(t, protocols.ProtocolMessageTypeHandshake, msg.ProtocolMessageType) + assert.False(t, msg.ID.IsPresent()) + assert.Equal(t, []byte(nil), msg.Data) +} + +func TestDecodeMessage(t *testing.T) { + //Message( + // uint8(ProtocolMessageTypes.handshake.value), + // None, + // bytes("This is a sample message to decode".encode(encoding = 'UTF-8', errors = 'string')) + //) + encodedHex := "0100000000225468697320697320612073616d706c65206d65737361676520746f206465636f6465" + + messageBytes, err := hex.DecodeString(encodedHex) + assert.NoError(t, err) + + msg, err := protocols.DecodeMessage(messageBytes) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, protocols.ProtocolMessageTypeHandshake, msg.ProtocolMessageType) + assert.False(t, msg.ID.IsPresent()) + assert.Equal(t, []byte("This is a sample message to decode"), msg.Data) +} + +func TestDecodeMessageData(t *testing.T) { + //Message( + // uint8(ProtocolMessageTypes.handshake.value), + // None, + // Handshake( + // "mainnet", + // "0.0.33", + // "1.2.11", + // uint16(8444), + // uint8(1), + // [(uint16(Capability.BASE.value), "1")], + // ) + //) + encodedHex := "01000000002d000000076d61696e6e657400000006302e302e333300000006312e322e313120fc010000000100010000000131" + + messageBytes, err := hex.DecodeString(encodedHex) + assert.NoError(t, err) + + handshake := &protocols.Handshake{} + err = protocols.DecodeMessageData(messageBytes, handshake) + assert.NoError(t, err) + assert.Equal(t, "mainnet", handshake.NetworkID) + assert.Equal(t, "0.0.33", handshake.ProtocolVersion) + assert.Equal(t, "1.2.11", handshake.SoftwareVersion) + assert.Equal(t, uint16(8444), handshake.ServerPort) + assert.Equal(t, protocols.NodeTypeFullNode, handshake.NodeType) + assert.IsType(t, []protocols.Capability{}, handshake.Capabilities) + assert.Len(t, handshake.Capabilities, 1) +} diff --git a/pkg/protocols/messagetypes.go b/pkg/protocols/messagetypes.go new file mode 100644 index 0000000..2bfc4cb --- /dev/null +++ b/pkg/protocols/messagetypes.go @@ -0,0 +1,17 @@ +package protocols + +// ProtocolMessageType corresponds to ProtocolMessageTypes in Chia +type ProtocolMessageType uint8 + +const ( + // ProtocolMessageTypeHandshake Handshake + ProtocolMessageTypeHandshake ProtocolMessageType = 1 + + // there are many more of these in Chia - only listing the ones current is use for now + + // ProtocolMessageTypeRequestPeers request_peers + ProtocolMessageTypeRequestPeers ProtocolMessageType = 43 + + // ProtocolMessageTypeRespondPeers respond_peers + ProtocolMessageTypeRespondPeers ProtocolMessageType = 44 +) diff --git a/pkg/protocols/shared.go b/pkg/protocols/shared.go new file mode 100644 index 0000000..3b48681 --- /dev/null +++ b/pkg/protocols/shared.go @@ -0,0 +1,55 @@ +package protocols + +// ProtocolVersion Current supported Protocol Version +// Not all of this is supported, but this was the current version at the time +// This library was started +const ProtocolVersion string = "0.0.33" + +// NodeType is the type of peer (farmer, full node, etc) +// Source for node types is chia/server/outbound_messages.py +type NodeType uint8 + +const ( + // NodeTypeFullNode NodeType for full node + NodeTypeFullNode NodeType = 1 + + // NodeTypeHarvester NodeType for Harvester + NodeTypeHarvester NodeType = 2 + + // NodeTypeFarmer NodeType for Farmer + NodeTypeFarmer NodeType = 3 + + // NodeTypeTimelord NodeType for Timelord + NodeTypeTimelord NodeType = 4 + + // NodeTypeIntroducer NodeType for Introducer + NodeTypeIntroducer NodeType = 5 + + // NodeTypeWallet NodeType for Wallet + NodeTypeWallet NodeType = 6 +) + +// CapabilityType is an internal references for types of capabilities +type CapabilityType uint16 + +const ( + // CapabilityTypeBase just means it supports the chia protocol at mainnet + CapabilityTypeBase CapabilityType = 1 +) + +// Capability reflects a capability of the peer +// This represents the Tuple that exists in the Python code +type Capability struct { + Capability CapabilityType `streamable:""` + Value string `streamable:""` +} + +// Handshake is a handshake message +type Handshake struct { + NetworkID string `streamable:""` + ProtocolVersion string `streamable:""` + SoftwareVersion string `streamable:""` + ServerPort uint16 `streamable:""` + NodeType NodeType `streamable:""` + Capabilities []Capability `streamable:""` // List[Tuple[uint16, str]] +} diff --git a/pkg/streamable/errors.go b/pkg/streamable/errors.go new file mode 100644 index 0000000..dc4cd88 --- /dev/null +++ b/pkg/streamable/errors.go @@ -0,0 +1,23 @@ +package streamable + +import ( + "reflect" +) + +// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal. +// (The argument to Unmarshal must be a non-nil pointer.) +type InvalidUnmarshalError struct { + Type reflect.Type +} + +// Error outputs the error message and satisfies the Error interface +func (e *InvalidUnmarshalError) Error() string { + if e.Type == nil { + return "json: Unmarshal(nil)" + } + + if e.Type.Kind() != reflect.Ptr { + return "streamable: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "streamable: Unmarshal(nil " + e.Type.String() + ")" +} diff --git a/pkg/streamable/streamable.go b/pkg/streamable/streamable.go new file mode 100644 index 0000000..1c465b0 --- /dev/null +++ b/pkg/streamable/streamable.go @@ -0,0 +1,358 @@ +package streamable + +import ( + "encoding/binary" + "fmt" + "reflect" + "strings" + "unsafe" + + "github.com/chia-network/go-chia-libs/pkg/util" +) + +const ( + // Name of the struct tag used to identify the streamable properties + tagName = "streamable" + + // Bytes that indicate bool yes or no when serialized + boolFalse uint8 = 0 + boolTrue uint8 = 1 +) + +// Unmarshal unmarshals a streamable type based on struct tags +// Struct order is extremely important in this decoding. Ensure the order/types are identical +// on both sides of the stream +func Unmarshal(bytes []byte, v interface{}) error { + tv := reflect.ValueOf(v) + if tv.Kind() != reflect.Ptr || tv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} + } + + // Gets rid of the pointer + tv = reflect.Indirect(tv) + + // Get the actual type + t := tv.Type() + + if t.Kind() != reflect.Struct { + return fmt.Errorf("streamable can't unmarshal into non-struct type") + } + + _, err := unmarshalStruct(bytes, t, tv) + return err +} + +func unmarshalStruct(bytes []byte, t reflect.Type, tv reflect.Value) ([]byte, error) { + var err error + + // Iterate over all available fields and read the tag value + for i := 0; i < t.NumField(); i++ { + structField := t.Field(i) + fieldValue := tv.Field(i) + fieldType := fieldValue.Type() + + bytes, err = unmarshalField(bytes, fieldType, fieldValue, structField) + if err != nil { + return bytes, err + } + } + + return bytes, nil +} + +func unmarshalSlice(bytes []byte, t reflect.Type, v reflect.Value) ([]byte, error) { + var err error + var newVal []byte + + // Slice/List is 4 byte prefix (number of items) and then serialization of each item + // Get 4 byte length prefix + var length []byte + length, bytes, err = util.ShiftNBytes(4, bytes) + if err != nil { + return nil, err + } + numItems := binary.BigEndian.Uint32(length) + + sliceKind := t.Elem().Kind() + switch sliceKind { + case reflect.Uint8: // same as byte + // In this case, numItems == numBytes, because its a uint8 + newVal, bytes, err = util.ShiftNBytes(uint(numItems), bytes) + if err != nil { + return bytes, err + } + if !v.CanSet() { + return bytes, fmt.Errorf("field %s is not settable", v.String()) + } + + sliceReflect := reflect.MakeSlice(v.Type(), 0, 0) + for _, newValBytes := range newVal { + sliceReflect = reflect.Append(sliceReflect, reflect.ValueOf(newValBytes)) + } + v.Set(sliceReflect) + case reflect.Struct: + sliceReflect := reflect.MakeSlice(v.Type(), 0, 0) + for j := uint32(0); j < numItems; j++ { + newValue := reflect.Indirect(reflect.New(v.Type().Elem())) + bytes, err = unmarshalStruct(bytes, t.Elem(), newValue) + if err != nil { + return nil, err + } + sliceReflect = reflect.Append(sliceReflect, newValue) + } + v.Set(sliceReflect) + default: + return bytes, fmt.Errorf("encountered type inside slice that is not implemented") + } + + return bytes, nil +} + +// Struct field is used to parse out the streamable tag +// Not needed for anything else +// When recursively calling this on a wrapper type like mo.Option, pass the parent/wrapping StructField +func unmarshalField(bytes []byte, fieldType reflect.Type, fieldValue reflect.Value, structField reflect.StructField) ([]byte, error) { + var tagPresent bool + if _, tagPresent = structField.Tag.Lookup(tagName); !tagPresent { + // Continuing because the tag isn't present + return bytes, nil + } + + var err error + var newVal []byte + + // Optionals are handled with mo.Option + // There will be one byte bool that indicates whether the field is present + if strings.HasPrefix(fieldType.String(), "mo.Option[") { + var presentFlag []byte + presentFlag, bytes, err = util.ShiftNBytes(1, bytes) + if err != nil { + return nil, err + } + if presentFlag[0] == boolFalse { + return bytes, nil + } + + // The unsafe.Pointer(..) stuff in here is to be able to set unexported fields of mo.Option + // See https://stackoverflow.com/questions/42664837/how-to-access-unexported-struct-fields + + // First we set the present attr to true + presentField := fieldValue.Field(0) + presentFieldType := presentField.Type() + reflect.NewAt(presentFieldType, unsafe.Pointer(presentField.UnsafeAddr())).Elem().SetBool(true) + + // Then, we can parse out the value of the field and set the value attr + optionalField := fieldValue.Field(1) + optionalType := optionalField.Type() + writableField := reflect.NewAt(optionalType, unsafe.Pointer(optionalField.UnsafeAddr())).Elem() + bytes, err = unmarshalField(bytes, optionalType, writableField, structField) + if err != nil { + return bytes, err + } + + return bytes, nil + } + + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + + // Need to init the field to something non-nil before using it + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + fieldValue = fieldValue.Elem() + } + + switch kind := fieldType.Kind(); kind { + case reflect.Uint8: + newVal, bytes, err = util.ShiftNBytes(1, bytes) + if err != nil { + return bytes, err + } + if !fieldValue.CanSet() { + return bytes, fmt.Errorf("field %s is not settable", fieldValue.String()) + } + fieldValue.SetUint(uint64(util.BytesToUint8(newVal))) + case reflect.Uint16: + newVal, bytes, err = util.ShiftNBytes(2, bytes) + if err != nil { + return bytes, err + } + if !fieldValue.CanSet() { + return bytes, fmt.Errorf("field %s is not settable", fieldValue.String()) + } + newInt := util.BytesToUint16(newVal) + fieldValue.SetUint(uint64(newInt)) + case reflect.Uint32: + newVal, bytes, err = util.ShiftNBytes(4, bytes) + if err != nil { + return bytes, err + } + if !fieldValue.CanSet() { + return bytes, fmt.Errorf("field %s is not settable", fieldValue.String()) + } + newInt := util.BytesToUint32(newVal) + fieldValue.SetUint(uint64(newInt)) + case reflect.Uint64: + newVal, bytes, err = util.ShiftNBytes(8, bytes) + if err != nil { + return bytes, err + } + if !fieldValue.CanSet() { + return bytes, fmt.Errorf("field %s is not settable", fieldValue.String()) + } + newInt := util.BytesToUint64(newVal) + fieldValue.SetUint(newInt) + case reflect.Slice: + bytes, err = unmarshalSlice(bytes, fieldType, fieldValue) + if err != nil { + return bytes, err + } + case reflect.String: + // 4 byte size prefix, then []byte which can be converted to utf-8 string + // Get 4 byte length prefix + var length []byte + length, bytes, err = util.ShiftNBytes(4, bytes) + if err != nil { + return nil, err + } + numBytes := binary.BigEndian.Uint32(length) + + var strBytes []byte + strBytes, bytes, err = util.ShiftNBytes(uint(numBytes), bytes) + if err != nil { + return nil, err + } + fieldValue.SetString(string(strBytes)) + default: + return bytes, fmt.Errorf("unimplemented type %s", fieldValue.Kind()) + } + + return bytes, nil +} + +// Marshal marshals the item into the streamable byte format +func Marshal(v interface{}) ([]byte, error) { + // Doesn't matter if a pointer or not for marshalling, so + // we just call this and let it deal with ptr or not ptr + tv := reflect.Indirect(reflect.ValueOf(v)) + + // Get the actual type + t := tv.Type() + + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("streamable can't marshal a non-struct type") + } + + // This will become the final encoded data + var finalBytes []byte + + return marshalStruct(finalBytes, t, tv) +} + +func marshalStruct(finalBytes []byte, t reflect.Type, tv reflect.Value) ([]byte, error) { + var err error + + // Iterate over all available fields in the type and encode to bytes + for i := 0; i < t.NumField(); i++ { + structField := t.Field(i) + fieldValue := tv.Field(i) + fieldType := fieldValue.Type() + + finalBytes, err = marshalField(finalBytes, fieldType, fieldValue, structField) + if err != nil { + return finalBytes, err + } + } + + return finalBytes, nil +} + +func marshalField(finalBytes []byte, fieldType reflect.Type, fieldValue reflect.Value, structField reflect.StructField) ([]byte, error) { + var err error + + var tagPresent bool + if _, tagPresent = structField.Tag.Lookup(tagName); !tagPresent { + // Continuing because the tag isn't present + return finalBytes, nil + } + + // Optionals are handled with mo.Option + // If the value is not present, 0x00 and move on + // otherwise 0x01 and encode the value + if strings.HasPrefix(fieldType.String(), "mo.Option[") { + isPresent := fieldValue.MethodByName("IsPresent").Call([]reflect.Value{})[0].Bool() + + if !isPresent { + // Field is not present, insert `false` byte and continue + finalBytes = append(finalBytes, boolFalse) + return finalBytes, nil + } + + finalBytes = append(finalBytes, boolTrue) + + // Get the underlying value and encode it + optionalVal := fieldValue.MethodByName("MustGet").Call([]reflect.Value{})[0] + optionalType := optionalVal.Type() + + return marshalField(finalBytes, optionalType, optionalVal, structField) + } + + // If field is still a pointer, get rid of that now that we're past the optional checking + fieldValue = reflect.Indirect(fieldValue) + + switch fieldValue.Kind() { + case reflect.Uint8: + newInt := uint8(fieldValue.Uint()) + finalBytes = append(finalBytes, newInt) + case reflect.Uint16: + newInt := uint16(fieldValue.Uint()) + finalBytes = append(finalBytes, util.Uint16ToBytes(newInt)...) + case reflect.Uint32: + newInt := uint32(fieldValue.Uint()) + finalBytes = append(finalBytes, util.Uint32ToBytes(newInt)...) + case reflect.Uint64: + finalBytes = append(finalBytes, util.Uint64ToBytes(fieldValue.Uint())...) + case reflect.Slice: + finalBytes, err = marshalSlice(finalBytes, fieldType, fieldValue) + if err != nil { + return finalBytes, err + } + case reflect.String: + // Strings get converted to []byte with a 4 byte size prefix + strBytes := []byte(fieldValue.String()) + numBytes := uint32(len(strBytes)) + finalBytes = append(finalBytes, util.Uint32ToBytes(numBytes)...) + + finalBytes = append(finalBytes, strBytes...) + default: + return finalBytes, fmt.Errorf("unimplemented type %s", fieldValue.Kind()) + } + + return finalBytes, nil +} + +func marshalSlice(finalBytes []byte, t reflect.Type, v reflect.Value) ([]byte, error) { + var err error + + // Slice/List is 4 byte prefix (number of items) and then serialization of each item + // Get 4 byte length prefix + numItems := uint32(v.Len()) + finalBytes = append(finalBytes, util.Uint32ToBytes(numItems)...) + + sliceKind := t.Elem().Kind() + switch sliceKind { + case reflect.Uint8: // same as byte + // This is the easy case - already a slice of bytes + finalBytes = append(finalBytes, v.Bytes()...) + case reflect.Struct: + for j := 0; j < v.Len(); j++ { + currentStruct := v.Index(j) + + finalBytes, err = marshalStruct(finalBytes, currentStruct.Type(), currentStruct) + if err != nil { + return finalBytes, err + } + } + } + + return finalBytes, nil +} diff --git a/pkg/streamable/streamable_test.go b/pkg/streamable/streamable_test.go new file mode 100644 index 0000000..432dd2c --- /dev/null +++ b/pkg/streamable/streamable_test.go @@ -0,0 +1,221 @@ +package streamable_test + +import ( + "encoding/hex" + "testing" + + "github.com/samber/mo" + "github.com/stretchr/testify/assert" + + "github.com/chia-network/go-chia-libs/pkg/protocols" + "github.com/chia-network/go-chia-libs/pkg/streamable" +) + +const ( + //Message( + // uint8(ProtocolMessageTypes.handshake.value), + // None, + // bytes("This is a sample message to decode".encode(encoding = 'UTF-8', errors = 'string')) + //) + encodedHex1 string = "0100000000225468697320697320612073616d706c65206d65737361676520746f206465636f6465" + + //Message( + // uint8(ProtocolMessageTypes.handshake.value), + // uint16(35256), + // bytes("This is a sample message to decode".encode(encoding = 'UTF-8', errors = 'string')) + //) + encodedHex2 string = "010189b8000000225468697320697320612073616d706c65206d65737361676520746f206465636f6465" + + //Message( + // uint8(ProtocolMessageTypes.handshake.value), + // None, + // Handshake( + // "mainnet", + // "0.0.33", + // "1.2.11", + // uint16(8444), + // uint8(1), + // [(uint16(Capability.BASE.value), "1")], + // ) + //) + encodedHexHandshake string = "01000000002d000000076d61696e6e657400000006302e302e333300000006312e322e313120fc010000000100010000000131" +) + +func TestUnmarshal_Message1(t *testing.T) { + // Hex to bytes + encodedBytes, err := hex.DecodeString(encodedHex1) + assert.NoError(t, err) + + // test that nil is not accepted + err = streamable.Unmarshal(encodedBytes, nil) + assert.Error(t, err) + + msg := &protocols.Message{ + ProtocolMessageType: 0, + Data: nil, + } + + // Test that pointers are required + err = streamable.Unmarshal(encodedBytes, *msg) + assert.Error(t, err) + + err = streamable.Unmarshal(encodedBytes, msg) + + assert.NoError(t, err) + assert.Equal(t, protocols.ProtocolMessageTypeHandshake, msg.ProtocolMessageType) + assert.False(t, msg.ID.IsPresent()) + assert.Equal(t, []byte("This is a sample message to decode"), msg.Data) +} + +func TestMarshal_Message1(t *testing.T) { + encodedBytes, err := hex.DecodeString(encodedHex1) + assert.NoError(t, err) + + msg := &protocols.Message{ + ProtocolMessageType: protocols.ProtocolMessageTypeHandshake, + Data: []byte("This is a sample message to decode"), + } + + bytes, err := streamable.Marshal(msg) + + assert.NoError(t, err) + assert.Equal(t, encodedBytes, bytes) +} + +// Unmarshals fully then remarshals to ensure we can go back and forth +func TestUnmarshal_Remarshal_Message1(t *testing.T) { + encodedBytes, err := hex.DecodeString(encodedHex1) + assert.NoError(t, err) + + msg := &protocols.Message{} + + err = streamable.Unmarshal(encodedBytes, msg) + assert.NoError(t, err) + + // Remarshal and check against original bytes + reencodedBytes, err := streamable.Marshal(msg) + assert.NoError(t, err) + assert.Equal(t, encodedBytes, reencodedBytes) +} + +func TestUnmarshal_Message2(t *testing.T) { + // Hex to bytes + encodedBytes, err := hex.DecodeString(encodedHex2) + assert.NoError(t, err) + + // test that nil is not accepted + err = streamable.Unmarshal(encodedBytes, nil) + assert.Error(t, err) + + msg := &protocols.Message{ + ProtocolMessageType: 0, + Data: nil, + } + + // Test that pointers are required + err = streamable.Unmarshal(encodedBytes, *msg) + assert.Error(t, err) + + err = streamable.Unmarshal(encodedBytes, msg) + + assert.NoError(t, err) + assert.Equal(t, protocols.ProtocolMessageTypeHandshake, msg.ProtocolMessageType) + assert.True(t, msg.ID.IsPresent()) + assert.Equal(t, uint16(35256), msg.ID.MustGet()) + assert.Equal(t, []byte("This is a sample message to decode"), msg.Data) +} + +func TestMarshal_Message2(t *testing.T) { + encodedBytes, err := hex.DecodeString(encodedHex2) + assert.NoError(t, err) + + msg := &protocols.Message{ + ProtocolMessageType: protocols.ProtocolMessageTypeHandshake, + ID: mo.Some(uint16(35256)), + Data: []byte("This is a sample message to decode"), + } + + bytes, err := streamable.Marshal(msg) + + assert.NoError(t, err) + assert.Equal(t, encodedBytes, bytes) +} + +// Unmarshals fully then remarshals to ensure we can go back and forth +func TestUnmarshal_Remarshal_Message2(t *testing.T) { + encodedBytes, err := hex.DecodeString(encodedHex2) + assert.NoError(t, err) + + msg := &protocols.Message{} + + err = streamable.Unmarshal(encodedBytes, msg) + assert.NoError(t, err) + + // Remarshal and check against original bytes + reencodedBytes, err := streamable.Marshal(msg) + assert.NoError(t, err) + assert.Equal(t, encodedBytes, reencodedBytes) +} + +func TestUnmarshal_Handshake(t *testing.T) { + // Hex to bytes + encodedBytes, err := hex.DecodeString(encodedHexHandshake) + assert.NoError(t, err) + + msg := &protocols.Message{} + + err = streamable.Unmarshal(encodedBytes, msg) + + assert.NoError(t, err) + assert.Equal(t, protocols.ProtocolMessageTypeHandshake, msg.ProtocolMessageType) + assert.False(t, msg.ID.IsPresent()) + + // No decode the handshake portion + handshake := &protocols.Handshake{} + + // Handshake( + // "mainnet", + // "0.0.33", + // "1.2.11", + // uint16(8444), + // uint8(1), + // [(uint16(Capability.BASE.value), "1")], + // ) + + err = streamable.Unmarshal(msg.Data, handshake) + assert.NoError(t, err) + assert.Equal(t, "mainnet", handshake.NetworkID) + assert.Equal(t, "0.0.33", handshake.ProtocolVersion) + assert.Equal(t, "1.2.11", handshake.SoftwareVersion) + assert.Equal(t, uint16(8444), handshake.ServerPort) + assert.Equal(t, protocols.NodeTypeFullNode, handshake.NodeType) + assert.IsType(t, []protocols.Capability{}, handshake.Capabilities) + assert.Len(t, handshake.Capabilities, 1) + + // Test each capability item + cap1 := handshake.Capabilities[0] + + assert.Equal(t, protocols.CapabilityTypeBase, cap1.Capability) + assert.Equal(t, "1", cap1.Value) +} + +// Unmarshals fully then remarshals to ensure we can go back and forth +func TestUnmarshal_Remarshal_Handshake(t *testing.T) { + encodedBytes, err := hex.DecodeString(encodedHexHandshake) + assert.NoError(t, err) + + msg := &protocols.Message{} + + err = streamable.Unmarshal(encodedBytes, msg) + assert.NoError(t, err) + + handshake := &protocols.Handshake{} + + err = streamable.Unmarshal(msg.Data, handshake) + assert.NoError(t, err) + + // Remarshal and check against original bytes + reencodedBytes, err := protocols.MakeMessageBytes(msg.ProtocolMessageType, handshake) + assert.NoError(t, err) + assert.Equal(t, encodedBytes, reencodedBytes) +} diff --git a/pkg/types/peerinfo.go b/pkg/types/peerinfo.go new file mode 100644 index 0000000..d1638ad --- /dev/null +++ b/pkg/types/peerinfo.go @@ -0,0 +1,8 @@ +package types + +// TimestampedPeerInfo contains information about peers with timestamps +type TimestampedPeerInfo struct { + Host string `streamable:""` + Port uint16 `streamable:""` + Timestamp uint64 `streamable:""` +} diff --git a/pkg/util/bytes.go b/pkg/util/bytes.go index 5873cf6..2092085 100644 --- a/pkg/util/bytes.go +++ b/pkg/util/bytes.go @@ -33,3 +33,18 @@ func FormatBytes(bytes types.Uint128) string { return fmt.Sprintf("%s %s", value.String(), labels[len(labels)-1]) } + +// ShiftNBytes returns the specified number of bytes from the start of the provided []byte +// and removes them from the original byte slice +// First returned value is the requested number of bytes from the beginning of the original byte slice +// Second returned value is the new original byte slice with the requested number of bytes removed from the front of it +func ShiftNBytes(numBytes uint, bytes []byte) ([]byte, []byte, error) { + if uint(len(bytes)) < numBytes { + return nil, bytes, fmt.Errorf("requested more bytes than available") + } + + requestedBytes := bytes[:numBytes] + bytes = bytes[numBytes:] + + return requestedBytes, bytes, nil +} diff --git a/pkg/util/bytes_test.go b/pkg/util/bytes_test.go new file mode 100644 index 0000000..0c7b901 --- /dev/null +++ b/pkg/util/bytes_test.go @@ -0,0 +1,48 @@ +package util_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/chia-network/go-chia-libs/pkg/util" +) + +func TestShiftNBytes(t *testing.T) { + origBytes := []byte{ + uint8(0), + uint8(1), + uint8(2), + uint8(3), + uint8(4), + uint8(5), + uint8(6), + uint8(7), + } + + // Ensure we're in a good starting place before changing things around + assert.Equal(t, 8, len(origBytes)) + assert.Equal(t, uint8(0), origBytes[0]) + assert.Equal(t, uint8(7), origBytes[7]) + + shift2, origBytes, err := util.ShiftNBytes(2, origBytes) + + assert.NoError(t, err) + + // Check expected lengths + assert.Len(t, shift2, 2) + assert.Len(t, origBytes, 6) + + // Check actual expected values + assert.Equal(t, uint8(0), shift2[0]) + assert.Equal(t, uint8(1), shift2[1]) + assert.Equal(t, uint8(2), origBytes[0]) + assert.Equal(t, uint8(7), origBytes[5]) + + // Test pulling off too many bytes + shiftTooMany, origBytes, err := util.ShiftNBytes(7, origBytes) + + assert.Error(t, err) + assert.Nil(t, shiftTooMany) + assert.Len(t, origBytes, 6) +} diff --git a/pkg/util/uints.go b/pkg/util/uints.go new file mode 100644 index 0000000..0c83a4a --- /dev/null +++ b/pkg/util/uints.go @@ -0,0 +1,59 @@ +package util + +import ( + "encoding/binary" +) + +// Uint8ToBytes Converts uint8 to []byte +// Kind of pointless, since byte is uint8, but here for consistency with the other methods +func Uint8ToBytes(num uint8) []byte { + return []byte{num} +} + +// BytesToUint8 returns uint8 from []byte +// if you have more than one byte in your []byte this wont work like you think +func BytesToUint8(bytes []byte) uint8 { + return bytes[0] +} + +// Uint16ToBytes Converts uint16 to []byte +func Uint16ToBytes(num uint16) []byte { + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, num) + + return b +} + +// BytesToUint16 returns uint16 from []byte +// if you have more than two bytes in your []byte this wont work like you think +func BytesToUint16(bytes []byte) uint16 { + return binary.BigEndian.Uint16(bytes) +} + +// Uint32ToBytes Converts uint32 to []byte +func Uint32ToBytes(num uint32) []byte { + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, num) + + return b +} + +// BytesToUint32 returns uint32 from []byte +// if you have more than four bytes in your []byte this wont work like you think +func BytesToUint32(bytes []byte) uint32 { + return binary.BigEndian.Uint32(bytes) +} + +// Uint64ToBytes Converts uint64 to []byte +func Uint64ToBytes(num uint64) []byte { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, num) + + return b +} + +// BytesToUint64 returns uint64 from []byte +// if you have more than eight bytes in your []byte this wont work like you think +func BytesToUint64(bytes []byte) uint64 { + return binary.BigEndian.Uint64(bytes) +} From 590cfec82156e0af3870f0bac436cd439fde96dc Mon Sep 17 00:00:00 2001 From: Chris Marslender Date: Wed, 28 Jun 2023 11:06:05 -0500 Subject: [PATCH 2/3] Add full node protocol with request_peers for now --- pkg/peerprotocol/connection.go | 2 +- pkg/peerprotocol/connectionoptions.go | 2 +- pkg/peerprotocol/fullnode.go | 22 ++++++++++++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) create mode 100644 pkg/peerprotocol/fullnode.go diff --git a/pkg/peerprotocol/connection.go b/pkg/peerprotocol/connection.go index b801736..1eea00c 100644 --- a/pkg/peerprotocol/connection.go +++ b/pkg/peerprotocol/connection.go @@ -1,4 +1,4 @@ -package protocol +package peerprotocol import ( "context" diff --git a/pkg/peerprotocol/connectionoptions.go b/pkg/peerprotocol/connectionoptions.go index ba53393..29cd4c8 100644 --- a/pkg/peerprotocol/connectionoptions.go +++ b/pkg/peerprotocol/connectionoptions.go @@ -1,4 +1,4 @@ -package protocol +package peerprotocol import ( "time" diff --git a/pkg/peerprotocol/fullnode.go b/pkg/peerprotocol/fullnode.go new file mode 100644 index 0000000..cc9a9e7 --- /dev/null +++ b/pkg/peerprotocol/fullnode.go @@ -0,0 +1,22 @@ +package peerprotocol + +import ( + "github.com/chia-network/go-chia-libs/pkg/protocols" +) + +// FullNodeProtocol is for interfacing with full nodes via the peer protocol +type FullNodeProtocol struct { + connection *Connection +} + +// NewFullNodeProtocol returns a new instance of the full node protocol +func NewFullNodeProtocol(connection *Connection) (*FullNodeProtocol, error) { + fnp := &FullNodeProtocol{connection: connection} + + return fnp, nil +} + +// RequestPeers asks the current peer to respond with their current peer list +func (c *FullNodeProtocol) RequestPeers() error { + return c.connection.Do(protocols.ProtocolMessageTypeRequestPeers, &protocols.RequestPeers{}) +} From bbc7029e96a674e64524dae2c245db57dfb5b96f Mon Sep 17 00:00:00 2001 From: Chris Marslender Date: Wed, 28 Jun 2023 11:16:59 -0500 Subject: [PATCH 3/3] Add readme --- pkg/streamable/readme.md | 41 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 pkg/streamable/readme.md diff --git a/pkg/streamable/readme.md b/pkg/streamable/readme.md new file mode 100644 index 0000000..ad6b3a6 --- /dev/null +++ b/pkg/streamable/readme.md @@ -0,0 +1,41 @@ +# Streamable + +This package implements the chia streamable format. Not all aspects of the streamable format are fully implemented, and +support for more types are added as protocol messages are added to this package. This is not intended to be used in +consensus critical applications and there may be unexpected errors for untested streamable objects. + +For more information on the streamable format, see the [streamable docs](https://docs.chia.net/serialization-protocol?_highlight=strea#streamable-format) + +## How to Use + +When defining structs that are streamable, the order of the fields is extremely important, and should match the order +of the fields in [chia-blockchain](https://github.com/chia-network/chia-blockchain). To support struct fields that are +not defined in chia-blockchain, streamable objects require a `streamable` tag on each field of the struct that should be +streamed. + +**Example Type Definition:** + +```go +// TimestampedPeerInfo contains information about peers with timestamps +type TimestampedPeerInfo struct { + Host string `streamable:""` + Port uint16 `streamable:""` + Timestamp uint64 `streamable:""` +} +``` + +For a given streamable object, the interface is very similar to json marshal/unmarshal. + +**Encode to Bytes** + +```go +peerInfo := &TimestampedPeerInfo{....} +bytes, err := streamable.Marshal(peerInfo) +``` + +**Decode to Struct** + +```go +peerInfo := &TimestampedPeerInfo{} +err := streamable.Unmarshal(bytes, peerInfo) +```