Skip to content

Commit

Permalink
make linter happy
Browse files Browse the repository at this point in the history
Signed-off-by: Niels Bantilan <[email protected]>
  • Loading branch information
cosmicBboy committed Oct 2, 2024
1 parent b5005ac commit 087b8cb
Showing 1 changed file with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# First, we import necessary libraries to run the training workflow.

import pandas as pd
from flytekit import current_context, task, workflow, Deck, ImageSpec
from flytekit import Deck, ImageSpec, current_context, task, workflow
from flytekit.types.file import FlyteFile

# %% [markdown]
Expand All @@ -30,7 +30,7 @@
],
# This registry is for a local flyte demo cluster. Replace this with your
# own registry, e.g. `docker.io/<username>/<imagename>`
registry="localhost:30000"
registry="localhost:30000",
)

# %% [markdown]
Expand All @@ -41,16 +41,19 @@

URL = "https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial01.csv"


@task(container_image=image)
def load_data() -> pd.DataFrame:
return pd.read_csv(URL)


# %% [markdown]
# ## Model Training Task
#
# This task trains the Neural Prophet model on the loaded data.
# We train the model in the hourly frequency for ten epochs.


@task(container_image=image)
def train_model(df: pd.DataFrame) -> FlyteFile:
from neuralprophet import NeuralProphet, save
Expand All @@ -62,12 +65,14 @@ def train_model(df: pd.DataFrame) -> FlyteFile:
save(model, model_fp)
return FlyteFile(model_fp)


# %% [markdown]
# ## Forecasting Task
#
# This task loads the trained model, makes predictions, and visualizes the
# results using a Flyte Deck.


@task(
container_image=image,
enable_deck=True,
Expand All @@ -77,7 +82,7 @@ def make_forecast(df: pd.DataFrame, model_file: FlyteFile) -> pd.DataFrame:

model_file.download()
model = load(model_file.path)

# Create a new dataframe reaching 365 into the future
# for our forecast, n_historic_predictions also shows historic data
df_future = model.make_future_dataframe(
Expand All @@ -95,12 +100,14 @@ def make_forecast(df: pd.DataFrame, model_file: FlyteFile) -> pd.DataFrame:

return forecast


# %% [markdown]
# ## Main Workflow
#
# Finally, this workflow orchestrates the entire process: loading data,
# training the model, and making forecasts.


@workflow
def main() -> pd.DataFrame:
df = load_data()
Expand Down

0 comments on commit 087b8cb

Please sign in to comment.