Skip to content

Commit

Permalink
rollback is_predict
Browse files Browse the repository at this point in the history
  • Loading branch information
liangs6212 committed Aug 30, 2022
1 parent 339609b commit f5b3392
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 13 deletions.
7 changes: 2 additions & 5 deletions python/chronos/src/bigdl/chronos/data/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(self, data, **schema):
self.scaler_index = [i for i in range(len(self.target_col))]
self.id_sensitive = None
self._has_generate_agg_feature = False
self.is_predict = False
self._check_basic_invariants()

self._id_list = list(np.unique(self.df[self.id_col]))
Expand Down Expand Up @@ -610,8 +609,7 @@ def roll(self,
# horizon_time is only for time_enc, the time_enc numpy ndarray won't have any
# shape change when the dataset is for prediction.
horizon_time = self.horizon
self.is_predict = is_predict
if self.is_predict:
if is_predict:
self.horizon = 0

if self.lookback == 'auto':
Expand Down Expand Up @@ -787,7 +785,6 @@ def to_torch_data_loader(self,
"of lookback and horizon, while get lookback+horizon="
f"{need_dflen} and the length of dataset is {len(self.df)}.")

self.is_predict = is_predict
torch_dataset = RollDataset(self.df,
dt_col=self.dt_col,
freq=self._freq,
Expand All @@ -798,7 +795,7 @@ def to_torch_data_loader(self,
id_col=self.id_col,
time_enc=time_enc,
label_len=label_len,
is_predict=self.is_predict)
is_predict=is_predict)
# TODO gen_rolling_feature and gen_global_feature will be support later
self.roll_target = target_col
self.roll_feature = feature_col
Expand Down
12 changes: 4 additions & 8 deletions python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,7 @@ def predict(self, data, batch_size=32, quantize=False):
horizon=self.data_config['future_seq_len'],
feature_col=data.roll_feature,
target_col=data.roll_target,
shuffle=False,
is_predict=data.is_predict)
shuffle=False)
# data transform
is_local_data = isinstance(data, (np.ndarray, DataLoader))
if is_local_data and self.distributed:
Expand Down Expand Up @@ -543,8 +542,7 @@ def predict_with_onnx(self, data, batch_size=32, quantize=False):
horizon=self.data_config['future_seq_len'],
feature_col=data.roll_feature,
target_col=data.roll_target,
shuffle=False,
is_predict=data.is_predict)
shuffle=False)
if quantize:
return _pytorch_fashion_inference(model=self.onnxruntime_int8,
input_data=data,
Expand Down Expand Up @@ -654,8 +652,7 @@ def evaluate(self, data, batch_size=32, multioutput="raw_values", quantize=False
horizon=self.data_config['future_seq_len'],
feature_col=data.roll_feature,
target_col=data.roll_target,
shuffle=False,
is_predict=False)
shuffle=False)
is_local_data = isinstance(data, (tuple, DataLoader))
if not is_local_data and not self.distributed:
data = xshard_to_np(data, mode="fit")
Expand Down Expand Up @@ -755,8 +752,7 @@ def evaluate_with_onnx(self, data,
horizon=self.data_config['future_seq_len'],
feature_col=data.roll_feature,
target_col=data.roll_target,
shuffle=False,
is_predict=False)
shuffle=False)
if isinstance(data, DataLoader):
input_data = data
target = np.concatenate(tuple(val[1] for val in data), axis=0)
Expand Down

0 comments on commit f5b3392

Please sign in to comment.