Skip to content

Commit

Permalink
new params is_predict
Browse files Browse the repository at this point in the history
  • Loading branch information
liangs6212 committed Aug 30, 2022
1 parent 91cb229 commit 339609b
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
7 changes: 5 additions & 2 deletions python/chronos/src/bigdl/chronos/data/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self, data, **schema):
self.scaler_index = [i for i in range(len(self.target_col))]
self.id_sensitive = None
self._has_generate_agg_feature = False
self.is_predict = False
self._check_basic_invariants()

self._id_list = list(np.unique(self.df[self.id_col]))
Expand Down Expand Up @@ -609,7 +610,8 @@ def roll(self,
# horizon_time is only for time_enc, the time_enc numpy ndarray won't have any
# shape change when the dataset is for prediction.
horizon_time = self.horizon
if is_predict:
self.is_predict = is_predict
if self.is_predict:
self.horizon = 0

if self.lookback == 'auto':
Expand Down Expand Up @@ -785,6 +787,7 @@ def to_torch_data_loader(self,
"of lookback and horizon, while get lookback+horizon="
f"{need_dflen} and the length of dataset is {len(self.df)}.")

self.is_predict = is_predict
torch_dataset = RollDataset(self.df,
dt_col=self.dt_col,
freq=self._freq,
Expand All @@ -795,7 +798,7 @@ def to_torch_data_loader(self,
id_col=self.id_col,
time_enc=time_enc,
label_len=label_len,
is_predict=is_predict)
is_predict=self.is_predict)
# TODO gen_rolling_feature and gen_global_feature will be support later
self.roll_target = target_col
self.roll_feature = feature_col
Expand Down
12 changes: 7 additions & 5 deletions python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def predict(self, data, batch_size=32, quantize=False):
feature_col=data.roll_feature,
target_col=data.roll_target,
shuffle=False,
is_predict=True)
is_predict=data.is_predict)
# data transform
is_local_data = isinstance(data, (np.ndarray, DataLoader))
if is_local_data and self.distributed:
Expand Down Expand Up @@ -544,7 +544,7 @@ def predict_with_onnx(self, data, batch_size=32, quantize=False):
feature_col=data.roll_feature,
target_col=data.roll_target,
shuffle=False,
is_predict=True)
is_predict=data.is_predict)
if quantize:
return _pytorch_fashion_inference(model=self.onnxruntime_int8,
input_data=data,
Expand Down Expand Up @@ -654,7 +654,8 @@ def evaluate(self, data, batch_size=32, multioutput="raw_values", quantize=False
horizon=self.data_config['future_seq_len'],
feature_col=data.roll_feature,
target_col=data.roll_target,
shuffle=False)
shuffle=False,
is_predict=False)
is_local_data = isinstance(data, (tuple, DataLoader))
if not is_local_data and not self.distributed:
data = xshard_to_np(data, mode="fit")
Expand Down Expand Up @@ -754,7 +755,8 @@ def evaluate_with_onnx(self, data,
horizon=self.data_config['future_seq_len'],
feature_col=data.roll_feature,
target_col=data.roll_target,
shuffle=False)
shuffle=False,
is_predict=False)
if isinstance(data, DataLoader):
input_data = data
target = np.concatenate(tuple(val[1] for val in data), axis=0)
Expand Down Expand Up @@ -1098,7 +1100,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, future_seq_len=None, **kwa
"""
Build a Forecaster Model.
:param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance.
:param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance.
:param past_seq_len: int or "auto", Specify the history time steps (i.e. lookback).
Do not specify the 'past_seq_len' if your tsdataset has called
the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, **kwargs):
'''
Build a LSTM Forecaster Model.
:param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance.
:param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance.
:param past_seq_len: Specify the history time steps (i.e. lookback).
Do not specify the 'past_seq_len' if your tsdataset has called
the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, future_seq_len=None, **kwa
"""
Build a NBeats Forecaster Model.
:param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance.
:param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance.
:param past_seq_len: Specify the history time steps (i.e. lookback).
Do not specify the 'past_seq_len' if your tsdataset has called
the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'.
Expand Down

0 comments on commit 339609b

Please sign in to comment.