diff --git a/federatedscope/core/configs/cfg_fl_setting.py b/federatedscope/core/configs/cfg_fl_setting.py index f684dec8a..1e31e727f 100644 --- a/federatedscope/core/configs/cfg_fl_setting.py +++ b/federatedscope/core/configs/cfg_fl_setting.py @@ -96,6 +96,7 @@ def extend_fl_setting_cfg(cfg): # Default values for 'dp': {'bucket_num':100, 'epsilon':None} # Default values for 'op_boost': {'algo':'global', 'lower_bound':1, # 'upper_bound':100, 'epsilon':2} + cfg.vertical.eval_protection = '' # ['', 'he'] cfg.vertical.data_size_for_debug = 0 # use a subset for debug in vfl, # 0 indicates using the entire dataset (disable debug mode) diff --git a/federatedscope/vertical_fl/loss/binary_cls.py b/federatedscope/vertical_fl/loss/binary_cls.py index aa95e8840..6223ef57a 100644 --- a/federatedscope/vertical_fl/loss/binary_cls.py +++ b/federatedscope/vertical_fl/loss/binary_cls.py @@ -1,4 +1,5 @@ import numpy as np +from sklearn import metrics class BinaryClsLoss(object): @@ -27,9 +28,10 @@ def _process_y_pred(self, y_pred): def get_metric(self, y, y_pred): y_pred = self._process_y_pred(y_pred) + auc = metrics.roc_auc_score(y, y_pred) y_pred = (y_pred >= 0.5).astype(np.float32) acc = np.sum(y_pred == y) / len(y) - return {'acc': acc} + return {'acc': acc, 'auc': auc} def get_loss(self, y, y_pred): y_pred = self._process_y_pred(y_pred) diff --git a/federatedscope/vertical_fl/utils.py b/federatedscope/vertical_fl/utils.py index d3e5d6947..959df62b1 100644 --- a/federatedscope/vertical_fl/utils.py +++ b/federatedscope/vertical_fl/utils.py @@ -1,6 +1,8 @@ from federatedscope.vertical_fl.xgb_base.worker import wrap_client_for_train, \ wrap_server_for_train, wrap_client_for_evaluation, \ wrap_server_for_evaluation +from federatedscope.vertical_fl.xgb_base.worker.he_evaluation_wrapper\ + import wrap_client_for_he_evaluation def wrap_vertical_server(server, config): @@ -13,7 +15,9 @@ def wrap_vertical_server(server, config): def wrap_vertical_client(client, config): if 
config.vertical.algo in ['xgb', 'gbdt', 'rf']: + if config.vertical.eval_protection == 'he': + client = wrap_client_for_he_evaluation(client) + else: + client = wrap_client_for_evaluation(client) client = wrap_client_for_train(client) - client = wrap_client_for_evaluation(client) - return client diff --git a/federatedscope/vertical_fl/xgb_base/baseline/gbdt_base_on_abalone.yaml b/federatedscope/vertical_fl/xgb_base/baseline/gbdt_base_on_abalone.yaml new file mode 100644 index 000000000..d6e3449ab --- /dev/null +++ b/federatedscope/vertical_fl/xgb_base/baseline/gbdt_base_on_abalone.yaml @@ -0,0 +1,36 @@ +use_gpu: False +device: 0 +backend: torch +federate: + mode: standalone + client_num: 2 +model: + type: gbdt_tree + lambda_: 0.1 + gamma: 0 + num_of_trees: 10 + max_tree_depth: 4 +data: + root: data/ + type: abalone + splits: [0.8, 0.2] +dataloader: + type: raw + batch_size: 2000 +criterion: + type: RegressionMSELoss +trainer: + type: verticaltrainer +train: + optimizer: + # learning rate for xgb model + eta: 0.5 +vertical: + use: True + dims: [4, 8] + algo: 'gbdt' + data_size_for_debug: 2000 + feature_subsample_ratio: 0.5 +eval: + freq: 3 + best_res_update_round_wise_key: test_loss \ No newline at end of file diff --git a/federatedscope/vertical_fl/xgb_base/baseline/rf_base_on_abalone.yaml b/federatedscope/vertical_fl/xgb_base/baseline/rf_base_on_abalone.yaml new file mode 100644 index 000000000..2bd00081d --- /dev/null +++ b/federatedscope/vertical_fl/xgb_base/baseline/rf_base_on_abalone.yaml @@ -0,0 +1,32 @@ +use_gpu: False +device: 0 +backend: torch +federate: + mode: standalone + client_num: 2 +model: + type: random_forest + lambda_: 0.1 + gamma: 0 + num_of_trees: 10 + max_tree_depth: 4 +data: + root: data/ + type: abalone + splits: [0.8, 0.2] +dataloader: + type: raw + batch_size: 2000 +criterion: + type: RegressionMSELoss +trainer: + type: verticaltrainer +vertical: + use: True + dims: [4, 8] + algo: 'rf' + data_size_for_debug: 1500 + feature_subsample_ratio: 
0.5 +eval: + freq: 3 + best_res_update_round_wise_key: test_loss \ No newline at end of file diff --git a/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_abalone.yaml b/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_abalone.yaml index ce18b6e51..4020cf05c 100644 --- a/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_abalone.yaml +++ b/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_abalone.yaml @@ -16,7 +16,7 @@ data: splits: [0.8, 0.2] dataloader: type: raw - batch_size: 4000 + batch_size: 2000 criterion: type: RegressionMSELoss trainer: diff --git a/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_adult_by_he_eval.yaml b/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_adult_by_he_eval.yaml new file mode 100644 index 000000000..2eeb454d5 --- /dev/null +++ b/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_adult_by_he_eval.yaml @@ -0,0 +1,36 @@ +use_gpu: False +device: 0 +backend: torch +federate: + mode: standalone + client_num: 2 +model: + type: xgb_tree + lambda_: 0.1 + gamma: 0 + num_of_trees: 10 + max_tree_depth: 3 +data: + root: data/ + type: adult + splits: [1.0, 0.0] +dataloader: + type: raw + batch_size: 2000 +criterion: + type: CrossEntropyLoss +trainer: + type: verticaltrainer +train: + optimizer: + # learning rate for xgb model + eta: 0.5 +vertical: + use: True + dims: [7, 14] + algo: 'xgb' + eval_protection: 'he' + data_size_for_debug: 2000 +eval: + freq: 3 + best_res_update_round_wise_key: test_loss \ No newline at end of file diff --git a/federatedscope/vertical_fl/xgb_base/baseline/xgb_feature_order_dp_on_abalone.yaml b/federatedscope/vertical_fl/xgb_base/baseline/xgb_feature_order_dp_on_abalone.yaml new file mode 100644 index 000000000..b493f0a86 --- /dev/null +++ b/federatedscope/vertical_fl/xgb_base/baseline/xgb_feature_order_dp_on_abalone.yaml @@ -0,0 +1,38 @@ +use_gpu: False +device: 0 +backend: torch +federate: + mode: standalone + client_num: 2 +model: + type: xgb_tree + lambda_: 
0.1 + gamma: 0 + num_of_trees: 10 + max_tree_depth: 3 +data: + root: data/ + type: abalone + splits: [0.8, 0.2] +dataloader: + type: raw + batch_size: 2000 +criterion: + type: RegressionMSELoss +trainer: + type: verticaltrainer +train: + optimizer: + # learning rate for xgb model + eta: 0.5 +vertical: + use: True + dims: [4, 8] + algo: 'xgb' + protect_object: 'feature_order' + protect_method: 'dp' + protect_args: [ { 'bucket_num': 100, 'epsilon': 10 } ] + data_size_for_debug: 2000 +eval: + freq: 5 + best_res_update_round_wise_key: test_loss \ No newline at end of file diff --git a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py index 1df94682b..9e1a80c3a 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py +++ b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py @@ -4,6 +4,8 @@ from federatedscope.core.workers import Client from federatedscope.core.message import Message +from federatedscope.vertical_fl.Paillier import \ + abstract_paillier logger = logging.getLogger(__name__) @@ -30,6 +32,11 @@ def __init__(self, self.msg_buffer = {'train': {}, 'eval': {}} self.client_num = self._cfg.federate.client_num + if self._cfg.vertical.eval_protection == 'he': + keys = abstract_paillier.generate_paillier_keypair( + n_length=self._cfg.vertical.key_size) + self.public_key, self.private_key = keys + self.feature_order = None self.merged_feature_order = None diff --git a/federatedscope/vertical_fl/xgb_base/worker/__init__.py b/federatedscope/vertical_fl/xgb_base/worker/__init__.py index 752d2db3c..f55fa814e 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/__init__.py +++ b/federatedscope/vertical_fl/xgb_base/worker/__init__.py @@ -4,8 +4,11 @@ wrap_server_for_train, wrap_client_for_train from federatedscope.vertical_fl.xgb_base.worker.evaluation_wrapper import \ wrap_server_for_evaluation, wrap_client_for_evaluation +from 
federatedscope.vertical_fl.xgb_base.worker.he_evaluation_wrapper\ + import wrap_client_for_he_evaluation __all__ = [ 'XGBServer', 'XGBClient', 'wrap_server_for_train', 'wrap_client_for_train', - 'wrap_server_for_evaluation', 'wrap_client_for_evaluation' + 'wrap_server_for_evaluation', 'wrap_client_for_evaluation', + 'wrap_client_for_he_evaluation' ] diff --git a/federatedscope/vertical_fl/xgb_base/worker/he_evaluation_wrapper.py b/federatedscope/vertical_fl/xgb_base/worker/he_evaluation_wrapper.py new file mode 100644 index 000000000..0835bccbd --- /dev/null +++ b/federatedscope/vertical_fl/xgb_base/worker/he_evaluation_wrapper.py @@ -0,0 +1,326 @@ +import types +import logging +import numpy as np + +from federatedscope.vertical_fl.loss.utils import get_vertical_loss +from federatedscope.core.message import Message + +logger = logging.getLogger(__name__) + + +def wrap_client_for_he_evaluation(client): + """ + Use PHE to perform secure evaluation. + For more details, please refer to the following papers: + An Efficient and Robust System for Vertically Federated Random Forest + (https://arxiv.org/pdf/2201.10761.pdf) + Privacy Preserving Vertical Federated Learning for Tree-based Models + (https://arxiv.org/pdf/2008.06170.pdf) + Fed-EINI: An Efficient and Interpretable Inference Framework for + Decision Tree Ensembles in Vertical Fed + (https://arxiv.org/pdf/2105.09540.pdf) + + """ + def eval(self, tree_num): + self.criterion = get_vertical_loss(loss_type=self._cfg.criterion.type, + model_type=self._cfg.model.type) + + off_node_list = list() + for node_num in range(2**self.model.max_depth - 1): + if self.model[tree_num][node_num].status == 'off': + off_node_list.append(node_num) + if off_node_list: + self.comm_manager.send( + Message( + msg_type='off_node_list', + sender=self.ID, + state=self.state, + receiver=[ + each + for each in list(self.comm_manager.neighbors.keys()) + if each != self.server_id + ], + content=(tree_num, off_node_list))) + + if self.test_x is 
None: + self.test_x, self.test_y = self._fetch_test_data() + self.merged_test_result = list() + self.test_result = np.zeros(self.test_x.shape[0]) + + self.one_tree_weight_vector = self.iterate_for_weight_vector( + tree_num, list(range(2**self.model.max_depth - 1))) + if self._cfg.model.type in ['xgb_tree', 'gbdt_tree']: + eta = self._cfg.train.optimizer.eta + else: + eta = 1.0 + self.one_tree_weight_vector = [ + eta * x for x in self.one_tree_weight_vector + ] + enc_one_tree_weight_vector = [ + self.public_key.encrypt(x) for x in self.one_tree_weight_vector + ] + indicator_array = self.get_test_result_for_one_tree(tree_num) + enc_indicator_weight_array = \ + indicator_array * enc_one_tree_weight_vector + self.comm_manager.send( + Message(msg_type='enc_pred_result', + sender=self.ID, + state=self.state, + receiver=self.ID - 1, + content=(tree_num, enc_indicator_weight_array))) + + def callback_func_for_off_node_list(self, message: Message): + tree_num, off_node_list = message.content + for node_num in off_node_list: + self.model[tree_num][node_num].status = 'off' + + def callback_func_for_enc_pred_result(self, message: Message): + if self.test_x is None: + self.test_x, self.test_y = self._fetch_test_data() + tree_num, enc_indicator_weight_array = message.content + indicator_array = self.get_test_result_for_one_tree(tree_num) + enc_indicator_weight_array = \ + indicator_array * enc_indicator_weight_array + if self.ID != 1: + self.comm_manager.send( + Message(msg_type='enc_pred_result', + sender=self.ID, + state=self.state, + receiver=self.ID - 1, + content=(tree_num, enc_indicator_weight_array))) + else: + enc_res = np.sum(enc_indicator_weight_array, axis=1) + self.comm_manager.send( + Message(msg_type='pred_result', + sender=self.ID, + state=self.state, + receiver=self.client_num, + content=(tree_num, enc_res))) + + def _fetch_test_data(self): + test_x = self.data['test']['x'] + test_y = self.data['test']['y'] if 'y' in self.data['test'] else None + + return test_x, 
test_y + + def callback_func_for_pred_result(self, message: Message): + tree_num, enc_res = message.content + self.test_result = np.asarray( + [self.private_key.decrypt(x) for x in enc_res]) + self.merged_test_result.append(self.test_result) + if ( + tree_num + 1 + ) % self._cfg.eval.freq == 0 or \ + tree_num + 1 == self._cfg.model.num_of_trees: + self._feedback_eval_metrics() + self.eval_finish_flag = True + self._check_eval_finish(tree_num) + + def _feedback_eval_metrics(self): + test_loss = self.criterion.get_loss(self.test_y, + self.merged_test_result) + metrics = self.criterion.get_metric(self.test_y, + self.merged_test_result) + modified_metrics = dict() + for key in metrics.keys(): + if 'test' not in key: + modified_metrics['test_' + key] = metrics[key] + else: + modified_metrics[key] = metrics[key] + modified_metrics.update({ + 'test_loss': test_loss, + 'test_total': len(self.test_y) + }) + + self.comm_manager.send( + Message(msg_type='eval_metric', + sender=self.ID, + state=self.state, + receiver=[self.server_id], + content=modified_metrics)) + self.comm_manager.send( + Message(msg_type='feature_importance', + sender=self.ID, + state=self.state, + receiver=[self.server_id], + content=self.feature_importance)) + self.comm_manager.send( + Message(msg_type='ask_for_feature_importance', + sender=self.ID, + state=self.state, + receiver=[ + each + for each in list(self.comm_manager.neighbors.keys()) + if each != self.server_id + ], + content='None')) + + def callback_func_for_feature_importance(self, message: Message): + state = message.state + self.comm_manager.send( + Message(msg_type='feature_importance', + sender=self.ID, + state=state, + receiver=[self.server_id], + content=self.feature_importance)) + + def iterate_for_leaf_vector(self, x, tree_num, tree_node_list, flag): + tree = self.model[tree_num] + node_num = tree_node_list[0] + feature_idx = tree[node_num].feature_idx + if tree[node_num].status == 'off': + return np.asarray([flag]) + else: + if flag 
== 0: + left_flag = right_flag = 0 + else: + if feature_idx is None or tree[node_num].feature_value is None: + left_flag = right_flag = 1 + elif x[feature_idx] < tree[node_num].feature_value: + left_flag, right_flag = 1, 0 + else: + left_flag, right_flag = 0, 1 + subtree_size = int(np.log2(len(tree_node_list))) + left_subtree_node_list = [] + right_subtree_node_list = [] + for i in range(1, subtree_size + 1): + subtree_node_list = tree_node_list[2**i - 1:2**(i + 1) - 1] + length = len(subtree_node_list) + left_subtree_node_list.extend(subtree_node_list[:length // 2]) + right_subtree_node_list.extend(subtree_node_list[length // 2:]) + left_vector = self.iterate_for_leaf_vector(x, tree_num, + left_subtree_node_list, + left_flag) + right_vector = self.iterate_for_leaf_vector( + x, tree_num, right_subtree_node_list, right_flag) + return np.concatenate((left_vector, right_vector)) + + def iterate_for_weight_vector(self, tree_num, tree_node_list): + tree = self.model[tree_num] + node_num = tree_node_list[0] + if tree[node_num].status == 'off': + return np.asarray([tree[node_num].weight]) + else: + subtree_size = int(np.log2(len(tree_node_list))) + left_subtree_node_list = [] + right_subtree_node_list = [] + for i in range(1, subtree_size + 1): + subtree_node_list = tree_node_list[2**i - 1:2**(i + 1) - 1] + length = len(subtree_node_list) + left_subtree_node_list.extend(subtree_node_list[:length // 2]) + right_subtree_node_list.extend(subtree_node_list[length // 2:]) + left_vector = self.iterate_for_weight_vector( + tree_num, left_subtree_node_list) + right_vector = self.iterate_for_weight_vector( + tree_num, right_subtree_node_list) + return np.concatenate((left_vector, right_vector)) + + def get_test_result_for_one_tree(self, tree_num): + res = [0] * self.test_x.shape[0] + for i in range(len(self.test_x)): + res[i] = self.iterate_for_leaf_vector( + self.test_x[i], + tree_num, + list(range(2**self.model.max_depth - 1)), + flag=1) + return np.asarray(res) + + # Bind 
method to instance + client.eval = types.MethodType(eval, client) + client._fetch_test_data = types.MethodType(_fetch_test_data, client) + client.iterate_for_leaf_vector = types.MethodType(iterate_for_leaf_vector, + client) + client._feedback_eval_metrics = types.MethodType(_feedback_eval_metrics, + client) + client.iterate_for_weight_vector = types.MethodType( + iterate_for_weight_vector, client) + client.get_test_result_for_one_tree = types.MethodType( + get_test_result_for_one_tree, client) + client.callback_func_for_off_node_list = types.MethodType( + callback_func_for_off_node_list, client) + client.callback_func_for_enc_pred_result = types.MethodType( + callback_func_for_enc_pred_result, client) + client.callback_func_for_pred_result = types.MethodType( + callback_func_for_pred_result, client) + client.callback_func_for_feature_importance = types.MethodType( + callback_func_for_feature_importance, client) + + # Register handler functions + client.register_handlers('off_node_list', + client.callback_func_for_off_node_list) + client.register_handlers('enc_pred_result', + client.callback_func_for_enc_pred_result) + client.register_handlers('pred_result', + client.callback_func_for_pred_result) + client.register_handlers('ask_for_feature_importance', + client.callback_func_for_feature_importance) + + return client + + +def wrap_server_for_evaluation(server): + def _check_and_save_result(self): + + state = max(self.msg_buffer['eval'].keys()) + buffer = self.msg_buffer['eval'][state] + if len(buffer['feature_importance'] + ) == self.client_num and buffer['metrics'] is not None: + self.state = state + self.feature_importance = dict( + sorted(buffer['feature_importance'].items(), + key=lambda x: x[0])) + self.metrics = buffer['metrics'] + self._monitor.update_best_result(self.best_results, + self.metrics, + results_type='server_global_eval') + self._monitor.add_items_to_best_result( + self.best_results, + self.feature_importance, + results_type='feature_importance') 
+ formatted_logs = self._monitor.format_eval_res( + self.metrics, + rnd=self.state, + role='Server #', + forms=self._cfg.eval.report) + formatted_logs['feature_importance'] = self.feature_importance + logger.info(formatted_logs) + + if self.state + 1 == self._cfg.model.num_of_trees: + self.terminate() + + def callback_func_for_feature_importance(self, message: Message): + # Save the feature importance + feature_importance = message.content + sender = message.sender + state = message.state + if state not in self.msg_buffer['eval']: + self.msg_buffer['eval'][state] = {} + self.msg_buffer['eval'][state]['feature_importance'] = {} + self.msg_buffer['eval'][state]['metrics'] = None + self.msg_buffer['eval'][state]['feature_importance'].update( + {str(sender): feature_importance}) + self._check_and_save_result() + + def callback_funcs_for_metrics(self, message: Message): + state, metrics = message.state, message.content + if state not in self.msg_buffer['eval']: + self.msg_buffer['eval'][state] = {} + self.msg_buffer['eval'][state]['feature_importance'] = {} + self.msg_buffer['eval'][state]['metrics'] = None + self.msg_buffer['eval'][state]['metrics'] = metrics + self._check_and_save_result() + + # Bind method to instance + server._check_and_save_result = types.MethodType(_check_and_save_result, + server) + server.callback_func_for_feature_importance = types.MethodType( + callback_func_for_feature_importance, server) + server.callback_funcs_for_metrics = types.MethodType( + callback_funcs_for_metrics, server) + + # Register handler functions + server.register_handlers('feature_importance', + server.callback_func_for_feature_importance) + server.register_handlers('eval_metric', server.callback_funcs_for_metrics) + + return server diff --git a/tests/test_tree_based_model_for_vfl.py b/tests/test_tree_based_model_for_vfl.py index 65fc8dd8d..3e5d0ed8f 100644 --- a/tests/test_tree_based_model_for_vfl.py +++ b/tests/test_tree_based_model_for_vfl.py @@ -49,6 +49,43 @@ def 
set_config_for_xgb_base(self, cfg): return backup_cfg + def set_config_for_he_eval(self, cfg): + backup_cfg = cfg.clone() + + import torch + cfg.use_gpu = torch.cuda.is_available() + + cfg.federate.mode = 'standalone' + cfg.federate.client_num = 2 + + cfg.model.type = 'xgb_tree' + cfg.model.lambda_ = 0.1 + cfg.model.gamma = 0 + cfg.model.num_of_trees = 10 + cfg.model.max_tree_depth = 3 + + cfg.train.optimizer.eta = 0.5 + + cfg.data.root = 'test_data/' + cfg.data.type = 'adult' + + cfg.dataloader.type = 'raw' + cfg.dataloader.batch_size = 2000 + + cfg.criterion.type = 'CrossEntropyLoss' + + cfg.vertical.use = True + cfg.vertical.dims = [7, 14] + cfg.vertical.algo = 'xgb' + cfg.vertical.data_size_for_debug = 2000 + cfg.vertical.eval_protection = 'he' + + cfg.trainer.type = 'verticaltrainer' + cfg.eval.freq = 5 + cfg.eval.best_res_update_round_wise_key = "test_loss" + + return backup_cfg + def set_config_for_gbdt_base(self, cfg): backup_cfg = cfg.clone() @@ -336,6 +373,27 @@ def test_XGB_Base(self): self.assertGreater(test_results['server_global_eval']['test_acc'], 0.79) + def test_XGB_Base_for_he_eval(self): + init_cfg = global_cfg.clone() + backup_cfg = self.set_config_for_he_eval(init_cfg) + setup_seed(init_cfg.seed) + update_logger(init_cfg, True) + + data, modified_config = get_data(init_cfg.clone()) + init_cfg.merge_from_other_cfg(modified_config) + self.assertIsNotNone(data) + + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) + self.assertIsNotNone(Fed_runner) + test_results = Fed_runner.run() + init_cfg.merge_from_other_cfg(backup_cfg) + print(test_results) + self.assertGreater(test_results['server_global_eval']['test_acc'], + 0.79) + def test_GBDT_Base(self): init_cfg = global_cfg.clone() backup_cfg = self.set_config_for_gbdt_base(init_cfg)