diff --git a/delta/task/learning.py b/delta/task/learning.py index 33943e9..683a9dd 100644 --- a/delta/task/learning.py +++ b/delta/task/learning.py @@ -2,35 +2,33 @@ import abc import logging -from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type +import delta.dataset import numpy as np import torch from torch.utils.data import DataLoader, Dataset -import delta.dataset -from ..dataset import split_dataset from ..core.strategy import ( - LearningStrategy, - RandomSelectStrategy, + CURVE_TYPE, EpochMergeStrategy, IterMergeStrategy, - WeightResultStrategy, - CURVE_TYPE, + LearningStrategy, + RandomSelectStrategy, ResultStrategy, SelectStrategy, + WeightResultStrategy, ) from ..core.task import ( - DataFormat, DataLocation, DataNode, GraphNode, InputGraphNode, MapOperator, MapReduceOperator, - Operator, ReduceOperator, ) +from ..dataset import split_dataset from .task import HorizontalTask _logger = logging.getLogger(__name__) @@ -60,14 +58,14 @@ def __init__( merge_strategy = IterMergeStrategy(merge_iteration) super().__init__( "FedAvg", - select_strategy(min_clients, max_clients), - merge_strategy, - result_strategy(), - wait_timeout, - connection_timeout, - False, - precision, - curve, + select_strategy=select_strategy(min_clients, max_clients), + merge_strategy=merge_strategy, + result_strategy=result_strategy(), + wait_timeout=wait_timeout, + connection_timeout=connection_timeout, + fault_tolerant=False, + precision=precision, + curve=curve, ) @@ -100,14 +98,14 @@ def __init__( super().__init__( "FaultTolerantFedAvg", - select_strategy(min_clients, max_clients), - merge_strategy, - result_strategy(), - wait_timeout, - connection_timeout, - True, - precision, - curve, + select_strategy=select_strategy(min_clients, max_clients), + merge_strategy=merge_strategy, + result_strategy=result_strategy(), + wait_timeout=wait_timeout, + connection_timeout=connection_timeout, + fault_tolerant=True, + precision=precision, + curve=curve, ) @@ -477,7 +475,9 @@ def reduce( res[key] = tmp.item() except ValueError: res[key] = tmp - _logger.info(f"Round {self.round} validating result {key}: {res[key]}") + _logger.info( + f"Round {self.round} validating result {key}: {res[key]}" + ) return res val_op = _ValidateOp( diff --git a/learning_example.py b/learning_example.py index 89eadde..0dc2f37 100644 --- a/learning_example.py +++ b/learning_example.py @@ -63,7 +63,7 @@ def __init__(self) -> None: max_clients=3, # Maximum nodes allowed in each round, must be greater equal than min_clients. merge_epoch=1, # The number of epochs to run before aggregation is performed. merge_iteration=0, # The number of iterations to run before aggregation is performed. One of this and the above number must be 0. - wait_timeout=45, # Timeout for calculation. + wait_timeout=90, # Timeout for calculation. connection_timeout=10, # Wait timeout for each step. ), ) @@ -159,6 +159,6 @@ def state_dict(self) -> Dict[str, torch.Tensor]: task_id = delta_node.create_task(task) if delta_node.trace(task_id): res = delta_node.get_result(task_id) - print(type(res)) + print(res) else: print("Task error") diff --git a/logit_example.py b/logit_example.py index f952432..31d4deb 100644 --- a/logit_example.py +++ b/logit_example.py @@ -17,7 +17,7 @@ def __init__( wait_timeout=5, # Timeout for calculation. connection_timeout=5, # Wait timeout for each step. verify_timeout=360, # Timeout for the final zero knownledge verification step - enable_verify=True # whether to enable final zero knownledge verification step + enable_verify=False # whether to enable final zero knownledge verification step ) def dataset(self): diff --git a/setup.py b/setup.py index 612fc84..47e82ed 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ def run_tests(self): setup( name="delta-task", - version="0.8.0", + version="0.8.1", license_files=("LICENSE"), packages=find_packages(), include_package_data=True,