Skip to content

Commit

Permalink
Chronos: reset batch_size according to forecaster.num_processes (#4480
Browse files Browse the repository at this point in the history
)

* change batch_size

* change batch_size

* add warning if batch_size % num_processes != 0

* fix code style problem

* show users the exact value of batch_size and self.num_process

* make the value shows only in the warning message

* fix pep8

* batch_size=max(1, xxx)
  • Loading branch information
Rebekk authored May 30, 2022
1 parent ebba044 commit 4efc267
Showing 1 changed file with 6 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,14 @@ def fit(self, data, epochs=1, batch_size=32):

# data transformation
if isinstance(data, tuple):
if batch_size % self.num_processes != 0:
warnings.warn("'batch_size' cannot be divided with no remainder by "
"'self.num_processes'. We got 'batch_size' = {} and "
"'self.num_processes' = {}".
format(batch_size, self.num_processes))
data = DataLoader(TensorDataset(torch.from_numpy(data[0]),
torch.from_numpy(data[1])),
batch_size=batch_size,
batch_size=max(1, batch_size//self.num_processes),
shuffle=True)

# Trainer init and fitting
Expand Down

0 comments on commit 4efc267

Please sign in to comment.