diff --git a/nvflare/fuel/f3/cellnet/cell.py b/nvflare/fuel/f3/cellnet/cell.py index 03ee43d1fb..7cc577a13e 100644 --- a/nvflare/fuel/f3/cellnet/cell.py +++ b/nvflare/fuel/f3/cellnet/cell.py @@ -14,14 +14,18 @@ import logging import threading +import uuid from typing import Dict, List, Union +from nvflare.apis.fl_constant import ServerCommandNames from nvflare.fuel.f3.cellnet.core_cell import CoreCell -from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, MessageType +from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, MessageType, ReturnCode +from nvflare.fuel.f3.cellnet.utils import decode_payload, encode_payload, make_reply from nvflare.fuel.f3.message import Message from nvflare.fuel.f3.stream_cell import StreamCell from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey from nvflare.fuel.f3.streaming.stream_types import StreamFuture +from nvflare.private.defs import CellChannel class SimpleWaiter: @@ -29,15 +33,15 @@ def __init__(self, req_id, result): super().__init__() self.req_id = req_id self.result = result - self.receiving_futre = None + self.receiving_future = None self.in_receiving = threading.Event() class Adapter: - def __init__(self, cb, my_info, nice_cell): + def __init__(self, cb, my_info, cell): self.cb = cb self.my_info = my_info - self.nice_cell = nice_cell + self.cell = cell self.logger = logging.getLogger(self.__class__.__name__) def call(self, future): # this will be called by StreamCell upon receiving the first byte of blob @@ -46,11 +50,15 @@ def call(self, future): # this will be called by StreamCell upon receiving the origin = headers.get(MessageHeaderKey.ORIGIN, None) result = future.result() request = Message(headers, result) + + decode_payload(request, StreamHeaderKey.PAYLOAD_ENCODING) + channel = request.get_header(StreamHeaderKey.CHANNEL) request.set_header(MessageHeaderKey.CHANNEL, channel) topic = request.get_header(StreamHeaderKey.TOPIC) request.set_header(MessageHeaderKey.TOPIC, topic) req_id = 
request.get_header(MessageHeaderKey.REQ_ID, "") + secure = request.get_header(MessageHeaderKey.SECURE, False) response = self.cb(request) response.add_headers( { @@ -59,7 +67,9 @@ def call(self, future): # this will be called by StreamCell upon receiving the StreamHeaderKey.STREAM_REQ_ID: stream_req_id, } ) - messagesend_future = self.nice_cell.send_blob(channel, topic, origin, response) + + encode_payload(response, StreamHeaderKey.PAYLOAD_ENCODING) + self.cell.send_blob(channel, topic, origin, response, secure) class Cell(StreamCell): @@ -79,7 +89,7 @@ def method(*args, **kwargs): return method def fire_and_forget( - self, channel: str, topic: str, targets: Union[str, List[str]], message: Message, optional=False + self, channel: str, topic: str, targets: Union[str, List[str]], message: Message, secure=False, optional=False ) -> Dict[str, str]: """ Send a message over a channel to specified destination cell(s), and do not wait for replies. @@ -89,25 +99,31 @@ def fire_and_forget( topic: topic of the message targets: one or more destination cell IDs. None means all. 
message: message to be sent + secure: End-end encryption if True optional: whether the message is optional Returns: None """ - # if channel == CellChannel.SERVER_COMMAND and topic == ServerCommandNames.HANDLE_DEAD_JOB: - # if isinstance(targets, list): - # for target in targets: - # self.send_blob(channel=channel, topic=topic, target=target, message=message) - # else: - # self.send_blob(channel=channel, topic=topic, target=targets, message=message) - # else: - # self.core_cell.fire_and_forget( - # channel=channel, topic=topic, targets=targets, message=message, optional=optional - # ) - - self.core_cell.fire_and_forget( - channel=channel, topic=topic, targets=targets, message=message, optional=optional - ) + + if channel == CellChannel.SERVER_COMMAND and topic == ServerCommandNames.HANDLE_DEAD_JOB: + + encode_payload(message, encoding_key=StreamHeaderKey.PAYLOAD_ENCODING) + + result = {} + if isinstance(targets, list): + for target in targets: + self.send_blob(channel=channel, topic=topic, target=target, message=message, secure=secure) + result[target] = "" + else: + self.send_blob(channel=channel, topic=topic, target=targets, message=message, secure=secure) + result[targets] = "" + + return result + else: + return self.core_cell.fire_and_forget( + channel=channel, topic=topic, targets=targets, message=message, optional=optional + ) def _get_result(self, req_id): waiter = self.requests_dict.pop(req_id) @@ -124,55 +140,60 @@ def _future_wait(self, future, timeout): last_progress = current_progress return True - def send_request(self, channel, target, topic, request, timeout=10.0, optional=False): + def send_request(self, channel, target, topic, request, timeout=10.0, secure=False, optional=False): + self.logger.debug(f"send_request: {channel=}, {topic=}, {target=}, {timeout=}") - # if channel != CellChannel.SERVER_COMMAND: - # return self.core_cell.send_request( - # channel=channel, target=target, topic=topic, request=request, timeout=timeout, optional=optional - # ) 
- # - # request.payload = fobs.dumps(request.payload) - # - # req_id = str(uuid.uuid4()) - # request.add_headers({StreamHeaderKey.STREAM_REQ_ID: req_id}) - # - # # this future can be used to check sending progress, but not for checking return blob - # future = self.send_blob(channel, topic, target, request) - # - # waiter = SimpleWaiter(req_id=req_id, result=make_reply(ReturnCode.TIMEOUT)) - # self.requests_dict[req_id] = waiter - # - # # Three stages, sending, waiting for receiving first byte, receiving - # - # # sending with progress timeout - # self.logger.debug(f"{req_id=}: entering sending wait {timeout=}") - # sending_complete = self._future_wait(future, timeout) - # if not sending_complete: - # self.logger.debug(f"{req_id=}: sending timeout") - # return self._get_result(req_id) - # self.logger.debug(f"{req_id=}: sending complete") - # - # # waiting for receiving first byte - # self.logger.debug(f"{req_id=}: entering remote process wait {timeout=}") - # if not waiter.in_receiving.wait(timeout): - # self.logger.debug(f"{req_id=}: remote processing timeout") - # return self._get_result(req_id) - # self.logger.debug(f"{req_id=}: in receiving") - # - # # receiving with progress timeout - # r_future = waiter.receiving_future - # self.logger.debug(f"{req_id=}: entering receiving wait {timeout=}") - # receiving_complete = self._future_wait(r_future, timeout) - # if not receiving_complete: - # self.logger.debug(f"{req_id=}: receiving timeout") - # return self._get_result(req_id) - # self.logger.debug(f"{req_id=}: receiving complete") - # waiter.result = Message(r_future.headers, r_future.result()) - # return self._get_result(req_id) - - return self.core_cell.send_request( - channel=channel, target=target, topic=topic, request=request, timeout=timeout, optional=optional - ) + + if channel != CellChannel.SERVER_COMMAND: + return self.core_cell.send_request( + channel=channel, + target=target, + topic=topic, + request=request, + timeout=timeout, + secure=secure, + 
optional=optional, + ) + + encode_payload(request, StreamHeaderKey.PAYLOAD_ENCODING) + + req_id = str(uuid.uuid4()) + request.add_headers({StreamHeaderKey.STREAM_REQ_ID: req_id}) + + # this future can be used to check sending progress, but not for checking return blob + future = self.send_blob(channel, topic, target, request, secure) + + waiter = SimpleWaiter(req_id=req_id, result=make_reply(ReturnCode.TIMEOUT)) + self.requests_dict[req_id] = waiter + + # Three stages, sending, waiting for receiving first byte, receiving + + # sending with progress timeout + self.logger.debug(f"{req_id=}: entering sending wait {timeout=}") + sending_complete = self._future_wait(future, timeout) + if not sending_complete: + self.logger.debug(f"{req_id=}: sending timeout") + return self._get_result(req_id) + self.logger.debug(f"{req_id=}: sending complete") + + # waiting for receiving first byte + self.logger.debug(f"{req_id=}: entering remote process wait {timeout=}") + if not waiter.in_receiving.wait(timeout): + self.logger.debug(f"{req_id=}: remote processing timeout") + return self._get_result(req_id) + self.logger.debug(f"{req_id=}: in receiving") + + # receiving with progress timeout + r_future = waiter.receiving_future + self.logger.debug(f"{req_id=}: entering receiving wait {timeout=}") + receiving_complete = self._future_wait(r_future, timeout) + if not receiving_complete: + self.logger.debug(f"{req_id=}: receiving timeout") + return self._get_result(req_id) + self.logger.debug(f"{req_id=}: receiving complete") + waiter.result = Message(r_future.headers, r_future.result()) + decode_payload(waiter.result, encoding_key=StreamHeaderKey.PAYLOAD_ENCODING) + return self._get_result(req_id) def _process_reply(self, future: StreamFuture): headers = future.headers @@ -180,7 +201,7 @@ def _process_reply(self, future: StreamFuture): try: waiter = self.requests_dict[req_id] except KeyError as e: - self.logger.warning(f"Receiving unknown {req_id=}, discarded") + 
self.logger.warning(f"Receiving unknown {req_id=}, discarded: {e}") return waiter.receiving_future = future waiter.in_receiving.set() @@ -199,19 +220,17 @@ def register_request_cb(self, channel: str, topic: str, cb, *args, **kwargs): Returns: """ + if not callable(cb): raise ValueError(f"specified request_cb {type(cb)} is not callable") - # if channel == CellChannel.SERVER_COMMAND and topic in [ - # "*", - # ServerCommandNames.GET_TASK, - # ServerCommandNames.SUBMIT_UPDATE, - # ]: - # self.logger.debug(f"Register blob CB for {channel=}, {topic=}") - # adapter = Adapter(cb, self.core_cell.my_info, self) - # self.register_blob_cb(channel, topic, adapter.call, *args, **kwargs) - # else: - # self.logger.debug(f"Register regular CB for {channel=}, {topic=}") - # self.core_cell.register_request_cb(channel, topic, cb, *args, **kwargs) - - self.logger.debug(f"Register regular CB for {channel=}, {topic=}") - self.core_cell.register_request_cb(channel, topic, cb, *args, **kwargs) + if channel == CellChannel.SERVER_COMMAND and topic in [ + "*", + ServerCommandNames.GET_TASK, + ServerCommandNames.SUBMIT_UPDATE, + ]: + self.logger.debug(f"Register blob CB for {channel=}, {topic=}") + adapter = Adapter(cb, self.core_cell.my_info, self) + self.register_blob_cb(channel, topic, adapter.call, *args, **kwargs) + else: + self.logger.debug(f"Register regular CB for {channel=}, {topic=}") + self.core_cell.register_request_cb(channel, topic, cb, *args, **kwargs) diff --git a/nvflare/fuel/f3/cellnet/cell_cipher.py b/nvflare/fuel/f3/cellnet/cell_cipher.py index 39621fb9d2..da8ddba23b 100644 --- a/nvflare/fuel/f3/cellnet/cell_cipher.py +++ b/nvflare/fuel/f3/cellnet/cell_cipher.py @@ -72,6 +72,13 @@ def _sign(k, m): def _verify(k, m, s): + + if not isinstance(m, bytes): + m = bytes(m) + + if not isinstance(s, bytes): + s = bytes(s) + k.verify( s, m, @@ -210,6 +217,10 @@ def decrypt(self, message: bytes, origin_cert: Certificate): message[NONCE_LENGTH : NONCE_LENGTH + KEY_ENC_LENGTH], 
message[NONCE_LENGTH + KEY_ENC_LENGTH : SIMPLE_HEADER_LENGTH], ) + + if not isinstance(key_enc, bytes): + key_enc = bytes(key_enc) + key_hash = hash(key_enc) dec = self._cached_dec.get(key_hash) if dec is None: diff --git a/nvflare/fuel/f3/cellnet/core_cell.py b/nvflare/fuel/f3/cellnet/core_cell.py index cdf8a172fa..ed4c12e38d 100644 --- a/nvflare/fuel/f3/cellnet/core_cell.py +++ b/nvflare/fuel/f3/cellnet/core_cell.py @@ -23,6 +23,7 @@ from urllib.parse import urlparse from nvflare.fuel.f3.cellnet.connector_manager import ConnectorManager +from nvflare.fuel.f3.cellnet.credential_manager import CredentialManager from nvflare.fuel.f3.cellnet.defs import ( AbortRun, AuthenticationError, @@ -52,6 +53,8 @@ _CHANNEL = "cellnet.channel" _TOPIC_BULK = "bulk" _TOPIC_BYE = "bye" +_SM_CHANNEL = "credential_manager" +_SM_TOPIC = "key_exchange" _ONE_MB = 1024 * 1024 @@ -143,17 +146,23 @@ class _BulkSender: - def __init__(self, cell, target: str, max_queue_size): + def __init__(self, cell, target: str, max_queue_size, secure=False): self.cell = cell self.target = target self.max_queue_size = max_queue_size + self.secure = secure self.messages = [] self.last_send_time = 0 self.lock = threading.Lock() self.logger = logging.getLogger(self.__class__.__name__) def queue_message(self, channel: str, topic: str, message: Message): + if self.secure: + message.add_headers({MessageHeaderKey.SECURE: True}) + encode_payload(message) + self.cell.encrypt_payload(message) + with self.lock: tm = TargetMessage(target=self.target, channel=channel, topic=topic, message=message) self.messages.append(tm) @@ -216,6 +225,44 @@ class _CounterName: REP_FILTER_ERROR = "rep_filter_error" +class CertificateExchanger: + """This class handles cert-exchange messages""" + + def __init__(self, core_cell, credential_manager: CredentialManager): + + self.core_cell = core_cell + self.credential_manager = credential_manager + self.core_cell.register_request_cb(_SM_CHANNEL, _SM_TOPIC, 
self._handle_cert_request) + + def get_certificate(self, target: str) -> bytes: + + cert = self.credential_manager.get_certificate(target) + if cert: + return cert + + cert = self.exchange_certificate(target) + self.credential_manager.save_certificate(target, cert) + + return cert + + def exchange_certificate(self, target: str) -> bytes: + root = FQCN.get_root(target) + req = self.credential_manager.create_request(root) + response = self.core_cell.send_request(_SM_CHANNEL, _SM_TOPIC, root, Message(None, req)) + reply = response.payload + + if not reply: + error_code = response.get_header(MessageHeaderKey.RETURN_CODE) + raise RuntimeError(f"Cert exchanged to {root} failed: {error_code}") + + return self.credential_manager.process_response(reply) + + def _handle_cert_request(self, request: Message): + + reply = self.credential_manager.process_request(request.payload) + return Message(None, reply) + + class CoreCell(MessageReceiver, EndpointMonitor): APP_ID = 1 @@ -433,6 +480,9 @@ def __init__( ) self.ALL_CELLS[fqcn] = self + self.credential_manager = CredentialManager(self.endpoint) + self.cert_ex = CertificateExchanger(self, self.credential_manager) + def log_error(self, log_text: str, msg: Union[None, Message], log_except=False): log_messaging_error( logger=self.logger, log_text=log_text, cell=self, msg=msg, log_except=log_except, log_level=logging.ERROR @@ -868,6 +918,58 @@ def register_request_cb(self, channel: str, topic: str, cb, *args, **kwargs): raise ValueError(f"specified request_cb {type(cb)} is not callable") self.req_reg.set(channel, topic, Callback(cb, args, kwargs)) + def encrypt_payload(self, message: Message): + + if not message.get_header(MessageHeaderKey.SECURE, False): + return + + encrypted = message.get_header(MessageHeaderKey.ENCRYPTED, False) + if encrypted: + # Prevent double encryption + return + + target = message.get_header(MessageHeaderKey.DESTINATION) + + if not target: + raise RuntimeError("Message destination missing") + + if 
message.payload is None: + message.payload = bytes(0) + + payload_len = len(message.payload) + message.add_headers( + { + MessageHeaderKey.PAYLOAD_LEN: payload_len, + MessageHeaderKey.ENCRYPTED: True, + } + ) + + target_cert = self.cert_ex.get_certificate(target) + message.payload = self.credential_manager.encrypt(target_cert, message.payload) + self.logger.debug(f"Payload ({payload_len} bytes) is encrypted ({len(message.payload)} bytes)") + + def decrypt_payload(self, message: Message): + + if not message.get_header(MessageHeaderKey.SECURE, False): + return + + encrypted = message.get_header(MessageHeaderKey.ENCRYPTED, False) + if not encrypted: + # Message is already decrypted + return + + message.remove_header(MessageHeaderKey.ENCRYPTED) + + origin = message.get_header(MessageHeaderKey.ORIGIN) + if not origin: + raise RuntimeError("Message origin missing") + + payload_len = message.get_header(MessageHeaderKey.PAYLOAD_LEN) + origin_cert = self.cert_ex.get_certificate(origin) + message.payload = self.credential_manager.decrypt(origin_cert, message.payload) + if len(message.payload) != payload_len: + raise RuntimeError(f"Payload size changed after decryption {len(message.payload)} <> {payload_len}") + def add_incoming_request_filter(self, channel: str, topic: str, cb, *args, **kwargs): if not callable(cb): raise ValueError(f"specified incoming_request_filter {type(cb)} is not callable") @@ -1004,6 +1106,8 @@ def _send_to_endpoint(self, to_endpoint: Endpoint, message: Message) -> str: err = "" try: encode_payload(message) + self.encrypt_payload(message) + message.set_header(MessageHeaderKey.SEND_TIME, time.time()) if not message.payload: msg_size = 0 @@ -1121,15 +1225,15 @@ def _send_to_targets( return self._send_target_messages(target_msgs) def send_request( - self, channel: str, topic: str, target: str, request: Message, timeout=None, optional=False + self, channel: str, topic: str, target: str, request: Message, timeout=None, secure=False, optional=False ) -> 
Message: self.logger.debug(f"{self.my_info.fqcn}: sending request {channel}:{topic} to {target}") - result = self.broadcast_request(channel, topic, [target], request, timeout, optional) + result = self.broadcast_request(channel, topic, [target], request, timeout, secure, optional) assert isinstance(result, dict) return result.get(target) def broadcast_multi_requests( - self, target_msgs: Dict[str, TargetMessage], timeout=None, optional=False + self, target_msgs: Dict[str, TargetMessage], timeout=None, secure=False, optional=False ) -> Dict[str, Message]: """ This is the core of the request/response handling. Be extremely careful when making any changes! @@ -1153,6 +1257,7 @@ def broadcast_multi_requests( Args: target_msgs: messages to be sent timeout: timeout value + secure: End-end encryption optional: whether the message is optional Returns: a dict of: target name => reply message @@ -1175,6 +1280,7 @@ def broadcast_multi_requests( { MessageHeaderKey.REQ_ID: waiter.id, MessageHeaderKey.REPLY_EXPECTED: True, + MessageHeaderKey.SECURE: secure, MessageHeaderKey.OPTIONAL: optional, } ) @@ -1228,7 +1334,14 @@ def broadcast_multi_requests( return result def broadcast_request( - self, channel: str, topic: str, targets: Union[str, List[str]], request: Message, timeout=None, optional=False + self, + channel: str, + topic: str, + targets: Union[str, List[str]], + request: Message, + timeout=None, + secure=False, + optional=False, ) -> Dict[str, Message]: """ Send a message over a channel to specified destination cell(s), and wait for reply @@ -1239,6 +1352,7 @@ def broadcast_request( targets: FQCN of the destination cell(s) request: message to be sent timeout: how long to wait for replies + secure: End-end encryption optional: whether the message is optional Returns: a dict of: cell_id => reply message @@ -1249,10 +1363,10 @@ def broadcast_request( target_msgs = {} for t in targets: target_msgs[t] = TargetMessage(t, channel, topic, request) - return 
self.broadcast_multi_requests(target_msgs, timeout, optional=optional) + return self.broadcast_multi_requests(target_msgs, timeout, secure=secure, optional=optional) def fire_and_forget( - self, channel: str, topic: str, targets: Union[str, List[str]], message: Message, optional=False + self, channel: str, topic: str, targets: Union[str, List[str]], message: Message, secure=False, optional=False ) -> Dict[str, str]: """ Send a message over a channel to specified destination cell(s), and do not wait for replies. @@ -1262,12 +1376,19 @@ def fire_and_forget( topic: topic of the message targets: one or more destination cell IDs. None means all. message: message to be sent + secure: End-end encryption of the message optional: whether the message is optional Returns: None """ - message.add_headers({MessageHeaderKey.REPLY_EXPECTED: False, MessageHeaderKey.OPTIONAL: optional}) + message.add_headers( + { + MessageHeaderKey.REPLY_EXPECTED: False, + MessageHeaderKey.OPTIONAL: optional, + MessageHeaderKey.SECURE: secure, + } + ) return self._send_to_targets(channel, topic, targets, message) def queue_message(self, channel: str, topic: str, targets: Union[str, List[str]], message: Message, optional=False): @@ -1360,7 +1481,7 @@ def fire_multi_requests_and_forget(self, target_msgs: Dict[str, TargetMessage], request.add_headers({MessageHeaderKey.REPLY_EXPECTED: False, MessageHeaderKey.OPTIONAL: optional}) return self._send_target_messages(target_msgs) - def send_reply(self, reply: Message, to_cell: str, for_req_ids: List[str], optional=False) -> str: + def send_reply(self, reply: Message, to_cell: str, for_req_ids: List[str], secure=False, optional=False) -> str: """Send a reply to respond to one or more requests. 
This is useful if the request receiver needs to delay its reply as follows: @@ -1373,6 +1494,7 @@ def send_reply(self, reply: Message, to_cell: str, for_req_ids: List[str], optio reply: the reply message to_cell: the target cell for_req_ids: the list of req IDs that the reply is for + secure: End-end encryption optional: whether the message is optional Returns: an error message if any @@ -1386,6 +1508,7 @@ def send_reply(self, reply: Message, to_cell: str, for_req_ids: List[str], optio MessageHeaderKey.DESTINATION: to_cell, MessageHeaderKey.REQ_ID: for_req_ids, MessageHeaderKey.MSG_TYPE: MessageType.REPLY, + MessageHeaderKey.SECURE: secure, MessageHeaderKey.OPTIONAL: optional, } ) @@ -1427,6 +1550,8 @@ def process_message(self, endpoint: Endpoint, connection: Connection, app_id: in def _process_request(self, origin: str, message: Message) -> Union[None, Message]: self.logger.debug(f"{self.my_info.fqcn}: processing incoming request") + + self.decrypt_payload(message) decode_payload(message) # this is a request for me - dispatch to the right CB channel = message.get_header(MessageHeaderKey.CHANNEL, "") @@ -1464,6 +1589,10 @@ def _process_request(self, origin: str, message: Message) -> Union[None, Message msg=message, ) reply = make_reply(ReturnCode.PROCESS_EXCEPTION, error="bad cb result") + + # Reply must be secure if request is + reply.add_headers({MessageHeaderKey.SECURE: message.get_header(MessageHeaderKey.SECURE, False)}) + return reply def _add_to_route(self, message: Message): @@ -1551,6 +1680,7 @@ def _process_reply(self, origin: str, message: Message, msg_type: str): topic = message.get_header(MessageHeaderKey.TOPIC, "") now = time.time() self.logger.debug(f"{self.my_info.fqcn}: processing reply from {origin} for type {msg_type}") + self.decrypt_payload(message) decode_payload(message) req_ids = message.get_header(MessageHeaderKey.REQ_ID) diff --git a/nvflare/fuel/f3/cellnet/credential_manager.py b/nvflare/fuel/f3/cellnet/credential_manager.py new file mode 
100644 index 0000000000..32a1fef618 --- /dev/null +++ b/nvflare/fuel/f3/cellnet/credential_manager.py @@ -0,0 +1,155 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import threading + +from cryptography import x509 +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.x509 import Certificate + +from nvflare.fuel.f3.cellnet.cell_cipher import SimpleCellCipher +from nvflare.fuel.f3.cellnet.fqcn import FQCN +from nvflare.fuel.f3.drivers.driver_params import DriverParams +from nvflare.fuel.f3.endpoint import Endpoint + +log = logging.getLogger(__name__) + +CERT_ERROR = "cert_error" +CERT_TARGET = "cert_target" +CERT_ORIGIN = "cert_origin" +CERT_CONTENT = "cert_content" +CERT_CA_CONTENT = "cert_ca_content" +CERT_REQ_TIMEOUT = 10 + + +class CredentialManager: + """Helper class for secure message. 
It holds the local credentials and certificate cache""" + + def __init__(self, local_endpoint: Endpoint): + + self.local_endpoint = local_endpoint + self.cert_cache = {} + self.lock = threading.Lock() + + conn_props = self.local_endpoint.conn_props + ca_cert_path = conn_props.get(DriverParams.CA_CERT) + server_cert_path = conn_props.get(DriverParams.SERVER_CERT) + if server_cert_path: + local_cert_path = server_cert_path + local_key_path = conn_props.get(DriverParams.SERVER_KEY) + else: + local_cert_path = conn_props.get(DriverParams.CLIENT_CERT) + local_key_path = conn_props.get(DriverParams.CLIENT_KEY) + + if not local_cert_path: + log.debug("Certificate is not configured, secure message is not supported") + self.ca_cert = None + self.local_cert = None + self.local_key = None + self.cell_cipher = None + else: + self.ca_cert = self.read_file(ca_cert_path) + self.local_cert = self.read_file(local_cert_path) + self.local_key = self.read_file(local_key_path) + self.cell_cipher = SimpleCellCipher(self.get_ca_cert(), self.get_local_key(), self.get_local_cert()) + + if not self.local_cert: + log.debug("Certificate is not configured, secure message is not supported") + self.cell_cipher = None + else: + self.cell_cipher = SimpleCellCipher(self.get_ca_cert(), self.get_local_key(), self.get_local_cert()) + + def encrypt(self, target_cert: bytes, payload: bytes) -> bytes: + + if not self.cell_cipher: + raise RuntimeError("Secure message not supported, Cell not running in secure mode") + + return self.cell_cipher.encrypt(payload, x509.load_pem_x509_certificate(target_cert)) + + def decrypt(self, origin_cert: bytes, cipher: bytes) -> bytes: + + if not self.cell_cipher: + raise RuntimeError("Secure message not supported, Cell not running in secure mode") + + return self.cell_cipher.decrypt(cipher, x509.load_pem_x509_certificate(origin_cert)) + + def get_certificate(self, fqcn: str) -> bytes: + + if not self.cell_cipher: + raise RuntimeError("This cell doesn't support 
certificate exchange, not running in secure mode") + + target = FQCN.get_root(fqcn) + return self.cert_cache.get(target) + + def save_certificate(self, fqcn: str, cert: bytes): + target = FQCN.get_root(fqcn) + self.cert_cache[target] = cert + + def create_request(self, target: str) -> dict: + + req = { + CERT_TARGET: target, + CERT_ORIGIN: FQCN.get_root(self.local_endpoint.name), + CERT_CONTENT: self.local_cert, + CERT_CA_CONTENT: self.ca_cert, + } + + return req + + def process_request(self, request: dict) -> dict: + + target = request.get(CERT_TARGET) + origin = request.get(CERT_ORIGIN) + + reply = {CERT_TARGET: target, CERT_ORIGIN: origin} + + if not self.local_cert: + reply[CERT_ERROR] = f"Target {target} is not running in secure mode" + else: + cert = request.get(CERT_CONTENT) + + # Save cert from requester in the cache + self.cert_cache[origin] = cert + + reply[CERT_CONTENT] = self.local_cert + reply[CERT_CA_CONTENT] = self.ca_cert + + return reply + + @staticmethod + def process_response(reply: dict) -> bytes: + + error = reply.get(CERT_ERROR) + if error: + raise RuntimeError(f"Request to get certificate from {target} failed: {error}") + + return reply.get(CERT_CONTENT) + + def get_local_cert(self) -> Certificate: + return x509.load_pem_x509_certificate(self.local_cert) + + def get_local_key(self) -> RSAPrivateKey: + return serialization.load_pem_private_key(self.local_key, password=None) + + def get_ca_cert(self) -> Certificate: + return x509.load_pem_x509_certificate(self.ca_cert) + + @staticmethod + def read_file(file_name: str): + if not file_name: + return None + + with open(file_name, "rb") as f: + return f.read() diff --git a/nvflare/fuel/f3/cellnet/defs.py b/nvflare/fuel/f3/cellnet/defs.py index e29c19a56c..a481a8c0c4 100644 --- a/nvflare/fuel/f3/cellnet/defs.py +++ b/nvflare/fuel/f3/cellnet/defs.py @@ -40,6 +40,9 @@ class MessageHeaderKey: ORIGINAL_HEADERS = CELLNET_PREFIX + "original_headers" SEND_TIME = CELLNET_PREFIX + "send_time" RETURN_REASON = 
CELLNET_PREFIX + "return_reason" + SECURE = CELLNET_PREFIX + "secure" + PAYLOAD_LEN = CELLNET_PREFIX + "payload_len" + ENCRYPTED = CELLNET_PREFIX + "encrypted" OPTIONAL = CELLNET_PREFIX + "optional" diff --git a/nvflare/fuel/f3/cellnet/utils.py b/nvflare/fuel/f3/cellnet/utils.py index 6caabe7969..14778dc245 100644 --- a/nvflare/fuel/f3/cellnet/utils.py +++ b/nvflare/fuel/f3/cellnet/utils.py @@ -67,8 +67,8 @@ def format_log_message(fqcn: str, message: Message, log: str) -> str: return " ".join(context) + f"] {log}" -def encode_payload(message: Message): - encoding = message.get_header(MessageHeaderKey.PAYLOAD_ENCODING) +def encode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING): + encoding = message.get_header(encoding_key) if not encoding: if message.payload is None: encoding = Encoding.NONE @@ -77,11 +77,11 @@ def encode_payload(message: Message): else: encoding = Encoding.FOBS message.payload = fobs.dumps(message.payload) - message.set_header(MessageHeaderKey.PAYLOAD_ENCODING, encoding) + message.set_header(encoding_key, encoding) -def decode_payload(message: Message): - encoding = message.get_header(MessageHeaderKey.PAYLOAD_ENCODING) +def decode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING): + encoding = message.get_header(encoding_key) if not encoding: return @@ -92,4 +92,4 @@ def decode_payload(message: Message): else: # assume to be bytes pass - message.remove_header(MessageHeaderKey.PAYLOAD_ENCODING) + message.remove_header(encoding_key) diff --git a/nvflare/fuel/f3/stream_cell.py b/nvflare/fuel/f3/stream_cell.py index 4f62ccdb9f..687cfa31d2 100644 --- a/nvflare/fuel/f3/stream_cell.py +++ b/nvflare/fuel/f3/stream_cell.py @@ -40,7 +40,7 @@ def get_chunk_size(): """ return ByteStreamer.get_chunk_size() - def send_stream(self, channel: str, topic: str, target: str, message: Message) -> StreamFuture: + def send_stream(self, channel: str, topic: str, target: str, message: Message, secure=False) -> 
StreamFuture: """ Send a byte-stream over a channel/topic asynchronously. The streaming is performed in a different thread. The streamer will read from stream and send the data in chunks till the stream reaches EOF. @@ -50,6 +50,7 @@ def send_stream(self, channel: str, topic: str, target: str, message: Message) - topic: topic for the stream target: destination cell FQCN message: The payload is the stream to send + secure: Send the message with end-end encryption if True Returns: StreamFuture that can be used to check status/progress, or register callbacks. The future result is the number of bytes sent @@ -59,7 +60,7 @@ def send_stream(self, channel: str, topic: str, target: str, message: Message) - if not isinstance(message.payload, Stream): raise StreamError(f"Message payload is not a stream: {type(message.payload)}") - return self.byte_streamer.send(channel, topic, target, message.headers, message.payload) + return self.byte_streamer.send(channel, topic, target, message.headers, message.payload, secure) def register_stream_cb(self, channel: str, topic: str, stream_cb: Callable, *args, **kwargs): """ @@ -87,7 +88,7 @@ def register_stream_cb(self, channel: str, topic: str, stream_cb: Callable, *arg """ self.byte_receiver.register_callback(channel, topic, stream_cb, *args, **kwargs) - def send_blob(self, channel: str, topic: str, target: str, message: Message) -> StreamFuture: + def send_blob(self, channel: str, topic: str, target: str, message: Message, secure=False) -> StreamFuture: """ Send a BLOB (Binary Large Object) to the target. The payload of message is the BLOB. The BLOB must fit in memory on the receiving end. 
@@ -97,16 +98,20 @@ def send_blob(self, channel: str, topic: str, target: str, message: Message) -> topic: topic of the message target: destination cell IDs message: the headers and the blob as payload + secure: Send the message with end-end encryption if True Returns: StreamFuture that can be used to check status/progress and get result The future result is the total number of bytes sent """ + if message.payload is None: + message.payload = bytes(0) + if not isinstance(message.payload, (bytes, bytearray, memoryview)): raise StreamError(f"Message payload is not a byte array: {type(message.payload)}") - return self.blob_streamer.send(channel, topic, target, message) + return self.blob_streamer.send(channel, topic, target, message, secure) def register_blob_cb(self, channel: str, topic: str, blob_cb, *args, **kwargs): """ @@ -126,7 +131,7 @@ def register_blob_cb(self, channel: str, topic: str, blob_cb, *args, **kwargs): """ self.blob_streamer.register_blob_callback(channel, topic, blob_cb, *args, **kwargs) - def send_file(self, channel: str, topic: str, target: str, message: Message) -> StreamFuture: + def send_file(self, channel: str, topic: str, target: str, message: Message, secure=False) -> StreamFuture: """ Send a file to target using stream API. 
@@ -135,6 +140,7 @@ def send_file(self, channel: str, topic: str, target: str, message: Message) -> topic: topic for the message target: destination cell FQCN message: the headers and the full path of the file to be sent as payload + secure: Send the message with end-end encryption if True Returns: StreamFuture that can be used to check status/progress and get the total bytes sent @@ -146,7 +152,7 @@ def send_file(self, channel: str, topic: str, target: str, message: Message) -> if not os.path.isfile(file_name) or not os.access(file_name, os.R_OK): raise StreamError(f"File {file_name} doesn't exist or isn't readable") - return self.file_streamer.send(channel, topic, target, message) + return self.file_streamer.send(channel, topic, target, message, secure) def register_file_cb(self, channel: str, topic: str, file_cb, *args, **kwargs): """ @@ -162,7 +168,7 @@ def register_file_cb(self, channel: str, topic: str, file_cb, *args, **kwargs): """ self.file_streamer.register_file_callback(channel, topic, file_cb, *args, **kwargs) - def send_objects(self, channel: str, topic: str, target: str, message: Message) -> ObjectStreamFuture: + def send_objects(self, channel: str, topic: str, target: str, message: Message, secure=False) -> ObjectStreamFuture: """ Send a list of objects to the destination. 
Each object is sent as BLOB, so it must fit in memory @@ -171,13 +177,15 @@ def send_objects(self, channel: str, topic: str, target: str, message: Message) topic: topic of the message target: destination cell IDs message: Headers and the payload which is an iterator that provides next object + secure: Send the message with end-end encryption if True + Returns: ObjectStreamFuture that can be used to check status/progress, or register callbacks """ if not isinstance(message.payload, ObjectIterator): raise StreamError(f"Message payload is not an object iterator: {type(message.payload)}") - return self.object_streamer.stream_objects(channel, topic, target, message.headers, message.payload) + return self.object_streamer.stream_objects(channel, topic, target, message.headers, message.payload, secure) def register_objects_cb( self, channel: str, topic: str, object_stream_cb: Callable, object_cb: Callable, *args, **kwargs diff --git a/nvflare/fuel/f3/streaming/blob_streamer.py b/nvflare/fuel/f3/streaming/blob_streamer.py index 842bba4ff2..08ee8b308f 100644 --- a/nvflare/fuel/f3/streaming/blob_streamer.py +++ b/nvflare/fuel/f3/streaming/blob_streamer.py @@ -19,7 +19,7 @@ from nvflare.fuel.f3.streaming.byte_receiver import ByteReceiver from nvflare.fuel.f3.streaming.byte_streamer import ByteStreamer from nvflare.fuel.f3.streaming.stream_const import EOS -from nvflare.fuel.f3.streaming.stream_types import Stream, StreamFuture +from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture from nvflare.fuel.f3.streaming.stream_utils import FastBuffer, stream_thread_pool, wrap_view from nvflare.security.logging import secure_format_traceback @@ -117,9 +117,15 @@ def __init__(self, byte_streamer: ByteStreamer, byte_receiver: ByteReceiver): self.byte_streamer = byte_streamer self.byte_receiver = byte_receiver - def send(self, channel: str, topic: str, target: str, message: Message) -> StreamFuture: + def send(self, channel: str, topic: str, target: str, 
message: Message, secure: bool) -> StreamFuture: + if message.payload is None: + message.payload = bytes(0) + + if not isinstance(message.payload, (bytes, bytearray, memoryview)): + raise StreamError(f"BLOB is invalid type: {type(message.payload)}") + blob_stream = BlobStream(message.payload, message.headers) - return self.byte_streamer.send(channel, topic, target, message.headers, blob_stream) + return self.byte_streamer.send(channel, topic, target, message.headers, blob_stream, secure) def register_blob_callback(self, channel, topic, blob_cb: Callable, *args, **kwargs): handler = BlobHandler(blob_cb) diff --git a/nvflare/fuel/f3/streaming/byte_receiver.py b/nvflare/fuel/f3/streaming/byte_receiver.py index 9e6709aa70..5401eb01ae 100644 --- a/nvflare/fuel/f3/streaming/byte_receiver.py +++ b/nvflare/fuel/f3/streaming/byte_receiver.py @@ -99,6 +99,9 @@ def read(self, chunk_size: int) -> bytes: with self.task.task_lock: last_chunk, buf = self.task.buffers.popleft() + if buf is None: + buf = bytes(0) + if 0 < chunk_size < len(buf): result = buf[0:chunk_size] # Put leftover to the head of the queue @@ -176,6 +179,8 @@ def _data_handler(self, message: Message): seq = message.get_header(StreamHeaderKey.SEQUENCE) error = message.get_header(StreamHeaderKey.ERROR_MSG, None) + payload = message.payload + with self.map_lock: task = self.rx_task_map.get(sid, None) if not task: @@ -215,7 +220,7 @@ def _data_handler(self, message: Message): task.last_chunk_received = True if seq == task.next_seq: - self._append(task, (last_chunk, message.payload)) + self._append(task, (last_chunk, payload)) task.next_seq += 1 # Try to reassemble out-of-seq buffers @@ -230,7 +235,7 @@ def _data_handler(self, message: Message): self.stop_task(task, StreamError(f"Too many out-of-sequence chunks: {len(task.out_seq_buffers)}")) return else: - task.out_seq_buffers[seq] = last_chunk, message.payload + task.out_seq_buffers[seq] = last_chunk, payload # If all chunks are lined up, the task can be deleted 
if not task.out_seq_buffers and task.buffers: diff --git a/nvflare/fuel/f3/streaming/byte_streamer.py b/nvflare/fuel/f3/streaming/byte_streamer.py index ed8ebfe03b..5a8e87d662 100644 --- a/nvflare/fuel/f3/streaming/byte_streamer.py +++ b/nvflare/fuel/f3/streaming/byte_streamer.py @@ -36,7 +36,7 @@ class TxTask: - def __init__(self, channel: str, topic: str, target: str, headers: dict, stream: Stream): + def __init__(self, channel: str, topic: str, target: str, headers: dict, stream: Stream, secure: bool): self.sid = gen_stream_id() self.buffer = bytearray(STREAM_CHUNK_SIZE) # Optimization to send the original buffer without copying @@ -53,6 +53,7 @@ def __init__(self, channel: str, topic: str, target: str, headers: dict, stream: self.seq = 0 self.offset = 0 self.offset_ack = 0 + self.secure = secure def __str__(self): return f"Tx[SID:{self.sid} to {self.target} for {self.channel}/{self.topic}]" @@ -69,8 +70,8 @@ def __init__(self, cell: CoreCell): def get_chunk_size(): return STREAM_CHUNK_SIZE - def send(self, channel: str, topic: str, target: str, headers: dict, stream: Stream) -> StreamFuture: - tx_task = TxTask(channel, topic, target, headers, stream) + def send(self, channel: str, topic: str, target: str, headers: dict, stream: Stream, secure=False) -> StreamFuture: + tx_task = TxTask(channel, topic, target, headers, stream, secure) with self.map_lock: self.tx_task_map[tx_task.sid] = tx_task @@ -118,7 +119,7 @@ def _transmit_task(self, task: TxTask): def _transmit(self, task: TxTask, final=False): if task.buffer_size == 0: - payload = None + payload = bytes(0) elif task.buffer_size == STREAM_CHUNK_SIZE: if task.direct_buf: payload = task.direct_buf @@ -151,7 +152,7 @@ def _transmit(self, task: TxTask, final=False): } ) - errors = self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_DATA_TOPIC, task.target, message) + errors = self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_DATA_TOPIC, task.target, message, secure=task.secure) error = errors.get(task.target) if 
error: msg = f"Message sending error to target {task.target}: {error}" @@ -186,7 +187,7 @@ def _stop_task(self, task: TxTask, error: StreamError = None, notify=True): StreamHeaderKey.ERROR_MSG: str(error), } ) - self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_DATA_TOPIC, task.target, message) + self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_DATA_TOPIC, task.target, message, secure=task.secure) else: # Result is the number of bytes streamed task.stream_future.set_result(task.offset) diff --git a/nvflare/fuel/f3/streaming/file_streamer.py b/nvflare/fuel/f3/streaming/file_streamer.py index 7dd785f873..9ead52517c 100644 --- a/nvflare/fuel/f3/streaming/file_streamer.py +++ b/nvflare/fuel/f3/streaming/file_streamer.py @@ -87,7 +87,7 @@ def __init__(self, byte_streamer: ByteStreamer, byte_receiver: ByteReceiver): self.byte_streamer = byte_streamer self.byte_receiver = byte_receiver - def send(self, channel: str, topic: str, target: str, message: Message) -> StreamFuture: + def send(self, channel: str, topic: str, target: str, message: Message, secure=False) -> StreamFuture: file_name = Path(message.payload).name file_stream = FileStream(message.payload, message.headers) @@ -98,7 +98,7 @@ def send(self, channel: str, topic: str, target: str, message: Message) -> Strea } ) - return self.byte_streamer.send(channel, topic, target, message.headers, file_stream) + return self.byte_streamer.send(channel, topic, target, message.headers, file_stream, secure) def register_file_callback(self, channel, topic, file_cb: Callable, *args, **kwargs): handler = FileHandler(file_cb) diff --git a/nvflare/fuel/f3/streaming/stream_const.py b/nvflare/fuel/f3/streaming/stream_const.py index 8fe8ee97b7..60977d284e 100644 --- a/nvflare/fuel/f3/streaming/stream_const.py +++ b/nvflare/fuel/f3/streaming/stream_const.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. + STREAM_PREFIX = "sm__" STREAM_CHANNEL = STREAM_PREFIX + "STREAM" STREAM_DATA_TOPIC = STREAM_PREFIX + "DATA" STREAM_ACK_TOPIC = STREAM_PREFIX + "ACK" +STREAM_CERT_TOPIC = STREAM_PREFIX + "CERT" + # End of Stream indicator EOS = bytes() @@ -49,3 +52,4 @@ class StreamHeaderKey: OBJECT_STREAM_ID = STREAM_PREFIX + "os" OBJECT_INDEX = STREAM_PREFIX + "oi" STREAM_REQ_ID = STREAM_PREFIX + "ri" + PAYLOAD_ENCODING = STREAM_PREFIX + "pe"