Skip to content

Commit

Permalink
add feature importance to xgb algo (#438)
Browse files Browse the repository at this point in the history
  • Loading branch information
qbc2016 authored Nov 21, 2022
1 parent 42707fe commit cfab8c7
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 3 deletions.
7 changes: 7 additions & 0 deletions federatedscope/core/monitors/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,3 +699,10 @@ def update_best_result(self, best_results, new_results, results_type):
logger.error(
"cfg.wandb.use=True but not install the wandb package")
exit()

def add_items_to_best_result(self, best_results, new_results,
results_type):
"""
Add a new key: value item (results-type: new_results) to best_result
"""
best_results[results_type] = new_results
2 changes: 1 addition & 1 deletion federatedscope/vertical_fl/worker/vertical_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def callback_funcs_for_encryped_gradient(self, message: Message):
formatted_logs = self._monitor.format_eval_res(
metrics,
rnd=self.state,
role='Global-Eval-Server #',
role='Server #',
forms=self._cfg.eval.report)
logger.info(formatted_logs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ train:
max_tree_depth: 3
xgb_base:
use: True
use_bin: False
use_bin: True
dims: [7, 14]
eval:
freq: 5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def callback_func_for_split(self, message: Message):
tree_num, node_num, split_ref = message.content
feature_idx = split_ref['feature_idx'] - self.client.feature_list[
self.client.ID - 1]
self.client.feature_importance[feature_idx] += 1
value_idx = split_ref['value_idx']
# feature_value = sorted(self.client.x[:, feature_idx])[value_idx]
feature_value = self.client.x[:, feature_idx][
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def callback_func_for_split(self, message: Message):
feature_idx = split_ref['feature_idx'] - self.client.feature_list[
self.client.ID - 1]
bin_idx = split_ref['bin_idx']
self.client.feature_importance[feature_idx] += 1
# feature_value = sorted(self.x[:, feature_idx])[value_idx]
if bin_idx == 0:
feature_value = self._min(
Expand Down
31 changes: 30 additions & 1 deletion federatedscope/vertical_fl/xgb_base/worker/XGBClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(self,

self.feature_order = [0] * self.my_num_of_feature

self.feature_importance = [0] * self.my_num_of_feature
# self.ss = AdditiveSecretSharing(shared_party_num=self.num_of_parties)
# self.ns = Node_split()
# self.fs = Feature_sort()
Expand All @@ -91,9 +92,10 @@ def __init__(self,
self.register_handlers('model_para', self.callback_func_for_model_para)
self.register_handlers('data_sample',
self.callback_func_for_data_sample)

self.register_handlers('compute_next_node',
self.callback_func_for_compute_next_node)
self.register_handlers('send_feature_importance',
self.callback_func_for_send_feature_importance)

# save the order of values in each feature
def order_feature(self, data):
Expand Down Expand Up @@ -125,6 +127,8 @@ def callback_func_for_model_para(self, message: Message):
# init y_hat
self.y_hat = np.random.uniform(low=0.0, high=1.0, size=len(self.y))
# self.y_hat = np.zeros(len(self.y))
logger.info(f'----------- Starting a new training round (Round '
f'#{self.state}) -------------')
self.comm_manager.send(
Message(
msg_type='data_sample',
Expand Down Expand Up @@ -200,6 +204,23 @@ def compute_weight(self, tree_num, node_num):
state=self.state,
receiver=self.server_id,
content=None))
self.comm_manager.send(
Message(msg_type='send_feature_importance',
sender=self.ID,
state=self.state,
receiver=[
each for each in list(
self.comm_manager.neighbors.keys())
if each != self.server_id
],
content=None))
self.comm_manager.send(
Message(msg_type='feature_importance',
sender=self.ID,
state=self.state,
receiver=self.server_id,
content=self.feature_importance))

else:
self.state += 1
logger.info(
Expand All @@ -214,3 +235,11 @@ def compute_weight(self, tree_num, node_num):
node_num].weight * self.tree_list[tree_num][
node_num].indicator
self.compute_weight(tree_num, node_num + 1)

def callback_func_for_send_feature_importance(self, message: Message):
self.comm_manager.send(
Message(msg_type='feature_importance',
sender=self.ID,
state=self.state,
receiver=self.server_id,
content=self.feature_importance))
15 changes: 15 additions & 0 deletions federatedscope/vertical_fl/xgb_base/worker/XGBServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,13 @@ def __init__(self,
self.tree_list = [
Tree(self.max_tree_depth).tree for _ in range(self.num_of_trees)
]
self.feature_importance_dict = dict()

self.register_handlers('test', self.callback_func_for_test)
self.register_handlers('test_result',
self.callback_func_for_test_result)
self.register_handlers('feature_importance',
self.callback_func_for_feature_importance)

def trigger_for_start(self):
if self.check_client_join_in():
Expand All @@ -64,6 +67,14 @@ def broadcast_model_para(self):
content=(self.lambda_, self.gamma, self.num_of_trees,
self.max_tree_depth)))

def callback_func_for_feature_importance(self, message: Message):
feature_importance = message.content
self.feature_importance_dict[message.sender] = feature_importance
if len(self.feature_importance_dict) == self.num_of_parties:
self.feature_importance_dict = dict(
sorted(self.feature_importance_dict.items(),
key=lambda x: x[0]))

def callback_func_for_test(self, message: Message):
test_x = self.data['test']['x']
test_y = self.data['test']['y']
Expand All @@ -88,6 +99,10 @@ def callback_func_for_test_result(self, message: Message):
self._monitor.update_best_result(self.best_results,
metrics,
results_type='server_global_eval')
self._monitor.add_items_to_best_result(
self.best_results,
self.feature_importance_dict,
results_type='feature_importance')
formatted_logs = self._monitor.format_eval_res(
metrics,
rnd=self.state,
Expand Down

0 comments on commit cfab8c7

Please sign in to comment.