Skip to content

Commit

Permalink
- add mode option to aggregator cli, default to "train_and_validate"
Browse files Browse the repository at this point in the history
- enhance Aggregator to take mode attribute to enable fedeval or training switching at aggregator level
- rebase 10.Jan.2
- fixed formatting issues
- update test_aggregator_api with changes in start command
Signed-off-by: Shailesh Pant <[email protected]>
  • Loading branch information
ishaileshpant committed Jan 10, 2025
1 parent dfe9512 commit 7c7a82a
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 13 deletions.
15 changes: 13 additions & 2 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


"""Aggregator module."""

import logging
Expand All @@ -20,6 +19,8 @@

logger = logging.getLogger(__name__)

VALID_MODES = {"train_and_validate", "evaluate"}


class Aggregator:
"""An Aggregator is the central node in federated learning.
Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(
log_memory_usage=False,
write_logs=False,
callbacks: Optional[List] = None,
mode: str = "train_and_validate",
):
"""Initializes the Aggregator.
Expand All @@ -108,7 +110,12 @@ def __init__(
Defaults to 1.
initial_tensor_dict (dict, optional): Initial tensor dictionary.
callbacks: List of callbacks to be used during the experiment.
mode (str, optional): Operation mode. Can be 'train_and_validate',
'evaluate'. Defaults to 'train_and_validate'.
"""
if mode not in VALID_MODES:
raise ValueError(f"Mode must be one of {VALID_MODES}, got {mode}")
self.mode = mode
self.round_number = 0

if single_col_cert_common_name:
Expand Down Expand Up @@ -208,9 +215,13 @@ def _load_initial_tensors(self):
self.model, compression_pipeline=self.compression_pipeline
)

if round_number > self.round_number:
# Check mode before updating round number
if self.mode == "evaluate":
logger.info(f"Skipping round_number check for mode {self.mode}")
elif round_number > self.round_number:
logger.info(f"Starting training from round {round_number} of previously saved model")
self.round_number = round_number

tensor_key_dict = {
TensorKey(k, self.uuid, self.round_number, False, ("model",)): v
for k, v in tensor_dict.items()
Expand Down
51 changes: 43 additions & 8 deletions openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,34 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


"""Aggregator module."""

import sys
from logging import getLogger
from pathlib import Path

from click import Path as ClickPath
from click import confirm, echo, group, option, pass_context, style
from click import (
Choice,
confirm,
echo,
group,
option,
pass_context,
style,
)
from click import (
Path as ClickPath,
)

from openfl.cryptography.ca import sign_certificate
from openfl.cryptography.io import get_csr_hash, read_crt, read_csr, read_key, write_crt, write_key
from openfl.cryptography.io import (
get_csr_hash,
read_crt,
read_csr,
read_key,
write_crt,
write_key,
)
from openfl.cryptography.participant import generate_csr
from openfl.federated import Plan
from openfl.interface.cli_helper import CERT_DIR
Expand Down Expand Up @@ -52,24 +68,43 @@ def aggregator(context):
default="plan/cols.yaml",
type=ClickPath(exists=True),
)
def start_(plan, authorized_cols):
"""Start the aggregator service."""
@option(
"-m",
"--mode",
type=Choice(["train_and_validate", "evaluate"]),
default="train_and_validate",
help="Operation mode - either train_and_validate or evaluate",
)
def start_(plan, authorized_cols, mode):
"""Start the aggregator service.
Args:
plan (str): Path to plan config file
authorized_cols (str): Path to authorized collaborators file
mode (str): Operation mode - either train_and_validate or evaluate
"""
if is_directory_traversal(plan):
echo("Federated learning plan path is out of the openfl workspace scope.")
sys.exit(1)
if is_directory_traversal(authorized_cols):
echo("Authorized collaborator list file path is out of the openfl workspace scope.")
sys.exit(1)

plan = Plan.parse(
# Parse plan and override mode if specified
parsed_plan = Plan.parse(
plan_config_path=Path(plan).absolute(),
cols_config_path=Path(authorized_cols).absolute(),
)

# Set mode in aggregator settings
if "settings" not in parsed_plan.config["aggregator"]:
parsed_plan.config["aggregator"]["settings"] = {}
parsed_plan.config["aggregator"]["settings"]["mode"] = mode
logger.info(f"Setting aggregator mode to: {mode}")

logger.info("🧿 Starting the Aggregator Service.")

plan.get_server().serve()
parsed_plan.get_server().serve()


@aggregator.command(name="generate-cert-request")
Expand Down
44 changes: 41 additions & 3 deletions tests/openfl/interface/test_aggregator_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,19 @@ def test_aggregator_start(mock_parse):
plan_config = plan_path.joinpath('plan.yaml')
cols_config = plan_path.joinpath('cols.yaml')

mock_parse.return_value = mock.Mock()
# Create a mock plan with the required fields
mock_plan = mock.MagicMock()
mock_plan.__getitem__.side_effect = {'mode': 'train_and_validate'}.get
mock_plan.get = {'mode': 'train_and_validate'}.get
# Add the config attribute with proper nesting
mock_plan.config = {
'aggregator': {
'settings': {
'mode': 'train_and_validate'
}
}
}
mock_parse.return_value = mock_plan

ret = start_(['-p', plan_config,
'-c', cols_config], standalone_mode=False)
Expand All @@ -32,7 +44,20 @@ def test_aggregator_start_illegal_plan(mock_parse, mock_is_directory_traversal):
plan_config = plan_path.joinpath('plan.yaml')
cols_config = plan_path.joinpath('cols.yaml')

mock_parse.return_value = mock.Mock()
# Create a mock plan with the required fields
mock_plan = mock.MagicMock()
mock_plan.__getitem__.side_effect = {'mode': 'train_and_validate'}.get
mock_plan.get = {'mode': 'train_and_validate'}.get
# Add the config attribute with proper nesting
mock_plan.config = {
'aggregator': {
'settings': {
'mode': 'train_and_validate'
}
}
}
mock_parse.return_value = mock_plan

mock_is_directory_traversal.side_effect = [True, False]

with TestCase.assertRaises(test_aggregator_start_illegal_plan, SystemExit):
Expand All @@ -48,7 +73,20 @@ def test_aggregator_start_illegal_cols(mock_parse, mock_is_directory_traversal):
plan_config = plan_path.joinpath('plan.yaml')
cols_config = plan_path.joinpath('cols.yaml')

mock_parse.return_value = mock.Mock()
# Create a mock plan with the required fields
mock_plan = mock.MagicMock()
mock_plan.__getitem__.side_effect = {'mode': 'train_and_validate'}.get
mock_plan.get = {'mode': 'train_and_validate'}.get
# Add the config attribute with proper nesting
mock_plan.config = {
'aggregator': {
'settings': {
'mode': 'train_and_validate'
}
}
}
mock_parse.return_value = mock_plan

mock_is_directory_traversal.side_effect = [False, True]

with TestCase.assertRaises(test_aggregator_start_illegal_cols, SystemExit):
Expand Down

0 comments on commit 7c7a82a

Please sign in to comment.