-
Notifications
You must be signed in to change notification settings - Fork 212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add xgb homoeval #549
Add xgb homoeval #549
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,9 +27,11 @@ def _process_y_pred(self, y_pred): | |
|
||
def get_metric(self, y, y_pred): | ||
y_pred = self._process_y_pred(y_pred) | ||
from sklearn import metrics | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move to top of this file |
||
auc = metrics.roc_auc_score(y, y_pred) | ||
y_pred = (y_pred >= 0.5).astype(np.float32) | ||
acc = np.sum(y_pred == y) / len(y) | ||
return {'acc': acc} | ||
return {'acc': acc, 'auc': auc} | ||
|
||
def get_loss(self, y, y_pred): | ||
y_pred = self._process_y_pred(y_pred) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
from federatedscope.vertical_fl.xgb_base.worker import wrap_client_for_train, \ | ||
wrap_server_for_train, wrap_client_for_evaluation, \ | ||
wrap_server_for_evaluation | ||
from federatedscope.vertical_fl.xgb_base.worker.homo_evaluation_wrapper\ | ||
import wrap_client_for_homo_evaluation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
|
||
def wrap_vertical_server(server, config): | ||
|
@@ -13,7 +15,9 @@ def wrap_vertical_server(server, config): | |
|
||
def wrap_vertical_client(client, config): | ||
if config.vertical.algo in ['xgb', 'gbdt', 'rf']: | ||
if config.vertical.eval == 'homo': | ||
client = wrap_client_for_homo_evaluation(client) | ||
else: | ||
client = wrap_client_for_evaluation(client) | ||
client = wrap_client_for_train(client) | ||
client = wrap_client_for_evaluation(client) | ||
|
||
return client |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
use_gpu: False | ||
device: 0 | ||
backend: torch | ||
federate: | ||
mode: standalone | ||
client_num: 2 | ||
model: | ||
type: gbdt_tree | ||
lambda_: 0.1 | ||
gamma: 0 | ||
num_of_trees: 10 | ||
max_tree_depth: 4 | ||
data: | ||
root: data/ | ||
type: abalone | ||
splits: [0.8, 0.2] | ||
dataloader: | ||
type: raw | ||
batch_size: 2000 | ||
criterion: | ||
type: RegressionMSELoss | ||
trainer: | ||
type: verticaltrainer | ||
train: | ||
optimizer: | ||
# learning rate for xgb model | ||
eta: 0.5 | ||
vertical: | ||
use: True | ||
dims: [4, 8] | ||
algo: 'gbdt' | ||
data_size_for_debug: 2000 | ||
feature_subsample_ratio: 0.5 | ||
eval: | ||
freq: 3 | ||
best_res_update_round_wise_key: test_loss |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use_gpu: False | ||
device: 0 | ||
backend: torch | ||
federate: | ||
mode: standalone | ||
client_num: 2 | ||
model: | ||
type: random_forest | ||
lambda_: 0.1 | ||
gamma: 0 | ||
num_of_trees: 10 | ||
max_tree_depth: 4 | ||
data: | ||
root: data/ | ||
type: abalone | ||
splits: [0.8, 0.2] | ||
dataloader: | ||
type: raw | ||
batch_size: 2000 | ||
criterion: | ||
type: RegressionMSELoss | ||
trainer: | ||
type: verticaltrainer | ||
vertical: | ||
use: True | ||
dims: [4, 8] | ||
algo: 'rf' | ||
data_size_for_debug: 1500 | ||
feature_subsample_ratio: 0.5 | ||
eval: | ||
freq: 3 | ||
best_res_update_round_wise_key: test_loss |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please change the name of this file |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
use_gpu: False | ||
device: 0 | ||
backend: torch | ||
federate: | ||
mode: standalone | ||
client_num: 2 | ||
model: | ||
type: xgb_tree | ||
lambda_: 0.1 | ||
gamma: 0 | ||
num_of_trees: 10 | ||
max_tree_depth: 3 | ||
data: | ||
root: data/ | ||
type: adult | ||
splits: [1.0, 0.0] | ||
dataloader: | ||
type: raw | ||
batch_size: 2000 | ||
criterion: | ||
type: CrossEntropyLoss | ||
trainer: | ||
type: verticaltrainer | ||
train: | ||
optimizer: | ||
# learning rate for xgb model | ||
eta: 0.5 | ||
vertical: | ||
use: True | ||
dims: [7, 14] | ||
algo: 'xgb' | ||
eval: 'homo' | ||
data_size_for_debug: 2000 | ||
eval: | ||
freq: 3 | ||
best_res_update_round_wise_key: test_loss |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
use_gpu: False | ||
device: 0 | ||
backend: torch | ||
federate: | ||
mode: standalone | ||
client_num: 2 | ||
model: | ||
type: xgb_tree | ||
lambda_: 0.1 | ||
gamma: 0 | ||
num_of_trees: 10 | ||
max_tree_depth: 3 | ||
data: | ||
root: data/ | ||
type: abalone | ||
splits: [0.8, 0.2] | ||
dataloader: | ||
type: raw | ||
batch_size: 2000 | ||
criterion: | ||
type: RegressionMSELoss | ||
trainer: | ||
type: verticaltrainer | ||
train: | ||
optimizer: | ||
# learning rate for xgb model | ||
eta: 0.5 | ||
vertical: | ||
use: True | ||
dims: [4, 8] | ||
algo: 'xgb' | ||
protect_object: 'feature_order' | ||
protect_method: 'dp' | ||
protect_args: [ { 'bucket_num': 100, 'epsilon': 10 } ] | ||
data_size_for_debug: 2000 | ||
eval: | ||
freq: 5 | ||
best_res_update_round_wise_key: test_loss |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,14 @@ def __init__(self, | |
self.msg_buffer = {'train': {}, 'eval': {}} | ||
self.client_num = self._cfg.federate.client_num | ||
|
||
# TODO: the following is used for homo_evaluation, | ||
# which may be put in an if condition | ||
from federatedscope.vertical_fl.Paillier import \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please move this to the top of this file |
||
abstract_paillier | ||
keys = abstract_paillier.generate_paillier_keypair( | ||
n_length=self._cfg.vertical.key_size) | ||
self.public_key, self.private_key = keys | ||
|
||
self.feature_order = None | ||
self.merged_feature_order = None | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,8 +4,11 @@ | |
wrap_server_for_train, wrap_client_for_train | ||
from federatedscope.vertical_fl.xgb_base.worker.evaluation_wrapper import \ | ||
wrap_server_for_evaluation, wrap_client_for_evaluation | ||
from federatedscope.vertical_fl.xgb_base.worker.homo_evaluation_wrapper\ | ||
import wrap_client_for_homo_evaluation | ||
|
||
__all__ = [ | ||
'XGBServer', 'XGBClient', 'wrap_server_for_train', 'wrap_client_for_train', | ||
'wrap_server_for_evaluation', 'wrap_client_for_evaluation' | ||
'wrap_server_for_evaluation', 'wrap_client_for_evaluation', | ||
'wrap_client_for_homo_evaluation' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as above |
||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add comments