Skip to content

Commit

Permalink
Refactor overseer_agent to AdminAPI from cli, add ha commands (NVIDIA…
Browse files Browse the repository at this point in the history
…#302)

* refactor overseer_agent to AdminAPI from cli, add working ha commands

* Fix CI

* fix CI

* remove check because it was performed earlier

* fix CI
  • Loading branch information
nvkevlu authored Mar 15, 2022
1 parent 5ac8a52 commit a5a589b
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 47 deletions.
98 changes: 91 additions & 7 deletions nvflare/fuel/hci/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
import traceback
from datetime import datetime

from nvflare.apis.overseer_spec import SP, OverseerAgent
from nvflare.fuel.hci.cmd_arg_utils import split_to_args
from nvflare.fuel.hci.conn import Connection, receive_and_process
from nvflare.fuel.hci.proto import make_error
from nvflare.fuel.hci.reg import CommandModule, CommandRegister
from nvflare.fuel.hci.security import get_certificate_common_name
from nvflare.fuel.hci.table import Table
from nvflare.ha.ha_admin_cmds import HACommandModule

from .api_spec import AdminAPISpec, ReplyProcessor
from .api_status import APIStatus
Expand Down Expand Up @@ -89,15 +91,19 @@ class AdminAPI(AdminAPISpec):
"""Underlying API to keep certs, keys and connection information and to execute admin commands through do_command.
Args:
host: cn provisioned for the project, with this fully qualified domain name resolving to the IP of the FL server
port: port provisioned as admin_port for FL admin communication, by default provisioned as 8003, must be int
host: cn provisioned for the server, with this fully qualified domain name resolving to the IP of the FL server. This may be set by the OverseerAgent.
port: port provisioned as admin_port for FL admin communication, by default provisioned as 8003, must be int if provided. This may be set by the OverseerAgent.
ca_cert: path to CA Cert file, by default provisioned rootCA.pem
client_cert: path to admin client Cert file, by default provisioned as client.crt
client_key: path to admin client Key file, by default provisioned as client.key
upload_dir: File transfer upload directory. Folders uploaded to the server to be deployed must be here. Folder must already exist and be accessible.
download_dir: File transfer download directory. Can be same as upload_dir. Folder must already exist and be accessible.
server_cn: server cn (only used for validating server cn)
cmd_modules: command modules to load and register
cmd_modules: command modules to load and register. Note that FileTransferModule is initialized here with upload_dir and download_dir if cmd_modules is None.
overseer_agent: initialized OverseerAgent to obtain the primary service provider to set the host and port of the active server
auto_login: Whether to use stored credentials to automatically log in (required to be True with OverseerAgent to provide high availability)
user_name: Username to authenticate with FL server
password: Password to authenticate with FL server (not used in secure mode with SSL)
poc: Whether to enable poc mode for using the proof of concept example without secure communication.
debug: Whether to print debug messages, which can help with diagnosing problems. False by default.
"""
Expand All @@ -114,6 +120,9 @@ def __init__(
server_cn=None,
cmd_modules=None,
overseer_agent=None,
auto_login: bool = False,
user_name: str = None,
password: str = None,
poc: bool = False,
debug: bool = False,
):
Expand All @@ -122,6 +131,17 @@ def __init__(
from .file_transfer import FileTransferModule

cmd_modules = [FileTransferModule(upload_dir=upload_dir, download_dir=download_dir)]
elif not isinstance(cmd_modules, list):
raise TypeError("cmd_modules must be a list, but got {}".format(type(cmd_modules)))
else:
for m in cmd_modules:
if not isinstance(m, CommandModule):
raise TypeError(
"cmd_modules must be a list of CommandModule, but got element of type {}".format(type(m))
)
cmd_modules.append(HACommandModule())

self.overseer_agent = overseer_agent
self.host = host
self.port = port
self.poc = poc
Expand All @@ -135,12 +155,31 @@ def __init__(
if len(client_key) <= 0:
raise Exception("missing Client Key file name")
self.client_key = client_key
if not isinstance(self.overseer_agent, OverseerAgent):
raise Exception("overseer_agent is missing but must be provided for secure context.")
self.overseer_agent.set_secure_context(
ca_path=self.ca_cert, cert_path=self.client_cert, prv_key_path=self.client_key
)
if self.overseer_agent:
self.overseer_agent.start(self._overseer_callback)
self.server_cn = server_cn
self.debug = debug

# for overseer agent
self.ssid = None

# for login
self.token = None
self.login_result = None
if auto_login:
self.auto_login = True
if not user_name:
raise Exception("for auto_login, user_name is required.")
self.user_name = user_name
if self.poc:
if not password:
raise Exception("for auto_login, password is required for credential_type password.")
self.password = password

self.server_cmd_reg = CommandRegister(app_ctx=self)
self.client_cmd_reg = CommandRegister(app_ctx=self)
Expand All @@ -158,13 +197,58 @@ def __init__(
self.sess_monitor_thread = None
self.sess_monitor_active = False

def _overseer_callback(self, overseer_agent):
sp = overseer_agent.get_primary_sp()
self._set_primary_sp(sp)

def _set_primary_sp(self, sp: SP):
if sp and sp.primary is True:
if self.host != sp.name or self.port != int(sp.admin_port) or self.ssid != sp.service_session_id:
# if needing to log out of previous server, this may be where to issue server_execute("_logout")
self.host = sp.name
self.port = int(sp.admin_port)
self.ssid = sp.service_session_id
print(
f"Got primary SP {self.host}:{sp.fl_port}:{self.port} from overseer. Host: {self.host} Admin_port: {self.port} SSID: {self.ssid}"
)

thread = threading.Thread(target=self._login_sp)
thread.start()

def _login_sp(self):
if not self._auto_login():
print("cannot log in, shutting down...")
self.shutdown_received = True

def _auto_login(self):
try_count = 0
while try_count < 5:
if self.poc:
self.login_with_password(username=self.user_name, password=self.password)
print(f"login_result: {self.login_result} token: {self.token}")
if self.login_result == "OK":
return True
elif self.login_result == "REJECT":
print("Incorrect password.")
return False
else:
print("Communication Error - please try later")
try_count += 1
else:
self.login(username=self.user_name)
if self.login_result == "OK":
return True
elif self.login_result == "REJECT":
print("Incorrect user name or certificate.")
return False
else:
print("Communication Error - please try later")
try_count += 1
return False

def _load_client_cmds(self, cmd_modules):
if cmd_modules:
if not isinstance(cmd_modules, list):
raise TypeError("cmd_modules must be a list")
for m in cmd_modules:
if not isinstance(m, CommandModule):
raise TypeError("cmd_modules must be a list of CommandModule")
self.client_cmd_reg.register_module(m, include_invisible=False)
self.client_cmd_reg.finalize(self.register_command)

Expand Down
57 changes: 22 additions & 35 deletions nvflare/fuel/hci/client/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@
import json
import os
import signal
import threading
import time
import traceback
from datetime import datetime
from enum import Enum
from functools import partial
from typing import List, Optional

from nvflare.apis.overseer_spec import SP, OverseerAgent
from nvflare.apis.overseer_spec import OverseerAgent
from nvflare.fuel.hci.cmd_arg_utils import join_args, split_to_args
from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandRegister, CommandSpec
from nvflare.fuel.hci.security import hash_password, verify_password
Expand Down Expand Up @@ -58,8 +57,8 @@ class AdminClient(cmd.Cmd):
"""Admin command prompt for submitting admin commands to the server through the CLI.
Args:
host: cn provisioned for the project, with this fully qualified domain name resolving to the IP of the FL server
port: port provisioned as admin_port for FL admin communication, by default provisioned as 8003, must be int
host: cn provisioned for the server, with this fully qualified domain name resolving to the IP of the FL server. This may be set by the OverseerAgent.
port: port provisioned as admin_port for FL admin communication, by default provisioned as 8003, must be int if provided. This may be set by the OverseerAgent.
prompt: prompt to use for the command prompt
ca_cert: path to CA Cert file, by default provisioned rootCA.pem
client_cert: path to admin client Cert file, by default provisioned as client.crt
Expand All @@ -68,9 +67,7 @@ class AdminClient(cmd.Cmd):
require_login: whether to require login
credential_type: what type of credential to use
cmd_modules: command modules to load and register
overseer_end_point: end point for the overseer in order to find the active server
project: project name to provide overseer
name: name of the provisioned admin to provide to the overseer
overseer_agent: initialized OverseerAgent to obtain the primary service provider to set the host and port of the active server
debug: whether to print debug messages. False by default.
"""

Expand Down Expand Up @@ -103,6 +100,8 @@ def __init__(
self.out_file = None
self.no_stdout = False

if not isinstance(overseer_agent, OverseerAgent):
raise TypeError("overseer_agent was not properly initialized.")
if not isinstance(credential_type, CredentialType):
raise TypeError("invalid credential_type {}".format(credential_type))

Expand All @@ -117,6 +116,8 @@ def __init__(

poc = True if self.credential_type == CredentialType.PASSWORD else False

self._get_login_creds()

self.api = AdminAPI(
host=host,
port=port,
Expand All @@ -126,39 +127,15 @@ def __init__(
server_cn=server_cn,
cmd_modules=modules,
overseer_agent=self.overseer_agent,
auto_login=True,
user_name=self.user_name,
password=self.password,
debug=self.debug,
poc=poc,
)

signal.signal(signal.SIGUSR1, partial(self.session_signal_handler))

self.ssid = None

if self.credential_type == CredentialType.CERT:
if self.overseer_agent:
self.overseer_agent.set_secure_context(ca_path=ca_cert, cert_path=client_cert, prv_key_path=client_key)

self.overseer_agent.start(self.overseer_callback)

def overseer_callback(self, overseer_agent):
sp = overseer_agent.get_primary_sp()
self.set_primary_sp(sp)

def set_primary_sp(self, sp: SP):
if sp and sp.primary is True:
if self.api.host != sp.name or self.api.port != int(sp.admin_port) or self.ssid != sp.service_session_id:
self.api.host = sp.name
self.api.port = int(sp.admin_port)
self.ssid = sp.service_session_id
print(f"Got primary SP. Host: {self.api.host} Admin_port: {self.api.port} SSID: {self.ssid}")

thread = threading.Thread(target=self._login_sp)
thread.start()

def _login_sp(self):
self.do_bye("logout")
self.login()

def session_ended(self, message):
self.write_error(message)
os.kill(os.getpid(), signal.SIGUSR1)
Expand Down Expand Up @@ -320,7 +297,7 @@ def _do_default(self, line):
ent = entries[0]
resp = ent.handler(args, self.api)
self.print_resp(resp)
if resp["status"] == APIStatus.ERROR_INACTIVE_SESSION:
if resp.get("status") == APIStatus.ERROR_INACTIVE_SESSION:
return True
return

Expand Down Expand Up @@ -432,14 +409,24 @@ def cmdloop(self, intro=None):
def run(self):

try:
self.stdout.write("Waiting for token from successful login...\n")
while self.api.token is None:
time.sleep(1.0)
if self.api.shutdown_received:
return False

# self.api.start_session_monitor(self.session_ended)
# above line was commented out, but if we want to use it, need to be logged in to call server_execute("_check_session") and consider how SP changes impact this
self.cmdloop(intro='Type ? to list commands; type "? cmdName" to show usage of a command.')
finally:
self.overseer_agent.end()

def _get_login_creds(self):
self.user_name = input("User Name: ")
if self.credential_type == CredentialType.PASSWORD:
self.password = getpass.getpass("Password: ")
self.pwd = hash_password(self.password)

def login(self):
if self.require_login:
if self.user_name:
Expand Down
8 changes: 4 additions & 4 deletions nvflare/fuel/hci/client/file_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def get_spec(self):
def upload_file(self, args, api: AdminAPISpec, cmd_name, file_to_str_func):
full_cmd_name = _server_cmd_name(cmd_name)
if len(args) < 2:
return {"status": APIStatus.ERROR_COMMAND_SYNTAX, "details": "syntax error: missing file names"}
return {"status": APIStatus.ERROR_SYNTAX, "details": "syntax error: missing file names"}

parts = [full_cmd_name]
for i in range(1, len(args)):
Expand All @@ -238,7 +238,7 @@ def upload_binary_file(self, args, api: AdminAPISpec):
def download_file(self, args, api: AdminAPISpec, cmd_name, str_to_file_func):
full_cmd_name = _server_cmd_name(cmd_name)
if len(args) < 2:
return {"status": APIStatus.ERROR_COMMAND_SYNTAX, "details": "syntax error: missing file names"}
return {"status": APIStatus.ERROR_SYNTAX, "details": "syntax error: missing file names"}

parts = [full_cmd_name]
for i in range(1, len(args)):
Expand All @@ -257,7 +257,7 @@ def download_binary_file(self, args, api: AdminAPISpec):

def upload_folder(self, args, api: AdminAPISpec):
if len(args) != 2:
return {"status": APIStatus.ERROR_COMMAND_SYNTAX, "details": "usage: upload_folder folder_name"}
return {"status": APIStatus.ERROR_SYNTAX, "details": "usage: upload_folder folder_name"}

folder_name = args[1]
if folder_name.endswith("/"):
Expand All @@ -281,7 +281,7 @@ def upload_folder(self, args, api: AdminAPISpec):

def download_folder(self, args, api: AdminAPISpec):
if len(args) != 2:
return {"status": APIStatus.ERROR_COMMAND_SYNTAX, "details": "usage: download_folder folder_name"}
return {"status": APIStatus.ERROR_SYNTAX, "details": "usage: download_folder folder_name"}

parts = [_server_cmd_name(ftd.SERVER_CMD_DOWNLOAD_FOLDER), args[1]]
command = join_args(parts)
Expand Down
Loading

0 comments on commit a5a589b

Please sign in to comment.