diff --git a/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py index a83eba92cca..71370f41bee 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py @@ -1096,7 +1096,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, future_seq_len=None, **kwa """ Build a Forecaster Model. - :param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance. + :param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance. :param past_seq_len: int or "auto", Specify the history time steps (i.e. lookback). Do not specify the 'past_seq_len' if your tsdataset has called the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'. @@ -1151,6 +1151,11 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len): fixMsg="Do not specify past_seq_len and future seq_len " "or call tsdataset.roll method again and specify time step") + if tsdataset.id_sensitive: + _id_list_len = len(tsdataset.id_col) + input_feature_num *= _id_list_len + output_feature_num *= _id_list_len + return cls(past_seq_len=past_seq_len, future_seq_len=future_seq_len, input_feature_num=input_feature_num, diff --git a/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py index 132de54ba4e..95b006fc09f 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py @@ -152,7 +152,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, **kwargs): ''' Build a LSTM Forecaster Model. - :param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance. + :param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance. :param past_seq_len: Specify the history time steps (i.e. lookback). Do not specify the 'past_seq_len' if your tsdataset has called the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'. @@ -198,6 +198,11 @@ def check_time_steps(tsdataset, past_seq_len): fixMsg="Do not specify past_seq_len " "or call tsdataset.roll method again and specify time step.") + if tsdataset.id_sensitive: + _id_list_len = len(tsdataset._id_list) + input_feature_num *= _id_list_len + output_feature_num *= _id_list_len + return cls(past_seq_len=past_seq_len, input_feature_num=input_feature_num, output_feature_num=output_feature_num, diff --git a/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py index d674f3813c6..0388ad057a0 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py @@ -174,7 +174,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, future_seq_len=None, **kwa """ Build a NBeats Forecaster Model. - :param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance. + :param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance. :param past_seq_len: Specify the history time steps (i.e. lookback). Do not specify the 'past_seq_len' if your tsdataset has called the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'. @@ -225,6 +225,9 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len): fixMsg="Do not specify past_seq_len and future seq_len " "or call tsdataset.roll method again and specify time step") + invalidInputError(not all([tsdataset.id_sensitive, len(tsdataset._id_list) > 1]), + "NBeats only supports univariate forecasting.") + return cls(past_seq_len=past_seq_len, future_seq_len=future_seq_len, **kwargs)