Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Jun 10, 2024
1 parent ab88fe5 commit f267971
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions fedn/network/combiner/aggregators/aggregatorbase.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib
import json
import queue
import traceback
from abc import ABC, abstractmethod

from fedn.common.log_config import logger
Expand All @@ -9,7 +10,7 @@


class AggregatorBase(ABC):
""" Abstract class defining an aggregator.
"""Abstract class defining an aggregator.
:param id: A reference to id of :class: `fedn.network.combiner.Combiner`
:type id: str
Expand All @@ -25,7 +26,7 @@ class AggregatorBase(ABC):

@abstractmethod
def __init__(self, storage, server, modelservice, round_handler):
""" Initialize the aggregator."""
"""Initialize the aggregator."""
self.name = self.__class__.__name__
self.storage = storage
self.server = server
Expand Down Expand Up @@ -75,11 +76,13 @@ def on_model_update(self, model_update):
else:
logger.warning("AGGREGATOR({}): Invalid model update, skipping.".format(self.name))
except Exception as e:
logger.error("AGGREGATOR({}): failed to receive model update! {}".format(self.name, e))
tb = traceback.format_exc()
logger.error("AGGREGATOR({}): failed to receive model update: {}".format(self.name, e))
logger.error(tb)
pass

def _validate_model_update(self, model_update):
""" Validate the model update.
"""Validate the model update.
:param model_update: A ModelUpdate message.
:type model_update: object
Expand All @@ -88,14 +91,16 @@ def _validate_model_update(self, model_update):
"""
try:
data = json.loads(model_update.meta)["training_metadata"]
num_examples = data["num_examples"]
except KeyError as e:
_ = data["num_examples"]
except KeyError:
tb = traceback.format_exc()
logger.error("AGGREGATOR({}): Invalid model update, missing metadata.".format(self.name))
logger.error(tb)
return False
return True

def next_model_update(self):
""" Get the next model update from the queue.
"""Get the next model update from the queue.
:param helper: A helper object.
:type helper: object
Expand All @@ -106,7 +111,7 @@ def next_model_update(self):
return model_update

def load_model_update(self, model_update, helper):
""" Load the memory representation of the model update.
"""Load the memory representation of the model update.
Load the model update paramters and the
associate metadata into memory.
Expand Down Expand Up @@ -134,15 +139,13 @@ def load_model_update(self, model_update, helper):
return model, training_metadata

def get_state(self):
""" Get the state of the aggregator's queue, including the number of model updates."""
state = {
"queue_len": self.model_updates.qsize()
}
"""Get the state of the aggregator's queue, including the number of model updates."""
state = {"queue_len": self.model_updates.qsize()}
return state


def get_aggregator(aggregator_module_name, storage, server, modelservice, control):
""" Return an instance of the helper class.
"""Return an instance of the helper class.
:param helper_module_name: The name of the helper plugin module.
:type helper_module_name: str
Expand Down

0 comments on commit f267971

Please sign in to comment.