diff --git a/python/chronos/src/bigdl/chronos/data/tsdataset.py b/python/chronos/src/bigdl/chronos/data/tsdataset.py index bcb274cd58b2..3b8c6a8201d7 100644 --- a/python/chronos/src/bigdl/chronos/data/tsdataset.py +++ b/python/chronos/src/bigdl/chronos/data/tsdataset.py @@ -59,7 +59,6 @@ def __init__(self, data, **schema): self.scaler_index = [i for i in range(len(self.target_col))] self.id_sensitive = None self._has_generate_agg_feature = False - self.is_predict = False self._check_basic_invariants() self._id_list = list(np.unique(self.df[self.id_col])) @@ -610,8 +609,7 @@ def roll(self, # horizon_time is only for time_enc, the time_enc numpy ndarray won't have any # shape change when the dataset is for prediction. horizon_time = self.horizon - self.is_predict = is_predict - if self.is_predict: + if is_predict: self.horizon = 0 if self.lookback == 'auto': @@ -787,7 +785,6 @@ def to_torch_data_loader(self, "of lookback and horizon, while get lookback+horizon=" f"{need_dflen} and the length of dataset is {len(self.df)}.") - self.is_predict = is_predict torch_dataset = RollDataset(self.df, dt_col=self.dt_col, freq=self._freq, @@ -798,7 +795,7 @@ def to_torch_data_loader(self, id_col=self.id_col, time_enc=time_enc, label_len=label_len, - is_predict=self.is_predict) + is_predict=is_predict) # TODO gen_rolling_feature and gen_global_feature will be support later self.roll_target = target_col self.roll_feature = feature_col diff --git a/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py index 0101ff2fa5eb..003fa37c4d93 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py @@ -449,8 +449,7 @@ def predict(self, data, batch_size=32, quantize=False): horizon=self.data_config['future_seq_len'], feature_col=data.roll_feature, target_col=data.roll_target, - shuffle=False, - is_predict=data.is_predict) + shuffle=False) # data transform is_local_data = isinstance(data, (np.ndarray, DataLoader)) if is_local_data and self.distributed: @@ -543,8 +542,7 @@ def predict_with_onnx(self, data, batch_size=32, quantize=False): horizon=self.data_config['future_seq_len'], feature_col=data.roll_feature, target_col=data.roll_target, - shuffle=False, - is_predict=data.is_predict) + shuffle=False) if quantize: return _pytorch_fashion_inference(model=self.onnxruntime_int8, input_data=data, @@ -654,8 +652,7 @@ def evaluate(self, data, batch_size=32, multioutput="raw_values", quantize=False horizon=self.data_config['future_seq_len'], feature_col=data.roll_feature, target_col=data.roll_target, - shuffle=False, - is_predict=False) + shuffle=False) is_local_data = isinstance(data, (tuple, DataLoader)) if not is_local_data and not self.distributed: data = xshard_to_np(data, mode="fit") @@ -755,8 +752,7 @@ def evaluate_with_onnx(self, data, horizon=self.data_config['future_seq_len'], feature_col=data.roll_feature, target_col=data.roll_target, - shuffle=False, - is_predict=False) + shuffle=False) if isinstance(data, DataLoader): input_data = data target = np.concatenate(tuple(val[1] for val in data), axis=0)