From 00f7e4fa23a461c0820633fd92d16dab9cf9a614 Mon Sep 17 00:00:00 2001 From: arvidn Date: Fri, 3 Nov 2023 11:43:43 +0100 Subject: [PATCH 1/3] use separate protocol versions for the different services --- chia/_tests/connection_utils.py | 2 +- chia/_tests/core/server/test_server.py | 4 ++-- chia/_tests/core/ssl/test_ssl.py | 4 +++- chia/protocols/shared_protocol.py | 11 ++++++++++- chia/server/server.py | 8 ++++++-- chia/server/ws_connection.py | 7 +++++-- 6 files changed, 27 insertions(+), 9 deletions(-) diff --git a/chia/_tests/connection_utils.py b/chia/_tests/connection_utils.py index cf204aa2c58d..0ebfef816ebf 100644 --- a/chia/_tests/connection_utils.py +++ b/chia/_tests/connection_utils.py @@ -88,7 +88,7 @@ async def add_dummy_connection_wsc( 30, local_capabilities_for_handshake=capabilities, ) - await wsc.perform_handshake(server._network_id, protocol_version, dummy_port, type) + await wsc.perform_handshake(server._network_id, protocol_version[type], dummy_port, type) if wsc.incoming_message_task is not None: wsc.incoming_message_task.cancel() return wsc, peer_id diff --git a/chia/_tests/core/server/test_server.py b/chia/_tests/core/server/test_server.py index 3811c7235bfb..612a0bca4664 100644 --- a/chia/_tests/core/server/test_server.py +++ b/chia/_tests/core/server/test_server.py @@ -16,7 +16,7 @@ from chia.protocols.protocol_message_types import ProtocolMessageTypes from chia.protocols.shared_protocol import Error, protocol_version from chia.protocols.wallet_protocol import RejectHeaderRequest -from chia.server.outbound_message import make_msg +from chia.server.outbound_message import NodeType, make_msg from chia.server.server import ChiaServer from chia.server.start_full_node import create_full_node_service from chia.server.start_wallet import create_wallet_service @@ -80,7 +80,7 @@ async def test_connection_versions( outgoing_connection = wallet_node.server.all_connections[full_node.server.node_id] incoming_connection = full_node.server.all_connections[wallet_node.server.node_id] for connection in [outgoing_connection, incoming_connection]: - assert connection.protocol_version == Version(protocol_version) + assert connection.protocol_version == Version(protocol_version[NodeType.FULL_NODE]) assert connection.version == __version__ assert connection.get_version() == connection.version diff --git a/chia/_tests/core/ssl/test_ssl.py b/chia/_tests/core/ssl/test_ssl.py index ef1976847e3f..00bc74f7ca4f 100644 --- a/chia/_tests/core/ssl/test_ssl.py +++ b/chia/_tests/core/ssl/test_ssl.py @@ -33,7 +33,9 @@ async def establish_connection(server: ChiaServer, self_hostname: str, ssl_conte 30, local_capabilities_for_handshake=capabilities, ) - await wsc.perform_handshake(server._network_id, protocol_version, dummy_port, NodeType.FULL_NODE) + await wsc.perform_handshake( + server._network_id, protocol_version[NodeType.FULL_NODE], dummy_port, NodeType.FULL_NODE + ) await wsc.close() diff --git a/chia/protocols/shared_protocol.py b/chia/protocols/shared_protocol.py index 36d495d65541..a51c3fac62d7 100644 --- a/chia/protocols/shared_protocol.py +++ b/chia/protocols/shared_protocol.py @@ -4,10 +4,19 @@ from enum import IntEnum from typing import List, Optional, Tuple +from chia.server.outbound_message import NodeType from chia.util.ints import int16, uint8, uint16 from chia.util.streamable import Streamable, streamable -protocol_version = "0.0.36" +protocol_version = { + NodeType.FULL_NODE: "0.0.36", + NodeType.HARVESTER: "0.0.36", + NodeType.FARMER: "0.0.36", + NodeType.TIMELORD: "0.0.36", + NodeType.INTRODUCER: "0.0.36", + NodeType.WALLET: "0.0.36", + NodeType.DATA_LAYER: "0.0.36", +} """ diff --git a/chia/server/server.py b/chia/server/server.py index fe70ee355d7d..29860b5d7b37 100644 --- a/chia/server/server.py +++ b/chia/server/server.py @@ -334,7 +334,9 @@ async def incoming_connection(self, request: web.Request) -> web.StreamResponse: outbound_rate_limit_percent=self._outbound_rate_limit_percent, local_capabilities_for_handshake=self._local_capabilities_for_handshake, ) - await connection.perform_handshake(self._network_id, protocol_version, self.get_port(), self._local_type) + await connection.perform_handshake( + self._network_id, protocol_version[self._local_type], self.get_port(), self._local_type + ) assert connection.connection_type is not None, "handshake failed to set connection type, still None" # Limit inbound connections to config's specifications. @@ -485,7 +487,9 @@ async def start_client( local_capabilities_for_handshake=self._local_capabilities_for_handshake, session=session, ) - await connection.perform_handshake(self._network_id, protocol_version, server_port, self._local_type) + await connection.perform_handshake( + self._network_id, protocol_version[self._local_type], server_port, self._local_type + ) await self.connection_added(connection, on_connect) # the session has been adopted by the connection, don't close it at # the end of the function diff --git a/chia/server/ws_connection.py b/chia/server/ws_connection.py index 63c1683015b1..c8cda6d30060 100644 --- a/chia/server/ws_connection.py +++ b/chia/server/ws_connection.py @@ -222,10 +222,11 @@ async def perform_handshake( if inbound_handshake.network_id != network_id: raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID) + inbound_node_type = NodeType(inbound_handshake.node_type) self.version = inbound_handshake.software_version self.protocol_version = Version(inbound_handshake.protocol_version) self.peer_server_port = inbound_handshake.server_port - self.connection_type = NodeType(inbound_handshake.node_type) + self.connection_type = inbound_node_type # "1" means capability is enabled self.peer_capabilities = known_active_capabilities(inbound_handshake.capabilities) else: @@ -249,11 +250,13 @@ async def perform_handshake( inbound_handshake = Handshake.from_bytes(message.data) if inbound_handshake.network_id != network_id: raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID) + + inbound_node_type = NodeType(inbound_handshake.node_type) await self._send_message(outbound_handshake) self.version = inbound_handshake.software_version self.protocol_version = Version(inbound_handshake.protocol_version) self.peer_server_port = inbound_handshake.server_port - self.connection_type = NodeType(inbound_handshake.node_type) + self.connection_type = inbound_node_type # "1" means capability is enabled self.peer_capabilities = known_active_capabilities(inbound_handshake.capabilities) From e373852e2410775ed8a49d3e4d70ebc83329e07f Mon Sep 17 00:00:00 2001 From: Harold Brenes Date: Fri, 3 Nov 2023 17:33:05 -0400 Subject: [PATCH 2/3] Validate protocol version on handshake --- chia/_tests/connection_utils.py | 4 +- chia/_tests/core/ssl/test_ssl.py | 6 +-- chia/server/server.py | 9 +---- chia/server/ws_connection.py | 66 ++++++++++++++++++++++++-------- 4 files changed, 55 insertions(+), 30 deletions(-) diff --git a/chia/_tests/connection_utils.py b/chia/_tests/connection_utils.py index 0ebfef816ebf..e73694af267e 100644 --- a/chia/_tests/connection_utils.py +++ b/chia/_tests/connection_utils.py @@ -11,7 +11,7 @@ from cryptography.hazmat.primitives import hashes, serialization from chia._tests.util.time_out_assert import time_out_assert -from chia.protocols.shared_protocol import capabilities, protocol_version +from chia.protocols.shared_protocol import capabilities from chia.server.outbound_message import NodeType from chia.server.server import ChiaServer, ssl_context_for_client from chia.server.ssl_context import chia_ssl_ca_paths, private_ssl_ca_paths @@ -88,7 +88,7 @@ async def add_dummy_connection_wsc( 30, local_capabilities_for_handshake=capabilities, ) - await wsc.perform_handshake(server._network_id, protocol_version[type], dummy_port, type) + await wsc.perform_handshake(server._network_id, dummy_port, type) if wsc.incoming_message_task is not None: wsc.incoming_message_task.cancel() return wsc, peer_id diff --git a/chia/_tests/core/ssl/test_ssl.py b/chia/_tests/core/ssl/test_ssl.py index 00bc74f7ca4f..9d0943f567da 100644 --- a/chia/_tests/core/ssl/test_ssl.py +++ b/chia/_tests/core/ssl/test_ssl.py @@ -3,7 +3,7 @@ import aiohttp import pytest -from chia.protocols.shared_protocol import capabilities, protocol_version +from chia.protocols.shared_protocol import capabilities from chia.server.outbound_message import NodeType from chia.server.server import ChiaServer, ssl_context_for_client from chia.server.ssl_context import chia_ssl_ca_paths, private_ssl_ca_paths @@ -33,9 +33,7 @@ async def establish_connection(server: ChiaServer, self_hostname: str, ssl_conte 30, local_capabilities_for_handshake=capabilities, ) - await wsc.perform_handshake( - server._network_id, protocol_version[NodeType.FULL_NODE], dummy_port, NodeType.FULL_NODE - ) + await wsc.perform_handshake(server._network_id, dummy_port, NodeType.FULL_NODE) await wsc.close() diff --git a/chia/server/server.py b/chia/server/server.py index 29860b5d7b37..b39cf8b38db6 100644 --- a/chia/server/server.py +++ b/chia/server/server.py @@ -27,7 +27,6 @@ from chia.protocols.protocol_message_types import ProtocolMessageTypes from chia.protocols.protocol_state_machine import message_requires_reply from chia.protocols.protocol_timing import INVALID_PROTOCOL_BAN_SECONDS -from chia.protocols.shared_protocol import protocol_version from chia.server.api_protocol import ApiProtocol from chia.server.introducer_peers import IntroducerPeers from chia.server.outbound_message import Message, NodeType @@ -334,9 +333,7 @@ async def incoming_connection(self, request: web.Request) -> web.StreamResponse: outbound_rate_limit_percent=self._outbound_rate_limit_percent, local_capabilities_for_handshake=self._local_capabilities_for_handshake, ) - await connection.perform_handshake( - self._network_id, protocol_version[self._local_type], self.get_port(), self._local_type - ) + await connection.perform_handshake(self._network_id, self.get_port(), self._local_type) assert connection.connection_type is not None, "handshake failed to set connection type, still None" # Limit inbound connections to config's specifications. @@ -487,9 +484,7 @@ async def start_client( local_capabilities_for_handshake=self._local_capabilities_for_handshake, session=session, ) - await connection.perform_handshake( - self._network_id, protocol_version[self._local_type], server_port, self._local_type - ) + await connection.perform_handshake(self._network_id, server_port, self._local_type) await self.connection_added(connection, on_connect) # the session has been adopted by the connection, don't close it at # the end of the function diff --git a/chia/server/ws_connection.py b/chia/server/ws_connection.py index c8cda6d30060..961e70b0ab5c 100644 --- a/chia/server/ws_connection.py +++ b/chia/server/ws_connection.py @@ -22,7 +22,7 @@ CONSENSUS_ERROR_BAN_SECONDS, INTERNAL_PROTOCOL_ERROR_BAN_SECONDS, ) -from chia.protocols.shared_protocol import Capability, Error, Handshake +from chia.protocols.shared_protocol import Capability, Error, Handshake, protocol_version from chia.server.api_protocol import ApiProtocol from chia.server.capabilities import known_active_capabilities from chia.server.outbound_message import Message, NodeType, make_msg @@ -188,22 +188,21 @@ def _get_extra_info(self, name: str) -> Optional[Any]: async def perform_handshake( self, network_id: str, - protocol_version: str, server_port: int, local_type: NodeType, ) -> None: - outbound_handshake = make_msg( - ProtocolMessageTypes.handshake, - Handshake( - network_id, - protocol_version, - __version__, - uint16(server_port), - uint8(local_type.value), - self.local_capabilities_for_handshake, - ), - ) if self.is_outbound: + outbound_handshake = make_msg( + ProtocolMessageTypes.handshake, + Handshake( + network_id, + protocol_version[local_type], + __version__, + uint16(server_port), + uint8(local_type.value), + self.local_capabilities_for_handshake, + ), + ) await self._send_message(outbound_handshake) inbound_handshake_msg = await self._read_one_message() if inbound_handshake_msg is None: @@ -222,11 +221,17 @@ async def perform_handshake( if inbound_handshake.network_id != network_id: raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID) - inbound_node_type = NodeType(inbound_handshake.node_type) + if inbound_handshake.protocol_version != protocol_version[local_type]: + self.log.warning( + f"protocol version mismatch: " + f"incoming={inbound_handshake.protocol_version} " + f"our={protocol_version[local_type]}" + ) + self.version = inbound_handshake.software_version self.protocol_version = Version(inbound_handshake.protocol_version) self.peer_server_port = inbound_handshake.server_port - self.connection_type = inbound_node_type + self.connection_type = NodeType(inbound_handshake.node_type) # "1" means capability is enabled self.peer_capabilities = known_active_capabilities(inbound_handshake.capabilities) else: @@ -251,12 +256,39 @@ async def perform_handshake( if inbound_handshake.network_id != network_id: raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID) - inbound_node_type = NodeType(inbound_handshake.node_type) + remote_node_type = NodeType(inbound_handshake.node_type) + if remote_node_type not in [ + NodeType.FULL_NODE, + NodeType.HARVESTER, + NodeType.FARMER, + NodeType.TIMELORD, + NodeType.INTRODUCER, + NodeType.WALLET, + NodeType.DATA_LAYER, + ]: + raise ProtocolError(Err.INVALID_HANDSHAKE) + + if inbound_handshake.protocol_version != protocol_version[remote_node_type]: + self.log.warning( + f"protocol version mismatch: incoming={inbound_handshake.protocol_version} our={protocol_version}" + ) + + outbound_handshake = make_msg( + ProtocolMessageTypes.handshake, + Handshake( + network_id, + protocol_version[remote_node_type], + __version__, + uint16(server_port), + uint8(local_type.value), + self.local_capabilities_for_handshake, + ), + ) await self._send_message(outbound_handshake) self.version = inbound_handshake.software_version self.protocol_version = Version(inbound_handshake.protocol_version) self.peer_server_port = inbound_handshake.server_port - self.connection_type = inbound_node_type + self.connection_type = remote_node_type # "1" means capability is enabled self.peer_capabilities = known_active_capabilities(inbound_handshake.capabilities) From d6ea95235f48ef9e1ee05a150dc58f9c404f5edf Mon Sep 17 00:00:00 2001 From: arvidn Date: Thu, 28 Mar 2024 23:25:08 +0100 Subject: [PATCH 3/3] simplify checking protocol version for incoming connections --- chia/server/ws_connection.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/chia/server/ws_connection.py b/chia/server/ws_connection.py index 961e70b0ab5c..de8e63892c02 100644 --- a/chia/server/ws_connection.py +++ b/chia/server/ws_connection.py @@ -257,16 +257,6 @@ async def perform_handshake( raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID) remote_node_type = NodeType(inbound_handshake.node_type) - if remote_node_type not in [ - NodeType.FULL_NODE, - NodeType.HARVESTER, - NodeType.FARMER, - NodeType.TIMELORD, - NodeType.INTRODUCER, - NodeType.WALLET, - NodeType.DATA_LAYER, - ]: - raise ProtocolError(Err.INVALID_HANDSHAKE) if inbound_handshake.protocol_version != protocol_version[remote_node_type]: self.log.warning(