Skip to content

Commit

Permalink
support explicit client auth
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Dec 7, 2024
1 parent a7545a2 commit f13d951
Show file tree
Hide file tree
Showing 14 changed files with 197 additions and 45 deletions.
19 changes: 19 additions & 0 deletions nvflare/fuel/f3/cellnet/core_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def __init__(
self.communicator.register_message_receiver(app_id=self.APP_ID, receiver=self)
self.communicator.register_monitor(monitor=self)
self.req_reg = Registry()
self.in_filter_reg = Registry() # for any incoming messages
self.in_req_filter_reg = Registry() # for request received
self.out_reply_filter_reg = Registry() # for reply going out
self.out_req_filter_reg = Registry() # for request sent
Expand Down Expand Up @@ -991,6 +992,11 @@ def decrypt_payload(self, message: Message):
if len(message.payload) != payload_len:
raise RuntimeError(f"Payload size changed after decryption {len(message.payload)} <> {payload_len}")

def add_incoming_filter(self, channel: str, topic: str, cb, *args, **kwargs):
if not callable(cb):
raise ValueError(f"specified incoming_filter {type(cb)} is not callable")
self.in_filter_reg.append(channel, topic, Callback(cb, args, kwargs))

def add_incoming_request_filter(self, channel: str, topic: str, cb, *args, **kwargs):
if not callable(cb):
raise ValueError(f"specified incoming_request_filter {type(cb)} is not callable")
Expand Down Expand Up @@ -1856,6 +1862,19 @@ def _process_received_msg(self, endpoint: Endpoint, connection: Connection, mess
category=self._stats_category(message), counter_name=_CounterName.RECEIVED
)

# invoke incoming filters
channel = message.get_header(MessageHeaderKey.CHANNEL, "")
topic = message.get_header(MessageHeaderKey.TOPIC, "")
in_filters = self.in_filter_reg.find(channel, topic)
if in_filters:
self.logger.debug(f"{self.my_info.fqcn}: invoking incoming filters")
assert isinstance(in_filters, list)
for f in in_filters:
assert isinstance(f, Callback)
reply = self._try_cb(message, f.cb, *f.args, **f.kwargs)
if reply:
return reply

if msg_type == MessageType.REQ and self.message_interceptor is not None:
reply = self._try_cb(
message, self.message_interceptor, *self.message_interceptor_args, **self.message_interceptor_kwargs
Expand Down
10 changes: 7 additions & 3 deletions nvflare/private/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class TaskConstant(object):
class EngineConstant(object):

FEDERATE_CLIENT = "federate_client"
FL_TOKEN = "fl_token"
CLIENT_TOKEN_FILE = "client_token.txt"
AUTH_TOKEN = "auth_token"
AUTH_TOKEN_SIGNATURE = "auth_token_signature"
ENGINE_TASK_NAME = "engine_task_name"


Expand Down Expand Up @@ -140,7 +140,8 @@ class CellMessageHeaderKeys:
CLIENT_NAME = "client_name"
CLIENT_IP = "client_ip"
PROJECT_NAME = "project_name"
TOKEN = "token"
TOKEN = "__token__"
TOKEN_SIGNATURE = "__token_signature__"
SSID = "ssid"
UNAUTHENTICATED = "unauthenticated"
JOB_ID = "job_id"
Expand All @@ -149,6 +150,9 @@ class CellMessageHeaderKeys:
ABORT_JOBS = "abort_jobs"


AUTH_CLIENT_NAME_FOR_SJ = "server_job"


class JobFailureMsgKey:

JOB_ID = "job_id"
Expand Down
6 changes: 4 additions & 2 deletions nvflare/private/fed/app/client/worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,12 @@ def main(args):
federated_client = deployer.create_fed_client(args)
federated_client.status = ClientStatus.STARTING

federated_client.communicator.set_auth(args.client_name, args.token, args.token_signature, args.ssid)
federated_client.token = args.token
federated_client.token_signature = args.token_signature
federated_client.ssid = args.ssid
federated_client.client_name = args.client_name
federated_client.fl_ctx.set_prop(FLContextKey.CLIENT_NAME, args.client_name, private=False)
federated_client.fl_ctx.set_prop(EngineConstant.FL_TOKEN, args.token, private=False)
federated_client.fl_ctx.set_prop(FLContextKey.WORKSPACE_ROOT, args.workspace, private=True)

client_app_runner = ClientAppRunner(time_out=kv_list.get("app_runner_timeout", 60.0))
Expand Down Expand Up @@ -150,7 +151,8 @@ def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
parser.add_argument("--startup", "-w", type=str, help="startup folder", required=True)
parser.add_argument("--token", "-t", type=str, help="token", required=True)
parser.add_argument("--token", "-t", type=str, help="auth token", required=True)
parser.add_argument("--token_signature", "-ts", type=str, help="auth token signature", required=True)
parser.add_argument("--ssid", "-d", type=str, help="ssid", required=True)
parser.add_argument("--job_id", "-n", type=str, help="job_id", required=True)
parser.add_argument("--client_name", "-c", type=str, help="client name", required=True)
Expand Down
2 changes: 1 addition & 1 deletion nvflare/private/fed/app/deployer/simulator_deployer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _create_client_cell(self, client_config, client_name, federated_client):
)
cell.start()
federated_client.cell = cell
federated_client.communicator.cell = cell
federated_client.communicator.set_cell(cell)
# if self.engine:
# self.engine.admin_agent.register_cell_cb()

Expand Down
17 changes: 16 additions & 1 deletion nvflare/private/fed/app/server/runner_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
from nvflare.apis.fl_constant import ConfigVarName, JobConstants, SystemConfigs
from nvflare.apis.workspace import Workspace
from nvflare.fuel.common.excepts import ConfigError
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.sec.audit import AuditService
from nvflare.fuel.sec.security_content_service import SecurityContentService
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.private.defs import AppFolderConstants
from nvflare.private.defs import AUTH_CLIENT_NAME_FOR_SJ, AppFolderConstants, CellMessageHeaderKeys
from nvflare.private.fed.app.fl_conf import FLServerStarterConfiger
from nvflare.private.fed.app.utils import monitor_parent_process
from nvflare.private.fed.server.server_app_runner import ServerAppRunner
Expand Down Expand Up @@ -112,6 +113,12 @@ def main(args):
server.cell = server.create_job_cell(
args.job_id, args.root_url, args.parent_url, secure_train, server_config
)

# set filter to add additional auth headers
server.cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=_add_auth_headers, config=args)

server.cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=_add_auth_headers, config=args)

server.server_state = HotState(host=args.host, port=args.port, ssid=args.ssid)

snapshot = None
Expand Down Expand Up @@ -142,6 +149,13 @@ def main(args):
raise e


def _add_auth_headers(message: CellMessage, config):
message.set_header(CellMessageHeaderKeys.SSID, config.ssid)
message.set_header(CellMessageHeaderKeys.CLIENT_NAME, AUTH_CLIENT_NAME_FOR_SJ)
message.set_header(CellMessageHeaderKeys.TOKEN, config.job_id)
message.set_header(CellMessageHeaderKeys.TOKEN_SIGNATURE, config.token_signature)


def parse_arguments():
"""FL Server program starting point."""
parser = argparse.ArgumentParser()
Expand All @@ -151,6 +165,7 @@ def parse_arguments():
)
parser.add_argument("--app_root", "-r", type=str, help="App Root", required=True)
parser.add_argument("--job_id", "-n", type=str, help="job id", required=True)
parser.add_argument("--token_signature", "-ts", type=str, help="auth token signature", required=True)
parser.add_argument("--root_url", "-u", type=str, help="root_url", required=True)
parser.add_argument("--host", "-host", type=str, help="server host", required=True)
parser.add_argument("--port", "-port", type=str, help="service port", required=True)
Expand Down
2 changes: 1 addition & 1 deletion nvflare/private/fed/app/simulator/simulator_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _create_client_cell(self, federated_client, root_url, parent_url):
cell.start()
mpm.add_cleanup_cb(cell.stop)
federated_client.cell = cell
federated_client.communicator.cell = cell
federated_client.communicator.set_cell(cell)

start = time.time()
while not cell.is_cell_connected(FQCN.ROOT_SERVER):
Expand Down
1 change: 0 additions & 1 deletion nvflare/private/fed/client/client_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def start_run(self, app_root, args, config_folder, federated_client, secure_trai
@staticmethod
def _set_fl_context(fl_ctx: FLContext, app_root, args, workspace, secure_train):
fl_ctx.set_prop(FLContextKey.CLIENT_NAME, args.client_name, private=False)
fl_ctx.set_prop(EngineConstant.FL_TOKEN, args.token, private=False)
fl_ctx.set_prop(FLContextKey.WORKSPACE_ROOT, args.workspace, private=True)
fl_ctx.set_prop(FLContextKey.ARGS, args, sticky=True)
fl_ctx.set_prop(FLContextKey.APP_ROOT, app_root, private=True, sticky=True)
Expand Down
17 changes: 0 additions & 17 deletions nvflare/private/fed/client/client_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,23 +185,6 @@ def notify_job_status(self, job_id: str, job_status):
def get_client_name(self):
return self.client.client_name

def _write_token_file(self, job_id, open_port):
token_file = os.path.join(self.args.workspace, EngineConstant.CLIENT_TOKEN_FILE)
if os.path.exists(token_file):
os.remove(token_file)
with open(token_file, "wt") as f:
f.write(
"%s\n%s\n%s\n%s\n%s\n%s\n"
% (
self.client.token,
self.client.ssid,
job_id,
self.client.client_name,
open_port,
list(self.client.servers.values())[0]["target"],
)
)

def abort_app(self, job_id: str) -> str:
status = self.client_executor.get_status(job_id)
if status == ClientStatus.STOPPED:
Expand Down
2 changes: 2 additions & 0 deletions nvflare/private/fed/client/client_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def start_app(
+ self.startup
+ " -t "
+ client.token
+ " -ts "
+ client.token_signature
+ " -d "
+ client.ssid
+ " -n "
Expand Down
49 changes: 36 additions & 13 deletions nvflare/private/fed/client/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from nvflare.fuel.f3.cellnet.core_cell import FQCN, CoreCell
from nvflare.fuel.f3.cellnet.defs import IdentityChallengeKey, MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.utils import format_size
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.private.defs import CellChannel, CellChannelTopic, CellMessageHeaderKeys, SpecialTaskName, new_cell_message
from nvflare.private.fed.client.client_engine_internal_spec import ClientEngineInternalSpec
from nvflare.private.fed.utils.fed_utils import get_scope_prop
Expand Down Expand Up @@ -93,7 +94,39 @@ def __init__(
self.timeout = timeout
self.maint_msg_timeout = maint_msg_timeout

# token and token_signature are issued by the Server after the client is authenticated
# they are added to every message going to the server as proof of authentication
self.token = None
self.token_signature = None
self.ssid = None
self.client_name = None

self.logger = logging.getLogger(self.__class__.__name__)
self.logger.info(f"==== Communicator GOT CELL: {type(cell)}")

def set_auth(self, client_name, token, token_signature, ssid):
self.ssid = ssid
self.token_signature = token_signature
self.token = token
self.client_name = client_name

def set_cell(self, cell):
self.cell = cell

# set filter to add additional auth headers
cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=self._add_auth_headers)
cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=self._add_auth_headers)

def _add_auth_headers(self, message: CellMessage):
if self.ssid:
message.set_header(CellMessageHeaderKeys.SSID, self.ssid)

if self.client_name:
message.set_header(CellMessageHeaderKeys.CLIENT_NAME, self.client_name)

if self.token:
message.set_header(CellMessageHeaderKeys.TOKEN, self.token)
message.set_header(CellMessageHeaderKeys.TOKEN_SIGNATURE, self.token_signature)

def _challenge_server(self, client_name, expected_host, root_cert_file):
# ask server for its info and make sure that it matches expected host
Expand Down Expand Up @@ -252,17 +285,19 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext):
raise FLCommunicationError("error:client_registration " + reason)

token = result.get_header(CellMessageHeaderKeys.TOKEN)
token_signature = result.get_header(CellMessageHeaderKeys.TOKEN_SIGNATURE, "NA")
ssid = result.get_header(CellMessageHeaderKeys.SSID)
if not token and not self.should_stop:
time.sleep(self.client_register_interval)
else:
self.set_auth(client_name, token, token_signature, ssid)
break

except Exception as ex:
traceback.print_exc()
raise FLCommunicationError("error:client_registration", ex)

return token, ssid
return token, token_signature, ssid

def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None):
"""Get a task from server.
Expand All @@ -285,9 +320,6 @@ def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None):
client_name = fl_ctx.get_identity_name()
task_message = new_cell_message(
{
CellMessageHeaderKeys.TOKEN: token,
CellMessageHeaderKeys.CLIENT_NAME: client_name,
CellMessageHeaderKeys.SSID: ssid,
CellMessageHeaderKeys.PROJECT_NAME: project_name,
},
shareable,
Expand Down Expand Up @@ -361,9 +393,6 @@ def submit_update(

task_message = new_cell_message(
{
CellMessageHeaderKeys.TOKEN: token,
CellMessageHeaderKeys.CLIENT_NAME: client_name,
CellMessageHeaderKeys.SSID: ssid,
CellMessageHeaderKeys.PROJECT_NAME: project_name,
},
shareable,
Expand Down Expand Up @@ -410,9 +439,6 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext):
client_name = fl_ctx.get_identity_name()
quit_message = new_cell_message(
{
CellMessageHeaderKeys.TOKEN: token,
CellMessageHeaderKeys.CLIENT_NAME: client_name,
CellMessageHeaderKeys.SSID: ssid,
CellMessageHeaderKeys.PROJECT_NAME: task_name,
},
shareable,
Expand Down Expand Up @@ -452,9 +478,6 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C
job_ids = engine.get_all_job_ids()
heartbeat_message = new_cell_message(
{
CellMessageHeaderKeys.TOKEN: token,
CellMessageHeaderKeys.SSID: ssid,
CellMessageHeaderKeys.CLIENT_NAME: client_name,
CellMessageHeaderKeys.PROJECT_NAME: task_name,
CellMessageHeaderKeys.JOB_IDS: job_ids,
},
Expand Down
9 changes: 5 additions & 4 deletions nvflare/private/fed/client/fed_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.private.defs import EngineConstant
from nvflare.private.fed.utils.fed_utils import set_scope_prop
from nvflare.security.logging import secure_format_exception

Expand Down Expand Up @@ -77,6 +76,7 @@ def __init__(

self.client_name = client_name
self.token = None
self.token_signature = None
self.ssid = None
self.client_args = client_args
self.servers = server_args
Expand Down Expand Up @@ -207,7 +207,7 @@ def _create_cell(self, location, scheme):
parent_url=parent_url,
)
self.cell.start()
self.communicator.cell = self.cell
self.communicator.set_cell(self.cell)
self.net_agent = NetAgent(self.cell)
mpm.add_cleanup_cb(self.net_agent.close)
mpm.add_cleanup_cb(self.cell.stop)
Expand Down Expand Up @@ -250,10 +250,11 @@ def client_register(self, project_name, fl_ctx: FLContext):
"""
if not self.token:
try:
self.token, self.ssid = self.communicator.client_registration(self.client_name, project_name, fl_ctx)
self.token, self.token_signature, self.ssid = self.communicator.client_registration(
self.client_name, project_name, fl_ctx
)
if self.token is not None:
self.fl_ctx.set_prop(FLContextKey.CLIENT_NAME, self.client_name, private=False)
self.fl_ctx.set_prop(EngineConstant.FL_TOKEN, self.token, private=False)
self.logger.info(
"Successfully registered client:{} for project {}. Token:{} SSID:{}".format(
self.client_name, project_name, self.token, self.ssid
Expand Down
Loading

0 comments on commit f13d951

Please sign in to comment.