diff --git a/.ci/tests/examples/is_success.py b/.ci/tests/examples/is_success.py index 90cb7c1c2..0e52dba5a 100644 --- a/.ci/tests/examples/is_success.py +++ b/.ci/tests/examples/is_success.py @@ -1,6 +1,7 @@ -import pymongo -from time import sleep import sys +from time import sleep + +import pymongo N_ROUNDS = 3 RETRIES= 6 diff --git a/.devcontainer/bin/init_venv.sh b/.devcontainer/bin/init_venv.sh index 4c5dcdf51..24e4d84c2 100755 --- a/.devcontainer/bin/init_venv.sh +++ b/.devcontainer/bin/init_venv.sh @@ -9,5 +9,7 @@ python -m venv .venv .venv/bin/pip install \ sphinx==4.4.0 \ sphinx_press_theme==0.8.0 \ - sphinx-autobuild==2021.3.14 + sphinx-autobuild==2021.3.14 \ + autopep8==1.5.7 \ + isort==5.10.1 .venv/bin/pip install -e fedn \ No newline at end of file diff --git a/.devcontainer/devcontainer.json.tpl b/.devcontainer/devcontainer.json.tpl index de392db69..3a09b005f 100644 --- a/.devcontainer/devcontainer.json.tpl +++ b/.devcontainer/devcontainer.json.tpl @@ -9,7 +9,8 @@ "ms-azuretools.vscode-docker", "ms-python.python", "exiasr.hadolint", - "yzhang.markdown-all-in-one" + "yzhang.markdown-all-in-one", + "ms-python.isort" ], "mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind,consistency=default", diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml new file mode 100644 index 000000000..3f32c87d0 --- /dev/null +++ b/.github/workflows/code-checks.yaml @@ -0,0 +1,30 @@ +name: "code checks" + +on: push + +jobs: + code-checks: + runs-on: ubuntu-20.04 + steps: + - name: checkout + uses: actions/checkout@v2 + + - name: init venv + run: .devcontainer/bin/init_venv.sh + + - name: check Python imports + run: > + .venv/bin/isort . --check --diff + --skip .venv + --skip .mnist-keras + --skip .mnist-pytorch + + - name: check Python formatting + run: > + .venv/bin/autopep8 --recursive --diff + --exclude .venv + --exclude .mnist-keras + --exclude .mnist-pytorch + . + + # TODO: add linting/formatting for all file types \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..35fa1902d --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.organizeImports": true + }, +} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c7a30a217..d2f774df8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -28,3 +28,11 @@ Report a bug or propose a feature by [opening a new GitHub Issue](https://github - if your branch is a hotfix, name it **hotfix/[GitHub-Issue-ID]** Open your pull requests against the **develop** branch unless you're resolving a critical bug in production (hotfix). Then your pull request should be against **master** branch. + +### Code checks +We defined GitHub actions that check code quality and formatting against pushed branches and pull requests. We use: + +- [autopep8](https://pypi.org/project/autopep8/) to conform to the PEP 8 code style +- [isort](https://github.com/PyCQA/isort) to organize imports + +For more information please refer to the code check action: [.github/workflows/code-checks.yaml](.github/workflows/code-checks.yaml). \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index c05432ccf..cc3543318 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,6 +12,7 @@ # import os import sys + sys.path.insert(0, os.path.abspath('../../fedn')) diff --git a/examples/mnist-keras/bin/get_data b/examples/mnist-keras/bin/get_data index 0faf8c04f..4c449d03e 100755 --- a/examples/mnist-keras/bin/get_data +++ b/examples/mnist-keras/bin/get_data @@ -1,17 +1,21 @@ #!./.mnist-keras/bin/python +import os + import fire -import tensorflow as tf import numpy as np -import os +import tensorflow as tf + def get_data(out_dir='data'): # Make dir if necessary if not os.path.exists(out_dir): os.mkdir(out_dir) - + # Download data (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() - np.savez(f'{out_dir}/mnist.npz', x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test) + np.savez(f'{out_dir}/mnist.npz', x_train=x_train, + y_train=y_train, x_test=x_test, y_test=y_test) + if __name__ == '__main__': - fire.Fire(get_data) \ No newline at end of file + fire.Fire(get_data) diff --git a/examples/mnist-keras/bin/split_data b/examples/mnist-keras/bin/split_data index 035bdc449..bb583b6d7 100755 --- a/examples/mnist-keras/bin/split_data +++ b/examples/mnist-keras/bin/split_data @@ -1,8 +1,10 @@ #!./.mnist-keras/bin/python import os -import numpy as np from math import floor + import fire +import numpy as np + def splitset(dataset, parts): n = dataset.shape[0] @@ -26,14 +28,15 @@ def split(dataset='data/mnist.npz', outdir='data', n_splits=2): # Make splits for i in range(n_splits): - subdir=f'{outdir}/clients/{str(i+1)}' + subdir = f'{outdir}/clients/{str(i+1)}' if not os.path.exists(subdir): os.mkdir(subdir) np.savez(f'{subdir}/mnist.npz', - x_train=data['x_train'][i], - y_train=data['y_train'][i], - x_test=data['x_test'][i], - y_test=data['y_test'][i]) + x_train=data['x_train'][i], + y_train=data['y_train'][i], + x_test=data['x_test'][i], + y_test=data['y_test'][i]) + if __name__ == '__main__': - fire.Fire(split) \ No newline at end of file + fire.Fire(split) diff --git a/examples/mnist-keras/client/entrypoint b/examples/mnist-keras/client/entrypoint index 7fef1eb18..e9609f2ca 100755 --- a/examples/mnist-keras/client/entrypoint +++ b/examples/mnist-keras/client/entrypoint @@ -1,23 +1,27 @@ #!./.mnist-keras/bin/python -import tensorflow as tf -import numpy as np -from fedn.utils.kerashelper import KerasHelper -import fire import json -import docker import os -NUM_CLASSES=10 +import docker +import fire +import numpy as np +import tensorflow as tf + +from fedn.utils.kerashelper import KerasHelper + +NUM_CLASSES = 10 + def _get_data_path(): # Figure out FEDn client number from container name client = docker.from_env() container = client.containers.get(os.environ['HOSTNAME']) number = container.name[-1] - + # Return data path return f"/var/data/clients/{number}/mnist.npz" + def _compile_model(img_rows=28, img_cols=28): # Set input shape input_shape = (img_rows, img_cols, 1) @@ -30,10 +34,11 @@ def _compile_model(img_rows=28, img_cols=28): model.add(tf.keras.layers.Dense(32, activation='relu')) model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')) model.compile(loss=tf.keras.losses.categorical_crossentropy, - optimizer=tf.keras.optimizers.Adam(), - metrics=['accuracy']) + optimizer=tf.keras.optimizers.Adam(), + metrics=['accuracy']) return model + def _load_data(data_path, is_train=True): # Load data if data_path is None: @@ -50,16 +55,18 @@ def _load_data(data_path, is_train=True): # Normalize X = X.astype('float32') - X = np.expand_dims(X,-1) + X = np.expand_dims(X, -1) X = X / 255 y = tf.keras.utils.to_categorical(y, NUM_CLASSES) return X, y + def init_seed(out_path='seed.npz'): - weights = _compile_model().get_weights() - helper = KerasHelper() - helper.save_model(weights, out_path) + weights = _compile_model().get_weights() + helper = KerasHelper() + helper.save_model(weights, out_path) + def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1): # Load data @@ -73,11 +80,12 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1 # Train model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs) - + # Save weights = model.get_weights() helper.save_model(weights, out_model_path) + def validate(in_model_path, out_json_path, data_path=None): # Load data x_train, y_train = _load_data(data_path) @@ -104,13 +112,14 @@ def validate(in_model_path, out_json_path, data_path=None): } # Save JSON - with open(out_json_path,"w") as fh: + with open(out_json_path, "w") as fh: fh.write(json.dumps(report)) + if __name__ == '__main__': fire.Fire({ 'init_seed': init_seed, 'train': train, 'validate': validate, - '_get_data_path': _get_data_path, # for testing - }) \ No newline at end of file + '_get_data_path': _get_data_path, # for testing + }) diff --git a/examples/mnist-pytorch/bin/get_data b/examples/mnist-pytorch/bin/get_data index 72b775433..abe8a7197 100755 --- a/examples/mnist-pytorch/bin/get_data +++ b/examples/mnist-pytorch/bin/get_data @@ -1,17 +1,22 @@ #!./.mnist-pytorch/bin/python +import os + import fire -import torchvision import numpy as np -import os +import torchvision + def get_data(out_dir='data'): # Make dir if necessary if not os.path.exists(out_dir): os.mkdir(out_dir) - + # Download data - torchvision.datasets.MNIST(root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor, train=True, download=True) - torchvision.datasets.MNIST(root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor, train=False, download=True) + torchvision.datasets.MNIST( + root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor, train=True, download=True) + torchvision.datasets.MNIST( + root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor, train=False, download=True) + if __name__ == '__main__': - fire.Fire(get_data) \ No newline at end of file + fire.Fire(get_data) diff --git a/examples/mnist-pytorch/bin/split_data b/examples/mnist-pytorch/bin/split_data index b884c054e..51e38b4d4 100755 --- a/examples/mnist-pytorch/bin/split_data +++ b/examples/mnist-pytorch/bin/split_data @@ -1,9 +1,11 @@ #!./.mnist-pytorch/bin/python -import torchvision -import torch +import os from math import floor + import fire -import os +import torch +import torchvision + def splitset(dataset, parts): n = dataset.shape[0] @@ -20,8 +22,10 @@ def split(out_dir='data', n_splits=2): os.mkdir(f'{out_dir}/clients') # Load and convert to dict - train_data = torchvision.datasets.MNIST(root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor, train=True) - test_data = torchvision.datasets.MNIST(root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor, train=False) + train_data = torchvision.datasets.MNIST( + root=f'{out_dir}/train', transform=torchvision.transforms.ToTensor, train=True) + test_data = torchvision.datasets.MNIST( + root=f'{out_dir}/test', transform=torchvision.transforms.ToTensor, train=False) data = { 'x_train': splitset(train_data.data, n_splits), 'y_train': splitset(train_data.targets, n_splits), @@ -31,7 +35,7 @@ def split(out_dir='data', n_splits=2): # Make splits for i in range(n_splits): - subdir=f'{out_dir}/clients/{str(i+1)}' + subdir = f'{out_dir}/clients/{str(i+1)}' if not os.path.exists(subdir): os.mkdir(subdir) torch.save({ @@ -40,7 +44,8 @@ def split(out_dir='data', n_splits=2): 'x_test': data['x_test'][i], 'y_test': data['y_test'][i], }, - f'{subdir}/mnist.pt') + f'{subdir}/mnist.pt') + if __name__ == '__main__': - fire.Fire(split) \ No newline at end of file + fire.Fire(split) diff --git a/examples/mnist-pytorch/client/entrypoint b/examples/mnist-pytorch/client/entrypoint index d56fb7e66..429ca268a 100755 --- a/examples/mnist-pytorch/client/entrypoint +++ b/examples/mnist-pytorch/client/entrypoint @@ -1,24 +1,28 @@ #!./.mnist-pytorch/bin/python -import torch -from fedn.utils.pytorchhelper import PytorchHelper -import fire -import json -import docker -import os import collections +import json import math +import os + +import docker +import fire +import torch + +from fedn.utils.pytorchhelper import PytorchHelper + +NUM_CLASSES = 10 -NUM_CLASSES=10 def _get_data_path(): # Figure out FEDn client number from container name client = docker.from_env() container = client.containers.get(os.environ['HOSTNAME']) number = container.name[-1] - + # Return data path return f"/var/data/clients/{number}/mnist.pt" + def _compile_model(): # Define model class Net(torch.nn.Module): @@ -38,6 +42,7 @@ def _compile_model(): # Return model return Net() + def _load_data(data_path, is_train=True): # Load data if data_path is None: @@ -57,6 +62,7 @@ def _load_data(data_path, is_train=True): return X, y + def _save_model(model, out_path): weights = model.state_dict() weights_np = collections.OrderedDict() @@ -65,6 +71,7 @@ def _save_model(model, out_path): helper = PytorchHelper() helper.save_model(weights, out_path) + def _load_model(model_path): helper = PytorchHelper() weights_np = helper.load_model(model_path) @@ -76,11 +83,13 @@ def _load_model(model_path): model.eval() return model + def init_seed(out_path='seed.npz'): # Init and save - model = _compile_model() + model = _compile_model() _save_model(model, out_path) + def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01): # Load data x_train, y_train = _load_data(data_path) @@ -92,8 +101,8 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1 optimizer = torch.optim.SGD(model.parameters(), lr=lr) n_batches = int(math.ceil(len(x_train) / batch_size)) criterion = torch.nn.NLLLoss() - for e in range(epochs): # epoch loop - for b in range(n_batches): # batch loop + for e in range(epochs): # epoch loop + for b in range(n_batches): # batch loop # Retrieve current batch batch_x = x_train[b * batch_size:(b + 1) * batch_size] batch_y = y_train[b * batch_size:(b + 1) * batch_size] @@ -104,12 +113,14 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1 loss.backward() optimizer.step() # Log - if b % 100 == 0: - print(f"Epoch {e}/{epochs-1} | Batch: {b}/{n_batches-1} | Loss: {loss.item()}") + if b % 100 == 0: + print( + f"Epoch {e}/{epochs-1} | Batch: {b}/{n_batches-1} | Loss: {loss.item()}") # Save _save_model(model, out_model_path) + def validate(in_model_path, out_json_path, data_path=None): # Load data x_train, y_train = _load_data(data_path) @@ -123,10 +134,12 @@ def validate(in_model_path, out_json_path, data_path=None): with torch.no_grad(): train_out = model(x_train) training_loss = criterion(train_out, y_train) - training_accuracy = torch.sum(torch.argmax(train_out, dim=1) == y_train) / len(train_out) + training_accuracy = torch.sum(torch.argmax( + train_out, dim=1) == y_train) / len(train_out) test_out = model(x_test) test_loss = criterion(test_out, y_test) - test_accuracy = torch.sum(torch.argmax(test_out, dim=1) == y_test) / len(test_out) + test_accuracy = torch.sum(torch.argmax( + test_out, dim=1) == y_test) / len(test_out) # JSON schema report = { @@ -137,13 +150,14 @@ def validate(in_model_path, out_json_path, data_path=None): } # Save JSON - with open(out_json_path,"w") as fh: + with open(out_json_path, "w") as fh: fh.write(json.dumps(report)) + if __name__ == '__main__': fire.Fire({ 'init_seed': init_seed, 'train': train, 'validate': validate, - '_get_data_path': _get_data_path, # for testing - }) \ No newline at end of file + '_get_data_path': _get_data_path, # for testing + }) diff --git a/fedn/cli/__init__.py b/fedn/cli/__init__.py index 4048eb69c..dbea41de3 100644 --- a/fedn/cli/__init__.py +++ b/fedn/cli/__init__.py @@ -1,3 +1,3 @@ +from .control_cmd import control_cmd from .main import main from .run_cmd import run_cmd -from .control_cmd import control_cmd diff --git a/fedn/cli/control_cmd.py b/fedn/cli/control_cmd.py index a1114c1ed..ecfa631f5 100644 --- a/fedn/cli/control_cmd.py +++ b/fedn/cli/control_cmd.py @@ -44,7 +44,8 @@ def package_cmd(ctx, reducer, port, token, name, upload, validate, cwd): if not name: from datetime import datetime - name = str(os.path.basename(cwd)) + '-' + datetime.today().strftime('%Y-%m-%d-%H%M%S') + name = str(os.path.basename(cwd)) + '-' + \ + datetime.today().strftime('%Y-%m-%d-%H%M%S') config = {'host': reducer, 'port': port, 'token': token, 'name': name, 'cwd': cwd} diff --git a/fedn/cli/main.py b/fedn/cli/main.py index bfaa3b95f..cc33c579b 100644 --- a/fedn/cli/main.py +++ b/fedn/cli/main.py @@ -1,7 +1,7 @@ -import click - import logging +import click + logging.basicConfig(format='%(asctime)s [%(filename)s:%(lineno)d] %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p') # , level=logging.DEBUG) diff --git a/fedn/cli/run_cmd.py b/fedn/cli/run_cmd.py index e87b7bd85..0e6766325 100644 --- a/fedn/cli/run_cmd.py +++ b/fedn/cli/run_cmd.py @@ -1,10 +1,15 @@ -import click +import time import uuid + +import click import yaml -import time -from fedn.clients.reducer.restservice import encode_auth_token, decode_auth_token + +from fedn.clients.reducer.restservice import (decode_auth_token, + encode_auth_token) + from .main import main + def get_statestore_config_from_file(init): """ @@ -18,6 +23,7 @@ def get_statestore_config_from_file(init): except yaml.YAMLError as e: raise (e) + def check_helper_config_file(config): control = config['control'] try: @@ -27,6 +33,7 @@ def check_helper_config_file(config): exit(-1) return helper + @main.group('run') @click.pass_context def run_cmd(ctx): @@ -57,8 +64,8 @@ def run_cmd(ctx): help='Set to a filename to (re)init client from file state.') @click.option('-l', '--logfile', required=False, default='{}-client.log'.format(time.strftime("%Y%m%d-%H%M%S")), help='Set logfile for client log to file.') -@click.option('--heartbeat-interval',required=False, default=2) -@click.option('--reconnect-after-missed-heartbeat',required=False, default=30) +@click.option('--heartbeat-interval', required=False, default=2) +@click.option('--reconnect-after-missed-heartbeat', required=False, default=30) @click.pass_context def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_package, dry_run, secure, preshared_cert, verify_cert, preferred_combiner, validator, trainer, init, logfile, heartbeat_interval, reconnect_after_missed_heartbeat): @@ -86,7 +93,7 @@ def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_pa config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'name': name, 'client_id': client_id, 'remote_compute_context': remote, 'dry_run': dry_run, 'secure': secure, 'preshared_cert': preshared_cert, 'verify_cert': verify_cert, 'preferred_combiner': preferred_combiner, - 'validator': validator, 'trainer': trainer, 'init': init, 'logfile': logfile,'heartbeat_interval': heartbeat_interval, + 'validator': validator, 'trainer': trainer, 'init': init, 'logfile': logfile, 'heartbeat_interval': heartbeat_interval, 'reconnect_after_missed_heartbeat': 30} if config['init']: @@ -109,7 +116,8 @@ def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_pa config['discover_host'] == '' or \ config['discover_host'] is None or \ config['discover_port'] == '': - print("Missing required configuration: discover_host, discover_port", flush=True) + print( + "Missing required configuration: discover_host, discover_port", flush=True) return except Exception as e: print("Could not load config appropriately. Check config", flush=True) @@ -140,7 +148,7 @@ def reducer_cmd(ctx, discoverhost, discoverport, secret_key, local_package, name :param init: """ remote = False if local_package else True - config = {'discover_host': discoverhost, 'discover_port': discoverport, 'secret_key': secret_key, 'name': name, + config = {'discover_host': discoverhost, 'discover_port': discoverport, 'secret_key': secret_key, 'name': name, 'remote_compute_context': remote, 'init': init} # Read settings from config file @@ -151,10 +159,10 @@ def reducer_cmd(ctx, discoverhost, discoverport, secret_key, local_package, name print('Failed to read config from settings file, exiting.', flush=True) print(e, flush=True) exit(-1) - + if not remote: helper = check_helper_config_file(fedn_config) - + try: network_id = fedn_config['network_id'] except KeyError: @@ -163,21 +171,22 @@ def reducer_cmd(ctx, discoverhost, discoverport, secret_key, local_package, name statestore_config = fedn_config['statestore'] if statestore_config['type'] == 'MongoDB': - from fedn.clients.reducer.statestore.mongoreducerstatestore import MongoReducerStateStore - statestore = MongoReducerStateStore(network_id, statestore_config['mongo_config'], defaults=config['init']) + from fedn.clients.reducer.statestore.mongoreducerstatestore import \ + MongoReducerStateStore + statestore = MongoReducerStateStore( + network_id, statestore_config['mongo_config'], defaults=config['init']) else: print("Unsupported statestore type, exiting. ", flush=True) exit(-1) - if config['secret_key']: - # If we already have a valid token in statestore config, use that one. + # If we already have a valid token in statestore config, use that one. existing_config = statestore.get_reducer() - if existing_config: + if existing_config: try: existing_config = statestore.get_reducer() current_token = existing_config['token'] - status = decode_auth_token(current_token,config['secret_key']) + status = decode_auth_token(current_token, config['secret_key']) if status != 'Success': token = encode_auth_token(config['secret_key']) config['token'] = token diff --git a/fedn/cli/tests/__init__.py b/fedn/cli/tests/__init__.py index 8b1378917..e69de29bb 100644 --- a/fedn/cli/tests/__init__.py +++ b/fedn/cli/tests/__init__.py @@ -1 +0,0 @@ - diff --git a/fedn/cli/tests/tests.py b/fedn/cli/tests/tests.py index ece34396a..4c71c3816 100644 --- a/fedn/cli/tests/tests.py +++ b/fedn/cli/tests/tests.py @@ -1,10 +1,12 @@ import unittest from unittest.mock import MagicMock, patch +from uuid import UUID + import yaml from click.testing import CliRunner -from uuid import UUID from run_cmd import check_helper_config_file + class TestReducerCLI(unittest.TestCase): def setUp(self): @@ -12,11 +14,11 @@ def setUp(self): self.INIT_FILE_REDUCER = { "network_id": "fedn-test-network", "token": "fedn_token", - "control":{ + "control": { "state": "idle", "helper": "keras", }, - "statestore":{ + "statestore": { "type": "MongoDB", "mongo_config": { "username": "fedn_admin", @@ -25,9 +27,9 @@ def setUp(self): "port": "6534" } }, - "storage":{ + "storage": { "storage_type": "S3", - "storage_config":{ + "storage_config": { "storage_hostname": "minio", "storage_port": "9000", "storage_access_key": "fedn_admin", @@ -38,36 +40,36 @@ def setUp(self): } } } - + @unittest.skip def test_get_statestore_config_from_file(self): pass - # def test_reducer_cmd_remote(self): - + # with self.runner.isolated_filesystem(): - + # COPY_INIT_FILE = self.INIT_FILE_REDUCER # del COPY_INIT_FILE["control"]["helper"] # with open('settings.yaml', 'w') as f: - # f.write(yaml.dump(COPY_INIT_FILE)) - + # f.write(yaml.dump(COPY_INIT_FILE)) + # result = self.runner.invoke(reducer_cmd, ['--remote', False, '--init',"settings.yaml"]) # self.assertEqual(result.output, "--remote was set to False, but no helper was found in --init settings file: settings.yaml\n") # self.assertEqual(result.exit_code, -1) def test_check_helper_config_file(self): - - self.assertEqual(check_helper_config_file(self.INIT_FILE_REDUCER), "keras") - + + self.assertEqual(check_helper_config_file( + self.INIT_FILE_REDUCER), "keras") + COPY_INIT_FILE = self.INIT_FILE_REDUCER del COPY_INIT_FILE["control"]["helper"] - + with self.assertRaises(SystemExit): helper = check_helper_config_file(COPY_INIT_FILE) - + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/fedn/fedn/__init__.py b/fedn/fedn/__init__.py index 7fe596081..f04a9cd80 100644 --- a/fedn/fedn/__init__.py +++ b/fedn/fedn/__init__.py @@ -1,10 +1,11 @@ -from os.path import dirname, basename, isfile import glob +import os +from os.path import basename, dirname, isfile modules = glob.glob(dirname(__file__) + "/*.py") -__all__ = [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')] +__all__ = [basename(f)[:-3] for f in modules if isfile(f) + and not f.endswith('__init__.py')] -import os _ROOT = os.path.abspath(os.path.dirname(__file__)) diff --git a/fedn/fedn/aggregators/aggregator.py b/fedn/fedn/aggregators/aggregator.py index 1af0f6640..3e7b3bd20 100644 --- a/fedn/fedn/aggregators/aggregator.py +++ b/fedn/fedn/aggregators/aggregator.py @@ -1,14 +1,14 @@ import collections -from abc import ABC, abstractmethod import os import tempfile +from abc import ABC, abstractmethod class AggregatorBase(ABC): """ Abstract class defining helpers. """ @abstractmethod - def __init__(self, id, storage, server, modelservice, control): + def __init__(self, id, storage, server, modelservice, control): """ """ self.name = "" self.storage = storage @@ -28,14 +28,14 @@ def on_model_validation(self, validation): @abstractmethod def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=None, timeout=180): pass - - -#def get_aggregator(aggregator_type): -# """ Return an instance of the aggregator class. + + +# def get_aggregator(aggregator_type): +# """ Return an instance of the aggregator class. # # :param aggregator_type (str): The aggregator type ('fedavg') -# :return: +# :return: # """ # if helper_type == 'fedavg': # from fedn.aggregators.fedavg import FedAvgAggregator -# return FedAvgAggregator() \ No newline at end of file +# return FedAvgAggregator() diff --git a/fedn/fedn/aggregators/fedavg.py b/fedn/fedn/aggregators/fedavg.py index d928adbce..034e22a70 100644 --- a/fedn/fedn/aggregators/fedavg.py +++ b/fedn/fedn/aggregators/fedavg.py @@ -1,18 +1,19 @@ import json import os import queue +import sys import tempfile import time import uuid -import sys import fedn.common.net.grpc.fedn_pb2 as fedn -from fedn.utils.helpers import get_helper from fedn.aggregators.aggregator import AggregatorBase +from fedn.utils.helpers import get_helper + class FedAvgAggregator(AggregatorBase): """ Local SGD / Federated Averaging (FedAvg) aggregator. - + :param id: A reference to id of :class: `fedn.combiner.Combiner` :type id: str :param storage: Model repository for :class: `fedn.combiner.Combiner` @@ -30,8 +31,8 @@ def __init__(self, id, storage, server, modelservice, control): """Constructor method """ - super().__init__(id,storage, server, modelservice, control) - + super().__init__(id, storage, server, modelservice, control) + self.name = "FedAvg" self.validations = {} self.model_updates = queue.Queue() @@ -40,19 +41,19 @@ def on_model_update(self, model_id): """Callback when a new model update is recieved from a client. Performs (optional) pre-processing and the puts the update id on the aggregation queue. - + :param model_id: ID of model update :type model_id: str """ try: self.server.report_status("AGGREGATOR({}): callback received model {}".format(self.name, model_id), - log_level=fedn.Status.INFO) + log_level=fedn.Status.INFO) # Push the model update to the processing queue self.model_updates.put(model_id) except Exception as e: self.server.report_status("AGGREGATOR({}): Failed to receive candidate model! {}".format(self.name, e), - log_level=fedn.Status.WARNING) + log_level=fedn.Status.WARNING) pass def on_model_validation(self, validation): @@ -63,10 +64,10 @@ def on_model_validation(self, validation): :type validation: dict """ - # Currently, the validations are actually sent as status messages + # Currently, the validations are actually sent as status messages # directly in the client, so here we are just storing them in the - # combiner memory. This will need to be refactored later so that this - # callback is responsible for reporting the validation to the db. + # combiner memory. This will need to be refactored later so that this + # callback is responsible for reporting the validation to the db. model_id = validation.model_id data = json.loads(validation.data) @@ -76,8 +77,7 @@ def on_model_validation(self, validation): self.validations[model_id] = [data] self.server.report_status("AGGREGATOR({}): callback processed validation {}".format(self.name, validation.model_id), - log_level=fedn.Status.INFO) - + log_level=fedn.Status.INFO) def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=None, timeout=180): """Compute a running average of model updates. @@ -94,12 +94,12 @@ def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=N :rtype: tuple """ - data = {} data['time_model_load'] = 0.0 data['time_model_aggregation'] = 0.0 - self.server.report_status("AGGREGATOR({}): Aggregating model updates...".format(self.name)) + self.server.report_status( + "AGGREGATOR({}): Aggregating model updates...".format(self.name)) round_time = 0.0 polling_interval = 1.0 @@ -107,46 +107,52 @@ def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=N while nr_processed_models < nr_expected_models: try: model_id = self.model_updates.get(block=False) - self.server.report_status("AGGREGATOR({}): Received model update with id {}".format(self.name, model_id)) + self.server.report_status( + "AGGREGATOR({}): Received model update with id {}".format(self.name, model_id)) # Load the model update from disk tic = time.time() model_str = self.control.load_model_fault_tolerant(model_id) if model_str: try: - model_next = helper.load_model_from_BytesIO(model_str.getbuffer()) + model_next = helper.load_model_from_BytesIO( + model_str.getbuffer()) except IOError: - self.server.report_status("AGGREGATOR({}): Failed to load model!".format(self.name)) - else: + self.server.report_status( + "AGGREGATOR({}): Failed to load model!".format(self.name)) + else: raise data['time_model_load'] += time.time() - tic - # Aggregate / reduce + # Aggregate / reduce tic = time.time() if nr_processed_models == 0: model = model_next else: - model = helper.increment_average(model, model_next, nr_processed_models + 1) + model = helper.increment_average( + model, model_next, nr_processed_models + 1) data['time_model_aggregation'] += time.time() - tic nr_processed_models += 1 self.model_updates.task_done() except queue.Empty: - self.server.report_status("AGGREGATOR({}): waiting for model updates: {} of {} completed.".format(self.name, + self.server.report_status("AGGREGATOR({}): waiting for model updates: {} of {} completed.".format(self.name, nr_processed_models, nr_expected_models)) time.sleep(polling_interval) round_time += polling_interval except Exception as e: - self.server.report_status("AGGERGATOR({}): Error encoutered while reading model update, skipping this update. {}".format(self.name, e)) + self.server.report_status( + "AGGERGATOR({}): Error encoutered while reading model update, skipping this update. {}".format(self.name, e)) nr_expected_models -= 1 if nr_expected_models <= 0: return None, data self.model_updates.task_done() - + if round_time >= timeout: - self.server.report_status("AGGREGATOR({}): training round timed out.".format(self.name), log_level=fedn.Status.WARNING) - # TODO: Generalize policy for what to do in case of timeout. + self.server.report_status("AGGREGATOR({}): training round timed out.".format( + self.name), log_level=fedn.Status.WARNING) + # TODO: Generalize policy for what to do in case of timeout. if nr_processed_models >= nr_required_models: break else: @@ -155,5 +161,5 @@ def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=N data['nr_successful_updates'] = nr_processed_models self.server.report_status("AGGREGATOR({}): Training round completed, aggregated {} models.".format(self.name, nr_processed_models), - log_level=fedn.Status.INFO) + log_level=fedn.Status.INFO) return model, data diff --git a/fedn/fedn/client.py b/fedn/fedn/client.py index f86e1bb28..50d6a885f 100644 --- a/fedn/fedn/client.py +++ b/fedn/fedn/client.py @@ -1,44 +1,41 @@ +import io import json import os +import queue import sys -import io -import uuid import tempfile -import threading, queue +import threading import time -from aiohttp import client +import uuid +from datetime import datetime import grpc +from aiohttp import client import fedn.common.net.grpc.fedn_pb2 as fedn import fedn.common.net.grpc.fedn_pb2_grpc as rpc -from fedn.common.net.connect import ConnectorClient, Status +from fedn.clients.client.state import ClientState, ClientStateToString from fedn.common.control.package import PackageRuntime - -from fedn.utils.logger import Logger -from fedn.utils.helpers import get_helper - +from fedn.common.net.connect import ConnectorClient, Status # TODO Remove from this level. Abstract to unified non implementation specific client. from fedn.utils.dispatcher import Dispatcher - -from fedn.clients.client.state import ClientState, ClientStateToString +from fedn.utils.helpers import get_helper +from fedn.utils.logger import Logger CHUNK_SIZE = 1024 * 1024 -from datetime import datetime - class Client: """FEDn Client. Service running on client/datanodes in a federation, recieving and handling model update and model validation requests. - + Attibutes --------- config: dict A configuration dictionary containing connection information for the discovery service (controller) and settings governing e.g. client-combiner assignment behavior. - + """ def __init__(self, config): @@ -54,8 +51,8 @@ def __init__(self, config): self.state = None self.error_state = False self._attached = False - self._missed_heartbeat=0 - self.config = config + self._missed_heartbeat = 0 + self.config = config self.connector = ConnectorClient(config['discover_host'], config['discover_port'], @@ -67,13 +64,14 @@ def __init__(self, config): secure=config['secure'], preshared_cert=config['preshared_cert'], verify_cert=config['verify_cert']) - + self.name = config['name'] dirname = time.strftime("%Y%m%d-%H%M%S") self.run_path = os.path.join(os.getcwd(), dirname) os.mkdir(self.run_path) - self.logger = Logger(to_file=config['logfile'], file_path=self.run_path) + self.logger = Logger( + to_file=config['logfile'], file_path=self.run_path) self.started_at = datetime.now() self.logs = [] @@ -81,21 +79,22 @@ def __init__(self, config): # Attach to the FEDn network (get combiner) client_config = self._attach() - + self._initialize_dispatcher(config) self._initialize_helper(client_config) if not self.helper: - print("Failed to retrive helper class settings! {}".format(client_config), flush=True) + print("Failed to retrive helper class settings! {}".format( + client_config), flush=True) self._subscribe_to_combiner(config) self.state = ClientState.idle def _detach(self): - # Setting _attached to False will make all processing threads return + # Setting _attached to False will make all processing threads return if not self._attached: - print("Client is not attached.",flush=True) + print("Client is not attached.", flush=True) self._attached = False # Close gRPC connection to combiner @@ -104,35 +103,38 @@ def _detach(self): def _attach(self): """ """ # Ask controller for a combiner and connect to that combiner. - if self._attached: - print("Client is already attached. ",flush=True) + if self._attached: + print("Client is already attached. ", flush=True) return None client_config = self._assign() self._connect(client_config) - if client_config: - self._attached=True + if client_config: + self._attached = True return client_config - def _initialize_helper(self,client_config): - + def _initialize_helper(self, client_config): + if 'model_type' in client_config.keys(): self.helper = get_helper(client_config['model_type']) - def _subscribe_to_combiner(self,config): + def _subscribe_to_combiner(self, config): """Listen to combiner message stream and start all processing threads. - + """ - # Start sending heartbeats to the combiner. - threading.Thread(target=self._send_heartbeat, kwargs={'update_frequency': config['heartbeat_interval']}, daemon=True).start() + # Start sending heartbeats to the combiner. + threading.Thread(target=self._send_heartbeat, kwargs={ + 'update_frequency': config['heartbeat_interval']}, daemon=True).start() - # Start listening for combiner training and validation messages + # Start listening for combiner training and validation messages if config['trainer'] == True: - threading.Thread(target=self._listen_to_model_update_request_stream, daemon=True).start() + threading.Thread( + target=self._listen_to_model_update_request_stream, daemon=True).start() if config['validator'] == True: - threading.Thread(target=self._listen_to_model_validation_request_stream, daemon=True).start() + threading.Thread( + target=self._listen_to_model_validation_request_stream, daemon=True).start() self._attached = True # Start processing the client message inbox @@ -147,11 +149,13 @@ def _initialize_dispatcher(self, config): tries = 10 while tries > 0: - retval = pr.download(config['discover_host'], config['discover_port'], config['token']) + retval = pr.download( + config['discover_host'], config['discover_port'], config['token']) if retval: break time.sleep(60) - print("No compute package available... retrying in 60s Trying {} more times.".format(tries), flush=True) + print("No compute package available... retrying in 60s Trying {} more times.".format( + tries), flush=True) tries -= 1 if retval: @@ -178,9 +182,9 @@ def _initialize_dispatcher(self, config): else: # TODO: Deprecate dispatch_config = {'entry_points': - {'predict': {'command': 'python3 predict.py'}, - 'train': {'command': 'python3 train.py'}, - 'validate': {'command': 'python3 validate.py'}}} + {'predict': {'command': 'python3 predict.py'}, + 'train': {'command': 'python3 train.py'}, + 'validate': {'command': 'python3 validate.py'}}} dispatch_dir = os.getcwd() from_path = os.path.join(os.getcwd(), 'client') @@ -188,8 +192,7 @@ def _initialize_dispatcher(self, config): copy_tree(from_path, self.run_path) self.dispatcher = Dispatcher(dispatch_config, self.run_path) - - def _assign(self): + def _assign(self): """Contacts the controller and asks for combiner assignment. """ print("Asking for assignment!", flush=True) @@ -210,30 +213,32 @@ def _assign(self): sys.exit("Exiting: UnMatchedConfig") time.sleep(5) print(".", end=' ', flush=True) - + print("Got assigned!", flush=True) return client_config def _connect(self, client_config): """Connect to assigned combiner. - + Parameters ---------- client_config : dict A dictionary with connection information and settings for the assigned combiner. - + """ # TODO use the client_config['certificate'] for setting up secure comms' if client_config['certificate']: import base64 - cert = base64.b64decode(client_config['certificate']) # .decode('utf-8') + cert = base64.b64decode( + client_config['certificate']) # .decode('utf-8') credentials = grpc.ssl_channel_credentials(root_certificates=cert) channel = grpc.secure_channel("{}:{}".format(client_config['host'], str(client_config['port'])), credentials) else: - channel = grpc.insecure_channel("{}:{}".format(client_config['host'], str(client_config['port']))) + channel = grpc.insecure_channel("{}:{}".format( + client_config['host'], str(client_config['port']))) self.channel = channel @@ -244,8 +249,9 @@ def _connect(self, client_config): print("Client: {} connected {} to {}:{}".format(self.name, "SECURED" if client_config['certificate'] else "INSECURE", client_config['host'], client_config['port']), flush=True) - - print("Client: Using {} compute package.".format(client_config["package"])) + + print("Client: Using {} compute package.".format( + client_config["package"])) def _disconnect(self): self.channel.close() @@ -254,12 +260,12 @@ def get_model(self, id): """Fetch a model from the assigned combiner. Downloads the model update object via a gRPC streaming channel, Dowload. - + Parameters ---------- id : str The id of the model update object. - + """ from io import BytesIO @@ -289,7 +295,7 @@ def set_model(self, model, id): The model update object. id : str The id of the model update object. - """ + """ from io import BytesIO @@ -312,9 +318,11 @@ def upload_request_generator(mdl): while True: b = mdl.read(CHUNK_SIZE) if b: - result = fedn.ModelRequest(data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) + result = fedn.ModelRequest( + data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) else: - result = fedn.ModelRequest(id=id, status=fedn.ModelStatus.OK) + result = fedn.ModelRequest( + id=id, status=fedn.ModelStatus.OK) yield result if not b: @@ -339,24 +347,24 @@ def _listen_to_model_update_request_stream(self): if request.sender.role == fedn.COMBINER: # Process training request self._send_status("Received model update request.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request) + type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request) self.inbox.put(('train', request)) - - if not self._attached: - return + + if not self._attached: + return except grpc.RpcError as e: status_code = e.code() - #TODO: make configurable + # TODO: make configurable timeout = 5 - #print("CLIENT __listen_to_model_update_request_stream: GRPC ERROR {} retrying in {}..".format( + # print("CLIENT __listen_to_model_update_request_stream: GRPC ERROR {} retrying in {}..".format( # status_code.name, timeout), flush=True) - time.sleep(timeout) + time.sleep(timeout) except: raise - if not self._attached: + if not self._attached: return def _listen_to_model_validation_request_stream(self): @@ -371,36 +379,37 @@ def _listen_to_model_validation_request_stream(self): # Process validation request model_id = request.model_id self._send_status("Recieved model validation request.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_VALIDATION_REQUEST, request=request) + type=fedn.StatusType.MODEL_VALIDATION_REQUEST, request=request) self.inbox.put(('validate', request)) except grpc.RpcError as e: status_code = e.code() # TODO: make configurable timeout = 5 - #print("CLIENT __listen_to_model_validation_request_stream: GRPC ERROR {} retrying in {}..".format( + # print("CLIENT __listen_to_model_validation_request_stream: GRPC ERROR {} retrying in {}..".format( # status_code.name, timeout), flush=True) time.sleep(timeout) except: - raise + raise - if not self._attached: + if not self._attached: return def process_request(self): """Process training and validation tasks. """ while True: - if not self._attached: - return + if not self._attached: + return try: - (task_type, request) = self.inbox.get(timeout=1.0) + (task_type, request) = self.inbox.get(timeout=1.0) if task_type == 'train': tic = time.time() self.state = ClientState.training - model_id, meta = self._process_training_request(request.model_id) + model_id, meta = self._process_training_request( + request.model_id) processing_time = time.time()-tic meta['processing_time'] = processing_time @@ -416,22 +425,23 @@ def process_request(self): update.timestamp = str(datetime.now()) update.correlation_id = request.correlation_id update.meta = json.dumps(meta) - #TODO: Check responses + # TODO: Check responses response = self.orchestrator.SendModelUpdate(update) self._send_status("Model update completed.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE, request=update) + type=fedn.StatusType.MODEL_UPDATE, request=update) else: self._send_status("Client {} failed to complete model update.", - log_level=fedn.Status.WARNING, - request=request) + log_level=fedn.Status.WARNING, + request=request) self.state = ClientState.idle self.inbox.task_done() elif task_type == 'validate': self.state = ClientState.validating - metrics = self._process_validation_request(request.model_id) + metrics = self._process_validation_request( + request.model_id) if metrics != None: # Send validation @@ -445,12 +455,13 @@ def process_request(self): self.str = str(datetime.now()) validation.timestamp = self.str validation.correlation_id = request.correlation_id - response = self.orchestrator.SendModelValidation(validation) + response = self.orchestrator.SendModelValidation( + validation) self._send_status("Model validation completed.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_VALIDATION, request=validation) + type=fedn.StatusType.MODEL_VALIDATION, request=validation) else: self._send_status("Client {} failed to complete model validation.".format(self.name), - log_level=fedn.Status.WARNING, request=request) + log_level=fedn.Status.WARNING, request=request) self.state = ClientState.idle self.inbox.task_done() @@ -459,15 +470,16 @@ def process_request(self): def _process_training_request(self, model_id): """Process a training (model update) request. - + Parameters ---------- model_id : Str The id of the model to update. - + """ - self._send_status("\t Starting processing of training request for model_id {}".format(model_id)) + self._send_status( + "\t Starting processing of training request for model_id {}".format(model_id)) self.state = ClientState.training try: @@ -500,7 +512,8 @@ def _process_training_request(self, model_id): os.unlink(outpath) except Exception as e: - print("ERROR could not process training request due to error: {}".format(e), flush=True) + print("ERROR could not process training request due to error: {}".format( + e), flush=True) updated_model_id = None meta = {'status': 'failed', 'error': str(e)} @@ -509,7 +522,8 @@ def _process_training_request(self, model_id): return updated_model_id, meta def _process_validation_request(self, model_id): - self._send_status("Processing validation request for model_id {}".format(model_id)) + self._send_status( + "Processing validation request for model_id {}".format(model_id)) self.state = ClientState.validating try: model = self.get_model(str(model_id)) @@ -540,33 +554,35 @@ def _handle_combiner_failure(self): """ Register failed combiner connection. """ - self._missed_heartbeat += 1 - if self._missed_heartbeat > self.config['reconnect_after_missed_heartbeat']: + self._missed_heartbeat += 1 + if self._missed_heartbeat > self.config['reconnect_after_missed_heartbeat']: self._detach() def _send_heartbeat(self, update_frequency=2.0): """Send a heartbeat to the combiner. - + Parameters ---------- update_frequency : float The interval in seconds between heartbeat messages. - + """ while True: - heartbeat = fedn.Heartbeat(sender=fedn.Client(name=self.name, role=fedn.WORKER)) + heartbeat = fedn.Heartbeat(sender=fedn.Client( + name=self.name, role=fedn.WORKER)) try: self.connection.SendHeartbeat(heartbeat) self._missed_heartbeat = 0 except grpc.RpcError as e: status_code = e.code() - print("CLIENT heartbeat: GRPC ERROR {} retrying..".format(status_code.name), flush=True) + print("CLIENT heartbeat: GRPC ERROR {} retrying..".format( + status_code.name), flush=True) self._handle_combiner_failure() time.sleep(update_frequency) - if not self._attached: - return + if not self._attached: + return def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None): """Send status message. """ @@ -590,17 +606,17 @@ def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None) status.status)) response = self.connection.SendStatus(status) - def run_web(self): """Starts a local logging UI (Flask app) serving on port 8080. - + Currently not in use as default. - + """ from flask import Flask app = Flask(__name__) from fedn.common.net.web.client import page, style + @app.route('/') def index(): """ @@ -613,7 +629,8 @@ def index(): return page.format(client=self.name, state=ClientStateToString(self.state), style=style, logs=logs_fancy) - import os, sys + import os + import sys self._original_stdout = sys.stdout sys.stdout = open(os.devnull, 'w') app.run(host="0.0.0.0", port="8080") @@ -630,13 +647,15 @@ def run(self): time.sleep(1) cnt += 1 if self.state != old_state: - print("{}:CLIENT in {} state".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), ClientStateToString(self.state)), flush=True) + print("{}:CLIENT in {} state".format(datetime.now().strftime( + '%Y-%m-%d %H:%M:%S'), ClientStateToString(self.state)), flush=True) if cnt > 5: - print("{}:CLIENT active".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')), flush=True) + print("{}:CLIENT active".format( + datetime.now().strftime('%Y-%m-%d %H:%M:%S')), flush=True) cnt = 0 if not self._attached: print("Detatched from combiner.", flush=True) - # TODO: Implement a check/condition to ulitmately close down if too many reattachment attepts have failed. s + # TODO: Implement a check/condition to ulitmately close down if too many reattachment attepts have failed. s self._attach() self._subscribe_to_combiner(self.config) if self.error_state: diff --git a/fedn/fedn/clients/combiner/modelservice.py b/fedn/fedn/clients/combiner/modelservice.py index da5eed7a9..b5cd35de7 100644 --- a/fedn/fedn/clients/combiner/modelservice.py +++ b/fedn/fedn/clients/combiner/modelservice.py @@ -32,8 +32,8 @@ def get_model(self, id): from io import BytesIO data = BytesIO() data.seek(0, 0) - import time import random + import time parts = self.Download(fedn.ModelRequest(id=id), self) for part in parts: @@ -73,9 +73,11 @@ def upload_request_generator(mdl): while True: b = mdl.read(CHUNK_SIZE) if b: - result = fedn.ModelRequest(data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) + result = fedn.ModelRequest( + data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) else: - result = fedn.ModelRequest(id=id, data=None, status=fedn.ModelStatus.OK) + result = fedn.ModelRequest( + id=id, data=None, status=fedn.ModelStatus.OK) yield result if not b: break @@ -83,7 +85,7 @@ def upload_request_generator(mdl): # TODO: Check result result = self.Upload(upload_request_generator(bt), self) - ## Model Service + # Model Service def Upload(self, request_iterator, context): """ @@ -100,7 +102,7 @@ def Upload(self, request_iterator, context): if request.status == fedn.ModelStatus.OK and not request.data: result = fedn.ModelResponse(id=request.id, status=fedn.ModelStatus.OK, - message="Got model successfully.") + message="Got model successfully.") # self.models_metadata.update({request.id: fedn.ModelStatus.OK}) self.models.set_meta(request.id, fedn.ModelStatus.OK) self.models.get_ptr(request.id).flush() diff --git a/fedn/fedn/clients/combiner/roundcontrol.py b/fedn/fedn/clients/combiner/roundcontrol.py index 6da204840..c6d3be9db 100644 --- a/fedn/fedn/clients/combiner/roundcontrol.py +++ b/fedn/fedn/clients/combiner/roundcontrol.py @@ -1,20 +1,19 @@ -import time import json import os import queue +import sys import tempfile import time import uuid -import sys -import queue +from threading import Lock, Thread import fedn.common.net.grpc.fedn_pb2 as fedn -from threading import Thread, Lock from fedn.utils.helpers import get_helper - + + class RoundControl: """ Combiner level round controller. - + The controller recieves round configurations from the global controller and acts on them by soliciting model updates and model validations from the connected clients. @@ -40,7 +39,8 @@ def __init__(self, id, storage, server, modelservice): # TODO, make runtime configurable from fedn.aggregators.fedavg import FedAvgAggregator - self.aggregator = FedAvgAggregator(self.id, self.storage, self.server, self.modelservice, self) + self.aggregator = FedAvgAggregator( + self.id, self.storage, self.server, self.modelservice, self) def push_round_config(self, round_config): """ Recieve a round_config (job description) and push on the queue. @@ -55,10 +55,11 @@ def push_round_config(self, round_config): round_config['_job_id'] = str(uuid.uuid4()) self.round_configs.put(round_config) except: - self.server.report_status("ROUNDCONTROL: Failed to push round config.", flush=True) + self.server.report_status( + "ROUNDCONTROL: Failed to push round config.", flush=True) raise return round_config['_job_id'] - + def load_model_fault_tolerant(self, model_id, retry=3): """Load model update object. @@ -79,7 +80,8 @@ def load_model_fault_tolerant(self, model_id, retry=3): while tries < retry: tries += 1 if not model_str or sys.getsizeof(model_str) == 80: - self.server.report_status("ROUNDCONTROL: Model download failed. retrying", flush=True) + self.server.report_status( + "ROUNDCONTROL: Model download failed. retrying", flush=True) import time time.sleep(1) model_str = self.modelservice.get_model(model_id) @@ -98,11 +100,12 @@ def _training_round(self, config, clients): """ # We flush the queue at a beginning of a round (no stragglers allowed) - # TODO: Support other ways to handle stragglers. + # TODO: Support other ways to handle stragglers. with self.aggregator.model_updates.mutex: self.aggregator.model_updates.queue.clear() - self.server.report_status("ROUNDCONTROL: Initiating training round, participating members: {}".format(clients)) + self.server.report_status( + "ROUNDCONTROL: Initiating training round, participating members: {}".format(clients)) self.server.request_model_update(config['model_id'], clients=clients) meta = {} @@ -115,8 +118,9 @@ def _training_round(self, config, clients): try: helper = get_helper(config['helper_type']) model, data = self.aggregator.combine_models(nr_expected_models=len(clients), - nr_required_models=int(config['clients_required']), - helper=helper, timeout=float(config['round_timeout'])) + nr_required_models=int( + config['clients_required']), + helper=helper, timeout=float(config['round_timeout'])) except Exception as e: print("TRAINING ROUND FAILED AT COMBINER! {}".format(e), flush=True) meta['time_combination'] = time.time() - tic @@ -150,7 +154,7 @@ def stage_model(self, model_id, timeout_retry=3, retry=2): if self.modelservice.models.exist(model_id): return - # If it is not there, download it from storage and stage it in memory at the server. + # If it is not there, download it from storage and stage it in memory at the server. tries = 0 while True: try: @@ -159,15 +163,16 @@ def stage_model(self, model_id, timeout_retry=3, retry=2): break except Exception as e: self.server.report_status("ROUNDCONTROL: Could not fetch model from storage backend, retrying.", - flush=True) + flush=True) time.sleep(timeout_retry) tries += 1 if tries > retry: - self.server.report_status("ROUNDCONTROL: Failed to stage model {} from storage backend!".format(model_id), flush=True) + self.server.report_status( + "ROUNDCONTROL: Failed to stage model {} from storage backend!".format(model_id), flush=True) return self.modelservice.set_model(model, model_id) - + def __assign_round_clients(self, n, type="trainers"): """ Obtain a list of clients (trainers or validators) to talk to in a round. @@ -184,14 +189,14 @@ def __assign_round_clients(self, n, type="trainers"): elif type == "trainers": clients = self.server.get_active_trainers() else: - self.server.report_status("ROUNDCONTROL(ERROR): {} is not a supported type of client".format(type), flush=True) + self.server.report_status( + "ROUNDCONTROL(ERROR): {} is not a supported type of client".format(type), flush=True) raise - - # If the number of requested trainers exceeds the number of available, use all available. + # If the number of requested trainers exceeds the number of available, use all available. if n > len(clients): n = len(clients) - + # If not, we pick a random subsample of all available clients. import random clients = random.sample(clients, n) @@ -239,10 +244,12 @@ def execute_validation(self, round_config): :type round_config: [type] """ model_id = round_config['model_id'] - self.server.report_status("COMBINER orchestrating validation of model {}".format(model_id)) + self.server.report_status( + "COMBINER orchestrating validation of model {}".format(model_id)) self.stage_model(model_id) - validators = self.__assign_round_clients(self.server.max_clients,type="validators") - self._validation_round(round_config,validators,model_id) + validators = self.__assign_round_clients( + self.server.max_clients, type="validators") + self._validation_round(round_config, validators, model_id) def execute_training(self, config): """ Coordinates clients to execute training and validation tasks. """ @@ -256,17 +263,19 @@ def execute_training(self, config): # Execute the configured number of rounds round_meta['local_round'] = {} for r in range(1, int(config['rounds']) + 1): - self.server.report_status("ROUNDCONTROL: Starting training round {}".format(r), flush=True) + self.server.report_status( + "ROUNDCONTROL: Starting training round {}".format(r), flush=True) clients = self.__assign_round_clients(self.server.max_clients) model, meta = self._training_round(config, clients) round_meta['local_round'][str(r)] = meta if model is None: - self.server.report_status("\t Failed to update global model in round {0}!".format(r)) + self.server.report_status( + "\t Failed to update global model in round {0}!".format(r)) if model is not None: helper = get_helper(config['helper_type']) a = helper.serialize_model_to_BytesIO(model) - # Send aggregated model to server + # Send aggregated model to server model_id = str(uuid.uuid4()) self.modelservice.set_model(a, model_id) a.close() @@ -275,7 +284,8 @@ def execute_training(self, config): self.server.set_active_model(model_id) print("------------------------------------------") - self.server.report_status("ROUNDCONTROL: TRAINING ROUND COMPLETED.", flush=True) + self.server.report_status( + "ROUNDCONTROL: TRAINING ROUND COMPLETED.", flush=True) print("\n") return round_meta @@ -293,15 +303,18 @@ def run(self): if round_config['task'] == 'training': tic = time.time() round_meta = self.execute_training(round_config) - round_meta['time_exec_training'] = time.time() - tic + round_meta['time_exec_training'] = time.time() - \ + tic round_meta['name'] = self.id self.server.tracer.set_round_meta(round_meta) elif round_config['task'] == 'validation': self.execute_validation(round_config) else: - self.server.report_status("ROUNDCONTROL: Round config contains unkown task type.", flush=True) + self.server.report_status( + "ROUNDCONTROL: Round config contains unkown task type.", flush=True) else: - self.server.report_status("ROUNDCONTROL: Failed to meet client allocation requirements for this round config.", flush=True) + self.server.report_status( + "ROUNDCONTROL: Failed to meet client allocation requirements for this round config.", flush=True) except queue.Empty: time.sleep(1) diff --git a/fedn/fedn/clients/reducer/control.py b/fedn/fedn/clients/reducer/control.py index 83c82ac1a..692dfcfca 100644 --- a/fedn/fedn/clients/reducer/control.py +++ b/fedn/fedn/clients/reducer/control.py @@ -4,10 +4,10 @@ import time from datetime import datetime +import fedn.utils.helpers from fedn.clients.reducer.interfaces import CombinerUnavailableError from fedn.clients.reducer.network import Network from fedn.common.tracer.mongotracer import MongoTracer -import fedn.utils.helpers from .state import ReducerState @@ -35,10 +35,12 @@ def __init__(self, statestore): try: config = self.statestore.get_storage_backend() except: - print("REDUCER CONTROL: Failed to retrive storage configuration, exiting.", flush=True) + print( + "REDUCER CONTROL: Failed to retrive storage configuration, exiting.", flush=True) raise MisconfiguredStorageBackend() if not config: - print("REDUCER CONTROL: No storage configuration available, exiting.", flush=True) + print( + "REDUCER CONTROL: No storage configuration available, exiting.", flush=True) raise MisconfiguredStorageBackend() if config['storage_type'] == 'S3': @@ -135,7 +137,8 @@ def get_compute_context(self): context = definition['filename'] return context except (IndexError, KeyError): - print("No context filename set for compute context definition", flush=True) + print( + "No context filename set for compute context definition", flush=True) return None else: return None @@ -164,7 +167,8 @@ def commit(self, model_id, model=None): outfile_name = helper.save_model(model) print("DONE", flush=True) print("Uploading model to Minio...", flush=True) - model_id = self.model_repository.set_model(outfile_name, is_file=True) + model_id = self.model_repository.set_model( + outfile_name, is_file=True) print("DONE", flush=True) os.unlink(outfile_name) @@ -182,7 +186,7 @@ def _out_of_sync(self, combiners=None): except CombinerUnavailableError: self._handle_unavailable_combiner(combiner) model_id = None - + if model_id and (model_id != self.get_latest_model()): osync.append(combiner) return osync @@ -198,7 +202,7 @@ def check_round_participation_policy(self, compute_plan, combiner_state): elif compute_plan['task'] == 'validation': nr_active_clients = int(combiner_state['nr_active_validators']) else: - print("Invalid task type!",flush=True) + print("Invalid task type!", flush=True) return False if int(compute_plan['clients_required']) <= nr_active_clients: @@ -228,10 +232,10 @@ def check_round_validity_policy(self, combiners): def _handle_unavailable_combiner(self, combiner): """ This callback is triggered if a combiner is found to be unresponsive. """ # TODO: Implement strategy to handle the case. - print("REDUCER CONTROL: Combiner {} unavailable.".format(combiner.name), flush=True) + print("REDUCER CONTROL: Combiner {} unavailable.".format( + combiner.name), flush=True) - - def _select_round_combiners(self,compute_plan): + def _select_round_combiners(self, compute_plan): combiners = [] for combiner in self.network.get_combiners(): try: @@ -241,9 +245,10 @@ def _select_round_combiners(self,compute_plan): combiner_state = None if combiner_state: - is_participating = self.check_round_participation_policy(compute_plan,combiner_state) + is_participating = self.check_round_participation_policy( + compute_plan, combiner_state) if is_participating: - combiners.append((combiner,compute_plan)) + combiners.append((combiner, compute_plan)) return combiners def round(self, config, round_number): @@ -275,23 +280,26 @@ def round(self, config, round_number): combiner_state = None if combiner_state != None: - is_participating = self.check_round_participation_policy(compute_plan, combiner_state) + is_participating = self.check_round_participation_policy( + compute_plan, combiner_state) if is_participating: combiners.append((combiner, compute_plan)) round_start = self.check_round_start_policy(combiners) - print("CONTROL: round start policy met, participating combiners {}".format(combiners), flush=True) + print("CONTROL: round start policy met, participating combiners {}".format( + combiners), flush=True) if not round_start: print("CONTROL: Round start policy not met, skipping round!", flush=True) return None # 2. Sync up and ask participating combiners to coordinate model updates # TODO refactor - + statestore_config = self.statestore.get_config() - self.tracer = MongoTracer(statestore_config['mongo_config'], statestore_config['network_id']) + self.tracer = MongoTracer( + statestore_config['mongo_config'], statestore_config['network_id']) start_time = datetime.now() @@ -346,7 +354,8 @@ def round(self, config, round_number): model, data = self.reduce(updated) round_meta['reduce'] = data except Exception as e: - print("CONTROL: Failed to reduce models from combiners: {}".format(updated), flush=True) + print("CONTROL: Failed to reduce models from combiners: {}".format( + updated), flush=True) print(e, flush=True) return None, round_meta print("DONE", flush=True) @@ -361,7 +370,8 @@ def round(self, config, round_number): self.commit(model_id, model) round_meta['time_commit'] = time.time() - tic else: - print("REDUCER: failed to update model in round with config {}".format(config), flush=True) + print("REDUCER: failed to update model in round with config {}".format( + config), flush=True) return None, round_meta print("DONE", flush=True) @@ -373,11 +383,12 @@ def round(self, config, round_number): combiner_config['task'] = 'validation' combiner_config['helper_type'] = self.statestore.get_framework() - validating_combiners = self._select_round_combiners(combiner_config) + validating_combiners = self._select_round_combiners( + combiner_config) for combiner, combiner_config in validating_combiners: try: - self.sync_combiners([combiner],self.get_latest_model()) + self.sync_combiners([combiner], self.get_latest_model()) combiner.start(combiner_config) except CombinerUnavailableError: # OK if validation fails for a combiner @@ -418,7 +429,8 @@ def instruct(self, config): # TODO: Refactor from fedn.common.tracer.mongotracer import MongoTracer statestore_config = self.statestore.get_config() - self.tracer = MongoTracer(statestore_config['mongo_config'], statestore_config['network_id']) + self.tracer = MongoTracer( + statestore_config['mongo_config'], statestore_config['network_id']) last_round = self.tracer.get_latest_round() for round in range(1, int(config['rounds'] + 1)): @@ -443,7 +455,8 @@ def instruct(self, config): end_time = datetime.now() if model_id: - print("REDUCER: Global round completed, new model: {}".format(model_id), flush=True) + print("REDUCER: Global round completed, new model: {}".format( + model_id), flush=True) round_time = end_time - start_time self.tracer.set_latest_time(current_round, round_time.seconds) round_meta['status'] = 'Success' diff --git a/fedn/fedn/clients/reducer/interfaces.py b/fedn/fedn/clients/reducer/interfaces.py index 359440fae..5cb855ca5 100644 --- a/fedn/fedn/clients/reducer/interfaces.py +++ b/fedn/fedn/clients/reducer/interfaces.py @@ -1,7 +1,9 @@ +import json + +import grpc + import fedn.common.net.grpc.fedn_pb2 as fedn import fedn.common.net.grpc.fedn_pb2_grpc as rpc -import grpc -import json class CombinerUnavailableError(Exception): @@ -19,10 +21,13 @@ def __init__(self, address, port, certificate): self.certificate = certificate if self.certificate: import copy - credentials = grpc.ssl_channel_credentials(root_certificates=copy.deepcopy(certificate)) - self.channel = grpc.secure_channel('{}:{}'.format(self.address, str(self.port)), credentials) + credentials = grpc.ssl_channel_credentials( + root_certificates=copy.deepcopy(certificate)) + self.channel = grpc.secure_channel('{}:{}'.format( + self.address, str(self.port)), credentials) else: - self.channel = grpc.insecure_channel('{}:{}'.format(self.address, str(self.port))) + self.channel = grpc.insecure_channel( + '{}:{}'.format(self.address, str(self.port))) def get_channel(self): """ @@ -105,7 +110,8 @@ def report(self, config=None): :param config: :return: """ - channel = Channel(self.address, self.port, self.certificate).get_channel() + channel = Channel(self.address, self.port, + self.certificate).get_channel() control = rpc.ControlStub(channel) request = fedn.ControlRequest() try: @@ -127,7 +133,8 @@ def configure(self, config=None): """ if not config: config = self.config - channel = Channel(self.address, self.port, self.certificate).get_channel() + channel = Channel(self.address, self.port, + self.certificate).get_channel() control = rpc.ControlStub(channel) request = fedn.ControlRequest() @@ -150,7 +157,8 @@ def start(self, config): :param config: :return: """ - channel = Channel(self.address, self.port, self.certificate).get_channel() + channel = Channel(self.address, self.port, + self.certificate).get_channel() control = rpc.ControlStub(channel) request = fedn.ControlRequest() request.command = fedn.Command.START @@ -175,7 +183,8 @@ def set_model_id(self, model_id): :param model_id: """ - channel = Channel(self.address, self.port, self.certificate).get_channel() + channel = Channel(self.address, self.port, + self.certificate).get_channel() control = rpc.ControlStub(channel) request = fedn.ControlRequest() p = request.parameter.add() @@ -195,7 +204,8 @@ def get_model_id(self): :return: """ - channel = Channel(self.address, self.port, self.certificate).get_channel() + channel = Channel(self.address, self.port, + self.certificate).get_channel() reducer = rpc.ReducerStub(channel) request = fedn.GetGlobalModelRequest() try: @@ -211,7 +221,8 @@ def get_model_id(self): def get_model(self, id=None): """ Retrive the model bundle from a combiner. """ - channel = Channel(self.address, self.port, self.certificate).get_channel() + channel = Channel(self.address, self.port, + self.certificate).get_channel() modelservice = rpc.ModelServiceStub(channel) if not id: @@ -235,7 +246,8 @@ def allowing_clients(self): :return: """ - channel = Channel(self.address, self.port, self.certificate).get_channel() + channel = Channel(self.address, self.port, + self.certificate).get_channel() connector = rpc.ConnectorStub(channel) request = fedn.ConnectionRequest() diff --git a/fedn/fedn/clients/reducer/network.py b/fedn/fedn/clients/reducer/network.py index 2c64762dc..0bdd4f0ee 100644 --- a/fedn/fedn/clients/reducer/network.py +++ b/fedn/fedn/clients/reducer/network.py @@ -1,10 +1,11 @@ +import base64 import copy import time -import base64 + +from fedn.clients.reducer.interfaces import (CombinerInterface, + CombinerUnavailableError) from .state import ReducerState -from fedn.clients.reducer.interfaces import CombinerInterface -from fedn.clients.reducer.interfaces import CombinerUnavailableError class Network: @@ -67,7 +68,7 @@ def add_client(self, client): """ if self.find_client(client['name']): - return + return print("adding client {}".format(client['name']), flush=True) self.statestore.set_client(client) @@ -125,4 +126,4 @@ def update_client_data(self, client_data, status, role): def get_client_info(self): """ list available client in DB""" - return self.statestore.list_clients() \ No newline at end of file + return self.statestore.list_clients() diff --git a/fedn/fedn/clients/reducer/plots.py b/fedn/fedn/clients/reducer/plots.py index aa7a9691b..b3c6f22e3 100644 --- a/fedn/fedn/clients/reducer/plots.py +++ b/fedn/fedn/clients/reducer/plots.py @@ -1,24 +1,24 @@ -from numpy.core.einsumfunc import _flop_count -import pymongo import json -import numpy -import plotly.graph_objs as go -from datetime import datetime, timedelta -import plotly -import os -from fedn.common.storage.db.mongo import connect_to_mongodb, drop_mongodb import math +import os +from datetime import datetime, timedelta -import plotly.express as px import geoip2.database -import pandas as pd - import networkx +import numpy import pandas as pd -from bokeh.models import (Circle, Label, LabelSet, - MultiLine, NodesAndLinkedEdges, Range1d, ColumnDataSource) -from bokeh.plotting import figure, from_networkx +import plotly +import plotly.express as px +import plotly.graph_objs as go +import pymongo +from bokeh.models import (Circle, ColumnDataSource, Label, LabelSet, MultiLine, + NodesAndLinkedEdges, Range1d) from bokeh.palettes import Spectral8 +from bokeh.plotting import figure, from_networkx +from numpy.core.einsumfunc import _flop_count + +from fedn.common.storage.db.mongo import connect_to_mongodb, drop_mongodb + class Plot: """ @@ -28,7 +28,8 @@ class Plot: def __init__(self, statestore): try: statestore_config = statestore.get_config() - self.mdb = connect_to_mongodb(statestore_config['mongo_config'], statestore_config['network_id']) + self.mdb = connect_to_mongodb( + statestore_config['mongo_config'], statestore_config['network_id']) self.status = self.mdb['control.status'] self.round_time = self.mdb["control.round_time"] self.combiner_round_time = self.mdb["control.combiner_round_time"] @@ -66,7 +67,8 @@ def create_table_plot(self): metrics = self.status.find_one({'type': 'MODEL_VALIDATION'}) if metrics == None: fig = go.Figure(data=[]) - fig.update_layout(title_text='No data currently available for table mean metrics') + fig.update_layout( + title_text='No data currently available for table mean metrics') table = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return False @@ -84,9 +86,11 @@ def create_table_plot(self): for post in self.status.find({'type': 'MODEL_VALIDATION'}): e = json.loads(post['data']) try: - validations[e['modelId']].append(float(json.loads(e['data'])[metric])) + validations[e['modelId']].append( + float(json.loads(e['data'])[metric])) except KeyError: - validations[e['modelId']] = [float(json.loads(e['data'])[metric])] + validations[e['modelId']] = [ + float(json.loads(e['data'])[metric])] vals = [] models = [] @@ -202,7 +206,8 @@ def create_client_training_distribution(self): if not training: return False fig = go.Figure(data=go.Histogram(x=training)) - fig.update_layout(title_text='Client model training time, mean: {}'.format(numpy.mean(training))) + fig.update_layout( + title_text='Client model training time, mean: {}'.format(numpy.mean(training))) histogram = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return histogram @@ -224,12 +229,14 @@ def create_client_plot(self): processing.append(meta['processing_time']) from plotly.subplots import make_subplots - fig = make_subplots(rows=1, cols=2, specs=[[{"type": "pie"}, {"type": "histogram"}]]) + fig = make_subplots(rows=1, cols=2, specs=[ + [{"type": "pie"}, {"type": "histogram"}]]) fig.update_layout( template="simple_white", xaxis=dict(title_text="Seconds"), - title="Total mean client processing time: {}".format(numpy.mean(processing)), + title="Total mean client processing time: {}".format( + numpy.mean(processing)), showlegend=True ) if not processing: @@ -266,13 +273,16 @@ def create_combiner_plot(self): except: pass - labels = ['Waiting for updates', 'Aggregating model updates', 'Loading model updates'] - val = [numpy.mean(waiting), numpy.mean(aggregation), numpy.mean(model_load)] + labels = ['Waiting for updates', + 'Aggregating model updates', 'Loading model updates'] + val = [numpy.mean(waiting), numpy.mean( + aggregation), numpy.mean(model_load)] fig = go.Figure() fig.update_layout( template="simple_white", - title="Total mean combiner round time: {}".format(numpy.mean(combination)), + title="Total mean combiner round time: {}".format( + numpy.mean(combination)), showlegend=True ) if not combination: @@ -315,9 +325,11 @@ def create_box_plot(self, metric): for post in self.status.find({'type': 'MODEL_VALIDATION'}): e = json.loads(post['data']) try: - validations[e['modelId']].append(float(json.loads(e['data'])[metric])) + validations[e['modelId']].append( + float(json.loads(e['data'])[metric])) except KeyError: - validations[e['modelId']] = [float(json.loads(e['data'])[metric])] + validations[e['modelId']] = [ + float(json.loads(e['data'])[metric])] # Make sure validations are plotted in chronological order model_trail = self.mdb.control.model.find_one({'key': 'model_trail'}) @@ -343,7 +355,8 @@ def create_box_plot(self, metric): box.add_trace(go.Box(y=acc, name=str(j), marker_color="royalblue", showlegend=False, boxpoints=False)) else: - box.add_trace(go.Scatter(x=[str(j)], y=[y[j]], showlegend=False)) + box.add_trace(go.Scatter( + x=[str(j)], y=[y[j]], showlegend=False)) rounds = list(range(len(y))) box.add_trace(go.Scatter( @@ -353,7 +366,8 @@ def create_box_plot(self, metric): )) box.update_xaxes(title_text='Rounds') - box.update_yaxes(tickvals=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]) + box.update_yaxes( + tickvals=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]) box.update_layout(title_text='Metric distribution over clients: {}'.format(metric), margin=dict(l=20, r=20, t=45, b=20)) box = json.dumps(box, cls=plotly.utils.PlotlyJSONEncoder) @@ -368,7 +382,8 @@ def create_round_plot(self): metrics = self.round_time.find_one({'key': 'round_time'}) if metrics == None: fig = go.Figure(data=[]) - fig.update_layout(title_text='No data currently available for round time') + fig.update_layout( + title_text='No data currently available for round time') ml = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return False @@ -408,7 +423,8 @@ def create_cpu_plot(self): metrics = self.psutil_usage.find_one({'key': 'cpu_mem_usage'}) if metrics == None: fig = go.Figure(data=[]) - fig.update_layout(title_text='No data currently available for MEM and CPU usage') + fig.update_layout( + title_text='No data currently available for MEM and CPU usage') cpu = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) return False @@ -452,10 +468,10 @@ def create_cpu_plot(self): def get_client_df(self): clients = self.network_clients df = pd.DataFrame(list(clients.find())) - active_clients = df['status']=="active" + active_clients = df['status'] == "active" print(df[active_clients]) return df - + def make_single_node_plot(self): """ Plot single node graph with reducer @@ -468,23 +484,24 @@ def make_single_node_plot(self): ("Role", "@role"), ("Status", "@status"), ("Id", "@index"), - ] - + ] + G = networkx.Graph() G.add_node("reducer", adjusted_node_size=20, role='reducer', - status='active', - name='reducer', - color_by_this_attribute=Spectral8[0]) + status='active', + name='reducer', + color_by_this_attribute=Spectral8[0]) network_graph = from_networkx(G, networkx.spring_layout) - network_graph.node_renderer.glyph = Circle(size=20, fill_color = Spectral8[0]) + network_graph.node_renderer.glyph = Circle( + size=20, fill_color=Spectral8[0]) network_graph.node_renderer.hover_glyph = Circle(size=20, fill_color='white', line_width=2) network_graph.node_renderer.selection_glyph = Circle(size=20, fill_color='white', line_width=2) plot = figure(tooltips=HOVER_TOOLTIPS, tools="pan,wheel_zoom,save,reset", active_scroll='wheel_zoom', - width=725, height=460, sizing_mode='stretch_width', - x_range=Range1d(-1.5, 1.5), y_range=Range1d(-1.5, 1.5)) - + width=725, height=460, sizing_mode='stretch_width', + x_range=Range1d(-1.5, 1.5), y_range=Range1d(-1.5, 1.5)) + plot.renderers.append(network_graph) plot.axis.visible = False @@ -492,14 +509,11 @@ def make_single_node_plot(self): plot.outline_line_color = None label = Label(x=0, y=0, text='reducer', - background_fill_color='#4bbf73', text_font_size='15px', - background_fill_alpha=.7, x_offset=-20, y_offset=10) - + background_fill_color='#4bbf73', text_font_size='15px', + background_fill_alpha=.7, x_offset=-20, y_offset=10) + plot.add_layout(label) return plot - - - def make_netgraph_plot(self, df, df_nodes): """ @@ -514,18 +528,21 @@ def make_netgraph_plot(self, df, df_nodes): """ if df.empty: - #no combiners and thus no clients, plot only reducer + # no combiners and thus no clients, plot only reducer plot = self.make_single_node_plot() return plot - - G = networkx.from_pandas_edgelist(df, 'source', 'target', create_using=networkx.Graph()) + + G = networkx.from_pandas_edgelist( + df, 'source', 'target', create_using=networkx.Graph()) degrees = dict(networkx.degree(G)) networkx.set_node_attributes(G, name='degree', values=degrees) number_to_adjust_by = 20 - adjusted_node_size = dict([(node, degree + number_to_adjust_by) for node, degree in networkx.degree(G)]) - networkx.set_node_attributes(G, name='adjusted_node_size', values=adjusted_node_size) - + adjusted_node_size = dict( + [(node, degree + number_to_adjust_by) for node, degree in networkx.degree(G)]) + networkx.set_node_attributes( + G, name='adjusted_node_size', values=adjusted_node_size) + # community from networkx.algorithms import community communities = community.greedy_modularity_communities(G) @@ -543,16 +560,15 @@ def make_netgraph_plot(self, df, df_nodes): networkx.set_node_attributes(G, modularity_class, 'modularity_class') networkx.set_node_attributes(G, modularity_color, 'modularity_color') - node_role = {k:v for k,v in zip(df_nodes.id, df_nodes.role)} + node_role = {k: v for k, v in zip(df_nodes.id, df_nodes.role)} networkx.set_node_attributes(G, node_role, 'role') - - node_status = {k:v for k,v in zip(df_nodes.id, df_nodes.status)} + + node_status = {k: v for k, v in zip(df_nodes.id, df_nodes.status)} networkx.set_node_attributes(G, node_status, 'status') - node_name = {k:v for k,v in zip(df_nodes.id, df_nodes.name)} + node_name = {k: v for k, v in zip(df_nodes.id, df_nodes.name)} networkx.set_node_attributes(G, node_name, 'name') - # Choose colors for node and edge highlighting node_highlight_color = 'white' edge_highlight_color = 'black' @@ -578,29 +594,34 @@ def make_netgraph_plot(self, df, df_nodes): # Create a network graph object # https://networkx.github.io/documentation/networkx-1.9/reference/generated/networkx.drawing.layout.spring_layout.html # if one like lock reducer add args: pos={'reducer':(0,1)}, fixed=['reducer'] - network_graph = from_networkx(G, networkx.spring_layout, scale=1, center=(0, 0), seed=45) + network_graph = from_networkx( + G, networkx.spring_layout, scale=1, center=(0, 0), seed=45) # Set node sizes and colors according to node degree (color as category from attribute) - network_graph.node_renderer.glyph = Circle(size=size_by_this_attribute, fill_color=color_by_this_attribute) + network_graph.node_renderer.glyph = Circle( + size=size_by_this_attribute, fill_color=color_by_this_attribute) # Set node highlight colors network_graph.node_renderer.hover_glyph = Circle(size=size_by_this_attribute, fill_color=node_highlight_color, line_width=2) network_graph.node_renderer.selection_glyph = Circle(size=size_by_this_attribute, fill_color=node_highlight_color, line_width=2) - + # Set edge opacity and width - network_graph.edge_renderer.glyph = MultiLine(line_alpha=0.5, line_width=1) + network_graph.edge_renderer.glyph = MultiLine( + line_alpha=0.5, line_width=1) # Set edge highlight colors - network_graph.edge_renderer.selection_glyph = MultiLine(line_color=edge_highlight_color, line_width=2) - network_graph.edge_renderer.hover_glyph = MultiLine(line_color=edge_highlight_color, line_width=2) + network_graph.edge_renderer.selection_glyph = MultiLine( + line_color=edge_highlight_color, line_width=2) + network_graph.edge_renderer.hover_glyph = MultiLine( + line_color=edge_highlight_color, line_width=2) # Highlight nodes and edges network_graph.selection_policy = NodesAndLinkedEdges() network_graph.inspection_policy = NodesAndLinkedEdges() plot.renderers.append(network_graph) - - #Node labels, red if status is offline, green is active + + # Node labels, red if status is offline, green is active x, y = zip(*network_graph.layout_provider.graph_layout.values()) node_names = list(G.nodes(data='name')) node_status = list(G.nodes(data='status')) @@ -615,14 +636,15 @@ def make_netgraph_plot(self, df, df_nodes): idx_offline.append(e) node_labels.append(n[1]) - source_on = ColumnDataSource({'x': numpy.asarray(x)[idx_online], 'y': numpy.asarray(y)[idx_online], 'name': numpy.asarray(node_labels)[idx_online]}) + source_on = ColumnDataSource({'x': numpy.asarray(x)[idx_online], 'y': numpy.asarray(y)[ + idx_online], 'name': numpy.asarray(node_labels)[idx_online]}) labels = LabelSet(x='x', y='y', text='name', source=source_on, background_fill_color='#4bbf73', text_font_size='15px', background_fill_alpha=.7, x_offset=-20, y_offset=10) plot.renderers.append(labels) - - source_off = ColumnDataSource({'x': numpy.asarray(x)[idx_offline], 'y': numpy.asarray(y)[idx_offline], 'name': numpy.asarray(node_labels)[idx_offline]}) + source_off = ColumnDataSource({'x': numpy.asarray(x)[idx_offline], 'y': numpy.asarray( + y)[idx_offline], 'name': numpy.asarray(node_labels)[idx_offline]}) labels = LabelSet(x='x', y='y', text='name', source=source_off, background_fill_color='#d9534f', text_font_size='15px', background_fill_alpha=.7, x_offset=-20, y_offset=10) @@ -631,4 +653,4 @@ def make_netgraph_plot(self, df, df_nodes): plot.axis.visible = False plot.grid.visible = False plot.outline_line_color = None - return plot \ No newline at end of file + return plot diff --git a/fedn/fedn/clients/reducer/restservice.py b/fedn/fedn/clients/reducer/restservice.py index 313b85549..860c8844f 100644 --- a/fedn/fedn/clients/reducer/restservice.py +++ b/fedn/fedn/clients/reducer/restservice.py @@ -1,33 +1,32 @@ -from urllib import response +import datetime +import json +import math +import os +import re import uuid -from fedn.clients.reducer.interfaces import CombinerInterface -from fedn.clients.reducer.state import ReducerState, ReducerStateToString -from idna import check_initial_combiner -from tenacity import retry -from werkzeug.utils import secure_filename - -from flask import Flask, jsonify, make_response, render_template, request -from flask import redirect, url_for, flash, abort - from threading import Lock -import re +from urllib import response -import os +import geoip2.database import jwt -import datetime -import json -import plotly -import pandas as pd import numpy -import math - +import pandas as pd +import plotly import plotly.express as px -import geoip2.database +from flask import (Flask, abort, flash, jsonify, make_response, redirect, + render_template, request, url_for) +from idna import check_initial_combiner +from tenacity import retry +from werkzeug.utils import secure_filename + +from fedn.clients.reducer.interfaces import CombinerInterface from fedn.clients.reducer.plots import Plot +from fedn.clients.reducer.state import ReducerState, ReducerStateToString UPLOAD_FOLDER = '/app/client/package/' ALLOWED_EXTENSIONS = {'gz', 'bz2', 'tar', 'zip', 'tgz'} + def allowed_file(filename): """ @@ -37,6 +36,7 @@ def allowed_file(filename): return '.' in filename and \ filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + def encode_auth_token(secret_key): """Generates the Auth Token :return: string @@ -57,6 +57,7 @@ def encode_auth_token(secret_key): except Exception as e: return e + def decode_auth_token(auth_token, secret): """Decodes the auth token :param auth_token: @@ -64,7 +65,7 @@ def decode_auth_token(auth_token, secret): """ try: payload = jwt.decode( - auth_token, + auth_token, secret, algorithms=['HS256'] ) @@ -92,18 +93,17 @@ def __init__(self, config, control, certificate_manager, certificate=None): self.port = config['discover_port'] self.network_id = config['name'] + '-network' - + if 'token' in config.keys(): self.token_auth_enabled = True else: self.token_auth_enabled = False - if 'secret_key' in config.keys(): + if 'secret_key' in config.keys(): self.SECRET_KEY = config['secret_key'] else: self.SECRET_KEY = None - self.remote_compute_context = config["remote_compute_context"] if self.remote_compute_context: self.package = 'remote' @@ -139,7 +139,7 @@ def check_compute_context(self): return False else: return True - + def check_initial_model(self): """Check if initial model (seed model) has been configured @@ -151,7 +151,7 @@ def check_initial_model(self): return True else: return False - + def check_configured_response(self): """Check if everything has been configured for client to connect, return response if not. @@ -168,7 +168,7 @@ def check_configured_response(self): return jsonify({'status': 'retry', 'package': self.package, 'msg': "Compute package is not configured. Please upload the compute package."}) - + if not self.check_initial_model(): return jsonify({'status': 'retry', 'package': self.package, @@ -211,28 +211,31 @@ def authorize(self, r, secret): """ try: # Get token - if 'Authorization' in r.headers: # header auth + if 'Authorization' in r.headers: # header auth request_token = r.headers.get('Authorization').split()[1] - elif 'token' in r.args: # args auth + elif 'token' in r.args: # args auth request_token = str(r.args.get('token')) elif 'fedn_token' in r.cookies: request_token = r.cookies.get('fedn_token') - else: # no token provided + else: # no token provided print('Authorization failed. No token provided.', flush=True) abort(401) # Log token and secret - print(f'Secret: {secret}. Request token: {request_token}.', flush=True) + print( + f'Secret: {secret}. Request token: {request_token}.', flush=True) # Authenticate status = decode_auth_token(request_token, secret) if status == 'Success': return True else: - print('Authorization failed. Status: "{}"'.format(status), flush=True) + print('Authorization failed. Status: "{}"'.format( + status), flush=True) abort(401) except Exception as e: - print('Authorization failed. Expection encountered: "{}".'.format(e), flush=True) + print('Authorization failed. Expection encountered: "{}".'.format( + e), flush=True) abort(401) def run(self): @@ -263,12 +266,12 @@ def index(): message = request.args.get('message', None) message_type = request.args.get('message_type', None) template = render_template('events.html', client=self.name, state=ReducerStateToString(self.control.state()), - events=events, - logs=None, refresh=True, configured=True, message=message, message_type=message_type) + events=events, + logs=None, refresh=True, configured=True, message=message, message_type=message_type) # Set token cookie in response if needed response = make_response(template) - if 'token' in request.args: # args auth + if 'token' in request.args: # args auth response.set_cookie('fedn_token', str(request.args['token'])) # Return response @@ -298,16 +301,16 @@ def netgraph(): "label": "Reducer", "role": 'reducer', "status": 'active', - "name": 'reducer', #TODO: get real host name + "name": 'reducer', # TODO: get real host name "type": 'reducer', }) - + combiner_info = combiner_status() client_info = client_status() if len(combiner_info) < 1: return result - + for combiner in combiner_info: print("combiner info {}".format(combiner_info), flush=True) try: @@ -315,7 +318,7 @@ def netgraph(): "id": combiner['name'], # "n{}".format(count), "label": "Combiner ({} clients)".format(combiner['nr_active_clients']), "role": 'combiner', - "status": 'active', #TODO: Hard-coded, combiner_info does not contain status + "status": 'active', # TODO: Hard-coded, combiner_info does not contain status "name": combiner['name'], "type": 'combiner', }) @@ -335,7 +338,7 @@ def netgraph(): }) except Exception as err: print(err) - + count = 0 for node in result['nodes']: try: @@ -380,6 +383,7 @@ def events(): :return: """ import json + from bson import json_util json_docs = [] @@ -413,7 +417,8 @@ def add(): if not combiner: # Create a new combiner import base64 - certificate, key = self.certificate_manager.get_or_create(address).get_keypair_raw() + certificate, key = self.certificate_manager.get_or_create( + address).get_keypair_raw() cert_b64 = base64.b64encode(certificate) key_b64 = base64.b64encode(key) @@ -490,7 +495,8 @@ def delete_model_trail(): if request.method == 'POST': from fedn.common.tracer.mongotracer import MongoTracer statestore_config = self.control.statestore.get_config() - self.tracer = MongoTracer(statestore_config['mongo_config'], statestore_config['network_id']) + self.tracer = MongoTracer( + statestore_config['mongo_config'], statestore_config['network_id']) try: self.control.drop_models() except: @@ -550,7 +556,7 @@ def control(): # checking if there are enough clients connected to start! clients_available = 0 for combiner in self.control.network.get_combiners(): - try: + try: combiner_state = combiner.report() nac = combiner_state['nr_active_clients'] clients_available = clients_available + int(nac) @@ -577,7 +583,8 @@ def control(): 'validate': validate, 'helper_type': helper_type} import threading - threading.Thread(target=self.control.instruct, args=(config,)).start() + threading.Thread(target=self.control.instruct, + args=(config,)).start() # self.control.instruct(config) return redirect(url_for('index', state=state, refresh=refresh, message="Sent execution plan.", message_type='SUCCESS')) @@ -614,7 +621,6 @@ def assign(): if response: return response - name = request.args.get('name', None) combiner_preferred = request.args.get('combiner', None) @@ -636,7 +642,7 @@ def assign(): 'status': 'available' } - # Add client to database + # Add client to database self.control.network.add_client(client) # Return connection information to client @@ -698,19 +704,26 @@ def client_status(): for client in combiner_info: active_trainers_str = client['active_trainers'] active_validators_str = client['active_validators'] - active_trainers_str = re.sub('[^a-zA-Z0-9-:\n\.]', '', active_trainers_str).replace('name:', ' ') - active_validators_str = re.sub('[^a-zA-Z0-9-:\n\.]', '', active_validators_str).replace('name:', ' ') - all_active_trainers.extend(' '.join(active_trainers_str.split(" ")).split()) - all_active_validators.extend(' '.join(active_validators_str.split(" ")).split()) - - active_trainers_list = [client for client in client_info if client['name'] in all_active_trainers] - active_validators_list = [cl for cl in client_info if cl['name'] in all_active_validators] + active_trainers_str = re.sub( + '[^a-zA-Z0-9-:\n\.]', '', active_trainers_str).replace('name:', ' ') + active_validators_str = re.sub( + '[^a-zA-Z0-9-:\n\.]', '', active_validators_str).replace('name:', ' ') + all_active_trainers.extend( + ' '.join(active_trainers_str.split(" ")).split()) + all_active_validators.extend( + ' '.join(active_validators_str.split(" ")).split()) + + active_trainers_list = [ + client for client in client_info if client['name'] in all_active_trainers] + active_validators_list = [ + cl for cl in client_info if cl['name'] in all_active_validators] all_clients = [cl for cl in client_info] for client in all_clients: status = 'offline' role = 'None' - self.control.network.update_client_data(client, status, role) + self.control.network.update_client_data( + client, status, role) all_active_clients = active_validators_list + active_trainers_list for client in all_active_clients: @@ -723,14 +736,15 @@ def client_status(): role = 'validator' else: role = 'unknown' - self.control.network.update_client_data(client, status, role) + self.control.network.update_client_data( + client, status, role) return {'active_clients': all_clients, 'active_trainers': active_trainers_list, 'active_validators': active_validators_list } except: - pass + pass return {'active_clients': [], 'active_trainers': [], @@ -757,7 +771,7 @@ def dashboard(): # Token auth if self.token_auth_enabled: self.authorize(request, app.config.get('SECRET_KEY')) - + not_configured = self.check_configured() if not_configured: return not_configured @@ -792,7 +806,7 @@ def network(): # Token auth if self.token_auth_enabled: self.authorize(request, app.config.get('SECRET_KEY')) - + not_configured = self.check_configured() if not_configured: return not_configured @@ -847,6 +861,7 @@ def config_download(): chk_string=chk_string) from io import BytesIO + from flask import send_file obj = BytesIO() obj.write(ctx.encode('UTF-8')) @@ -868,7 +883,8 @@ def context(): # if self.control.state() != ReducerState.setup or self.control.state() != ReducerState.idle: # return "Error, Context already assigned!" - reset = request.args.get('reset', None) # if reset is not empty then allow context re-set + # if reset is not empty then allow context re-set + reset = request.args.get('reset', None) if reset: return render_template('context.html') @@ -888,7 +904,8 @@ def context(): if file and allowed_file(file.filename): filename = secure_filename(file.filename) - file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename) + file_path = os.path.join( + app.config['UPLOAD_FOLDER'], filename) file.save(file_path) if self.control.state() == ReducerState.instructing or self.control.state() == ReducerState.monitoring: @@ -907,7 +924,7 @@ def context(): return render_template('context.html') # There is a potential race condition here, if one client requests a package and at - # the same time another one triggers a fetch from Minio and writes to disk. + # the same time another one triggers a fetch from Minio and writes to disk. try: mutex = Lock() mutex.acquire() @@ -957,5 +974,5 @@ def checksum(): str(self.certificate.key_path)), flush=True) app.run(host="0.0.0.0", port=self.port, ssl_context=(str(self.certificate.cert_path), str(self.certificate.key_path))) - + return app diff --git a/fedn/fedn/clients/reducer/statestore/mongoreducerstatestore.py b/fedn/fedn/clients/reducer/statestore/mongoreducerstatestore.py index 1c36c2466..1a1030853 100644 --- a/fedn/fedn/clients/reducer/statestore/mongoreducerstatestore.py +++ b/fedn/fedn/clients/reducer/statestore/mongoreducerstatestore.py @@ -1,5 +1,7 @@ -from fedn.clients.reducer.state import ReducerStateToString, StringToReducerState +from fedn.clients.reducer.state import (ReducerStateToString, + StringToReducerState) from fedn.common.storage.db.mongo import connect_to_mongodb + from .reducerstatestore import ReducerStateStore @@ -22,7 +24,7 @@ def __init__(self, network_id, config, defaults=None): self.clients = self.network['clients'] self.storage = self.network['storage'] self.certificates = self.network['certificates'] - # Control + # Control self.control = self.mdb['control'] self.control_config = self.control['config'] self.state = self.control['state'] @@ -70,11 +72,12 @@ def __init__(self, network_id, config, defaults=None): flush=True) if "context" in control: - print("Setting filepath to {}".format(control['context']), flush=True) + print("Setting filepath to {}".format( + control['context']), flush=True) # TODO Fix the ugly latering of indirection due to a bug in secure_filename returning an object with filename as attribute # TODO fix with unboxing of value before storing and where consuming. self.control.config.update_one({'key': 'package'}, - {'$set': {'filename': control['context']}}, True) + {'$set': {'filename': control['context']}}, True) if "helper" in control: # self.set_framework(control['helper']) pass @@ -133,7 +136,8 @@ def transition(self, state): if old_state != state: return self.state.update_one({'state': 'current_state'}, {'$set': {'state': ReducerStateToString(state)}}, True) else: - print("Not updating state, already in {}".format(ReducerStateToString(state))) + print("Not updating state, already in {}".format( + ReducerStateToString(state))) def set_latest(self, model_id): """ @@ -141,14 +145,16 @@ def set_latest(self, model_id): :param model_id: """ from datetime import datetime - x = self.model.update_one({'key': 'current_model'}, {'$set': {'model': model_id}}, True) + x = self.model.update_one({'key': 'current_model'}, { + '$set': {'model': model_id}}, True) self.model.update_one({'key': 'model_trail'}, {'$push': {'model': model_id, 'committed_at': str(datetime.now())}}, - True) + True) def get_first(self): """ Return model_id for the latest model in the model_trail """ import pymongo - ret = self.model.find_one({'key': 'model_trail'}, sort=[("committed_at", pymongo.ASCENDING)]) + ret = self.model.find_one({'key': 'model_trail'}, sort=[ + ("committed_at", pymongo.ASCENDING)]) if ret == None: return None @@ -180,7 +186,8 @@ def set_round_config(self, config): :param config: """ from datetime import datetime - x = self.control.config.update_one({'key': 'round_config'}, {'$set': config}, True) + x = self.control.config.update_one( + {'key': 'round_config'}, {'$set': config}, True) def get_round_config(self): """ @@ -202,9 +209,10 @@ def set_compute_context(self, filename): :param filename: """ from datetime import datetime - x = self.control.config.update_one({'key': 'package'}, {'$set': {'filename': filename}}, True) + x = self.control.config.update_one( + {'key': 'package'}, {'$set': {'filename': filename}}, True) self.control.config.update_one({'key': 'package_trail'}, - {'$push': {'filename': filename, 'committed_at': str(datetime.now())}}, True) + {'$push': {'filename': filename, 'committed_at': str(datetime.now())}}, True) def get_compute_context(self): """ @@ -226,7 +234,7 @@ def set_framework(self, helper): :param helper: """ self.control.config.update_one({'key': 'package'}, - {'$set': {'helper': helper}}, True) + {'$set': {'helper': helper}}, True) def get_framework(self): """ @@ -234,9 +242,9 @@ def get_framework(self): :return: """ ret = self.control.config.find_one({'key': 'package'}) - #if local compute package used, then 'package' is None + # if local compute package used, then 'package' is None if not ret: - #get framework from round_config instead + # get framework from round_config instead ret = self.control.config.find_one({'key': 'round_config'}) print('FRAMEWORK:', ret) try: @@ -275,25 +283,28 @@ def get_events(self): def get_storage_backend(self): """ """ try: - ret = self.storage.find({'status': 'enabled'}, projection={'_id': False}) + ret = self.storage.find( + {'status': 'enabled'}, projection={'_id': False}) return ret[0] except (KeyError, IndexError): return None def set_storage_backend(self, config): """ """ - from datetime import datetime import copy + from datetime import datetime config = copy.deepcopy(config) config['updated_at'] = str(datetime.now()) config['status'] = 'enabled' - ret = self.storage.update_one({'storage_type': config['storage_type']}, {'$set': config}, True) + ret = self.storage.update_one( + {'storage_type': config['storage_type']}, {'$set': config}, True) def set_reducer(self, reducer_data): """ """ from datetime import datetime reducer_data['updated_at'] = str(datetime.now()) - ret = self.reducer.update_one({'name': reducer_data['name']}, {'$set': reducer_data}, True) + ret = self.reducer.update_one({'name': reducer_data['name']}, { + '$set': reducer_data}, True) def get_reducer(self): """ """ @@ -334,14 +345,16 @@ def set_combiner(self, combiner_data): """ from datetime import datetime combiner_data['updated_at'] = str(datetime.now()) - ret = self.combiners.update_one({'name': combiner_data['name']}, {'$set': combiner_data}, True) + ret = self.combiners.update_one({'name': combiner_data['name']}, { + '$set': combiner_data}, True) def delete_combiner(self, combiner): """ """ try: self.combiners.delete_one({'name': combiner}) except: - print("WARNING, failed to delete combiner: {}".format(combiner), flush=True) + print("WARNING, failed to delete combiner: {}".format( + combiner), flush=True) def set_client(self, client_data): """ @@ -350,7 +363,8 @@ def set_client(self, client_data): """ from datetime import datetime client_data['updated_at'] = str(datetime.now()) - ret = self.clients.update_one({'name': client_data['name']}, {'$set': client_data}, True) + ret = self.clients.update_one({'name': client_data['name']}, { + '$set': client_data}, True) def get_client(self, name): """ """ @@ -373,7 +387,7 @@ def list_clients(self): def drop_control(self): """ """ - # Control + # Control self.state.drop() self.control_config.drop() self.control.drop() @@ -400,4 +414,4 @@ def update_client_status(self, client_data, status, role): "status": status, "role": role } - }) + }) diff --git a/fedn/fedn/combiner.py b/fedn/fedn/combiner.py index cb907fc51..8b38a6d37 100644 --- a/fedn/fedn/combiner.py +++ b/fedn/fedn/combiner.py @@ -1,27 +1,25 @@ +import base64 +import io +import json import os import queue import sys import threading +import time import uuid +from collections import defaultdict from datetime import datetime, timedelta +from enum import Enum + +import requests -from fedn.common.net.connect import ConnectorCombiner, Status -from fedn.common.net.grpc.server import Server import fedn.common.net.grpc.fedn_pb2 as fedn import fedn.common.net.grpc.fedn_pb2_grpc as rpc - from fedn.clients.combiner.modelservice import ModelService +from fedn.common.net.connect import ConnectorCombiner, Status +from fedn.common.net.grpc.server import Server from fedn.common.storage.s3.s3repo import S3ModelRepository -import requests -import json -import io -import time -import base64 - -from collections import defaultdict - -from enum import Enum class Role(Enum): WORKER = 1 @@ -54,7 +52,7 @@ class Combiner(rpc.CombinerServicer, rpc.ReducerServicer, rpc.ConnectorServicer, def __init__(self, connect_config): - # Holds client queues + # Holds client queues self.clients = {} self.modelservice = ModelService() @@ -81,13 +79,13 @@ def __init__(self, connect_config): continue if status == Status.Assigned: config = response - print("COMBINER: was announced successfully. Waiting for clients and commands!", flush=True) + print( + "COMBINER: was announced successfully. Waiting for clients and commands!", flush=True) break if status == Status.UnAuthorized: print(response, flush=True) sys.exit("Exiting: Unauthorized") - cert = base64.b64decode(config['certificate']) # .decode('utf-8') key = base64.b64decode(config['key']) # .decode('utf-8') @@ -96,15 +94,18 @@ def __init__(self, connect_config): 'certificate': cert, 'key': key} - self.repository = S3ModelRepository(config['storage']['storage_config']) + self.repository = S3ModelRepository( + config['storage']['storage_config']) self.server = Server(self, self.modelservice, grpc_config) from fedn.common.tracer.mongotracer import MongoTracer - self.tracer = MongoTracer(config['statestore']['mongo_config'], config['statestore']['network_id']) + self.tracer = MongoTracer( + config['statestore']['mongo_config'], config['statestore']['network_id']) from fedn.clients.combiner.roundcontrol import RoundControl - self.control = RoundControl(self.id, self.repository, self, self.modelservice) - threading.Thread(target=self.control.run, daemon=True).start() + self.control = RoundControl( + self.id, self.repository, self, self.modelservice) + threading.Thread(target=self.control.run, daemon=True).start() self.server.start() @@ -144,11 +145,12 @@ def set_active_model(self, model_id): self.model_id = model_id def report_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, flush=True): - print("{}:COMBINER({}):{} {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), self.id, log_level, msg), flush=flush) + print("{}:COMBINER({}):{} {}".format(datetime.now().strftime( + '%Y-%m-%d %H:%M:%S'), self.id, log_level, msg), flush=flush) def request_model_update(self, model_id, clients=[]): """ Ask clients to update the current global model. - + Parameters ---------- model_id : str @@ -157,7 +159,7 @@ def request_model_update(self, model_id, clients=[]): List of clients to submit a model update request to. An empty list (default) results in a broadcast to all connected trainig clients. - + """ request = fedn.ModelUpdateRequest() @@ -173,9 +175,9 @@ def request_model_update(self, model_id, clients=[]): request.receiver.name = client.name request.receiver.role = fedn.WORKER self.SendModelUpdateRequest(request, self) - - print("COMBINER: Sent model update request for model {} to clients {}".format(model_id,clients), flush=True) + print("COMBINER: Sent model update request for model {} to clients {}".format( + model_id, clients), flush=True) def request_model_validation(self, model_id, clients=[]): """ Ask clients to validate the current global model. @@ -188,9 +190,9 @@ def request_model_validation(self, model_id, clients=[]): List of clients to submit a model update request to. An empty list (default) results in a broadcast to all connected trainig clients. - + """ - + request = fedn.ModelValidationRequest() self.__whoami(request.sender, self) request.model_id = model_id @@ -205,7 +207,8 @@ def request_model_validation(self, model_id, clients=[]): request.receiver.role = fedn.WORKER self.SendModelValidationRequest(request, self) - print("COMBINER: Sent validation request for model {} to clients {}".format(model_id,clients), flush=True) + print("COMBINER: Sent validation request for model {} to clients {}".format( + model_id, clients), flush=True) def _list_clients(self, channel): request = fedn.ListClientsRequest() @@ -279,7 +282,8 @@ def __route_request_to_client(self, request, client, queue_name): q = self.__get_queue(client, queue_name) q.put(request) except: - print("Failed to route request to client: {} {}", request.receiver, queue_name) + print("Failed to route request to client: {} {}", + request.receiver, queue_name) raise def _send_status(self, status): @@ -300,7 +304,7 @@ def __register_heartbeat(self, client): ##################################################################################################################### - ## Control Service + # Control Service def Start(self, control: fedn.ControlRequest, context): """ Push a round config to RoundControl. @@ -315,7 +319,8 @@ def Start(self, control: fedn.ControlRequest, context): config = {} for parameter in control.parameter: config.update({parameter.key: parameter.value}) - print("\n\nSTARTING ROUND AT COMBINER WITH ROUND CONFIG: {}\n\n".format(config), flush=True) + print("\n\nSTARTING ROUND AT COMBINER WITH ROUND CONFIG: {}\n\n".format( + config), flush=True) job_id = self.control.push_round_config(config) return response @@ -353,7 +358,7 @@ def Report(self, control: fedn.ControlRequest, context): p = response.parameter.add() p.key = "nr_active_trainers" p.value = str(len(active_trainers)) - + active_validators = self.get_active_validators() p = response.parameter.add() p.key = "nr_active_validators" @@ -378,7 +383,7 @@ def Report(self, control: fedn.ControlRequest, context): p = response.parameter.add() p.key = "nr_active_clients" p.value = str(len(active_trainers)+len(active_validators)) - + p = response.parameter.add() p.key = "model_id" model_id = self.get_active_model() @@ -393,14 +398,15 @@ def Report(self, control: fedn.ControlRequest, context): p = response.parameter.add() p.key = "name" p.value = str(self.id) - + return response ##################################################################################################################### def AllianceStatusStream(self, response, context): """ A server stream RPC endpoint that emits status messages. """ - status = fedn.Status(status="Client {} connecting to AllianceStatusStream.".format(response.sender)) + status = fedn.Status( + status="Client {} connecting to AllianceStatusStream.".format(response.sender)) status.log_level = fedn.Status.INFO status.sender.name = self.id status.sender.role = role_to_proto_role(self.role) @@ -466,7 +472,8 @@ def AcceptingClients(self, request: fedn.ConnectionRequest, context): :return: """ response = fedn.ConnectionResponse() - active_clients = self._list_active_clients(fedn.Channel.MODEL_UPDATE_REQUESTS) + active_clients = self._list_active_clients( + fedn.Channel.MODEL_UPDATE_REQUESTS) try: requested = int(self.max_clients) @@ -494,7 +501,7 @@ def SendHeartbeat(self, heartbeat: fedn.Heartbeat, context): response.response = "Heartbeat received" return response - ## Combiner Service + # Combiner Service def ModelUpdateStream(self, update, context): """ @@ -503,7 +510,8 @@ def ModelUpdateStream(self, update, context): :param context: """ client = update.sender - status = fedn.Status(status="Client {} connecting to ModelUpdateStream.".format(client.name)) + status = fedn.Status( + status="Client {} connecting to ModelUpdateStream.".format(client.name)) status.log_level = fedn.Status.INFO status.sender.name = self.id status.sender.role = role_to_proto_role(self.role) @@ -517,7 +525,7 @@ def ModelUpdateStream(self, update, context): try: yield q.get(timeout=1.0) except queue.Empty: - pass + pass def ModelUpdateRequestStream(self, response, context): """ A server stream RPC endpoint. Messages from client stream. """ @@ -527,14 +535,15 @@ def ModelUpdateRequestStream(self, response, context): if metadata: print("\n\n\nGOT METADATA: {}\n\n\n".format(metadata), flush=True) - status = fedn.Status(status="Client {} connecting to ModelUpdateRequestStream.".format(client.name)) + status = fedn.Status( + status="Client {} connecting to ModelUpdateRequestStream.".format(client.name)) status.log_level = fedn.Status.INFO status.timestamp = str(datetime.now()) - self.__whoami(status.sender, self) - self._subscribe_client_to_queue(client, fedn.Channel.MODEL_UPDATE_REQUESTS) + self._subscribe_client_to_queue( + client, fedn.Channel.MODEL_UPDATE_REQUESTS) q = self.__get_queue(client, fedn.Channel.MODEL_UPDATE_REQUESTS) self._send_status(status) @@ -543,8 +552,7 @@ def ModelUpdateRequestStream(self, response, context): try: yield q.get(timeout=1.0) except queue.Empty: - pass - + pass def ModelValidationStream(self, update, context): """ @@ -553,7 +561,8 @@ def ModelValidationStream(self, update, context): :param context: """ client = update.sender - status = fedn.Status(status="Client {} connecting to ModelValidationStream.".format(client.name)) + status = fedn.Status( + status="Client {} connecting to ModelValidationStream.".format(client.name)) status.log_level = fedn.Status.INFO status.sender.name = self.id @@ -568,20 +577,21 @@ def ModelValidationStream(self, update, context): try: yield q.get(timeout=1.0) except queue.Empty: - pass + pass def ModelValidationRequestStream(self, response, context): """ A server stream RPC endpoint. Messages from client stream. """ client = response.sender - status = fedn.Status(status="Client {} connecting to ModelValidationRequestStream.".format(client.name)) + status = fedn.Status( + status="Client {} connecting to ModelValidationRequestStream.".format(client.name)) status.log_level = fedn.Status.INFO status.sender.name = self.id status.sender.role = role_to_proto_role(self.role) status.timestamp = str(datetime.now()) - - self._subscribe_client_to_queue(client, fedn.Channel.MODEL_VALIDATION_REQUESTS) + self._subscribe_client_to_queue( + client, fedn.Channel.MODEL_VALIDATION_REQUESTS) q = self.__get_queue(client, fedn.Channel.MODEL_VALIDATION_REQUESTS) self._send_status(status) @@ -590,14 +600,15 @@ def ModelValidationRequestStream(self, response, context): try: yield q.get(timeout=1.0) except queue.Empty: - pass + pass def SendModelUpdateRequest(self, request, context): """ Send a model update request. """ self._send_request(request, fedn.Channel.MODEL_UPDATE_REQUESTS) response = fedn.Response() - response.response = "CONTROLLER RECEIVED ModelUpdateRequest from client {}".format(request.sender.name) + response.response = "CONTROLLER RECEIVED ModelUpdateRequest from client {}".format( + request.sender.name) return response # TODO Fill later def SendModelUpdate(self, request, context): @@ -606,7 +617,8 @@ def SendModelUpdate(self, request, context): print("ORCHESTRATOR: Received model update", flush=True) response = fedn.Response() - response.response = "RECEIVED ModelUpdate {} from client {}".format(response, response.sender.name) + response.response = "RECEIVED ModelUpdate {} from client {}".format( + response, response.sender.name) return response # TODO Fill later def SendModelValidationRequest(self, request, context): @@ -614,7 +626,8 @@ def SendModelValidationRequest(self, request, context): self._send_request(request, fedn.Channel.MODEL_VALIDATION_REQUESTS) response = fedn.Response() - response.response = "CONTROLLER RECEIVED ModelValidationRequest from client {}".format(request.sender.name) + response.response = "CONTROLLER RECEIVED ModelValidationRequest from client {}".format( + request.sender.name) return response # TODO Fill later def SendModelValidation(self, request, context): @@ -622,10 +635,11 @@ def SendModelValidation(self, request, context): self.control.aggregator.on_model_validation(request) print("ORCHESTRATOR received validation ", flush=True) response = fedn.Response() - response.response = "RECEIVED ModelValidation {} from client {}".format(response, response.sender.name) + response.response = "RECEIVED ModelValidation {} from client {}".format( + response, response.sender.name) return response # TODO Fill later - ## Reducer Service + # Reducer Service def GetGlobalModel(self, request, context): """ @@ -651,7 +665,8 @@ def run(self): """ import signal - print("COMBINER: {} started, ready for requests. ".format(self.id), flush=True) + print("COMBINER: {} started, ready for requests. ".format( + self.id), flush=True) try: while True: signal.pause() diff --git a/fedn/fedn/common/control/package.py b/fedn/fedn/common/control/package.py index 8f31bb83e..0118b4529 100644 --- a/fedn/fedn/common/control/package.py +++ b/fedn/fedn/common/control/package.py @@ -68,11 +68,14 @@ def upload(self): """ if self.package_file: - import requests import os + + import requests + # data = {'name': self.package_file, 'hash': str(self.package_hash)} # print("going to send {}".format(data),flush=True) - f = open(os.path.join(os.path.dirname(self.file_path), self.package_file), 'rb') + f = open(os.path.join(os.path.dirname( + self.file_path), self.package_file), 'rb') print("Sending the following file {}".format(f.read()), flush=True) f.seek(0, 0) files = {'file': f} @@ -82,7 +85,8 @@ def upload(self): # data=data, headers={'Authorization': 'Token {}'.format(self.reducer_token)}) except Exception as e: - print("failed to put execution context to reducer. {}".format(e), flush=True) + print("failed to put execution context to reducer. {}".format( + e), flush=True) finally: f.close() @@ -97,9 +101,9 @@ class PackageRuntime: def __init__(self, package_path, package_dir): self.dispatch_config = {'entry_points': - {'predict': {'command': 'python3 predict.py'}, - 'train': {'command': 'python3 train.py'}, - 'validate': {'command': 'python3 validate.py'}}} + {'predict': {'command': 'python3 predict.py'}, + 'train': {'command': 'python3 train.py'}, + 'validate': {'command': 'python3 validate.py'}}} self.pkg_path = package_path self.pkg_name = None @@ -125,7 +129,8 @@ def download(self, host, port, token, name=None): with requests.get(path, stream=True, verify=False, headers={'Authorization': 'Token {}'.format(token)}) as r: if 200 <= r.status_code < 204: import cgi - params = cgi.parse_header(r.headers.get('Content-Disposition', ''))[-1] + params = cgi.parse_header( + r.headers.get('Content-Disposition', ''))[-1] try: self.pkg_name = params['filename'] except KeyError: @@ -186,11 +191,14 @@ def unpack(self): if self.pkg_name: f = None if self.pkg_name.endswith('tar.gz'): - f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), 'r:gz') + f = tarfile.open(os.path.join( + self.pkg_path, self.pkg_name), 'r:gz') if self.pkg_name.endswith('.tgz'): - f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), 'r:gz') + f = tarfile.open(os.path.join( + self.pkg_path, self.pkg_name), 'r:gz') if self.pkg_name.endswith('tar.bz2'): - f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), 'r:bz2') + f = tarfile.open(os.path.join( + self.pkg_path, self.pkg_name), 'r:bz2') else: print( "Failed to unpack compute package, no pkg_name set. Has the reducer been configured with a compute package?") @@ -202,7 +210,8 @@ def unpack(self): if f: f.extractall() - print("Successfully extracted compute package content in {}".format(self.dir), flush=True) + print("Successfully extracted compute package content in {}".format( + self.dir), flush=True) except: print("Error extracting files!") @@ -213,9 +222,11 @@ def dispatcher(self, run_path): :return: """ from_path = os.path.join(os.getcwd(), 'client') - + from distutils.dir_util import copy_tree - copy_tree(from_path, run_path, preserve_times=False) # preserve_times=False ensures compatibility with Gramine LibOS + + # preserve_times=False ensures compatibility with Gramine LibOS + copy_tree(from_path, run_path, preserve_times=False) try: cfg = None @@ -225,7 +236,8 @@ def dispatcher(self, run_path): self.dispatch_config = cfg except Exception as e: - print("Error trying to load and unpack dispatcher config - trying default", flush=True) + print( + "Error trying to load and unpack dispatcher config - trying default", flush=True) dispatcher = Dispatcher(self.dispatch_config, run_path) diff --git a/fedn/fedn/common/net/connect.py b/fedn/fedn/common/net/connect.py index 0e4a3e806..ebfd15bc5 100644 --- a/fedn/fedn/common/net/connect.py +++ b/fedn/fedn/common/net/connect.py @@ -3,6 +3,8 @@ import requests as r +from fedn.common.security.certificate import Certificate + class State(enum.Enum): Disconnected = 0 @@ -15,10 +17,7 @@ class Status(enum.Enum): Assigned = 1 TryAgain = 2 UnAuthorized = 3 - UnMatchedConfig = 4 - - -from fedn.common.security.certificate import Certificate + UnMatchedConfig = 4 class ConnectorClient: @@ -56,8 +55,10 @@ def __init__(self, host, port, token, name, remote_package, combiner=None, id=No self.verify_cert = False self.prefix = prefix - self.connect_string = "{}{}:{}".format(self.prefix, self.host, self.port) - print("\n\nsetting the connection string to {}\n\n".format(self.connect_string), flush=True) + self.connect_string = "{}{}:{}".format( + self.prefix, self.host, self.port) + print("\n\nsetting the connection string to {}\n\n".format( + self.connect_string), flush=True) if self.certificate: print("Securely connecting with certificate", flush=True) @@ -87,15 +88,16 @@ def assign(self): print('***** {}'.format(e), flush=True) # self.state = State.Disconnected return Status.Unassigned, {} - + if retval.status_code == 401: reason = "Unauthorized connection to reducer, make sure the correct token is set" return Status.UnAuthorized, reason - + reducer_package = retval.json()['package'] if reducer_package != self.package: - reason = "Unmatched config of compute package between client and reducer.\n"+\ - "Reducer uses {} package and client uses {}.".format(reducer_package, self.package) + reason = "Unmatched config of compute package between client and reducer.\n" +\ + "Reducer uses {} package and client uses {}.".format( + reducer_package, self.package) return Status.UnMatchedConfig, reason if retval.status_code >= 200 and retval.status_code < 204: @@ -145,8 +147,10 @@ def __init__(self, host, port, myhost, myport, token, name, secure=True, preshar self.verify_cert = False self.prefix = prefix - self.connect_string = "{}{}:{}".format(self.prefix, self.host, self.port) - print("\n\nsetting the connection string to {}\n\n".format(self.connect_string), flush=True) + self.connect_string = "{}{}:{}".format( + self.prefix, self.host, self.port) + print("\n\nsetting the connection string to {}\n\n".format( + self.connect_string), flush=True) print("Securely connecting with certificate", flush=True) def state(self): @@ -172,11 +176,11 @@ def announce(self): except Exception as e: # self.state = State.Disconnected return Status.Unassigned, {} - + if retval.status_code == 401: reason = "Unauthorized connection to reducer, make sure the correct token is set" return Status.UnAuthorized, reason - + if retval.status_code >= 200 and retval.status_code < 204: if retval.json()['status'] == 'retry': reason = "Reducer was not ready. Try again later." diff --git a/fedn/fedn/common/net/grpc/fedn_pb2.py b/fedn/fedn/common/net/grpc/fedn_pb2.py index 64a7a6ddf..f53fd40e6 100644 --- a/fedn/fedn/common/net/grpc/fedn_pb2.py +++ b/fedn/fedn/common/net/grpc/fedn_pb2.py @@ -2,11 +2,11 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # source: fedn/common/net/grpc/fedn.proto """Generated protocol buffer code.""" -from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import enum_type_wrapper # @@protoc_insertion_point(imports) diff --git a/fedn/fedn/common/net/grpc/fedn_pb2_grpc.py b/fedn/fedn/common/net/grpc/fedn_pb2_grpc.py index 072493cc5..9989824f7 100644 --- a/fedn/fedn/common/net/grpc/fedn_pb2_grpc.py +++ b/fedn/fedn/common/net/grpc/fedn_pb2_grpc.py @@ -2,7 +2,8 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -from fedn.common.net.grpc import fedn_pb2 as fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2 +from fedn.common.net.grpc import \ + fedn_pb2 as fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2 class ModelServiceStub(object): diff --git a/fedn/fedn/common/net/grpc/server.py b/fedn/fedn/common/net/grpc/server.py index 1b15f7383..79a251de1 100644 --- a/fedn/fedn/common/net/grpc/server.py +++ b/fedn/fedn/common/net/grpc/server.py @@ -1,8 +1,9 @@ from concurrent import futures -import fedn.common.net.grpc.fedn_pb2_grpc as rpc import grpc +import fedn.common.net.grpc.fedn_pb2_grpc as rpc + class Server: """ @@ -26,15 +27,17 @@ def __init__(self, servicer, modelservicer, config): rpc.add_ControlServicer_to_server(servicer, self.server) if config['secure']: - from fedn.common.security.certificate import Certificate import os - # self.certificate = Certificate(os.getcwd() + '/certs/', cert_name='combiner-cert.pem', key_name='combiner-key.pem') + from fedn.common.security.certificate import Certificate + + # self.certificate = Certificate(os.getcwd() + '/certs/', cert_name='combiner-cert.pem', key_name='combiner-key.pem') # self.certificate.set_keypair_raw(config['certificate'], config['key']) server_credentials = grpc.ssl_server_credentials( ((config['key'], config['certificate'],),)) - self.server.add_secure_port('[::]:' + str(config['port']), server_credentials) + self.server.add_secure_port( + '[::]:' + str(config['port']), server_credentials) else: self.server.add_insecure_port('[::]:' + str(config['port'])) diff --git a/fedn/fedn/common/net/web/__init__.py b/fedn/fedn/common/net/web/__init__.py index ff1b58e20..c87d94d17 100644 --- a/fedn/fedn/common/net/web/__init__.py +++ b/fedn/fedn/common/net/web/__init__.py @@ -1,5 +1,6 @@ -from os.path import dirname, basename, isfile import glob +from os.path import basename, dirname, isfile modules = glob.glob(dirname(__file__) + "/*.py") -__all__ = [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')] +__all__ = [basename(f)[:-3] for f in modules if isfile(f) + and not f.endswith('__init__.py')] diff --git a/fedn/fedn/common/security/certificate.py b/fedn/fedn/common/security/certificate.py index cd2412d3a..bf90645f7 100644 --- a/fedn/fedn/common/security/certificate.py +++ b/fedn/fedn/common/security/certificate.py @@ -18,7 +18,8 @@ def __init__(self, cwd, name=None, key_name="key.pem", cert_name="cert.pem", cre except OSError: print("Directory exists, will store all cert and keys here.") else: - print("Successfully created the directory to store cert and keys in {}".format(cwd)) + print( + "Successfully created the directory to store cert and keys in {}".format(cwd)) self.key_path = os.path.join(cwd, key_name) self.cert_path = os.path.join(cwd, cert_name) import uuid @@ -62,10 +63,12 @@ def set_keypair_raw(self, certificate, privatekey): :param privatekey: """ with open(self.key_path, "wb") as keyfile: - keyfile.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, privatekey)) + keyfile.write(crypto.dump_privatekey( + crypto.FILETYPE_PEM, privatekey)) with open(self.cert_path, "wb") as certfile: - certfile.write(crypto.dump_certificate(crypto.FILETYPE_PEM, certificate)) + certfile.write(crypto.dump_certificate( + crypto.FILETYPE_PEM, certificate)) def get_keypair_raw(self): """ diff --git a/fedn/fedn/common/security/certificatemanager.py b/fedn/fedn/common/security/certificatemanager.py index acfeb6e52..5637d7344 100644 --- a/fedn/fedn/common/security/certificatemanager.py +++ b/fedn/fedn/common/security/certificatemanager.py @@ -23,7 +23,8 @@ def get_or_create(self, name): if search: return search else: - cert = Certificate(self.directory, name=name, cert_name=name + '-cert.pem', key_name=name + '-key.pem') + cert = Certificate(self.directory, name=name, + cert_name=name + '-cert.pem', key_name=name + '-key.pem') cert.gen_keypair() self.certificates.append(cert) return cert @@ -52,7 +53,8 @@ def load_all(self): # print("trying with {}".format(key_name)) if os.path.isfile(os.path.join(self.directory, key_name)): - c = Certificate(self.directory, name=name, cert_name=filename, key_name=key_name) + c = Certificate(self.directory, name=name, + cert_name=filename, key_name=key_name) self.certificates.append(c) else: c = Certificate(self.directory, name=name, cert_name=filename, diff --git a/fedn/fedn/common/storage/models/memorymodelstorage.py b/fedn/fedn/common/storage/models/memorymodelstorage.py index 4e8047a6e..27ad9c8be 100644 --- a/fedn/fedn/common/storage/models/memorymodelstorage.py +++ b/fedn/fedn/common/storage/models/memorymodelstorage.py @@ -1,6 +1,7 @@ -from fedn.common.storage.models.modelstorage import ModelStorage -from collections import defaultdict import io +from collections import defaultdict + +from fedn.common.storage.models.modelstorage import ModelStorage CHUNK_SIZE = 1024 * 1024 @@ -12,6 +13,7 @@ class MemoryModelStorage(ModelStorage): def __init__(self): import tempfile + # self.dir = tempfile.TemporaryDirectory() self.models = defaultdict(io.BytesIO) self.models_metadata = {} diff --git a/fedn/fedn/common/storage/models/tempmodelstorage.py b/fedn/fedn/common/storage/models/tempmodelstorage.py index 89232c27d..681573b94 100644 --- a/fedn/fedn/common/storage/models/tempmodelstorage.py +++ b/fedn/fedn/common/storage/models/tempmodelstorage.py @@ -1,10 +1,10 @@ +import os + import fedn.common.net.grpc.fedn_pb2 as fedn from fedn.common.storage.models.modelstorage import ModelStorage CHUNK_SIZE = 1024 * 1024 -import os - class TempModelStorage(ModelStorage): """ @@ -13,7 +13,8 @@ class TempModelStorage(ModelStorage): def __init__(self): - self.default_dir = os.environ.get('FEDN_MODEL_DIR', '/tmp/models') # set default to tmp + self.default_dir = os.environ.get( + 'FEDN_MODEL_DIR', '/tmp/models') # set default to tmp if not os.path.exists(self.default_dir): os.makedirs(self.default_dir) diff --git a/fedn/fedn/common/storage/s3/miniorepo.py b/fedn/fedn/common/storage/s3/miniorepo.py index c0977e7e8..cb489de89 100644 --- a/fedn/fedn/common/storage/s3/miniorepo.py +++ b/fedn/fedn/common/storage/s3/miniorepo.py @@ -1,14 +1,15 @@ +import io +import json +import logging import os +import uuid +from urllib.parse import urlparse + import requests -from .base import Repository from minio import Minio from minio.error import InvalidResponseError -import io -import logging -from urllib.parse import urlparse -import uuid -import json +from .base import Repository logger = logging.getLogger(__name__) @@ -43,11 +44,13 @@ def __init__(self, config): self.secure_mode = False if not self.secure_mode: - print("\n\n\nWARNING : S3/MINIO RUNNING IN **INSECURE** MODE! THIS IS NOT FOR PRODUCTION!\n\n\n") + print( + "\n\n\nWARNING : S3/MINIO RUNNING IN **INSECURE** MODE! THIS IS NOT FOR PRODUCTION!\n\n\n") if self.secure_mode: from urllib3.poolmanager import PoolManager - manager = PoolManager(num_pools=100, cert_reqs='CERT_NONE', assert_hostname=False) + manager = PoolManager( + num_pools=100, cert_reqs='CERT_NONE', assert_hostname=False) self.client = Minio("{0}:{1}".format(config['storage_hostname'], config['storage_port']), access_key=access_key, secret_key=secret_key, @@ -83,7 +86,8 @@ def set_artifact(self, instance_name, instance, is_file=False, bucket=''): self.client.fput_object(bucket, instance_name, instance) else: try: - self.client.put_object(bucket, instance_name, io.BytesIO(instance), len(instance)) + self.client.put_object( + bucket, instance_name, io.BytesIO(instance), len(instance)) except Exception as e: raise Exception("Could not load data into bytes {}".format(e)) @@ -129,7 +133,8 @@ def list_artifacts(self): print(obj.object_name) objects_to_delete.append(obj.object_name) except Exception as e: - raise Exception("Could not list models in bucket {}".format(self.bucket)) + raise Exception( + "Could not list models in bucket {}".format(self.bucket)) return objects_to_delete def delete_artifact(self, instance_name, bucket=[]): diff --git a/fedn/fedn/common/storage/s3/s3repo.py b/fedn/fedn/common/storage/s3/s3repo.py index 89518818d..d572be4ba 100644 --- a/fedn/fedn/common/storage/s3/s3repo.py +++ b/fedn/fedn/common/storage/s3/s3repo.py @@ -15,7 +15,8 @@ def get_model(self, model_id): :param model_id: :return: """ - print("Client {} trying to get model with id: {}".format(self.client, model_id), flush=True) + print("Client {} trying to get model with id: {}".format( + self.client, model_id), flush=True) return self.get_artifact(model_id) def get_model_stream(self, model_id): @@ -24,7 +25,8 @@ def get_model_stream(self, model_id): :param model_id: :return: """ - print("Client {} trying to get model with id: {}".format(self.client, model_id), flush=True) + print("Client {} trying to get model with id: {}".format( + self.client, model_id), flush=True) return self.get_artifact_stream(model_id) def set_model(self, model, is_file=True): @@ -38,7 +40,8 @@ def set_model(self, model, is_file=True): model_id = uuid.uuid4() # TODO: Check that this call succeeds try: - self.set_artifact(str(model_id), model, bucket=self.bucket, is_file=is_file) + self.set_artifact(str(model_id), model, + bucket=self.bucket, is_file=is_file) except Exception as e: print("Failed to write model with ID {} to repository.".format(model_id)) raise @@ -52,7 +55,8 @@ def set_compute_context(self, name, compute_package, is_file=True): :param is_file: """ try: - self.set_artifact(str(name), compute_package, bucket="fedn-context", is_file=is_file) + self.set_artifact(str(name), compute_package, + bucket="fedn-context", is_file=is_file) except Exception as e: print("Failed to write compute_package to repository.") raise diff --git a/fedn/fedn/common/tracer/mongotracer.py b/fedn/fedn/common/tracer/mongotracer.py index be401d9b7..0af957e79 100644 --- a/fedn/fedn/common/tracer/mongotracer.py +++ b/fedn/fedn/common/tracer/mongotracer.py @@ -1,10 +1,12 @@ -from fedn.common.tracer.tracer import Tracer -from fedn.common.storage.db.mongo import connect_to_mongodb -import time import threading -import psutil +import time from datetime import datetime +import psutil + +from fedn.common.storage.db.mongo import connect_to_mongodb +from fedn.common.tracer.tracer import Tracer + class MongoTracer(Tracer): """ @@ -37,7 +39,7 @@ def report(self, msg): print("LOG: \n {} \n".format(data), flush=True) - if self.status!=None: + if self.status != None: self.status.insert_one(data) def drop_round_time(self): @@ -95,8 +97,10 @@ def set_latest_time(self, round, round_time): :param round: :param round_time: """ - self.round_time.update_one({'key': 'round_time'}, {'$push': {'round': round}}, True) - self.round_time.update_one({'key': 'round_time'}, {'$push': {'round_time': round_time}}, True) + self.round_time.update_one({'key': 'round_time'}, { + '$push': {'round': round}}, True) + self.round_time.update_one({'key': 'round_time'}, { + '$push': {'round_time': round_time}}, True) def set_combiner_time(self, round, round_time): """ @@ -104,8 +108,10 @@ def set_combiner_time(self, round, round_time): :param round: :param round_time: """ - self.combiner_round_time.update_one({'key': 'combiner_round_time'}, {'$push': {'round': round}}, True) - self.combiner_round_time.update_one({'key': 'combiner_round_time'}, {'$push': {'round_time': round_time}}, True) + self.combiner_round_time.update_one({'key': 'combiner_round_time'}, { + '$push': {'round': round}}, True) + self.combiner_round_time.update_one({'key': 'combiner_round_time'}, { + '$push': {'round_time': round_time}}, True) # def set_combiner_queue_length(self,timestamp,ql): # self.combiner_queue_length({'key': 'combiner_queue_length'}, {'$push': {'queue_length': ql}}, True) @@ -117,14 +123,16 @@ def set_round_meta(self, round_meta): :param round_meta: """ - self.round.update_one({'key': str(round_meta['round_id'])}, {'$push': {'combiners': round_meta}}, True) + self.round.update_one({'key': str(round_meta['round_id'])}, { + '$push': {'combiners': round_meta}}, True) def set_round_meta_reducer(self, round_meta): """ :param round_meta: """ - self.round.update_one({'key': str(round_meta['round_id'])}, {'$push': {'reducer': round_meta}}, True) + self.round.update_one({'key': str(round_meta['round_id'])}, { + '$push': {'reducer': round_meta}}, True) def get_latest_round(self): """ @@ -149,10 +157,14 @@ def ps_util_monitor(self, round=None): mem_percents = currentProcess.memory_percent() ps_time = str(datetime.now()) - self.psutil_monitoring.update_one({'key': 'cpu_mem_usage'}, {'$push': {'cpu': cpu_percents}}, True) - self.psutil_monitoring.update_one({'key': 'cpu_mem_usage'}, {'$push': {'mem': mem_percents}}, True) - self.psutil_monitoring.update_one({'key': 'cpu_mem_usage'}, {'$push': {'time': ps_time}}, True) - self.psutil_monitoring.update_one({'key': 'cpu_mem_usage'}, {'$push': {'round': round}}, True) + self.psutil_monitoring.update_one({'key': 'cpu_mem_usage'}, { + '$push': {'cpu': cpu_percents}}, True) + self.psutil_monitoring.update_one({'key': 'cpu_mem_usage'}, { + '$push': {'mem': mem_percents}}, True) + self.psutil_monitoring.update_one({'key': 'cpu_mem_usage'}, { + '$push': {'time': ps_time}}, True) + self.psutil_monitoring.update_one({'key': 'cpu_mem_usage'}, { + '$push': {'round': round}}, True) def start_monitor(self, round=None): """ diff --git a/fedn/fedn/reducer.py b/fedn/fedn/reducer.py index ce2592b7c..1394c5904 100644 --- a/fedn/fedn/reducer.py +++ b/fedn/fedn/reducer.py @@ -5,8 +5,9 @@ from fedn.clients.reducer.interfaces import ReducerInferenceInterface from fedn.clients.reducer.restservice import ReducerRestService from fedn.clients.reducer.state import ReducerStateToString +from fedn.clients.reducer.statestore.mongoreducerstatestore import \ + MongoReducerStateStore from fedn.common.security.certificatemanager import CertificateManager -from fedn.clients.reducer.statestore.mongoreducerstatestore import MongoReducerStateStore class InvalidReducerConfiguration(Exception): @@ -42,7 +43,8 @@ def __init__(self, statestore): self.control = ReducerControl(self.statestore) self.inference = ReducerInferenceInterface() rest_certificate = self.certificate_manager.get_or_create("reducer") - self.rest = ReducerRestService(config, self.control, self.certificate_manager, certificate=rest_certificate) + self.rest = ReducerRestService( + config, self.control, self.certificate_manager, certificate=rest_certificate) def run(self): """ diff --git a/fedn/fedn/tests/test_reducer_service.py b/fedn/fedn/tests/test_reducer_service.py index 199b3da08..93cf761b1 100644 --- a/fedn/fedn/tests/test_reducer_service.py +++ b/fedn/fedn/tests/test_reducer_service.py @@ -1,7 +1,8 @@ import unittest +from unittest.mock import MagicMock, patch + from fedn.clients.reducer.restservice import ReducerRestService from fedn.clients.reducer.state import ReducerState -from unittest.mock import MagicMock, patch class TestInit(unittest.TestCase): @@ -18,7 +19,7 @@ def test_discover_host(self, mock_control): restservice = ReducerRestService(CONFIG, mock_control, None) self.assertEqual(restservice.name, 'TEST_HOST') self.assertEqual(restservice.network_id, 'TEST_NAME-network') - + @patch('fedn.clients.reducer.control.ReducerControl') def test_name(self, mock_control): CONFIG = { @@ -30,7 +31,7 @@ def test_name(self, mock_control): } restservice = ReducerRestService(CONFIG, mock_control, None) self.assertEqual(restservice.name, 'TEST_NAME') - + @patch('fedn.clients.reducer.control.ReducerControl') def test_network_id(self, mock_control): CONFIG = { @@ -43,6 +44,7 @@ def test_network_id(self, mock_control): restservice = ReducerRestService(CONFIG, mock_control, None) self.assertEqual(restservice.network_id, 'TEST_NAME-network') + class TestChecks(unittest.TestCase): @patch('fedn.clients.reducer.control.ReducerControl') def setUp(self, mock_control): @@ -58,7 +60,8 @@ def setUp(self, mock_control): def test_check_compute_package(self): - self.restservice.control.get_compute_context.return_value = {'NOT': 'NONE'} + self.restservice.control.get_compute_context.return_value = { + 'NOT': 'NONE'} retval = self.restservice.check_compute_context() self.assertTrue(retval) @@ -70,7 +73,7 @@ def test_check_compute_package(self): retval = self.restservice.check_compute_context() self.assertFalse(retval) - self.restservice.remote_compute_context = False + self.restservice.remote_compute_context = False retval = self.restservice.check_compute_context() self.assertTrue(retval) @@ -102,21 +105,19 @@ def setUp(self, mock_control): } self.restservice = ReducerRestService(CONFIG, mock_control, None) - + def test_encode_decode_auth_token(self): SECRET_KEY = 'test_secret' token = self.restservice.encode_auth_token(SECRET_KEY) payload_success = self.restservice.decode_auth_token(token, SECRET_KEY) - payload_invalid = self.restservice.decode_auth_token('wrong_token', SECRET_KEY) + payload_invalid = self.restservice.decode_auth_token( + 'wrong_token', SECRET_KEY) payload_error = self.restservice.decode_auth_token(token, 'wrong_key') self.assertEqual(payload_success, "Success") self.assertEqual(payload_invalid, "Invalid token.") self.assertEqual(payload_error, "Invalid token.") - - - if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/fedn/fedn/utils/dispatcher.py b/fedn/fedn/utils/dispatcher.py index d2b8d1332..228c4f4bd 100644 --- a/fedn/fedn/utils/dispatcher.py +++ b/fedn/fedn/utils/dispatcher.py @@ -1,5 +1,6 @@ -import re import logging +import re + from fedn.utils.process import run_process logger = logging.getLogger(__name__) diff --git a/fedn/fedn/utils/helpers.py b/fedn/fedn/utils/helpers.py index 52cd5af6a..6e4d990a3 100644 --- a/fedn/fedn/utils/helpers.py +++ b/fedn/fedn/utils/helpers.py @@ -1,7 +1,7 @@ import collections -from abc import ABC, abstractmethod import os import tempfile +from abc import ABC, abstractmethod class HelperBase(ABC): diff --git a/fedn/fedn/utils/kerashelper.py b/fedn/fedn/utils/kerashelper.py index 9c4613fe1..5d5b2a164 100644 --- a/fedn/fedn/utils/kerashelper.py +++ b/fedn/fedn/utils/kerashelper.py @@ -1,8 +1,8 @@ +import collections import os import tempfile + import numpy as np -import collections -import tempfile from .helpers import HelperBase diff --git a/fedn/fedn/utils/logger.py b/fedn/fedn/utils/logger.py index 00d0fd61e..22e50ebc1 100644 --- a/fedn/fedn/utils/logger.py +++ b/fedn/fedn/utils/logger.py @@ -22,6 +22,7 @@ def __init__(self, log_level=logging.DEBUG, to_file='', file_path=os.getcwd()): root.addHandler(sh) if to_file != '': - fh = logging.FileHandler(os.path.join(file_path, '{}'.format(to_file))) + fh = logging.FileHandler(os.path.join( + file_path, '{}'.format(to_file))) fh.setFormatter(logging.Formatter(log_format)) root.addHandler(fh) diff --git a/fedn/fedn/utils/numpyarrayhelper.py b/fedn/fedn/utils/numpyarrayhelper.py index d54574206..1f481a309 100644 --- a/fedn/fedn/utils/numpyarrayhelper.py +++ b/fedn/fedn/utils/numpyarrayhelper.py @@ -1,10 +1,10 @@ +import collections import os -import tempfile -import numpy as np import pickle -import collections import tempfile +import numpy as np + from .helpers import HelperBase diff --git a/fedn/fedn/utils/process.py b/fedn/fedn/utils/process.py index 11850a1ae..5a005e0cd 100644 --- a/fedn/fedn/utils/process.py +++ b/fedn/fedn/utils/process.py @@ -1,5 +1,5 @@ -import subprocess import logging +import subprocess logger = logging.getLogger() @@ -10,7 +10,8 @@ def run_process(args, cwd): :param args: :param cwd: """ - status = subprocess.Popen(args, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + status = subprocess.Popen( + args, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) # print(status) def check_io(): diff --git a/fedn/fedn/utils/pytorchhelper.py b/fedn/fedn/utils/pytorchhelper.py index 3066a2763..7d816bc8c 100644 --- a/fedn/fedn/utils/pytorchhelper.py +++ b/fedn/fedn/utils/pytorchhelper.py @@ -1,10 +1,12 @@ import os import tempfile from collections import OrderedDict -from .helpers import HelperBase from functools import reduce + import numpy as np +from .helpers import HelperBase + class PytorchHelper(HelperBase): diff --git a/fedn/setup.py b/fedn/setup.py index 861bf98f9..611d28ead 100644 --- a/fedn/setup.py +++ b/fedn/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( name='fedn',