Skip to content

Commit

Permalink
add several Byzantine robust algorithms (#552)
Browse files Browse the repository at this point in the history
  • Loading branch information
private-mechanism authored Mar 27, 2023
1 parent d2e7d08 commit 2f31956
Show file tree
Hide file tree
Showing 12 changed files with 584 additions and 23 deletions.
19 changes: 19 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -642,3 +642,22 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

---------------------------------------------------------------------------------
The implementations of median aggregator in federatedscope/core/aggregators/median_aggregator.py
and trimmedmean aggregator in federatedscope/core/aggregators/trimmedmean_aggregator.py
are adapted from https://github.com/bladesteam/blades (Apache License)

Copyright (c) 2022 lishenghui

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
11 changes: 11 additions & 0 deletions federatedscope/core/aggregators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
import ServerClientsInterpolateAggregator
from federatedscope.core.aggregators.fedopt_aggregator import FedOptAggregator
from federatedscope.core.aggregators.krum_aggregator import KrumAggregator
from federatedscope.core.aggregators.median_aggregator import MedianAggregator
from federatedscope.core.aggregators.trimmedmean_aggregator import \
TrimmedmeanAggregator
from federatedscope.core.aggregators.bulyan_aggregator import \
BulyanAggregator
from federatedscope.core.aggregators.normbounding_aggregator import \
NormboundingAggregator

__all__ = [
'Aggregator',
Expand All @@ -18,4 +25,8 @@
'ServerClientsInterpolateAggregator',
'FedOptAggregator',
'KrumAggregator',
'MedianAggregator',
'TrimmedmeanAggregator',
'BulyanAggregator',
'NormboundingAggregator',
]
106 changes: 106 additions & 0 deletions federatedscope/core/aggregators/bulyan_aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import copy
import torch
from federatedscope.core.aggregators import ClientsAvgAggregator


class BulyanAggregator(ClientsAvgAggregator):
"""
Implementation of Bulyan refers to `The Hidden Vulnerability
of Distributed Learning in Byzantium`
[Mhamdi et al., 2018]
(http://proceedings.mlr.press/v80/mhamdi18a/mhamdi18a.pdf)
It combines the MultiKrum aggregator and the treamedmean aggregator
"""
def __init__(self, model=None, device='cpu', config=None):
super(BulyanAggregator, self).__init__(model, device, config)
self.byzantine_node_num = config.aggregator.byzantine_node_num
self.sample_client_rate = config.federate.sample_client_rate
assert 4 * self.byzantine_node_num + 3 <= config.federate.client_num

def aggregate(self, agg_info):
"""
To preform aggregation with Median aggregation rule
Arguments:
agg_info (dict): the feedbacks from clients
:returns: the aggregated results
:rtype: dict
"""
models = agg_info["client_feedback"]
avg_model = self._aggre_with_bulyan(models)
updated_model = copy.deepcopy(avg_model)
init_model = self.model.state_dict()
for key in avg_model:
updated_model[key] = init_model[key] + avg_model[key]
return updated_model

def _calculate_distance(self, model_a, model_b):
"""
Calculate the Euclidean distance between two given model para delta
"""
distance = 0.0

for key in model_a:
if isinstance(model_a[key], torch.Tensor):
model_a[key] = model_a[key].float()
model_b[key] = model_b[key].float()
else:
model_a[key] = torch.FloatTensor(model_a[key])
model_b[key] = torch.FloatTensor(model_b[key])

distance += torch.dist(model_a[key], model_b[key], p=2)
return distance

def _calculate_score(self, models):
"""
Calculate Krum scores
"""
model_num = len(models)
closest_num = model_num - self.byzantine_node_num - 2

distance_matrix = torch.zeros(model_num, model_num)
for index_a in range(model_num):
for index_b in range(index_a, model_num):
if index_a == index_b:
distance_matrix[index_a, index_b] = float('inf')
else:
distance_matrix[index_a, index_b] = distance_matrix[
index_b, index_a] = self._calculate_distance(
models[index_a], models[index_b])

sorted_distance = torch.sort(distance_matrix)[0]
krum_scores = torch.sum(sorted_distance[:, :closest_num], axis=-1)
return krum_scores

def _aggre_with_bulyan(self, models):
'''
Apply MultiKrum to select \theta (\theta <= client_num-
2*self.byzantine_node_num) local models
'''
init_model = self.model.state_dict()
global_update = copy.deepcopy(init_model)
models_para = [each_model[1] for each_model in models]
krum_scores = self._calculate_score(models_para)
index_order = torch.sort(krum_scores)[1].numpy()
reliable_models = list()
for number, index in enumerate(index_order):
if number < len(models) - int(
2 * self.sample_client_rate * self.byzantine_node_num):
reliable_models.append(models[index])
'''
Sort parameter for each coordinate of the rest \theta reliable
local models, and find \gamma (gamma<\theta-2*self.byzantine_num)
parameters closest to the median to perform averaging
'''
exluded_num = int(self.sample_client_rate * self.byzantine_node_num)
gamma = len(reliable_models) - 2 * exluded_num
for key in init_model:
temp = torch.stack(
[each_model[1][key] for each_model in reliable_models], 0)
pos_largest, _ = torch.topk(temp, exluded_num, 0)
neg_smallest, _ = torch.topk(-temp, exluded_num, 0)
new_stacked = torch.cat([temp, -pos_largest,
neg_smallest]).sum(0).float()
new_stacked /= gamma
global_update[key] = new_stacked
return global_update
2 changes: 1 addition & 1 deletion federatedscope/core/aggregators/krum_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class KrumAggregator(ClientsAvgAggregator):
def __init__(self, model=None, device='cpu', config=None):
super(KrumAggregator, self).__init__(model, device, config)
self.byzantine_node_num = config.aggregator.byzantine_node_num
self.krum_agg_num = config.aggregator.krum.agg_num
self.krum_agg_num = config.aggregator.BFT_args.krum_agg_num
assert 2 * self.byzantine_node_num + 2 < config.federate.client_num, \
"it should be satisfied that 2*byzantine_node_num + 2 < client_num"

Expand Down
52 changes: 52 additions & 0 deletions federatedscope/core/aggregators/median_aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import copy
import torch
import numpy as np
from federatedscope.core.aggregators import ClientsAvgAggregator
import logging

logger = logging.getLogger(__name__)


class MedianAggregator(ClientsAvgAggregator):
"""
Implementation of median refers to `Byzantine-robust distributed
learning: Towards optimal statistical rates`
[Yin et al., 2018]
(http://proceedings.mlr.press/v80/yin18a/yin18a.pdf)
It computes the coordinate-wise median of recieved updates from clients
The code is adapted from https://github.com/bladesteam/blades
"""
def __init__(self, model=None, device='cpu', config=None):
super(MedianAggregator, self).__init__(model, device, config)
self.byzantine_node_num = config.aggregator.byzantine_node_num
assert 2 * self.byzantine_node_num + 2 < config.federate.client_num, \
"it should be satisfied that 2*byzantine_node_num + 2 < client_num"

def aggregate(self, agg_info):
"""
To preform aggregation with Median aggregation rule
Arguments:
agg_info (dict): the feedbacks from clients
:returns: the aggregated results
:rtype: dict
"""
models = agg_info["client_feedback"]
avg_model = self._aggre_with_median(models)
updated_model = copy.deepcopy(avg_model)
init_model = self.model.state_dict()
for key in avg_model:
updated_model[key] = init_model[key] + avg_model[key]
return updated_model

def _aggre_with_median(self, models):
init_model = self.model.state_dict()
global_update = copy.deepcopy(init_model)
for key in init_model:
temp = torch.stack([each_model[1][key] for each_model in models],
0)
temp_pos, _ = torch.median(temp, dim=0)
temp_neg, _ = torch.median(-temp, dim=0)
global_update[key] = (temp_pos - temp_neg) / 2
return global_update
64 changes: 64 additions & 0 deletions federatedscope/core/aggregators/normbounding_aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import logging
import copy
import torch
import numpy as np
from federatedscope.core.aggregators import ClientsAvgAggregator

logger = logging.getLogger(__name__)


class NormboundingAggregator(ClientsAvgAggregator):
"""
The server clips each update to reduce the negative impact \
of malicious updates.
"""
def __init__(self, model=None, device='cpu', config=None):
super(NormboundingAggregator, self).__init__(model, device, config)
self.norm_bound = config.aggregator.BFT_args.normbounding_norm_bound

def aggregate(self, agg_info):
"""
To preform aggregation with normbounding aggregation rule
Arguments:
agg_info (dict): the feedbacks from clients
:returns: the aggregated results
:rtype: dict
"""
models = agg_info["client_feedback"]
avg_model = self._aggre_with_normbounding(models)
updated_model = copy.deepcopy(avg_model)
init_model = self.model.state_dict()
for key in avg_model:
updated_model[key] = init_model[key] + avg_model[key]
return updated_model

def _aggre_with_normbounding(self, models):
models_temp = []
for each_model in models:
param = self._flatten_updates(each_model[1])
if torch.norm(param, p=2) > self.norm_bound:
scaling_rate = self.norm_bound / torch.norm(param, p=2)
scaled_param = scaling_rate * param
models_temp.append(
(each_model[0], self._reconstruct_updates(scaled_param)))
else:
models_temp.append(each_model)
return self._para_weighted_avg(models_temp)

def _flatten_updates(self, model):
model_update = []
init_model = self.model.state_dict()
for key in init_model:
model_update.append(model[key].view(-1))
return torch.cat(model_update, dim=0)

def _reconstruct_updates(self, flatten_updates):
start_idx = 0
init_model = self.model.state_dict()
reconstructed_model = copy.deepcopy(init_model)
for key in init_model:
reconstructed_model[key] = flatten_updates[
start_idx:start_idx + len(init_model[key].view(-1))].reshape(
init_model[key].shape)
start_idx = start_idx + len(init_model[key].view(-1))
return reconstructed_model
57 changes: 57 additions & 0 deletions federatedscope/core/aggregators/trimmedmean_aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import copy
import torch
import numpy as np
from federatedscope.core.aggregators import ClientsAvgAggregator
import logging

logger = logging.getLogger(__name__)


class TrimmedmeanAggregator(ClientsAvgAggregator):
"""
Implementation of median refer to `Byzantine-robust distributed
learning: Towards optimal statistical rates`
[Yin et al., 2018]
(http://proceedings.mlr.press/v80/yin18a/yin18a.pdf)
The code is adapted from https://github.com/bladesteam/blades
"""
def __init__(self, model=None, device='cpu', config=None):
super(TrimmedmeanAggregator, self).__init__(model, device, config)
self.excluded_ratio = \
config.aggregator.BFT_args.trimmedmean_excluded_ratio
self.byzantine_node_num = config.aggregator.byzantine_node_num
assert 2 * self.byzantine_node_num + 2 < config.federate.client_num, \
"it should be satisfied that 2*byzantine_node_num + 2 < client_num"
assert self.excluded_ratio < 0.5

def aggregate(self, agg_info):
"""
To preform aggregation with trimmedmean aggregation rule
Arguments:
agg_info (dict): the feedbacks from clients
:returns: the aggregated results
:rtype: dict
"""
models = agg_info["client_feedback"]
avg_model = self._aggre_with_trimmedmean(models)
updated_model = copy.deepcopy(avg_model)
init_model = self.model.state_dict()
for key in avg_model:
updated_model[key] = init_model[key] + avg_model[key]
return updated_model

def _aggre_with_trimmedmean(self, models):
init_model = self.model.state_dict()
global_update = copy.deepcopy(init_model)
excluded_num = int(len(models) * self.excluded_ratio)
for key in init_model:
temp = torch.stack([each_model[1][key] for each_model in models],
0)
pos_largest, _ = torch.topk(temp, excluded_num, 0)
neg_smallest, _ = torch.topk(-temp, excluded_num, 0)
new_stacked = torch.cat([temp, -pos_largest,
neg_smallest]).sum(0).float()
new_stacked /= len(temp) - 2 * excluded_num
global_update[key] = new_stacked
return global_update
28 changes: 22 additions & 6 deletions federatedscope/core/auxiliaries/aggregator_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,18 @@ def get_aggregator(method, model=None, device=None, online=False, config=None):
from federatedscope.core.aggregators import ClientsAvgAggregator, \
OnlineClientsAvgAggregator, ServerClientsInterpolateAggregator, \
FedOptAggregator, NoCommunicationAggregator, \
AsynClientsAvgAggregator, KrumAggregator
AsynClientsAvgAggregator, KrumAggregator, \
MedianAggregator, TrimmedmeanAggregator, \
BulyanAggregator, NormboundingAggregator

STR2AGG = {
'fedavg': ClientsAvgAggregator,
'krum': KrumAggregator,
'median': MedianAggregator,
'bulyan': BulyanAggregator,
'trimmedmean': TrimmedmeanAggregator,
'normbounding': NormboundingAggregator
}

if method.lower() in constants.AGGREGATOR_TYPE:
aggregator_type = constants.AGGREGATOR_TYPE[method.lower()]
Expand Down Expand Up @@ -87,12 +98,17 @@ def get_aggregator(method, model=None, device=None, online=False, config=None):
return AsynClientsAvgAggregator(model=model,
device=device,
config=config)
elif config.aggregator.krum.use:
return KrumAggregator(model=model, device=device, config=config)
else:
return ClientsAvgAggregator(model=model,
device=device,
config=config)
if config.aggregator.robust_rule not in STR2AGG:
logger.warning(
f'The specified {config.aggregator.robust_rule} aggregtion\
rule has not been supported, the vanilla fedavg algorithm \
will be used instead.')
return STR2AGG.get(config.aggregator.robust_rule,
ClientsAvgAggregator)(model=model,
device=device,
config=config)

elif aggregator_type == 'server_clients_interpolation':
return ServerClientsInterpolateAggregator(
model=model,
Expand Down
Loading

0 comments on commit 2f31956

Please sign in to comment.