From f13d951d23274e96fa9c05a7bd110cccee5fd7ee Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Sat, 7 Dec 2024 16:35:16 -0500 Subject: [PATCH] support explicit client auth --- nvflare/fuel/f3/cellnet/core_cell.py | 19 +++++ nvflare/private/defs.py | 10 ++- .../private/fed/app/client/worker_process.py | 6 +- .../fed/app/deployer/simulator_deployer.py | 2 +- .../private/fed/app/server/runner_process.py | 17 +++- .../fed/app/simulator/simulator_worker.py | 2 +- .../private/fed/client/client_app_runner.py | 1 - nvflare/private/fed/client/client_engine.py | 17 ---- nvflare/private/fed/client/client_executor.py | 2 + nvflare/private/fed/client/communicator.py | 49 +++++++++--- nvflare/private/fed/client/fed_client_base.py | 9 ++- nvflare/private/fed/server/fed_server.py | 79 +++++++++++++++++++ nvflare/private/fed/server/server_engine.py | 16 +++- nvflare/private/fed/utils/identity_utils.py | 13 ++- 14 files changed, 197 insertions(+), 45 deletions(-) diff --git a/nvflare/fuel/f3/cellnet/core_cell.py b/nvflare/fuel/f3/cellnet/core_cell.py index 7d0122429a..296405a290 100644 --- a/nvflare/fuel/f3/cellnet/core_cell.py +++ b/nvflare/fuel/f3/cellnet/core_cell.py @@ -406,6 +406,7 @@ def __init__( self.communicator.register_message_receiver(app_id=self.APP_ID, receiver=self) self.communicator.register_monitor(monitor=self) self.req_reg = Registry() + self.in_filter_reg = Registry() # for any incoming messages self.in_req_filter_reg = Registry() # for request received self.out_reply_filter_reg = Registry() # for reply going out self.out_req_filter_reg = Registry() # for request sent @@ -991,6 +992,11 @@ def decrypt_payload(self, message: Message): if len(message.payload) != payload_len: raise RuntimeError(f"Payload size changed after decryption {len(message.payload)} <> {payload_len}") + def add_incoming_filter(self, channel: str, topic: str, cb, *args, **kwargs): + if not callable(cb): + raise ValueError(f"specified incoming_filter {type(cb)} is not callable") + self.in_filter_reg.append(channel, topic, Callback(cb, args, kwargs)) + def add_incoming_request_filter(self, channel: str, topic: str, cb, *args, **kwargs): if not callable(cb): raise ValueError(f"specified incoming_request_filter {type(cb)} is not callable") @@ -1856,6 +1862,19 @@ def _process_received_msg(self, endpoint: Endpoint, connection: Connection, mess category=self._stats_category(message), counter_name=_CounterName.RECEIVED ) + # invoke incoming filters + channel = message.get_header(MessageHeaderKey.CHANNEL, "") + topic = message.get_header(MessageHeaderKey.TOPIC, "") + in_filters = self.in_filter_reg.find(channel, topic) + if in_filters: + self.logger.debug(f"{self.my_info.fqcn}: invoking incoming filters") + assert isinstance(in_filters, list) + for f in in_filters: + assert isinstance(f, Callback) + reply = self._try_cb(message, f.cb, *f.args, **f.kwargs) + if reply: + return reply + if msg_type == MessageType.REQ and self.message_interceptor is not None: reply = self._try_cb( message, self.message_interceptor, *self.message_interceptor_args, **self.message_interceptor_kwargs diff --git a/nvflare/private/defs.py b/nvflare/private/defs.py index 10c50abfa6..d8367dbb8c 100644 --- a/nvflare/private/defs.py +++ b/nvflare/private/defs.py @@ -33,8 +33,8 @@ class TaskConstant(object): class EngineConstant(object): FEDERATE_CLIENT = "federate_client" - FL_TOKEN = "fl_token" - CLIENT_TOKEN_FILE = "client_token.txt" + AUTH_TOKEN = "auth_token" + AUTH_TOKEN_SIGNATURE = "auth_token_signature" ENGINE_TASK_NAME = "engine_task_name" @@ -140,7 +140,8 @@ class CellMessageHeaderKeys: CLIENT_NAME = "client_name" CLIENT_IP = "client_ip" PROJECT_NAME = "project_name" - TOKEN = "token" + TOKEN = "__token__" + TOKEN_SIGNATURE = "__token_signature__" SSID = "ssid" UNAUTHENTICATED = "unauthenticated" JOB_ID = "job_id" @@ -149,6 +150,9 @@ class CellMessageHeaderKeys: ABORT_JOBS = "abort_jobs" +AUTH_CLIENT_NAME_FOR_SJ = "server_job" + + class JobFailureMsgKey: JOB_ID = "job_id" diff --git a/nvflare/private/fed/app/client/worker_process.py b/nvflare/private/fed/app/client/worker_process.py index 4b71d4cc76..b29cd16f8f 100644 --- a/nvflare/private/fed/app/client/worker_process.py +++ b/nvflare/private/fed/app/client/worker_process.py @@ -111,11 +111,12 @@ def main(args): federated_client = deployer.create_fed_client(args) federated_client.status = ClientStatus.STARTING + federated_client.communicator.set_auth(args.client_name, args.token, args.token_signature, args.ssid) federated_client.token = args.token + federated_client.token_signature = args.token_signature federated_client.ssid = args.ssid federated_client.client_name = args.client_name federated_client.fl_ctx.set_prop(FLContextKey.CLIENT_NAME, args.client_name, private=False) - federated_client.fl_ctx.set_prop(EngineConstant.FL_TOKEN, args.token, private=False) federated_client.fl_ctx.set_prop(FLContextKey.WORKSPACE_ROOT, args.workspace, private=True) client_app_runner = ClientAppRunner(time_out=kv_list.get("app_runner_timeout", 60.0)) @@ -150,7 +151,8 @@ def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True) parser.add_argument("--startup", "-w", type=str, help="startup folder", required=True) - parser.add_argument("--token", "-t", type=str, help="token", required=True) + parser.add_argument("--token", "-t", type=str, help="auth token", required=True) + parser.add_argument("--token_signature", "-ts", type=str, help="auth token signature", required=True) parser.add_argument("--ssid", "-d", type=str, help="ssid", required=True) parser.add_argument("--job_id", "-n", type=str, help="job_id", required=True) parser.add_argument("--client_name", "-c", type=str, help="client name", required=True) diff --git a/nvflare/private/fed/app/deployer/simulator_deployer.py b/nvflare/private/fed/app/deployer/simulator_deployer.py index cae1858bb3..bb8c615af4 100644 --- a/nvflare/private/fed/app/deployer/simulator_deployer.py +++ b/nvflare/private/fed/app/deployer/simulator_deployer.py @@ -100,7 +100,7 @@ def _create_client_cell(self, client_config, client_name, federated_client): ) cell.start() federated_client.cell = cell - federated_client.communicator.cell = cell + federated_client.communicator.set_cell(cell) # if self.engine: # self.engine.admin_agent.register_cell_cb() diff --git a/nvflare/private/fed/app/server/runner_process.py b/nvflare/private/fed/app/server/runner_process.py index f839e8e210..e4466faf6a 100644 --- a/nvflare/private/fed/app/server/runner_process.py +++ b/nvflare/private/fed/app/server/runner_process.py @@ -23,12 +23,13 @@ from nvflare.apis.fl_constant import ConfigVarName, JobConstants, SystemConfigs from nvflare.apis.workspace import Workspace from nvflare.fuel.common.excepts import ConfigError +from nvflare.fuel.f3.message import Message as CellMessage from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.sec.audit import AuditService from nvflare.fuel.sec.security_content_service import SecurityContentService from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService -from nvflare.private.defs import AppFolderConstants +from nvflare.private.defs import AUTH_CLIENT_NAME_FOR_SJ, AppFolderConstants, CellMessageHeaderKeys from nvflare.private.fed.app.fl_conf import FLServerStarterConfiger from nvflare.private.fed.app.utils import monitor_parent_process from nvflare.private.fed.server.server_app_runner import ServerAppRunner @@ -112,6 +113,12 @@ def main(args): server.cell = server.create_job_cell( args.job_id, args.root_url, args.parent_url, secure_train, server_config ) + + # set filter to add additional auth headers + server.cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=_add_auth_headers, config=args) + + server.cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=_add_auth_headers, config=args) + server.server_state = HotState(host=args.host, port=args.port, ssid=args.ssid) snapshot = None @@ -142,6 +149,13 @@ def main(args): raise e +def _add_auth_headers(message: CellMessage, config): + message.set_header(CellMessageHeaderKeys.SSID, config.ssid) + message.set_header(CellMessageHeaderKeys.CLIENT_NAME, AUTH_CLIENT_NAME_FOR_SJ) + message.set_header(CellMessageHeaderKeys.TOKEN, config.job_id) + message.set_header(CellMessageHeaderKeys.TOKEN_SIGNATURE, config.token_signature) + + def parse_arguments(): """FL Server program starting point.""" parser = argparse.ArgumentParser() @@ -151,6 +165,7 @@ def parse_arguments(): ) parser.add_argument("--app_root", "-r", type=str, help="App Root", required=True) parser.add_argument("--job_id", "-n", type=str, help="job id", required=True) + parser.add_argument("--token_signature", "-ts", type=str, help="auth token signature", required=True) parser.add_argument("--root_url", "-u", type=str, help="root_url", required=True) parser.add_argument("--host", "-host", type=str, help="server host", required=True) parser.add_argument("--port", "-port", type=str, help="service port", required=True) diff --git a/nvflare/private/fed/app/simulator/simulator_worker.py b/nvflare/private/fed/app/simulator/simulator_worker.py index 262f78026f..d320fec685 100644 --- a/nvflare/private/fed/app/simulator/simulator_worker.py +++ b/nvflare/private/fed/app/simulator/simulator_worker.py @@ -214,7 +214,7 @@ def _create_client_cell(self, federated_client, root_url, parent_url): cell.start() mpm.add_cleanup_cb(cell.stop) federated_client.cell = cell - federated_client.communicator.cell = cell + federated_client.communicator.set_cell(cell) start = time.time() while not cell.is_cell_connected(FQCN.ROOT_SERVER): diff --git a/nvflare/private/fed/client/client_app_runner.py b/nvflare/private/fed/client/client_app_runner.py index 4ac3d3cb09..c1e6e793fc 100644 --- a/nvflare/private/fed/client/client_app_runner.py +++ b/nvflare/private/fed/client/client_app_runner.py @@ -70,7 +70,6 @@ def start_run(self, app_root, args, config_folder, federated_client, secure_trai @staticmethod def _set_fl_context(fl_ctx: FLContext, app_root, args, workspace, secure_train): fl_ctx.set_prop(FLContextKey.CLIENT_NAME, args.client_name, private=False) - fl_ctx.set_prop(EngineConstant.FL_TOKEN, args.token, private=False) fl_ctx.set_prop(FLContextKey.WORKSPACE_ROOT, args.workspace, private=True) fl_ctx.set_prop(FLContextKey.ARGS, args, sticky=True) fl_ctx.set_prop(FLContextKey.APP_ROOT, app_root, private=True, sticky=True) diff --git a/nvflare/private/fed/client/client_engine.py b/nvflare/private/fed/client/client_engine.py index 6f7c8d6690..fc556510ba 100644 --- a/nvflare/private/fed/client/client_engine.py +++ b/nvflare/private/fed/client/client_engine.py @@ -185,23 +185,6 @@ def notify_job_status(self, job_id: str, job_status): def get_client_name(self): return self.client.client_name - def _write_token_file(self, job_id, open_port): - token_file = os.path.join(self.args.workspace, EngineConstant.CLIENT_TOKEN_FILE) - if os.path.exists(token_file): - os.remove(token_file) - with open(token_file, "wt") as f: - f.write( - "%s\n%s\n%s\n%s\n%s\n%s\n" - % ( - self.client.token, - self.client.ssid, - job_id, - self.client.client_name, - open_port, - list(self.client.servers.values())[0]["target"], - ) - ) - def abort_app(self, job_id: str) -> str: status = self.client_executor.get_status(job_id) if status == ClientStatus.STOPPED: diff --git a/nvflare/private/fed/client/client_executor.py b/nvflare/private/fed/client/client_executor.py index 8830cdbf9e..79f9646f25 100644 --- a/nvflare/private/fed/client/client_executor.py +++ b/nvflare/private/fed/client/client_executor.py @@ -184,6 +184,8 @@ def start_app( + self.startup + " -t " + client.token + + " -ts " + + client.token_signature + " -d " + client.ssid + " -n " diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index 06e0eccad9..c5da940c84 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -31,6 +31,7 @@ from nvflare.fuel.f3.cellnet.core_cell import FQCN, CoreCell from nvflare.fuel.f3.cellnet.defs import IdentityChallengeKey, MessageHeaderKey, ReturnCode from nvflare.fuel.f3.cellnet.utils import format_size +from nvflare.fuel.f3.message import Message as CellMessage from nvflare.private.defs import CellChannel, CellChannelTopic, CellMessageHeaderKeys, SpecialTaskName, new_cell_message from nvflare.private.fed.client.client_engine_internal_spec import ClientEngineInternalSpec from nvflare.private.fed.utils.fed_utils import get_scope_prop @@ -93,7 +94,39 @@ def __init__( self.timeout = timeout self.maint_msg_timeout = maint_msg_timeout + # token and token_signature are issued by the Server after the client is authenticated + # they are added to every message going to the server as proof of authentication + self.token = None + self.token_signature = None + self.ssid = None + self.client_name = None + self.logger = logging.getLogger(self.__class__.__name__) + self.logger.info(f"==== Communicator GOT CELL: {type(cell)}") + + def set_auth(self, client_name, token, token_signature, ssid): + self.ssid = ssid + self.token_signature = token_signature + self.token = token + self.client_name = client_name + + def set_cell(self, cell): + self.cell = cell + + # set filter to add additional auth headers + cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=self._add_auth_headers) + cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=self._add_auth_headers) + + def _add_auth_headers(self, message: CellMessage): + if self.ssid: + message.set_header(CellMessageHeaderKeys.SSID, self.ssid) + + if self.client_name: + message.set_header(CellMessageHeaderKeys.CLIENT_NAME, self.client_name) + + if self.token: + message.set_header(CellMessageHeaderKeys.TOKEN, self.token) + message.set_header(CellMessageHeaderKeys.TOKEN_SIGNATURE, self.token_signature) def _challenge_server(self, client_name, expected_host, root_cert_file): # ask server for its info and make sure that it matches expected host @@ -252,17 +285,19 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext): raise FLCommunicationError("error:client_registration " + reason) token = result.get_header(CellMessageHeaderKeys.TOKEN) + token_signature = result.get_header(CellMessageHeaderKeys.TOKEN_SIGNATURE, "NA") ssid = result.get_header(CellMessageHeaderKeys.SSID) if not token and not self.should_stop: time.sleep(self.client_register_interval) else: + self.set_auth(client_name, token, token_signature, ssid) break except Exception as ex: traceback.print_exc() raise FLCommunicationError("error:client_registration", ex) - return token, ssid + return token, token_signature, ssid def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None): """Get a task from server. @@ -285,9 +320,6 @@ def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None): client_name = fl_ctx.get_identity_name() task_message = new_cell_message( { - CellMessageHeaderKeys.TOKEN: token, - CellMessageHeaderKeys.CLIENT_NAME: client_name, - CellMessageHeaderKeys.SSID: ssid, CellMessageHeaderKeys.PROJECT_NAME: project_name, }, shareable, @@ -361,9 +393,6 @@ def submit_update( task_message = new_cell_message( { - CellMessageHeaderKeys.TOKEN: token, - CellMessageHeaderKeys.CLIENT_NAME: client_name, - CellMessageHeaderKeys.SSID: ssid, CellMessageHeaderKeys.PROJECT_NAME: project_name, }, shareable, @@ -410,9 +439,6 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): client_name = fl_ctx.get_identity_name() quit_message = new_cell_message( { - CellMessageHeaderKeys.TOKEN: token, - CellMessageHeaderKeys.CLIENT_NAME: client_name, - CellMessageHeaderKeys.SSID: ssid, CellMessageHeaderKeys.PROJECT_NAME: task_name, }, shareable, @@ -452,9 +478,6 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C job_ids = engine.get_all_job_ids() heartbeat_message = new_cell_message( { - CellMessageHeaderKeys.TOKEN: token, - CellMessageHeaderKeys.SSID: ssid, - CellMessageHeaderKeys.CLIENT_NAME: client_name, CellMessageHeaderKeys.PROJECT_NAME: task_name, CellMessageHeaderKeys.JOB_IDS: job_ids, }, diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index acad0c9875..b0593c3c53 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -31,7 +31,6 @@ from nvflare.fuel.f3.drivers.driver_params import DriverParams from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.utils.argument_utils import parse_vars -from nvflare.private.defs import EngineConstant from nvflare.private.fed.utils.fed_utils import set_scope_prop from nvflare.security.logging import secure_format_exception @@ -77,6 +76,7 @@ def __init__( self.client_name = client_name self.token = None + self.token_signature = None self.ssid = None self.client_args = client_args self.servers = server_args @@ -207,7 +207,7 @@ def _create_cell(self, location, scheme): parent_url=parent_url, ) self.cell.start() - self.communicator.cell = self.cell + self.communicator.set_cell(self.cell) self.net_agent = NetAgent(self.cell) mpm.add_cleanup_cb(self.net_agent.close) mpm.add_cleanup_cb(self.cell.stop) @@ -250,10 +250,11 @@ def client_register(self, project_name, fl_ctx: FLContext): """ if not self.token: try: - self.token, self.ssid = self.communicator.client_registration(self.client_name, project_name, fl_ctx) + self.token, self.token_signature, self.ssid = self.communicator.client_registration( + self.client_name, project_name, fl_ctx + ) if self.token is not None: self.fl_ctx.set_prop(FLContextKey.CLIENT_NAME, self.client_name, private=False) - self.fl_ctx.set_prop(EngineConstant.FL_TOKEN, self.token, private=False) self.logger.info( "Successfully registered client:{} for project {}. Token:{} SSID:{}".format( self.client_name, project_name, self.token, self.ssid diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index c21b9e7bc7..413a95e6c0 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -17,6 +17,7 @@ import shutil import threading import time +import uuid from abc import ABC, abstractmethod from threading import Lock from typing import Dict, List, Optional @@ -337,6 +338,11 @@ def __init__( self.name_to_reg = {} self.id_asserter = None + # these are used when the server sends a message to itself. + self.my_own_auth_client_name = "server" + self.my_own_token = "server" + self.my_own_token_signature = None + def _register_cellnet_cbs(self): self.cell.register_request_cb( channel=CellChannel.SERVER_MAIN, @@ -384,6 +390,57 @@ def _register_cellnet_cbs(self): reg_checker = threading.Thread(target=self._check_regs, daemon=True) reg_checker.start() + def _add_auth_headers(self, message: Message): + origin = message.get_header(MessageHeaderKey.ORIGIN) + dest = message.get_header(MessageHeaderKey.DESTINATION) + if origin == FQCN.ROOT_SERVER and dest == origin: + message.set_header(CellMessageHeaderKeys.CLIENT_NAME, self.my_own_auth_client_name) + message.set_header(CellMessageHeaderKeys.TOKEN, self.my_own_token) + + if not self.my_own_token_signature: + self.my_own_token_signature = self.sign_auth_token(self.my_own_auth_client_name, self.my_own_token) + message.set_header(CellMessageHeaderKeys.TOKEN_SIGNATURE, self.my_own_token_signature) + + def _validate_auth_headers(self, message: Message): + headers = message.headers + self.logger.info(f"**** _validate_auth_headers: {headers=}") + topic = message.get_header(MessageHeaderKey.TOPIC) + channel = message.get_header(MessageHeaderKey.CHANNEL) + + origin = message.get_header(MessageHeaderKey.ORIGIN) + + if topic in [CellChannelTopic.Register, CellChannelTopic.Challenge] and channel == CellChannel.SERVER_MAIN: + # skip: client not registered yet + self.logger.info(f"skip special message {topic=} {channel=}") + return None + + client_name = message.get_header(CellMessageHeaderKeys.CLIENT_NAME) + if not client_name: + err = "missing client name" + self.logger.error(f"unauthenticated msg received from {origin}: {err}") + return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED) + + token = message.get_header(CellMessageHeaderKeys.TOKEN) + if not token: + err = "missing auth token" + self.logger.error(f"unauthenticated msg received from {origin}: {err}") + return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED) + + signature = message.get_header(CellMessageHeaderKeys.TOKEN_SIGNATURE) + if not signature: + err = "missing auth token signature" + self.logger.error(f"unauthenticated msg received from {origin}: {err}") + return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED) + + if not self.verify_auth_token(client_name, token, signature): + err = "invalid auth token signature" + self.logger.error(f"unauthenticated msg received from {origin}: {err}") + return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED) + + # all good + self.logger.info(f"auth valid from {origin}: {topic=} {channel=}") + return None + def _check_regs(self): while True: with self.reg_lock: @@ -554,6 +611,14 @@ def _get_id_asserter(self): self.id_asserter = IdentityAsserter(private_key_file=private_key_file, cert_file=cert_file) return self.id_asserter + def sign_auth_token(self, client_name: str, token: str): + id_asserter = self._get_id_asserter() + return id_asserter.sign(client_name + token, return_str=True) + + def verify_auth_token(self, client_name: str, token: str, signature): + id_asserter = self._get_id_asserter() + return id_asserter.verify_signature(client_name + token, signature) + def _ready_for_registration(self, fl_ctx: FLContext): self._before_service(fl_ctx) state_check = self.server_state.register(fl_ctx) @@ -634,8 +699,10 @@ def register_client(self, request: Message) -> Message: if self.admin_server: self.admin_server.client_heartbeat(client.token, client.name) + token_signature = self.sign_auth_token(client.name, client.token) headers = { CellMessageHeaderKeys.TOKEN: client.token, + CellMessageHeaderKeys.TOKEN_SIGNATURE: token_signature, CellMessageHeaderKeys.SSID: self.server_state.ssid, } else: @@ -906,6 +973,18 @@ def deploy(self, args, grpc_args=None, secure_train=False): self.engine.cell = self.cell self._register_cellnet_cbs() + if secure_train: + core_cell = self.cell.core_cell + core_cell.add_incoming_filter( + channel="*", + topic="*", + cb=self._validate_auth_headers, + ) + + # set filter to add additional auth headers + core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=self._add_auth_headers) + core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=self._add_auth_headers) + self.overseer_agent.start(self.overseer_callback) def _init_agent(self, args=None): diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index 65393b7e94..3724ca5509 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -53,7 +53,14 @@ from nvflare.fuel.utils.network_utils import get_open_ports from nvflare.fuel.utils.zip_utils import zip_directory_to_bytes from nvflare.private.admin_defs import Message, MsgHeader -from nvflare.private.defs import CellChannel, CellMessageHeaderKeys, RequestHeader, TrainingTopic, new_cell_message +from nvflare.private.defs import ( + AUTH_CLIENT_NAME_FOR_SJ, + CellChannel, + CellMessageHeaderKeys, + RequestHeader, + TrainingTopic, + new_cell_message, +) from nvflare.private.fed.server.server_json_config import ServerJsonConfigurator from nvflare.private.fed.server.server_state import ServerState from nvflare.private.fed.utils.fed_utils import ( @@ -253,6 +260,11 @@ def _start_runner_process( for t in args.set: command_options += " " + t + # create token and signature for SJ + token = run_number # use the run_number as the auth token + client_name = AUTH_CLIENT_NAME_FOR_SJ + signature = self.server.sign_auth_token(client_name, token) + command = ( sys.executable + " -m nvflare.private.fed.app.server.runner_process -m " @@ -261,6 +273,8 @@ def _start_runner_process( + app_root + " -n " + str(run_number) + + " -ts " + + signature + " -p " + str(cell.get_internal_listener_url()) + " -u " diff --git a/nvflare/private/fed/utils/identity_utils.py b/nvflare/private/fed/utils/identity_utils.py index e45fb2562b..73a5bfcaf6 100644 --- a/nvflare/private/fed/utils/identity_utils.py +++ b/nvflare/private/fed/utils/identity_utils.py @@ -67,9 +67,20 @@ def __init__(self, private_key_file: str, cert_file: str): self.cert = load_cert_bytes(self.cert_data) self.cn = get_cn_from_cert(self.cert) - def sign_common_name(self, nonce: str) -> str: + def sign_common_name(self, nonce: str): return sign_content(self.cn + nonce, self.pri_key, return_str=False) + def sign(self, content, return_str: bool) -> str: + return sign_content(content, self.pri_key, return_str=return_str) + + def verify_signature(self, content, signature) -> bool: + pub_key = self.cert.public_key() + try: + verify_content(content=content, signature=signature, public_key=pub_key) + return True + except Exception: + return False + class IdentityVerifier: def __init__(self, root_cert_file: str):