Skip to content

Commit

Permalink
Feature/SK-838 | Ruff linting (#599)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede authored May 7, 2024
1 parent 5f1f43a commit cbc7dbc
Show file tree
Hide file tree
Showing 76 changed files with 1,609 additions and 1,960 deletions.
32 changes: 13 additions & 19 deletions examples/async-clients/client/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics

HELPER_MODULE = 'numpyhelper'
HELPER_MODULE = "numpyhelper"
ARRAY_SIZE = 10000


Expand All @@ -22,7 +22,7 @@ def compile_model(max_iter=1):


def save_parameters(model, out_path):
""" Save model to disk.
"""Save model to disk.
:param model: The model to save.
:type model: torch.nn.Module
Expand All @@ -36,7 +36,7 @@ def save_parameters(model, out_path):


def load_parameters(model_path):
""" Load model from disk.
"""Load model from disk.
param model_path: The path to load from.
:type model_path: str
Expand All @@ -49,8 +49,8 @@ def load_parameters(model_path):
return parameters


def init_seed(out_path='seed.npz'):
""" Initialize seed model.
def init_seed(out_path="seed.npz"):
"""Initialize seed model.
:param out_path: The path to save the seed model to.
:type out_path: str
Expand All @@ -61,7 +61,7 @@ def init_seed(out_path='seed.npz'):


def make_data(n_min=50, n_max=100):
""" Generate / simulate a random number n data points.
"""Generate / simulate a random number n data points.
n will fall in the interval (n_min, n_max)
Expand All @@ -78,14 +78,12 @@ def make_data(n_min=50, n_max=100):


def train(in_model_path, out_model_path):
""" Train model.
"""
"""Train model."""

# Load model
parameters = load_parameters(in_model_path)
model = compile_model()
n = len(parameters)//2
n = len(parameters) // 2
model.coefs_ = parameters[:n]
model.intercepts_ = parameters[n:]

Expand All @@ -97,7 +95,7 @@ def train(in_model_path, out_model_path):

# Metadata needed for aggregation server side
metadata = {
'num_examples': len(X_train),
"num_examples": len(X_train),
}

# Save JSON metadata file
Expand All @@ -108,7 +106,7 @@ def train(in_model_path, out_model_path):


def validate(in_model_path, out_json_path):
""" Validate model.
"""Validate model.
:param in_model_path: The path to the input model.
:type in_model_path: str
Expand All @@ -119,7 +117,7 @@ def validate(in_model_path, out_json_path):
"""
parameters = load_parameters(in_model_path)
model = compile_model()
n = len(parameters)//2
n = len(parameters) // 2
model.coefs_ = parameters[:n]
model.intercepts_ = parameters[n:]

Expand All @@ -134,9 +132,5 @@ def validate(in_model_path, out_json_path):
save_metrics(report, out_json_path)


if __name__ == '__main__':
fire.Fire({
'init_seed': init_seed,
'train': train,
'validate': validate
})
if __name__ == "__main__":
fire.Fire({"init_seed": init_seed, "train": train, "validate": validate})
6 changes: 3 additions & 3 deletions examples/async-clients/init_fedn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from fedn import APIClient

DISCOVER_HOST = '127.0.0.1'
DISCOVER_HOST = "127.0.0.1"
DISCOVER_PORT = 8092

client = APIClient(DISCOVER_HOST, DISCOVER_PORT)
client.set_active_package('package.tgz', 'numpyhelper')
client.set_active_model('seed.npz')
client.set_active_package("package.tgz", "numpyhelper")
client.set_active_model("seed.npz")
62 changes: 41 additions & 21 deletions examples/async-clients/run_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,39 @@

# Use with a local deployment
settings = {
'DISCOVER_HOST': '127.0.0.1',
'DISCOVER_PORT': 8092,
'TOKEN': None,
'N_CLIENTS': 10,
'N_CYCLES': 100,
'CLIENTS_MAX_DELAY': 10,
'CLIENTS_ONLINE_FOR_SECONDS': 120
"DISCOVER_HOST": "127.0.0.1",
"DISCOVER_PORT": 8092,
"TOKEN": None,
"N_CLIENTS": 10,
"N_CYCLES": 100,
"CLIENTS_MAX_DELAY": 10,
"CLIENTS_ONLINE_FOR_SECONDS": 120,
}

client_config = {'discover_host': settings['DISCOVER_HOST'], 'discover_port': settings['DISCOVER_PORT'], 'token': settings['TOKEN'], 'name': 'testclient',
'client_id': 1, 'remote_compute_context': True, 'force_ssl': False, 'dry_run': False, 'secure': False,
'preshared_cert': False, 'verify': False, 'preferred_combiner': False,
'validator': True, 'trainer': True, 'init': None, 'logfile': 'test.log', 'heartbeat_interval': 2,
'reconnect_after_missed_heartbeat': 30}
client_config = {
"discover_host": settings["DISCOVER_HOST"],
"discover_port": settings["DISCOVER_PORT"],
"token": settings["TOKEN"],
"name": "testclient",
"client_id": 1,
"remote_compute_context": True,
"force_ssl": False,
"dry_run": False,
"secure": False,
"preshared_cert": False,
"verify": False,
"preferred_combiner": False,
"validator": True,
"trainer": True,
"init": None,
"logfile": "test.log",
"heartbeat_interval": 2,
"reconnect_after_missed_heartbeat": 30,
}


def run_client(online_for=120, name='client'):
""" Simulates a client that starts and stops
def run_client(online_for=120, name="client"):
"""Simulates a client that starts and stops
at random intervals.
The client will start after a radom time 'mean_delay',
Expand All @@ -55,23 +70,28 @@ def run_client(online_for=120, name='client'):
"""

conf = copy.deepcopy(client_config)
conf['name'] = name
conf["name"] = name

for i in range(settings['N_CYCLES']):
for i in range(settings["N_CYCLES"]):
# Sample a delay until the client starts
t_start = np.random.randint(0, settings['CLIENTS_MAX_DELAY'])
t_start = np.random.randint(0, settings["CLIENTS_MAX_DELAY"])
time.sleep(t_start)
fl_client = Client(conf)
time.sleep(online_for)
fl_client.disconnect()


if __name__ == '__main__':

if __name__ == "__main__":
# We start N_CLIENTS independent client processes
processes = []
for i in range(settings['N_CLIENTS']):
p = Process(target=run_client, args=(settings['CLIENTS_ONLINE_FOR_SECONDS'], 'client{}'.format(i),))
for i in range(settings["N_CLIENTS"]):
p = Process(
target=run_client,
args=(
settings["CLIENTS_ONLINE_FOR_SECONDS"],
"client{}".format(i),
),
)
processes.append(p)
p.start()

Expand Down
12 changes: 5 additions & 7 deletions examples/async-clients/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@

from fedn import APIClient

DISCOVER_HOST = '127.0.0.1'
DISCOVER_HOST = "127.0.0.1"
DISCOVER_PORT = 8092
client = APIClient(DISCOVER_HOST, DISCOVER_PORT)

if __name__ == '__main__':

if __name__ == "__main__":
# Run six sessions, each with 100 rounds.
num_sessions = 6
for s in range(num_sessions):

session_config = {
"helper": "numpyhelper",
"id": str(uuid.uuid4()),
Expand All @@ -23,12 +21,12 @@
}

session = client.start_session(**session_config)
if session['success'] is False:
print(session['message'])
if session["success"] is False:
print(session["message"])
exit(0)

print("Started session: {}".format(session))

# Wait for session to finish
while not client.session_is_finished(session_config['id']):
while not client.session_is_finished(session_config["id"]):
time.sleep(2)
4 changes: 1 addition & 3 deletions examples/flower-client/client/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def train(in_model_path, out_model_path):
parameters_np = helper.load(in_model_path)

# Train on flower client
params, num_examples = flwr_adapter.train(
parameters=parameters_np, partition_id=_get_node_id(), config={}
)
params, num_examples = flwr_adapter.train(parameters=parameters_np, partition_id=_get_node_id(), config={})

# Metadata needed for aggregation server side
metadata = {
Expand Down
7 changes: 2 additions & 5 deletions examples/flower-client/client/flwr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,15 @@
"""

from flwr.client import ClientApp, NumPyClient
from flwr_task import (DEVICE, Net, get_weights, load_data, set_weights, test,
train)
from flwr_task import DEVICE, Net, get_weights, load_data, set_weights, test, train


# Define FlowerClient and client_fn
class FlowerClient(NumPyClient):
def __init__(self, cid) -> None:
super().__init__()
self.net = Net().to(DEVICE)
self.trainloader, self.testloader = load_data(
partition_id=int(cid), num_clients=10
)
self.trainloader, self.testloader = load_data(partition_id=int(cid), num_clients=10)

def get_parameters(self, config):
return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
Expand Down
6 changes: 3 additions & 3 deletions examples/flower-client/init_fedn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from fedn import APIClient

DISCOVER_HOST = '127.0.0.1'
DISCOVER_HOST = "127.0.0.1"
DISCOVER_PORT = 8092

client = APIClient(DISCOVER_HOST, DISCOVER_PORT)
client.set_package('package.tgz', 'numpyhelper')
client.set_initial_model('seed.npz')
client.set_package("package.tgz", "numpyhelper")
client.set_initial_model("seed.npz")
Loading

0 comments on commit cbc7dbc

Please sign in to comment.