Skip to content

Commit

Permalink
Chronos: Add support for id_sensitive=True to Forecaster.from_tsdat…
Browse files Browse the repository at this point in the history
…aset (intel-analytics#5551)

* supports id_sensitive=True

* new params is_predict

* rollback is_predict

* fix known issues

* fix code style
  • Loading branch information
liangs6212 authored and ForJadeForest committed Sep 20, 2022
1 parent 7fc1cb7 commit 9cd288d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, future_seq_len=None, **kwa
"""
Build a Forecaster Model.
:param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance.
:param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance.
:param past_seq_len: int or "auto", Specify the history time steps (i.e. lookback).
Do not specify the 'past_seq_len' if your tsdataset has called
the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'.
Expand Down Expand Up @@ -1151,6 +1151,11 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len):
fixMsg="Do not specify past_seq_len and future seq_len "
"or call tsdataset.roll method again and specify time step")

if tsdataset.id_sensitive:
_id_list_len = len(tsdataset.id_col)
input_feature_num *= _id_list_len
output_feature_num *= _id_list_len

return cls(past_seq_len=past_seq_len,
future_seq_len=future_seq_len,
input_feature_num=input_feature_num,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, **kwargs):
'''
Build a LSTM Forecaster Model.
:param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance.
:param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance.
:param past_seq_len: Specify the history time steps (i.e. lookback).
Do not specify the 'past_seq_len' if your tsdataset has called
the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'.
Expand Down Expand Up @@ -198,6 +198,11 @@ def check_time_steps(tsdataset, past_seq_len):
fixMsg="Do not specify past_seq_len "
"or call tsdataset.roll method again and specify time step.")

if tsdataset.id_sensitive:
_id_list_len = len(tsdataset._id_list)
input_feature_num *= _id_list_len
output_feature_num *= _id_list_len

return cls(past_seq_len=past_seq_len,
input_feature_num=input_feature_num,
output_feature_num=output_feature_num,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, future_seq_len=None, **kwa
"""
Build a NBeats Forecaster Model.
:param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance.
:param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance.
:param past_seq_len: Specify the history time steps (i.e. lookback).
Do not specify the 'past_seq_len' if your tsdataset has called
the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'.
Expand Down Expand Up @@ -225,6 +225,9 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len):
fixMsg="Do not specify past_seq_len and future seq_len "
"or call tsdataset.roll method again and specify time step")

invalidInputError(not all([tsdataset.id_sensitive, len(tsdataset._id_list) > 1]),
"NBeats only supports univariate forecasting.")

return cls(past_seq_len=past_seq_len,
future_seq_len=future_seq_len,
**kwargs)

0 comments on commit 9cd288d

Please sign in to comment.