Skip to content

Commit

Permalink
move "batch_size" from data to dataloader (#475)
Browse files Browse the repository at this point in the history
  • Loading branch information
qbc2016 authored Dec 20, 2022
1 parent 61d429e commit f837073
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ model:
data:
root: data/
type: abalone
batch_size: 4000
splits: [0.8, 0.2]
dataloader:
type: raw
batch_size: 4000
criterion:
type: Regression
trainer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ model:
data:
root: data/
type: blog
batch_size: 8000
splits: [1.0, 0.0]
dataloader:
type: raw
batch_size: 8000
criterion:
type: Regression
trainer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ model:
data:
root: data/
type: credit
batch_size: 2000
splits: [0.8, 0.2]
dataloader:
type: raw
batch_size: 2000
criterion:
type: CrossEntropyLoss
trainer:
Expand Down
4 changes: 2 additions & 2 deletions federatedscope/vertical_fl/xgb_base/worker/XGBClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self,
self.federate_mode = config.federate.mode

self.bin_num = config.train.optimizer.bin_num
self.batch_size = config.data.batch_size
self.batch_size = config.dataloader.batch_size

self.data = data
self.own_label = ('y' in self.data['train'])
Expand All @@ -74,7 +74,7 @@ def _init_data_related_var(self):
self.num_of_parties = self._cfg.federate.client_num

self.dataloader = batch_iter(self.data['train'],
self._cfg.data.batch_size,
self.batch_size,
shuffled=True)

self.feature_order = None
Expand Down
2 changes: 1 addition & 1 deletion federatedscope/vertical_fl/xgb_base/worker/XGBServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self,
self.callback_func_for_feature_importance)

def _init_data_related_var(self):
self.batch_size = self._cfg.data.batch_size
self.batch_size = self._cfg.dataloader.batch_size
self.feature_list = [0] + self.vertical_dims
self.feature_partition = [
self.feature_list[i + 1] - self.feature_list[i]
Expand Down

0 comments on commit f837073

Please sign in to comment.