From 9e45908c1bb57451698df7848269fc4f8b227ecf Mon Sep 17 00:00:00 2001 From: liangs6212 <80952198+liangs6212@users.noreply.github.com> Date: Sat, 10 Sep 2022 19:14:17 +0800 Subject: [PATCH] Chronos: Add support for `id_sensitive`=True to Forecaster.from_tsdataset (#5551) * supports id_sensitive=True * new params is_predict * rollback is_predict * fix known issues * fix code style --- .../src/bigdl/chronos/forecaster/base_forecaster.py | 7 ++++++- .../src/bigdl/chronos/forecaster/lstm_forecaster.py | 7 ++++++- .../src/bigdl/chronos/forecaster/nbeats_forecaster.py | 5 ++++- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py index a83eba92cca..71370f41bee 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py @@ -1096,7 +1096,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, future_seq_len=None, **kwa """ Build a Forecaster Model. - :param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance. + :param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance. :param past_seq_len: int or "auto", Specify the history time steps (i.e. lookback). Do not specify the 'past_seq_len' if your tsdataset has called the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'. @@ -1151,6 +1151,11 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len): fixMsg="Do not specify past_seq_len and future seq_len " "or call tsdataset.roll method again and specify time step") + if tsdataset.id_sensitive: + _id_list_len = len(tsdataset.id_col) + input_feature_num *= _id_list_len + output_feature_num *= _id_list_len + return cls(past_seq_len=past_seq_len, future_seq_len=future_seq_len, input_feature_num=input_feature_num, diff --git a/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py index 132de54ba4e..95b006fc09f 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py @@ -152,7 +152,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, **kwargs): ''' Build a LSTM Forecaster Model. - :param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance. + :param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance. :param past_seq_len: Specify the history time steps (i.e. lookback). Do not specify the 'past_seq_len' if your tsdataset has called the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'. @@ -198,6 +198,11 @@ def check_time_steps(tsdataset, past_seq_len): fixMsg="Do not specify past_seq_len " "or call tsdataset.roll method again and specify time step.") + if tsdataset.id_sensitive: + _id_list_len = len(tsdataset._id_list) + input_feature_num *= _id_list_len + output_feature_num *= _id_list_len + return cls(past_seq_len=past_seq_len, input_feature_num=input_feature_num, output_feature_num=output_feature_num, diff --git a/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py index d674f3813c6..0388ad057a0 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py @@ -174,7 +174,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, future_seq_len=None, **kwa """ Build a NBeats Forecaster Model. - :param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance. + :param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance. :param past_seq_len: Specify the history time steps (i.e. lookback). Do not specify the 'past_seq_len' if your tsdataset has called the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'. @@ -225,6 +225,9 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len): fixMsg="Do not specify past_seq_len and future seq_len " "or call tsdataset.roll method again and specify time step") + invalidInputError(not all([tsdataset.id_sensitive, len(tsdataset._id_list) > 1]), + "NBeats only supports univariate forecasting.") + return cls(past_seq_len=past_seq_len, future_seq_len=future_seq_len, **kwargs)