Skip to content

Commit

Permalink
supports id_sensitive=True
Browse files Browse the repository at this point in the history
  • Loading branch information
liangs6212 committed Aug 26, 2022
1 parent fdf18bc commit 91cb229
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
11 changes: 9 additions & 2 deletions python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,8 @@ def predict(self, data, batch_size=32, quantize=False):
horizon=self.data_config['future_seq_len'],
feature_col=data.roll_feature,
target_col=data.roll_target,
shuffle=False)
shuffle=False,
is_predict=True)
# data transform
is_local_data = isinstance(data, (np.ndarray, DataLoader))
if is_local_data and self.distributed:
Expand Down Expand Up @@ -542,7 +543,8 @@ def predict_with_onnx(self, data, batch_size=32, quantize=False):
horizon=self.data_config['future_seq_len'],
feature_col=data.roll_feature,
target_col=data.roll_target,
shuffle=False)
shuffle=False,
is_predict=True)
if quantize:
return _pytorch_fashion_inference(model=self.onnxruntime_int8,
input_data=data,
Expand Down Expand Up @@ -1151,6 +1153,11 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len):
fixMsg="Do not specify past_seq_len and future seq_len "
"or call tsdataset.roll method again and specify time step")

if tsdataset.id_sensitive:
_id_list_len = len(tsdataset.id_col)
input_feature_num *= _id_list_len
output_feature_num *= _id_list_len

return cls(past_seq_len=past_seq_len,
future_seq_len=future_seq_len,
input_feature_num=input_feature_num,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ def check_time_steps(tsdataset, past_seq_len):
fixMsg="Do not specify past_seq_len "
"or call tsdataset.roll method again and specify time step.")

if tsdataset.id_sensitive:
_id_list_len = len(tsdataset._id_list)
input_feature_num *= _id_list_len
output_feature_num *= _id_list_len

return cls(past_seq_len=past_seq_len,
input_feature_num=input_feature_num,
output_feature_num=output_feature_num,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len):
f"but found {past_seq_len, future_seq_len}",
fixMsg="Do not specify past_seq_len and future seq_len "
"or call tsdataset.roll method again and specify time step")

invalidInputError(tsdataset.id_sensitive and len(tsdataset._id_list) > 1,
"NBeats only supports univariate forecasting.")

return cls(past_seq_len=past_seq_len,
future_seq_len=future_seq_len,
Expand Down

0 comments on commit 91cb229

Please sign in to comment.