Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add feature importance to xgb algo #438

Merged
merged 6 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions federatedscope/core/monitors/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,3 +699,10 @@ def update_best_result(self, best_results, new_results, results_type):
logger.error(
"cfg.wandb.use=True but not install the wandb package")
exit()

def add_items_to_best_result(self, best_results, new_results,
results_type):
"""
Add a new key: value item (results-type: new_results) to best_result
"""
best_results[results_type] = new_results
2 changes: 1 addition & 1 deletion federatedscope/vertical_fl/worker/vertical_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def callback_funcs_for_encryped_gradient(self, message: Message):
formatted_logs = self._monitor.format_eval_res(
metrics,
rnd=self.state,
role='Global-Eval-Server #',
role='Server #',
forms=self._cfg.eval.report)
logger.info(formatted_logs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ train:
max_tree_depth: 3
xgb_base:
use: True
use_bin: False
use_bin: True
dims: [7, 14]
eval:
freq: 5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def callback_func_for_split(self, message: Message):
tree_num, node_num, split_ref = message.content
feature_idx = split_ref['feature_idx'] - self.client.feature_list[
self.client.ID - 1]
self.client.feature_importance[feature_idx] += 1
value_idx = split_ref['value_idx']
# feature_value = sorted(self.client.x[:, feature_idx])[value_idx]
feature_value = self.client.x[:, feature_idx][
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def callback_func_for_split(self, message: Message):
feature_idx = split_ref['feature_idx'] - self.client.feature_list[
self.client.ID - 1]
bin_idx = split_ref['bin_idx']
self.client.feature_importance[feature_idx] += 1
# feature_value = sorted(self.x[:, feature_idx])[value_idx]
if bin_idx == 0:
feature_value = self._min(
Expand Down
31 changes: 30 additions & 1 deletion federatedscope/vertical_fl/xgb_base/worker/XGBClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(self,

self.feature_order = [0] * self.my_num_of_feature

self.feature_importance = [0] * self.my_num_of_feature
# self.ss = AdditiveSecretSharing(shared_party_num=self.num_of_parties)
# self.ns = Node_split()
# self.fs = Feature_sort()
Expand All @@ -91,9 +92,10 @@ def __init__(self,
self.register_handlers('model_para', self.callback_func_for_model_para)
self.register_handlers('data_sample',
self.callback_func_for_data_sample)

self.register_handlers('compute_next_node',
self.callback_func_for_compute_next_node)
self.register_handlers('send_feature_importance',
self.callback_func_for_send_feature_importance)

# save the order of values in each feature
def order_feature(self, data):
Expand Down Expand Up @@ -125,6 +127,8 @@ def callback_func_for_model_para(self, message: Message):
# init y_hat
self.y_hat = np.random.uniform(low=0.0, high=1.0, size=len(self.y))
# self.y_hat = np.zeros(len(self.y))
logger.info(f'----------- Starting a new training round (Round '
f'#{self.state}) -------------')
self.comm_manager.send(
Message(
msg_type='data_sample',
Expand Down Expand Up @@ -200,6 +204,23 @@ def compute_weight(self, tree_num, node_num):
state=self.state,
receiver=self.server_id,
content=None))
self.comm_manager.send(
Message(msg_type='send_feature_importance',
sender=self.ID,
state=self.state,
receiver=[
each for each in list(
self.comm_manager.neighbors.keys())
if each != self.server_id
],
content=None))
self.comm_manager.send(
Message(msg_type='feature_importance',
sender=self.ID,
state=self.state,
receiver=self.server_id,
content=self.feature_importance))

else:
self.state += 1
logger.info(
Expand All @@ -214,3 +235,11 @@ def compute_weight(self, tree_num, node_num):
node_num].weight * self.tree_list[tree_num][
node_num].indicator
self.compute_weight(tree_num, node_num + 1)

def callback_func_for_send_feature_importance(self, message: Message):
self.comm_manager.send(
Message(msg_type='feature_importance',
sender=self.ID,
state=self.state,
receiver=self.server_id,
content=self.feature_importance))
15 changes: 15 additions & 0 deletions federatedscope/vertical_fl/xgb_base/worker/XGBServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,13 @@ def __init__(self,
self.tree_list = [
Tree(self.max_tree_depth).tree for _ in range(self.num_of_trees)
]
self.feature_importance_dict = dict()

self.register_handlers('test', self.callback_func_for_test)
self.register_handlers('test_result',
self.callback_func_for_test_result)
self.register_handlers('feature_importance',
self.callback_func_for_feature_importance)

def trigger_for_start(self):
if self.check_client_join_in():
Expand All @@ -64,6 +67,14 @@ def broadcast_model_para(self):
content=(self.lambda_, self.gamma, self.num_of_trees,
self.max_tree_depth)))

def callback_func_for_feature_importance(self, message: Message):
feature_importance = message.content
self.feature_importance_dict[message.sender] = feature_importance
if len(self.feature_importance_dict) == self.num_of_parties:
self.feature_importance_dict = dict(
sorted(self.feature_importance_dict.items(),
key=lambda x: x[0]))

def callback_func_for_test(self, message: Message):
test_x = self.data['test']['x']
test_y = self.data['test']['y']
Expand All @@ -88,6 +99,10 @@ def callback_func_for_test_result(self, message: Message):
self._monitor.update_best_result(self.best_results,
metrics,
results_type='server_global_eval')
self._monitor.add_items_to_best_result(
self.best_results,
self.feature_importance_dict,
results_type='feature_importance')
formatted_logs = self._monitor.format_eval_res(
metrics,
rnd=self.state,
Expand Down