Skip to content

Commit

Permalink
Merge branch 'master' into feature/document-common
Browse files Browse the repository at this point in the history
  • Loading branch information
Andreas Hellander committed Dec 1, 2023
2 parents 19fcd11 + 6c2e7dc commit 722fdd5
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 115 deletions.
5 changes: 0 additions & 5 deletions fedn/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import logging

import click

logging.basicConfig(format='%(asctime)s [%(filename)s:%(lineno)d] %(message)s',
datefmt='%m/%d/%Y %I:%M:%S %p') # , level=logging.DEBUG)

CONTEXT_SETTINGS = dict(
# Support -h as a shortcut for --help
help_option_names=['-h', '--help'],
Expand Down
10 changes: 6 additions & 4 deletions fedn/cli/run_cmd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import time
import uuid

import click
Expand Down Expand Up @@ -102,13 +101,15 @@ def run_cmd(ctx):
@click.option('-tr', '--trainer', required=False, default=True)
@click.option('-in', '--init', required=False, default=None,
help='Set to a filename to (re)init client from file state.')
@click.option('-l', '--logfile', required=False, default='{}-client.log'.format(time.strftime("%Y%m%d-%H%M%S")),
@click.option('-l', '--logfile', required=False, default=None,
help='Set logfile for client log to file.')
@click.option('--heartbeat-interval', required=False, default=2)
@click.option('--reconnect-after-missed-heartbeat', required=False, default=30)
@click.option('--verbosity', required=False, default='INFO', type=click.Choice(['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'], case_sensitive=False))
@click.pass_context
def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_package, force_ssl, dry_run, secure, preshared_cert,
verify, preferred_combiner, validator, trainer, init, logfile, heartbeat_interval, reconnect_after_missed_heartbeat):
verify, preferred_combiner, validator, trainer, init, logfile, heartbeat_interval, reconnect_after_missed_heartbeat,
verbosity):
"""
:param ctx:
Expand All @@ -127,14 +128,15 @@ def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_pa
:param logfile:
:param hearbeat_interval
:param reconnect_after_missed_heartbeat
:param verbosity
:return:
"""
remote = False if local_package else True
config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'name': name,
'client_id': client_id, 'remote_compute_context': remote, 'force_ssl': force_ssl, 'dry_run': dry_run, 'secure': secure,
'preshared_cert': preshared_cert, 'verify': verify, 'preferred_combiner': preferred_combiner,
'validator': validator, 'trainer': trainer, 'init': init, 'logfile': logfile, 'heartbeat_interval': heartbeat_interval,
'reconnect_after_missed_heartbeat': reconnect_after_missed_heartbeat}
'reconnect_after_missed_heartbeat': reconnect_after_missed_heartbeat, 'verbosity': verbosity}

if init:
apply_config(config)
Expand Down
56 changes: 56 additions & 0 deletions fedn/fedn/common/log_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import logging
import logging.config

import urllib3

# Suppress urllib3's InsecureRequestWarning and limit urllib3's own logger to
# ERROR and above.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
logging.getLogger("urllib3").setLevel(logging.ERROR)

# Import-time logging setup. logging.getLogger() is called without a name, so
# `logger` is the root logger; a StreamHandler (stderr by default) is attached
# to it and its threshold starts at DEBUG.
handler = logging.StreamHandler()
logger = logging.getLogger()
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
# Output format "<timestamp> [<LEVEL>] <message>"; the same formatter object is
# reused by set_log_stream() for file output.
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
handler.setFormatter(formatter)


def set_log_level_from_string(level_str):
    """Set the level of the module-level ``logger`` from a level name.

    :param level_str: Name of the log level, case-insensitive. One of
        ``CRITICAL``, ``ERROR``, ``WARNING``, ``INFO`` or ``DEBUG``.
        Surrounding whitespace is ignored.
    :type level_str: str
    :raises ValueError: If ``level_str`` is not a string or does not name
        one of the supported levels.
    """
    # Mapping of string representation to logging constants
    level_mapping = {
        'CRITICAL': logging.CRITICAL,
        'ERROR': logging.ERROR,
        'WARNING': logging.WARNING,
        'INFO': logging.INFO,
        'DEBUG': logging.DEBUG,
    }

    # A non-string value (e.g. None from a config entry that exists but is
    # unset) is reported with the same ValueError as an unknown name, rather
    # than surfacing as an AttributeError from .upper().
    if not isinstance(level_str, str):
        raise ValueError(f"Invalid log level: {level_str}")

    level = level_mapping.get(level_str.strip().upper())

    # Compare against None explicitly so the check does not depend on the
    # numeric value of the level constants.
    if level is None:
        raise ValueError(f"Invalid log level: {level_str}")

    # Set the log level
    logger.setLevel(level)


def set_log_stream(log_file):
    """Redirect the output of the module-level ``logger`` to a file.

    Does nothing when ``log_file`` is falsy (``None`` or an empty string).
    Otherwise every handler currently attached to ``logger`` is replaced by
    a single ``logging.FileHandler`` that uses the module-level ``formatter``.

    :param log_file: Path of the log file, or a falsy value to leave the
        current handlers untouched.
    :type log_file: str
    :raises OSError: If the log file cannot be opened; the existing
        handlers are left in place in that case.
    """
    if not log_file:
        return

    # Open the file before touching the existing handlers: if this raises
    # (bad path, missing permissions) the logger keeps its current output
    # instead of being left without any handler.
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)

    # Detach the existing handlers and close them so that resources they
    # hold (e.g. the file of a FileHandler installed by an earlier call)
    # are released rather than leaked.
    for old_handler in logger.handlers[:]:
        logger.removeHandler(old_handler)
        old_handler.close()

    # Add the file handler to the logger
    logger.addHandler(file_handler)
98 changes: 50 additions & 48 deletions fedn/fedn/network/clients/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@

import fedn.common.net.grpc.fedn_pb2 as fedn
import fedn.common.net.grpc.fedn_pb2_grpc as rpc
from fedn.common.log_config import (logger, set_log_level_from_string,
set_log_stream)
from fedn.network.clients.connect import ConnectorClient, Status
from fedn.network.clients.package import PackageRuntime
from fedn.network.clients.state import ClientState, ClientStateToString
from fedn.utils.dispatcher import Dispatcher
from fedn.utils.helpers import get_helper
from fedn.utils.logger import Logger

CHUNK_SIZE = 1024 * 1024
VALID_NAME_REGEX = '^[a-zA-Z0-9_-]*$'
Expand All @@ -51,13 +52,15 @@ class Client:

def __init__(self, config):
"""Initialize the client."""

self.state = None
self.error_state = False
self._attached = False
self._missed_heartbeat = 0
self.config = config

set_log_level_from_string(config.get('verbosity', "INFO"))
set_log_stream(config.get('logfile', None))

self.connector = ConnectorClient(host=config['discover_host'],
port=config['discover_port'],
token=config['token'],
Expand All @@ -78,8 +81,6 @@ def __init__(self, config):
self.run_path = os.path.join(os.getcwd(), dirname)
os.mkdir(self.run_path)

self.logger = Logger(
to_file=config['logfile'], file_path=self.run_path)
self.started_at = datetime.now()
self.logs = []

Expand All @@ -92,8 +93,8 @@ def __init__(self, config):

self._initialize_helper(client_config)
if not self.helper:
print("Failed to retrive helper class settings! {}".format(
client_config), flush=True)
logger.warning("Failed to retrieve helper class settings: {}".format(
client_config))

self._subscribe_to_combiner(config)

Expand All @@ -106,27 +107,26 @@ def _assign(self):
:rtype: dict
"""

print("Asking for assignment!", flush=True)
logger.info("Initiating assignment request.")
while True:
status, response = self.connector.assign()
if status == Status.TryAgain:
print(response, flush=True)
logger.info(response)
time.sleep(5)
continue
if status == Status.Assigned:
client_config = response
break
if status == Status.UnAuthorized:
print(response, flush=True)
logger.critical(response)
sys.exit("Exiting: Unauthorized")
if status == Status.UnMatchedConfig:
print(response, flush=True)
logger.critical(response)
sys.exit("Exiting: UnMatchedConfig")
time.sleep(5)
print(".", end=' ', flush=True)

print("Got assigned!", flush=True)
print("Received combiner config: {}".format(client_config), flush=True)
logger.info("Assignment successfully received.")
logger.info("Received combiner configuration: {}".format(client_config))
return client_config

def _add_grpc_metadata(self, key, value):
Expand Down Expand Up @@ -177,31 +177,31 @@ def _connect(self, client_config):
host = client_config['host']
# Add host to gRPC metadata
self._add_grpc_metadata('grpc-server', host)
print("CLIENT: Using metadata: {}".format(self.metadata), flush=True)
logger.info("Client using metadata: {}.".format(self.metadata))
port = client_config['port']
secure = False
if client_config['fqdn'] is not None:
host = client_config['fqdn']
# assuming https if fqdn is used
port = 443
print(f"CLIENT: Connecting to combiner host: {host}:{port}", flush=True)
logger.info(f"Initiating connection to combiner host at: {host}:{port}")

if client_config['certificate']:
print("CLIENT: using certificate from Reducer for GRPC channel")
logger.info("Utilizing CA certificate for GRPC channel authentication.")
secure = True
cert = base64.b64decode(
client_config['certificate']) # .decode('utf-8')
credentials = grpc.ssl_channel_credentials(root_certificates=cert)
channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials)
elif os.getenv("FEDN_GRPC_ROOT_CERT_PATH"):
secure = True
print("CLIENT: using root certificate from environment variable for GRPC channel")
logger.info("Using root certificate from environment variable for GRPC channel.")
with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], 'rb') as f:
credentials = grpc.ssl_channel_credentials(f.read())
channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials)
elif self.config['secure']:
secure = True
print("CLIENT: using CA certificate for GRPC channel")
logger.info("Using CA certificate for GRPC channel.")
cert = self._get_ssl_certificate(host, port=port)

credentials = grpc.ssl_channel_credentials(cert.encode('utf-8'))
Expand All @@ -212,7 +212,7 @@ def _connect(self, client_config):
else:
channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials)
else:
print("CLIENT: using insecure GRPC channel")
logger.info("Using insecure GRPC channel.")
if port == 443:
port = 80
channel = grpc.insecure_channel("{}:{}".format(
Expand All @@ -225,13 +225,11 @@ def _connect(self, client_config):
self.combinerStub = rpc.CombinerStub(channel)
self.modelStub = rpc.ModelServiceStub(channel)

print("Client: {} connected {} to {}:{}".format(self.name,
"SECURED" if secure else "INSECURE",
host,
port),
flush=True)
logger.info("Successfully established {} connection to {}:{}".format("secure" if secure else "insecure",
host,
port))

print("Client: Using {} compute package.".format(
logger.info("Using {} compute package.".format(
client_config["package"]))

def _disconnect(self):
Expand All @@ -242,7 +240,7 @@ def _detach(self):
"""Detach from the FEDn network (disconnect from combiner)"""
# Setting _attached to False will make all processing threads return
if not self._attached:
print("Client is not attached.", flush=True)
logger.info("Client is not attached.")

self._attached = False
# Close gRPC connection to combiner
Expand All @@ -252,7 +250,7 @@ def _attach(self):
"""Attach to the FEDn network (connect to combiner)"""
# Ask controller for a combiner and connect to that combiner.
if self._attached:
print("Client is already attached. ", flush=True)
logger.info("Client is already attached. ")
return None

client_config = self._assign()
Expand Down Expand Up @@ -325,19 +323,16 @@ def _initialize_dispatcher(self, config):
if retval:
break
time.sleep(60)
print("No compute package available... retrying in 60s Trying {} more times.".format(
tries), flush=True)
logger.warning("Compute package not available. Retrying in 60 seconds. {} attempts remaining.".format(tries))
tries -= 1

if retval:
if 'checksum' not in config:
print(
"\nWARNING: Skipping security validation of local package!, make sure you trust the package source.\n",
flush=True)
logger.warning("Bypassing security validation for local package. Ensure the package source is trusted.")
else:
checks_out = pr.validate(config['checksum'])
if not checks_out:
print("Validation was enforced and invalid, client closing!")
logger.critical("Validation of local package failed. Client terminating.")
self.error_state = True
return

Expand All @@ -346,10 +341,12 @@ def _initialize_dispatcher(self, config):

self.dispatcher = pr.dispatcher(self.run_path)
try:
print("Running Dispatcher for entrypoint: startup", flush=True)
logger.info("Initiating Dispatcher with entrypoint set to: startup")
self.dispatcher.run_cmd("startup")
except KeyError:
pass
except Exception as e:
logger.error(f"Caught exception: {type(e).__name__}")
else:
# TODO: Deprecate
dispatch_config = {'entry_points':
Expand Down Expand Up @@ -523,11 +520,14 @@ def _process_training_request(self, model_id):
outpath = self.helper.get_tmp_path()
tic = time.time()
# TODO: Check return status, fail gracefully

self.dispatcher.run_cmd("train {} {}".format(inpath, outpath))

meta['exec_training'] = time.time() - tic

tic = time.time()
out_model = None

with open(outpath, "rb") as fr:
out_model = io.BytesIO(fr.read())

Expand All @@ -546,11 +546,15 @@ def _process_training_request(self, model_id):
os.unlink(outpath+'-metadata')

except Exception as e:
print("ERROR could not process training request due to error: {}".format(
e), flush=True)
logger.error("Could not process training request due to error: {}".format(e))
updated_model_id = None
meta = {'status': 'failed', 'error': str(e)}

# Push model update to combiner server
updated_model_id = uuid.uuid4()
self.set_model(out_model, str(updated_model_id))
meta['upload_model'] = time.time() - tic

self.state = ClientState.idle

return updated_model_id, meta
Expand Down Expand Up @@ -591,7 +595,7 @@ def _process_validation_request(self, model_id, is_inference):
os.unlink(outpath)

except Exception as e:
print("Validation failed with exception {}".format(e), flush=True)
logger.warning("Validation failed with exception {}".format(e))
raise
self.state = ClientState.idle
return None
Expand Down Expand Up @@ -701,8 +705,9 @@ def _send_heartbeat(self, update_frequency=2.0):
self._missed_heartbeat = 0
except grpc.RpcError as e:
status_code = e.code()
print("CLIENT heartbeat: GRPC ERROR {} retrying..".format(
status_code.name), flush=True)
logger.warning("Client heartbeat: GRPC error, {}. Retrying.".format(
status_code.name))
logger.debug(e)
self._handle_combiner_failure()

time.sleep(update_frequency)
Expand Down Expand Up @@ -745,20 +750,17 @@ def run(self):
old_state = self.state
while True:
time.sleep(1)
cnt += 1
if cnt == 0:
logger.info("Client is active, waiting for model update requests.")
cnt = 1
if self.state != old_state:
print("{}:CLIENT in {} state".format(datetime.now().strftime(
'%Y-%m-%d %H:%M:%S'), ClientStateToString(self.state)), flush=True)
if cnt > 5:
print("{}:CLIENT active".format(
datetime.now().strftime('%Y-%m-%d %H:%M:%S')), flush=True)
cnt = 0
logger.info("Client in {} state.".format(ClientStateToString(self.state)))
if not self._attached:
print("Detatched from combiner.", flush=True)
logger.info("Detached from combiner.")
# TODO: Implement a check/condition to ultimately close down if too many reattachment attempts have failed.
self._attach()
self._subscribe_to_combiner(self.config)
if self.error_state:
return
except KeyboardInterrupt:
print("Ok, exiting..")
logger.info("Shutting down.")
Loading

0 comments on commit 722fdd5

Please sign in to comment.