From bc812fa670eca722283c9558d514bb6f58d8d87e Mon Sep 17 00:00:00 2001 From: liangs6212 Date: Fri, 26 Aug 2022 14:11:49 +0800 Subject: [PATCH 1/5] supports id_sensitive=True --- .../src/bigdl/chronos/forecaster/base_forecaster.py | 11 +++++++++-- .../src/bigdl/chronos/forecaster/lstm_forecaster.py | 5 +++++ .../src/bigdl/chronos/forecaster/nbeats_forecaster.py | 3 +++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py index 1c6d622c8b9..d9f43fd826b 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py @@ -449,7 +449,8 @@ def predict(self, data, batch_size=32, quantize=False): horizon=self.data_config['future_seq_len'], feature_col=data.roll_feature, target_col=data.roll_target, - shuffle=False) + shuffle=False, + is_predict=True) # data transform is_local_data = isinstance(data, (np.ndarray, DataLoader)) if is_local_data and self.distributed: @@ -542,7 +543,8 @@ def predict_with_onnx(self, data, batch_size=32, quantize=False): horizon=self.data_config['future_seq_len'], feature_col=data.roll_feature, target_col=data.roll_target, - shuffle=False) + shuffle=False, + is_predict=True) if quantize: return _pytorch_fashion_inference(model=self.onnxruntime_int8, input_data=data, @@ -1151,6 +1153,11 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len): fixMsg="Do not specify past_seq_len and future seq_len " "or call tsdataset.roll method again and specify time step") + if tsdataset.id_sensitive: + _id_list_len = len(tsdataset._id_list) + input_feature_num *= _id_list_len + output_feature_num *= _id_list_len + return cls(past_seq_len=past_seq_len, future_seq_len=future_seq_len, input_feature_num=input_feature_num, diff --git a/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py 
b/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py index 132de54ba4e..003150963f5 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py @@ -198,6 +198,11 @@ def check_time_steps(tsdataset, past_seq_len): fixMsg="Do not specify past_seq_len " "or call tsdataset.roll method again and specify time step.") + if tsdataset.id_sensitive: + _id_list_len = len(tsdataset._id_list) + input_feature_num *= _id_list_len + output_feature_num *= _id_list_len + return cls(past_seq_len=past_seq_len, input_feature_num=input_feature_num, output_feature_num=output_feature_num, diff --git a/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py index d674f3813c6..be667f6cffc 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py @@ -224,6 +224,9 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len): f"but found {past_seq_len, future_seq_len}", fixMsg="Do not specify past_seq_len and future seq_len " "or call tsdataset.roll method again and specify time step") + + invalidInputError(tsdataset.id_sensitive and len(tsdataset._id_list) > 1, + "NBeats only supports univariate forecasting.") return cls(past_seq_len=past_seq_len, future_seq_len=future_seq_len, From d6a3c4d857215df2b940a1fda3be68d530e28368 Mon Sep 17 00:00:00 2001 From: liangs6212 Date: Tue, 30 Aug 2022 10:22:39 +0800 Subject: [PATCH 2/5] new params is_predict --- python/chronos/src/bigdl/chronos/data/tsdataset.py | 7 +++++-- .../src/bigdl/chronos/forecaster/base_forecaster.py | 12 +++++++----- .../src/bigdl/chronos/forecaster/lstm_forecaster.py | 2 +- .../bigdl/chronos/forecaster/nbeats_forecaster.py | 2 +- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/chronos/src/bigdl/chronos/data/tsdataset.py 
b/python/chronos/src/bigdl/chronos/data/tsdataset.py index 4041cd999f1..15170b3288e 100644 --- a/python/chronos/src/bigdl/chronos/data/tsdataset.py +++ b/python/chronos/src/bigdl/chronos/data/tsdataset.py @@ -59,6 +59,7 @@ def __init__(self, data, **schema): self.scaler_index = [i for i in range(len(self.target_col))] self.id_sensitive = None self._has_generate_agg_feature = False + self.is_predict = False self._check_basic_invariants() self._id_list = list(np.unique(self.df[self.id_col])) @@ -609,7 +610,8 @@ def roll(self, # horizon_time is only for time_enc, the time_enc numpy ndarray won't have any # shape change when the dataset is for prediction. horizon_time = self.horizon - if is_predict: + self.is_predict = is_predict + if self.is_predict: self.horizon = 0 if self.lookback == 'auto': @@ -785,6 +787,7 @@ def to_torch_data_loader(self, "of lookback and horizon, while get lookback+horizon=" f"{need_dflen} and the length of dataset is {len(self.df)}.") + self.is_predict = is_predict torch_dataset = RollDataset(self.df, dt_col=self.dt_col, freq=self._freq, @@ -795,7 +798,7 @@ def to_torch_data_loader(self, id_col=self.id_col, time_enc=time_enc, label_len=label_len, - is_predict=is_predict) + is_predict=self.is_predict) # TODO gen_rolling_feature and gen_global_feature will be support later self.roll_target = target_col self.roll_feature = feature_col diff --git a/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py index d9f43fd826b..d74a5dbd48f 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py @@ -450,7 +450,7 @@ def predict(self, data, batch_size=32, quantize=False): feature_col=data.roll_feature, target_col=data.roll_target, shuffle=False, - is_predict=True) + is_predict=data.is_predict) # data transform is_local_data = isinstance(data, (np.ndarray, DataLoader)) if is_local_data and 
self.distributed: @@ -544,7 +544,7 @@ def predict_with_onnx(self, data, batch_size=32, quantize=False): feature_col=data.roll_feature, target_col=data.roll_target, shuffle=False, - is_predict=True) + is_predict=data.is_predict) if quantize: return _pytorch_fashion_inference(model=self.onnxruntime_int8, input_data=data, @@ -654,7 +654,8 @@ def evaluate(self, data, batch_size=32, multioutput="raw_values", quantize=False horizon=self.data_config['future_seq_len'], feature_col=data.roll_feature, target_col=data.roll_target, - shuffle=False) + shuffle=False, + is_predict=False) is_local_data = isinstance(data, (tuple, DataLoader)) if not is_local_data and not self.distributed: data = xshard_to_np(data, mode="fit") @@ -754,7 +755,8 @@ def evaluate_with_onnx(self, data, horizon=self.data_config['future_seq_len'], feature_col=data.roll_feature, target_col=data.roll_target, - shuffle=False) + shuffle=False, + is_predict=False) if isinstance(data, DataLoader): input_data = data target = np.concatenate(tuple(val[1] for val in data), axis=0) @@ -1098,7 +1100,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, future_seq_len=None, **kwa """ Build a Forecaster Model. - :param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance. + :param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance. :param past_seq_len: int or "auto", Specify the history time steps (i.e. lookback). Do not specify the 'past_seq_len' if your tsdataset has called the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'. diff --git a/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py index 003150963f5..95b006fc09f 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/lstm_forecaster.py @@ -152,7 +152,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, **kwargs): ''' Build a LSTM Forecaster Model. 
- :param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance. + :param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance. :param past_seq_len: Specify the history time steps (i.e. lookback). Do not specify the 'past_seq_len' if your tsdataset has called the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'. diff --git a/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py index be667f6cffc..701dc10eb6c 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py @@ -174,7 +174,7 @@ def from_tsdataset(cls, tsdataset, past_seq_len=None, future_seq_len=None, **kwa """ Build a NBeats Forecaster Model. - :param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance. + :param tsdataset: Train tsdataset, a bigdl.chronos.data.tsdataset.TSDataset instance. :param past_seq_len: Specify the history time steps (i.e. lookback). Do not specify the 'past_seq_len' if your tsdataset has called the 'TSDataset.roll' method or 'TSDataset.to_torch_data_loader'. 
From 04f689e33cf34e577122f9d4d43aa8689d27135d Mon Sep 17 00:00:00 2001 From: liangs6212 Date: Tue, 30 Aug 2022 16:50:48 +0800 Subject: [PATCH 3/5] rollback is_predict --- python/chronos/src/bigdl/chronos/data/tsdataset.py | 7 ++----- .../src/bigdl/chronos/forecaster/base_forecaster.py | 12 ++++-------- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/python/chronos/src/bigdl/chronos/data/tsdataset.py b/python/chronos/src/bigdl/chronos/data/tsdataset.py index 15170b3288e..4041cd999f1 100644 --- a/python/chronos/src/bigdl/chronos/data/tsdataset.py +++ b/python/chronos/src/bigdl/chronos/data/tsdataset.py @@ -59,7 +59,6 @@ def __init__(self, data, **schema): self.scaler_index = [i for i in range(len(self.target_col))] self.id_sensitive = None self._has_generate_agg_feature = False - self.is_predict = False self._check_basic_invariants() self._id_list = list(np.unique(self.df[self.id_col])) @@ -610,8 +609,7 @@ def roll(self, # horizon_time is only for time_enc, the time_enc numpy ndarray won't have any # shape change when the dataset is for prediction. 
horizon_time = self.horizon - self.is_predict = is_predict - if self.is_predict: + if is_predict: self.horizon = 0 if self.lookback == 'auto': @@ -787,7 +785,6 @@ def to_torch_data_loader(self, "of lookback and horizon, while get lookback+horizon=" f"{need_dflen} and the length of dataset is {len(self.df)}.") - self.is_predict = is_predict torch_dataset = RollDataset(self.df, dt_col=self.dt_col, freq=self._freq, @@ -798,7 +795,7 @@ def to_torch_data_loader(self, id_col=self.id_col, time_enc=time_enc, label_len=label_len, - is_predict=self.is_predict) + is_predict=is_predict) # TODO gen_rolling_feature and gen_global_feature will be support later self.roll_target = target_col self.roll_feature = feature_col diff --git a/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py index d74a5dbd48f..cd0fb42effe 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py @@ -449,8 +449,7 @@ def predict(self, data, batch_size=32, quantize=False): horizon=self.data_config['future_seq_len'], feature_col=data.roll_feature, target_col=data.roll_target, - shuffle=False, - is_predict=data.is_predict) + shuffle=False) # data transform is_local_data = isinstance(data, (np.ndarray, DataLoader)) if is_local_data and self.distributed: @@ -543,8 +542,7 @@ def predict_with_onnx(self, data, batch_size=32, quantize=False): horizon=self.data_config['future_seq_len'], feature_col=data.roll_feature, target_col=data.roll_target, - shuffle=False, - is_predict=data.is_predict) + shuffle=False) if quantize: return _pytorch_fashion_inference(model=self.onnxruntime_int8, input_data=data, @@ -654,8 +652,7 @@ def evaluate(self, data, batch_size=32, multioutput="raw_values", quantize=False horizon=self.data_config['future_seq_len'], feature_col=data.roll_feature, target_col=data.roll_target, - shuffle=False, - is_predict=False) + shuffle=False) 
is_local_data = isinstance(data, (tuple, DataLoader)) if not is_local_data and not self.distributed: data = xshard_to_np(data, mode="fit") @@ -755,8 +752,7 @@ def evaluate_with_onnx(self, data, horizon=self.data_config['future_seq_len'], feature_col=data.roll_feature, target_col=data.roll_target, - shuffle=False, - is_predict=False) + shuffle=False) if isinstance(data, DataLoader): input_data = data target = np.concatenate(tuple(val[1] for val in data), axis=0) From ee978e0d6b925de7aa6a206e82fa8f69f745190f Mon Sep 17 00:00:00 2001 From: liangs6212 Date: Sat, 3 Sep 2022 12:48:36 +0800 Subject: [PATCH 4/5] fix known issues --- .../chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py index 701dc10eb6c..495b9d9dcac 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py @@ -225,7 +225,7 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len): fixMsg="Do not specify past_seq_len and future seq_len " "or call tsdataset.roll method again and specify time step") - invalidInputError(tsdataset.id_sensitive and len(tsdataset._id_list) > 1, + invalidInputError(not all([tsdataset.id_sensitive, len(tsdataset._id_list) > 1]), "NBeats only supports univariate forecasting.") return cls(past_seq_len=past_seq_len, From ad7021dc53dbb90f7c6d0c1e7349a8944533693b Mon Sep 17 00:00:00 2001 From: liangs6212 Date: Sat, 3 Sep 2022 12:50:04 +0800 Subject: [PATCH 5/5] fix code style --- .../chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py index 495b9d9dcac..0388ad057a0 100644 --- 
a/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/nbeats_forecaster.py @@ -224,7 +224,7 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len): f"but found {past_seq_len, future_seq_len}", fixMsg="Do not specify past_seq_len and future seq_len " "or call tsdataset.roll method again and specify time step") - + invalidInputError(not all([tsdataset.id_sensitive, len(tsdataset._id_list) > 1]), "NBeats only supports univariate forecasting.")