Skip to content

Commit

Permalink
Merge pull request #139 from ibm-granite/pipeline_preprocess
Browse files Browse the repository at this point in the history
Pipeline preprocessor behavior
  • Loading branch information
wgifford authored Sep 19, 2024
2 parents c889210 + 7076577 commit f124336
Show file tree
Hide file tree
Showing 7 changed files with 906 additions and 848 deletions.

Large diffs are not rendered by default.

1,676 changes: 866 additions & 810 deletions services/inference/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion services/inference/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ __version_tuple__ = (0, 0, 0)
# including 3.9 causes poetry lock to run forever
python = ">=3.10,<3.13"
numpy = { version = "<2" }
tsfm_public = { git = "https://github.com/IBM/tsfm.git", tag = "v0.2.6" }
tsfm_public = { git = "https://github.com/IBM-granite/granite-tsfm.git", tag = "v0.2.9" }

# trying to pick up cpu version for tsfminference
# to make image smaller
Expand Down
19 changes: 11 additions & 8 deletions services/inference/tsfminference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def _forecast_common(self, input_payload: ForecastingInferenceInput) -> PredictO
# train to estimate freq if not available
preprocessor.train(data)

LOGGER.info(f"Data frequency determined: {preprocessor.freq}")

# warn if future data is not provided, but is needed by the model
if preprocessor.exogenous_channel_indices and future_data is None:
raise ValueError(
Expand All @@ -126,6 +128,7 @@ def _forecast_common(self, input_payload: ForecastingInferenceInput) -> PredictO
explode_forecasts=True,
feature_extractor=preprocessor,
add_known_ground_truth=False,
freq=preprocessor.freq,
)

# truncate data length when exploding
Expand All @@ -135,16 +138,16 @@ def _forecast_common(self, input_payload: ForecastingInferenceInput) -> PredictO
# data, id_columns=input.id_columns, start_index=-context_length
# )

test_data = preprocessor.preprocess(data)
# test_data = preprocessor.preprocess(data)

if future_data is not None:
# future data needs some values for targets, but they are unused
# Eventually this will be part of the forecast pipeline.
future_data[input_payload.target_columns] = 0
future_data = preprocessor.preprocess(future_data)
future_data.drop(columns=input_payload.target_columns)
# if future_data is not None:
# # future data needs some values for targets, but they are unused
# # Eventually this will be part of the forecast pipeline.
# future_data[input_payload.target_columns] = 0
# future_data = preprocessor.preprocess(future_data)
# future_data.drop(columns=input_payload.target_columns)

forecasts = forecast_pipeline(test_data, future_time_series=future_data, inverse_scale_outputs=True)
forecasts = forecast_pipeline(data, future_time_series=future_data, inverse_scale_outputs=True)

return PredictOutput(
model_id=input_payload.model_id,
Expand Down
2 changes: 1 addition & 1 deletion tests/toolkit/test_time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def test_forecasting_pipeline_forecasts_with_preprocessor(patchtst_model):
inverse_scale_outputs=True,
)

forecasts = forecast_pipeline(tsp.preprocess(test_data))
forecasts = forecast_pipeline(test_data)

assert forecasts.shape == (
test_end_index - test_start_index - context_length + 1,
Expand Down
18 changes: 11 additions & 7 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""Hugging Face Pipeline for Time Series Tasks"""

import inspect
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -116,13 +116,11 @@ def __init__(
self,
model: Union["PreTrainedModel"],
*args,
freq: Optional[str] = None,
explode_forecasts: bool = False,
inverse_scale_outputs: bool = True,
add_known_ground_truth: bool = True,
**kwargs,
):
kwargs["freq"] = freq
kwargs["explode_forecasts"] = explode_forecasts
kwargs["inverse_scale_outputs"] = inverse_scale_outputs
kwargs["add_known_ground_truth"] = add_known_ground_truth
Expand All @@ -142,10 +140,6 @@ def __init__(
if p not in kwargs:
kwargs[p] = getattr(kwargs["feature_extractor"], p)

# get freq from kwargs or the preprocessor
if "freq" not in kwargs:
kwargs["freq"] = kwargs["feature_extractor"].freq

if "context_length" not in kwargs:
kwargs["context_length"] = model.config.context_length

Expand Down Expand Up @@ -331,6 +325,7 @@ def preprocess(self, time_series, **kwargs) -> Dict[str, Union[GenericTensor, Li
prediction_length = kwargs.get("prediction_length")
timestamp_column = kwargs.get("timestamp_column")
id_columns = kwargs.get("id_columns")
target_columns = kwargs.get("target_columns")
# context_length = kwargs.get("context_length")

# use the feature extractor here
Expand All @@ -343,6 +338,9 @@ def preprocess(self, time_series, **kwargs) -> Dict[str, Union[GenericTensor, Li

future_time_series = kwargs.pop("future_time_series", None)

if self.feature_extractor:
time_series = self.feature_extractor.preprocess(time_series)

if future_time_series is not None:
if isinstance(future_time_series, str):
future_time_series = pd.read_csv(
Expand Down Expand Up @@ -370,6 +368,12 @@ def preprocess(self, time_series, **kwargs) -> Dict[str, Union[GenericTensor, Li
f"If provided, `future_time_series` data should cover the prediction length for each of the time series in the test dataset. Received data of length {future_time_series.shape[0]} but expected {prediction_length * id_count}"
)

if self.feature_extractor:
# future data needs some values for targets, but they are unused
future_time_series[target_columns] = 0
future_time_series = self.feature_extractor.preprocess(future_time_series)
future_time_series.drop(columns=target_columns)

time_series = pd.concat((time_series, future_time_series), axis=0)
else:
# no additional exogenous data provided, extend with empty periods
Expand Down
2 changes: 2 additions & 0 deletions tsfm_public/toolkit/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def plot_predictions(

if indices is None:
l = len(predictions_df)
num_plots = min(num_plots, l)
indices = np.random.choice(l, size=num_plots, replace=False)
predictions_subset = [predictions_df.iloc[i] for i in indices]

Expand All @@ -297,6 +298,7 @@ def plot_predictions(

with torch.no_grad():
if indices is None:
num_plots = min(num_plots, len(dset))
indices = np.random.choice(len(dset), size=num_plots, replace=False)

signature = inspect.signature(model.forward)
Expand Down

0 comments on commit f124336

Please sign in to comment.