diff --git a/examples/time_series_modeling/time_series_modeling/neural_prophet.py b/examples/time_series_modeling/time_series_modeling/neural_prophet.py index be6bbeddf..7f1426bbd 100644 --- a/examples/time_series_modeling/time_series_modeling/neural_prophet.py +++ b/examples/time_series_modeling/time_series_modeling/neural_prophet.py @@ -10,7 +10,7 @@ # First, we import necessary libraries to run the training workflow. import pandas as pd -from flytekit import current_context, task, workflow, Deck, ImageSpec +from flytekit import Deck, ImageSpec, current_context, task, workflow from flytekit.types.file import FlyteFile # %% [markdown] @@ -30,7 +30,7 @@ ], # This registry is for a local flyte demo cluster. Replace this with your # own registry, e.g. `docker.io//` - registry="localhost:30000" + registry="localhost:30000", ) # %% [markdown] @@ -41,16 +41,19 @@ URL = "https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial01.csv" + @task(container_image=image) def load_data() -> pd.DataFrame: return pd.read_csv(URL) + # %% [markdown] # ## Model Training Task # # This task trains the Neural Prophet model on the loaded data. # We train the model in the hourly frequency for ten epochs. + @task(container_image=image) def train_model(df: pd.DataFrame) -> FlyteFile: from neuralprophet import NeuralProphet, save @@ -62,12 +65,14 @@ def train_model(df: pd.DataFrame) -> FlyteFile: save(model, model_fp) return FlyteFile(model_fp) + # %% [markdown] # ## Forecasting Task # # This task loads the trained model, makes predictions, and visualizes the # results using a Flyte Deck. + @task( container_image=image, enable_deck=True, @@ -77,7 +82,7 @@ def make_forecast(df: pd.DataFrame, model_file: FlyteFile) -> pd.DataFrame: model_file.download() model = load(model_file.path) - + # Create a new dataframe reaching 365 into the future # for our forecast, n_historic_predictions also shows historic data df_future = model.make_future_dataframe( @@ -95,12 +100,14 @@ def make_forecast(df: pd.DataFrame, model_file: FlyteFile) -> pd.DataFrame: return forecast + # %% [markdown] # ## Main Workflow # # Finally, this workflow orchestrates the entire process: loading data, # training the model, and making forecasts. + @workflow def main() -> pd.DataFrame: df = load_data()