Skip to content

Commit

Permalink
Revert changes in aggregator, setup.py
Browse files Browse the repository at this point in the history
Signed-off-by: Chaurasiya, Payal <[email protected]>
  • Loading branch information
payalcha committed Dec 11, 2024
1 parent 1a20951 commit 8d20e42
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
31 changes: 30 additions & 1 deletion openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from openfl.pipelines import NoCompressionPipeline, TensorCodec
from openfl.protocols import base_pb2, utils
from openfl.utilities import TaskResultKey, TensorKey, change_tags
from openfl.utilities.logs import get_memory_usage, write_memory_usage_to_file
from openfl.utilities.logs import get_memory_usage, write_memory_usage_to_file, write_metric


class Aggregator:
Expand All @@ -39,6 +39,7 @@ class Aggregator:
db_store_rounds* (int): Rounds to store in TensorDB.
logger: Object for logging.
write_logs (bool): Flag to enable log writing.
log_metric_callback: Callback for logging metrics.
best_model_score (optional): Score of the best model. Defaults to
None.
metric_queue (queue.Queue): Queue for metrics.
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
db_store_rounds=1,
write_logs=False,
log_memory_usage=False,
log_metric_callback=None,
initial_tensor_dict=None,
):
"""Initializes the Aggregator.
Expand Down Expand Up @@ -104,6 +106,8 @@ def __init__(
Defaults to 1.
write_logs (bool, optional): Whether to write logs. Defaults to
False.
log_metric_callback (optional): Callback for log metric. Defaults
to None.
**kwargs: Additional keyword arguments.
"""
self.logger = getLogger(__name__)
Expand Down Expand Up @@ -143,6 +147,13 @@ def __init__(

# Gathered together logging-related objects
self.write_logs = write_logs
self.log_metric_callback = log_metric_callback

if self.write_logs:
self.log_metric = write_metric
if self.log_metric_callback:
self.log_metric = log_metric_callback
self.logger.info("Using custom log metric: %s", self.log_metric)

self.best_model_score = None
self.metric_queue = queue.Queue()
Expand Down Expand Up @@ -653,6 +664,14 @@ def send_local_task_results(
}
self.metric_queue.put(metrics)
self.logger.metric("%s", str(metrics))
if self.write_logs:
self.log_metric(
collaborator_name,
task_name,
tensor_key.tensor_name,
float(value),
round_number,
)

task_results.append(tensor_key)

Expand Down Expand Up @@ -698,7 +717,9 @@ def _process_named_tensor(self, named_tensor, collaborator_name):
Returns:
tensor_key (TensorKey): Tensor key.
The tensorkey extracted from the protobuf.
nparray (np.array): Numpy array.
The numpy array associated with the returned tensorkey.
"""
raw_bytes = named_tensor.data_bytes
metadata = [
Expand Down Expand Up @@ -925,6 +946,14 @@ def _compute_validation_related_task_metrics(self, task_name):

self.metric_queue.put(metrics)
self.logger.metric("%s", metrics)
if self.write_logs:
self.log_metric(
"aggregator",
task_name,
tensor_key.tensor_name,
float(agg_results),
round_number,
)

# FIXME: Configurable logic for min/max criteria in saving best.
if "validate_agg" in tags:
Expand Down
15 changes: 15 additions & 0 deletions openfl/utilities/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,21 @@
import psutil
from rich.console import Console
from rich.logging import RichHandler
from torch.utils.tensorboard import SummaryWriter


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback.
This function logs a metric to TensorBoard.
Args:
node_name (str): The name of the node.
task_name (str): The name of the task.
metric_name (str): The name of the metric.
metric (float): The value of the metric.
round_number (int): The current round number.
"""
writer = SummaryWriter("./logs/tensorboard", flush_secs=5)
writer.add_scalar(f"{node_name}/{task_name}/{metric_name}", metric, round_number)


def setup_loggers(log_level=logging.INFO):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def run(self):
'pandas',
'scikit-learn',
'flatten_json',
'tensorboardX',
'protobuf>=4.22,<6.0.0',
'grpcio>=1.56.2,<1.66.0',
],
Expand Down

0 comments on commit 8d20e42

Please sign in to comment.