Skip to content

Commit

Permalink
Add xgb HE eval (#549)
Browse files Browse the repository at this point in the history
  • Loading branch information
qbc2016 authored Mar 30, 2023
1 parent 5e3abad commit 7c4a3f2
Show file tree
Hide file tree
Showing 12 changed files with 548 additions and 5 deletions.
1 change: 1 addition & 0 deletions federatedscope/core/configs/cfg_fl_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def extend_fl_setting_cfg(cfg):
# Default values for 'dp': {'bucket_num':100, 'epsilon':None}
# Default values for 'op_boost': {'algo':'global', 'lower_bound':1,
# 'upper_bound':100, 'epsilon':2}
cfg.vertical.eval_protection = '' # ['', 'he']
cfg.vertical.data_size_for_debug = 0 # use a subset for debug in vfl,
# 0 indicates using the entire dataset (disable debug mode)

Expand Down
4 changes: 3 additions & 1 deletion federatedscope/vertical_fl/loss/binary_cls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from sklearn import metrics


class BinaryClsLoss(object):
Expand Down Expand Up @@ -27,9 +28,10 @@ def _process_y_pred(self, y_pred):

def get_metric(self, y, y_pred):
    """Compute evaluation metrics for binary classification.

    Args:
        y: ground-truth labels (array of 0/1).
        y_pred: raw model outputs; normalised by ``_process_y_pred``.

    Returns:
        dict with 'acc' (accuracy at the 0.5 threshold) and
        'auc' (ROC-AUC over the continuous scores).
    """
    y_pred = self._process_y_pred(y_pred)
    # AUC must be computed on the continuous scores, i.e. BEFORE the
    # 0.5 thresholding applied below for accuracy.
    auc = metrics.roc_auc_score(y, y_pred)
    y_pred = (y_pred >= 0.5).astype(np.float32)
    acc = np.sum(y_pred == y) / len(y)
    # Diff residue removed: the pre-change `return {'acc': acc}` line was
    # still present alongside the new return; only the new one is kept.
    return {'acc': acc, 'auc': auc}

def get_loss(self, y, y_pred):
y_pred = self._process_y_pred(y_pred)
Expand Down
8 changes: 6 additions & 2 deletions federatedscope/vertical_fl/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from federatedscope.vertical_fl.xgb_base.worker import wrap_client_for_train, \
wrap_server_for_train, wrap_client_for_evaluation, \
wrap_server_for_evaluation
from federatedscope.vertical_fl.xgb_base.worker.he_evaluation_wrapper\
import wrap_client_for_he_evaluation


def wrap_vertical_server(server, config):
Expand All @@ -13,7 +15,9 @@ def wrap_vertical_server(server, config):

def wrap_vertical_client(client, config):
    """Wrap a client with vertical-FL training/evaluation behaviours.

    For tree-based algorithms (xgb/gbdt/rf): applies the HE-protected
    evaluation wrapper when ``config.vertical.eval_protection == 'he'``,
    otherwise the plain evaluation wrapper, then adds the training wrapper.

    Args:
        client: the client worker to wrap.
        config: global cfg node; ``config.vertical`` is consulted.

    Returns:
        The (possibly wrapped) client.
    """
    if config.vertical.algo in ['xgb', 'gbdt', 'rf']:
        if config.vertical.eval_protection == 'he':
            client = wrap_client_for_he_evaluation(client)
        else:
            client = wrap_client_for_evaluation(client)
        # Exactly ONE evaluation wrapper is applied; the leftover
        # unconditional `wrap_client_for_evaluation` call (pre-change code
        # shown in the diff residue) would double-wrap evaluation and
        # bypass the HE protection, so it is removed.
        client = wrap_client_for_train(client)

    return client
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Vertical FL config: GBDT regression on abalone, 2 standalone clients.
# NOTE(review): the scrape flattened all YAML indentation; conventional
# two-space nesting is restored below from the key semantics.
use_gpu: False
device: 0
backend: torch
federate:
  mode: standalone
  client_num: 2
model:
  type: gbdt_tree
  lambda_: 0.1
  gamma: 0
  num_of_trees: 10
  max_tree_depth: 4
data:
  root: data/
  type: abalone
  splits: [0.8, 0.2]
dataloader:
  type: raw
  batch_size: 2000
criterion:
  type: RegressionMSELoss
trainer:
  type: verticaltrainer
train:
  optimizer:
    # learning rate for xgb model
    eta: 0.5
vertical:
  use: True
  dims: [4, 8]
  algo: 'gbdt'
  data_size_for_debug: 2000
  feature_subsample_ratio: 0.5
eval:
  freq: 3
  best_res_update_round_wise_key: test_loss
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Vertical FL config: random forest regression on abalone, 2 standalone
# clients.  NOTE(review): the scrape flattened all YAML indentation;
# conventional two-space nesting is restored below from the key semantics.
use_gpu: False
device: 0
backend: torch
federate:
  mode: standalone
  client_num: 2
model:
  type: random_forest
  lambda_: 0.1
  gamma: 0
  num_of_trees: 10
  max_tree_depth: 4
data:
  root: data/
  type: abalone
  splits: [0.8, 0.2]
dataloader:
  type: raw
  batch_size: 2000
criterion:
  type: RegressionMSELoss
trainer:
  type: verticaltrainer
vertical:
  use: True
  dims: [4, 8]
  algo: 'rf'
  data_size_for_debug: 1500
  feature_subsample_ratio: 0.5
eval:
  freq: 3
  best_res_update_round_wise_key: test_loss
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ data:
splits: [0.8, 0.2]
dataloader:
type: raw
batch_size: 4000
batch_size: 2000
criterion:
type: RegressionMSELoss
trainer:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Vertical FL config: XGBoost binary classification on adult with
# HE-protected evaluation (Paillier), 2 standalone clients.
# NOTE(review): the scrape flattened all YAML indentation; conventional
# two-space nesting is restored below from the key semantics.
use_gpu: False
device: 0
backend: torch
federate:
  mode: standalone
  client_num: 2
model:
  type: xgb_tree
  lambda_: 0.1
  gamma: 0
  num_of_trees: 10
  max_tree_depth: 3
data:
  root: data/
  type: adult
  # adult ships a separate test set, hence no local test split here.
  splits: [1.0, 0.0]
dataloader:
  type: raw
  batch_size: 2000
criterion:
  type: CrossEntropyLoss
trainer:
  type: verticaltrainer
train:
  optimizer:
    # learning rate for xgb model
    eta: 0.5
vertical:
  use: True
  dims: [7, 14]
  algo: 'xgb'
  # Enables the homomorphic-encryption evaluation path added in this commit.
  eval_protection: 'he'
  data_size_for_debug: 2000
eval:
  freq: 3
  best_res_update_round_wise_key: test_loss
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Vertical FL config: XGBoost regression on abalone with DP protection of
# the feature order, 2 standalone clients.
# NOTE(review): the scrape flattened all YAML indentation; conventional
# two-space nesting is restored below from the key semantics.
use_gpu: False
device: 0
backend: torch
federate:
  mode: standalone
  client_num: 2
model:
  type: xgb_tree
  lambda_: 0.1
  gamma: 0
  num_of_trees: 10
  max_tree_depth: 3
data:
  root: data/
  type: abalone
  splits: [0.8, 0.2]
dataloader:
  type: raw
  batch_size: 2000
criterion:
  type: RegressionMSELoss
trainer:
  type: verticaltrainer
train:
  optimizer:
    # learning rate for xgb model
    eta: 0.5
vertical:
  use: True
  dims: [4, 8]
  algo: 'xgb'
  protect_object: 'feature_order'
  protect_method: 'dp'
  protect_args: [ { 'bucket_num': 100, 'epsilon': 10 } ]
  data_size_for_debug: 2000
eval:
  freq: 5
  best_res_update_round_wise_key: test_loss
7 changes: 7 additions & 0 deletions federatedscope/vertical_fl/xgb_base/worker/XGBClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from federatedscope.core.workers import Client
from federatedscope.core.message import Message
from federatedscope.vertical_fl.Paillier import \
abstract_paillier

logger = logging.getLogger(__name__)

Expand All @@ -30,6 +32,11 @@ def __init__(self,
self.msg_buffer = {'train': {}, 'eval': {}}
self.client_num = self._cfg.federate.client_num

if self._cfg.vertical.eval_protection == 'he':
keys = abstract_paillier.generate_paillier_keypair(
n_length=self._cfg.vertical.key_size)
self.public_key, self.private_key = keys

self.feature_order = None
self.merged_feature_order = None

Expand Down
5 changes: 4 additions & 1 deletion federatedscope/vertical_fl/xgb_base/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
wrap_server_for_train, wrap_client_for_train
from federatedscope.vertical_fl.xgb_base.worker.evaluation_wrapper import \
wrap_server_for_evaluation, wrap_client_for_evaluation
from federatedscope.vertical_fl.xgb_base.worker.he_evaluation_wrapper\
import wrap_client_for_he_evaluation

# Public API of the xgb_base.worker package.
# Diff residue removed: the pre-change closing line (ending in
# 'wrap_client_for_evaluation' with no trailing comma) sat next to the new
# lines, which would implicitly concatenate two adjacent string literals
# into one bogus name; only the post-change list is kept.
__all__ = [
    'XGBServer', 'XGBClient', 'wrap_server_for_train', 'wrap_client_for_train',
    'wrap_server_for_evaluation', 'wrap_client_for_evaluation',
    'wrap_client_for_he_evaluation'
]
Loading

0 comments on commit 7c4a3f2

Please sign in to comment.