From 394cd0673955b3e2473e7d9600222622aad73a31 Mon Sep 17 00:00:00 2001 From: bruwbird Date: Mon, 16 Sep 2024 12:45:53 +0900 Subject: [PATCH] multi: ignore non-peerswap related custom messages Ignore custom messages that are not relevant to peerswap. Additionally, it refactors the existing message conversion logic for better clarity and maintainability. Fixes https://github.com/ElementsProject/peerswap/issues/305 --- messages/errors.go | 25 ++++++++++--- messages/messages.go | 4 --- messages/messages_test.go | 75 --------------------------------------- messages/types.go | 60 ++++++++++++++++--------------- messages/types_test.go | 37 +++++++++++++++++++ poll/service.go | 51 +++++++++++++++----------- swap/service.go | 8 ++++- 7 files changed, 127 insertions(+), 133 deletions(-) delete mode 100644 messages/messages.go delete mode 100644 messages/messages_test.go create mode 100644 messages/types_test.go diff --git a/messages/errors.go b/messages/errors.go index 26531c5c..cdb011f5 100644 --- a/messages/errors.go +++ b/messages/errors.go @@ -2,10 +2,27 @@ package messages import "fmt" -var ( - ErrEvenMessageType = fmt.Errorf("message type is even") - ErrMessageNotInRange = fmt.Errorf("message type not in range") -) +// ErrNotPeerswapCustomMessage represents an error indicating +// that the message type is not a peerswap custom message. +type ErrNotPeerswapCustomMessage struct { + MessageType string +} + +// NewErrNotPeerswapCustomMessage creates a new ErrNotPeerswapCustomMessage with the given message type. +func NewErrNotPeerswapCustomMessage(messageType string) ErrNotPeerswapCustomMessage { + return ErrNotPeerswapCustomMessage{MessageType: messageType} +} + +// Error returns the error message for ErrNotPeerswapCustomMessage. +func (e ErrNotPeerswapCustomMessage) Error() string { + return fmt.Sprintf("message type %s is not a peerswap custom message", e.MessageType) +} + +// Is checks if the target error is of type ErrNotPeerswapCustomMessage. +func (e ErrNotPeerswapCustomMessage) Is(target error) bool { + _, ok := target.(*ErrNotPeerswapCustomMessage) + return ok +} type ErrAlreadyHasASender string diff --git a/messages/messages.go b/messages/messages.go deleted file mode 100644 index eb001a32..00000000 --- a/messages/messages.go +++ /dev/null @@ -1,4 +0,0 @@ -package messages - -// Message needs to have a MessageType. -type Message interface{} diff --git a/messages/messages_test.go b/messages/messages_test.go deleted file mode 100644 index f0deac67..00000000 --- a/messages/messages_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package messages - -import ( - "strconv" - "testing" -) - -func TestInRange(t *testing.T) { - type args struct { - msg string - } - tests := []struct { - name string - args args - want bool - wantErr bool - }{ - {"t1", args{MessageTypeToHexString(BASE_MESSAGE_TYPE - 1)}, false, true}, - {"t2", args{MessageTypeToHexString(BASE_MESSAGE_TYPE)}, true, false}, - {"t3", args{MessageTypeToHexString(MESSAGETYPE_SWAPINAGREEMENT)}, true, false}, - {"t4", args{MessageTypeToHexString(UPPER_MESSAGE_BOUND - 1)}, true, false}, - {"t5", args{MessageTypeToHexString(UPPER_MESSAGE_BOUND)}, false, true}, - {"t6", args{MessageTypeToHexString(UPPER_MESSAGE_BOUND + 1)}, false, false}, - {"t7", args{"z"}, false, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := inRangeStr(tt.args.msg) - if (err != nil) != tt.wantErr { - t.Errorf("InRange() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("InRange() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestHexStringToMessageType(t *testing.T) { - type args struct { - msgType string - } - tests := []struct { - name string - args args - want MessageType - wantErr bool - }{ - {"t1", args{MessageTypeToHexString(MESSAGETYPE_SWAPINREQUEST)}, MESSAGETYPE_SWAPINREQUEST, false}, - {"t2", args{MessageTypeToHexString(BASE_MESSAGE_TYPE + 1)}, 0, true}, - {"t3", args{MessageTypeToHexString(UPPER_MESSAGE_BOUND + 1)}, 0, true}, - {"t4", args{"z"}, 0, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := HexStringToMessageType(tt.args.msgType) - if (err != nil) != tt.wantErr { - t.Errorf("HexStringToMessageType() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("HexStringToMessageType() got = %v, want %v", got, tt.want) - } - }) - } -} - -func inRangeStr(msg string) (bool, error) { - msgInt, err := strconv.ParseInt(msg, 16, 64) - if err != nil { - return false, err - } - return InRange(MessageType(msgInt)) -} diff --git a/messages/types.go b/messages/types.go index b49c0d67..6fac9446 100644 --- a/messages/types.go +++ b/messages/types.go @@ -35,17 +35,41 @@ const ( MESSAGETYPE_POLL _ MESSAGETYPE_REQUEST_POLL - UPPER_MESSAGE_BOUND ) -// InRange checks if the message type lays in the -// peerswap message range. -func InRange(msgType MessageType) (bool, error) { - // MessageType we do not accept even message types - if msgType%2 == 0 { - return false, ErrEvenMessageType +// PeerswapCustomMessageType converts a hexadecimal string representation of a message type +// to its corresponding MessageType. If the message type is not recognized, it returns an error. +func PeerswapCustomMessageType(msgType string) (MessageType, error) { + // Parse the hexadecimal string to an integer. + msgTypeInt, err := strconv.ParseInt(msgType, 16, 64) + if err != nil { + return 0, fmt.Errorf("could not parse hex string to message type: %w", err) + } + + // Match the parsed integer to the corresponding MessageType. + switch MessageType(msgTypeInt) { + case MESSAGETYPE_SWAPINREQUEST: + return MESSAGETYPE_SWAPINREQUEST, nil + case MESSAGETYPE_SWAPOUTREQUEST: + return MESSAGETYPE_SWAPOUTREQUEST, nil + case MESSAGETYPE_SWAPINAGREEMENT: + return MESSAGETYPE_SWAPINAGREEMENT, nil + case MESSAGETYPE_SWAPOUTAGREEMENT: + return MESSAGETYPE_SWAPOUTAGREEMENT, nil + case MESSAGETYPE_OPENINGTXBROADCASTED: + return MESSAGETYPE_OPENINGTXBROADCASTED, nil + case MESSAGETYPE_CANCELED: + return MESSAGETYPE_CANCELED, nil + case MESSAGETYPE_COOPCLOSE: + return MESSAGETYPE_COOPCLOSE, nil + case MESSAGETYPE_POLL: + return MESSAGETYPE_POLL, nil + case MESSAGETYPE_REQUEST_POLL: + return MESSAGETYPE_REQUEST_POLL, nil + default: + // Return an error if the message type is not recognized. + return 0, NewErrNotPeerswapCustomMessage(msgType) } - return BASE_MESSAGE_TYPE <= msgType && msgType < UPPER_MESSAGE_BOUND, nil } // MessageTypeToHexStr returns the hex encoded string @@ -53,23 +77,3 @@ func InRange(msgType MessageType) (bool, error) { func MessageTypeToHexString(messageIndex MessageType) string { return strconv.FormatInt(int64(messageIndex), 16) } - -// HexStrToMsgType returns the message type from a -// hex encoded string. -func HexStringToMessageType(msgTypeStr string) (MessageType, error) { - msgTypeInt, err := strconv.ParseInt(msgTypeStr, 16, 64) - if err != nil { - return 0, fmt.Errorf("could not parse hex string to message type: %w", err) - } - - msgType := MessageType(msgTypeInt) - - inRange, err := InRange(msgType) - if err != nil { - return 0, err - } - if !inRange { - return 0, ErrMessageNotInRange - } - return msgType, nil -} diff --git a/messages/types_test.go b/messages/types_test.go new file mode 100644 index 00000000..0569a540 --- /dev/null +++ b/messages/types_test.go @@ -0,0 +1,37 @@ +package messages + +import "testing" + +func TestPeerswapCustomMessageType(t *testing.T) { + t.Parallel() + tests := map[string]struct { + msgType string + want MessageType + wantErr bool + }{ + "swapinrequest": {msgType: MessageTypeToHexString(MESSAGETYPE_SWAPINREQUEST), want: MESSAGETYPE_SWAPINREQUEST}, + "swapoutrequest": {msgType: MessageTypeToHexString(MESSAGETYPE_SWAPOUTREQUEST), want: MESSAGETYPE_SWAPOUTREQUEST}, + "swapinagreement": {msgType: MessageTypeToHexString(MESSAGETYPE_SWAPINAGREEMENT), want: MESSAGETYPE_SWAPINAGREEMENT}, + "swapoutagreement": {msgType: MessageTypeToHexString(MESSAGETYPE_SWAPOUTAGREEMENT), want: MESSAGETYPE_SWAPOUTAGREEMENT}, + "openintxbroadcasted": {msgType: MessageTypeToHexString(MESSAGETYPE_OPENINGTXBROADCASTED), want: MESSAGETYPE_OPENINGTXBROADCASTED}, + "canceled": {msgType: MessageTypeToHexString(MESSAGETYPE_CANCELED), want: MESSAGETYPE_CANCELED}, + "coopclose": {msgType: MessageTypeToHexString(MESSAGETYPE_COOPCLOSE), want: MESSAGETYPE_COOPCLOSE}, + "poll": {msgType: MessageTypeToHexString(MESSAGETYPE_POLL), want: MESSAGETYPE_POLL}, + "request_poll": {msgType: MessageTypeToHexString(MESSAGETYPE_REQUEST_POLL), want: MESSAGETYPE_REQUEST_POLL}, + "invalid": {msgType: "invalid", wantErr: true}, + } + for name, tt := range tests { + tt := tt + t.Run(name, func(t *testing.T) { + t.Parallel() + got, err := PeerswapCustomMessageType(tt.msgType) + if (err != nil) != tt.wantErr { + t.Errorf("PeerswapCustomMessageType() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("PeerswapCustomMessageType() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/poll/service.go b/poll/service.go index da1a6bd6..0ebe0f46 100644 --- a/poll/service.go +++ b/poll/service.go @@ -3,6 +3,7 @@ package poll import ( "context" "encoding/json" + "errors" "fmt" "sync" "time" @@ -167,62 +168,70 @@ func (s *Service) RequestAllPeerPolls() { // MessageHandler checks for the incoming messages // type and takes the incoming payload to update the // store. -func (s *Service) MessageHandler(peerId string, msgType string, payload []byte) error { - messageType, err := messages.HexStringToMessageType(msgType) +func (s *Service) MessageHandler(peerID, msgType string, payload []byte) error { + messageType, err := messages.PeerswapCustomMessageType(msgType) if err != nil { + // Check for specific errors: even message type or message out of range + // message type that peerswap is not interested in. + if errors.Is(err, &messages.ErrNotPeerswapCustomMessage{}) { + // These errors are expected and can be handled gracefully + return nil + } return err } switch messageType { case messages.MESSAGETYPE_POLL: var msg PollMessage - err = json.Unmarshal(payload, &msg) - if err != nil { - return err + if jerr := json.Unmarshal(payload, &msg); jerr != nil { + return jerr } - s.store.Update(peerId, PollInfo{ + if serr := s.store.Update(peerID, PollInfo{ ProtocolVersion: msg.Version, Assets: msg.Assets, PeerAllowed: msg.PeerAllowed, LastSeen: time.Now(), - }) - if ti, ok := s.tmpStore[peerId]; ok { + }); serr != nil { + return serr + } + if ti, ok := s.tmpStore[peerID]; ok { if ti == string(payload) { return nil } } if msg.Version != swap.PEERSWAP_PROTOCOL_VERSION { - log.Debugf("Received poll from INCOMPATIBLE peer %s: %s", peerId, string(payload)) + log.Debugf("Received poll from INCOMPATIBLE peer %s: %s", peerID, string(payload)) } else { - log.Debugf("Received poll from peer %s: %s", peerId, string(payload)) + log.Debugf("Received poll from peer %s: %s", peerID, string(payload)) } - s.tmpStore[peerId] = string(payload) + s.tmpStore[peerID] = string(payload) return nil case messages.MESSAGETYPE_REQUEST_POLL: var msg RequestPollMessage - err = json.Unmarshal([]byte(payload), &msg) - if err != nil { - return err + if jerr := json.Unmarshal(payload, &msg); jerr != nil { + return jerr } - s.store.Update(peerId, PollInfo{ + if serr := s.store.Update(peerID, PollInfo{ ProtocolVersion: msg.Version, Assets: msg.Assets, PeerAllowed: msg.PeerAllowed, LastSeen: time.Now(), - }) + }); serr != nil { + return serr + } // Send a poll on request - s.Poll(peerId) - if ti, ok := s.tmpStore[peerId]; ok { + s.Poll(peerID) + if ti, ok := s.tmpStore[peerID]; ok { if ti == string(payload) { return nil } } if msg.Version != swap.PEERSWAP_PROTOCOL_VERSION { - log.Debugf("Received poll from INCOMPATIBLE peer %s: %s", peerId, string(payload)) + log.Debugf("Received poll from INCOMPATIBLE peer %s: %s", peerID, string(payload)) } else { - log.Debugf("Received poll from peer %s: %s", peerId, string(payload)) + log.Debugf("Received poll from peer %s: %s", peerID, string(payload)) } - s.tmpStore[peerId] = string(payload) + s.tmpStore[peerID] = string(payload) return nil default: return nil diff --git a/swap/service.go b/swap/service.go index 5a4e45e3..7065d9fc 100644 --- a/swap/service.go +++ b/swap/service.go @@ -178,8 +178,14 @@ func (s *SwapService) OnMessageReceived(peerId string, msgTypeString string, pay if len(payload) > 100*1024 { return errors.New("Payload is unexpectedly large") } - msgType, err := messages.HexStringToMessageType(msgTypeString) + msgType, err := messages.PeerswapCustomMessageType(msgTypeString) if err != nil { + // Check for specific errors: even message type or message out of range + // message type that peerswap is not interested in. + if errors.Is(err, &messages.ErrNotPeerswapCustomMessage{}) { + // These errors are expected and can be handled gracefully + return nil + } return err } msgBytes := []byte(payload)