Skip to content

Commit

Permalink
fix: recover finetune and model changes (#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
AzulGarza authored Apr 1, 2024
1 parent 6dfba26 commit dc8b72c
Show file tree
Hide file tree
Showing 19 changed files with 495 additions and 555 deletions.
26 changes: 13 additions & 13 deletions nbs/distributed.timegpt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,14 @@
" X_df: Optional[fugue.AnyDataFrame] = None,\n",
" level: Optional[List[Union[int, float]]] = None,\n",
" quantiles: Optional[List[float]] = None,\n",
" fewshot_steps: int = 0,\n",
" fewshot_loss: str = 'default',\n",
" finetune_steps: int = 0,\n",
" finetune_loss: str = 'default',\n",
" clean_ex_first: bool = True,\n",
" validate_token: bool = False,\n",
" add_history: bool = False,\n",
" date_features: Union[bool, List[str]] = False,\n",
" date_features_to_one_hot: Union[bool, List[str]] = True,\n",
" model: str = 'short-horizon',\n",
" model: str = 'timegpt-1',\n",
" num_partitions: Optional[int] = None,\n",
" ) -> fugue.AnyDataFrame:\n",
" kwargs = dict(\n",
Expand All @@ -179,8 +179,8 @@
" target_col=target_col,\n",
" level=level,\n",
" quantiles=quantiles,\n",
" fewshot_steps=fewshot_steps,\n",
" fewshot_loss=fewshot_loss,\n",
" finetune_steps=finetune_steps,\n",
" finetune_loss=finetune_loss,\n",
" clean_ex_first=clean_ex_first,\n",
" validate_token=validate_token,\n",
" add_history=add_history,\n",
Expand Down Expand Up @@ -217,7 +217,7 @@
" validate_token: bool = False,\n",
" date_features: Union[bool, List[str]] = False,\n",
" date_features_to_one_hot: Union[bool, List[str]] = True,\n",
" model: str = 'short-horizon',\n",
" model: str = 'timegpt-1',\n",
" num_partitions: Optional[int] = None,\n",
" ) -> fugue.AnyDataFrame:\n",
" kwargs = dict(\n",
Expand Down Expand Up @@ -254,13 +254,13 @@
" target_col: str = 'y',\n",
" level: Optional[List[Union[int, float]]] = None,\n",
" quantiles: Optional[List[float]] = None,\n",
" fewshot_steps: int = 0,\n",
" fewshot_loss: str = 'default',\n",
" finetune_steps: int = 0,\n",
" finetune_loss: str = 'default',\n",
" clean_ex_first: bool = True,\n",
" validate_token: bool = False,\n",
" date_features: Union[bool, List[str]] = False,\n",
" date_features_to_one_hot: Union[bool, List[str]] = True,\n",
" model: str = 'short-horizon',\n",
" model: str = 'timegpt-1',\n",
" n_windows: int = 1,\n",
" step_size: Optional[int] = None,\n",
" num_partitions: Optional[int] = None,\n",
Expand All @@ -273,8 +273,8 @@
" target_col=target_col,\n",
" level=level,\n",
" quantiles=quantiles,\n",
" fewshot_steps=fewshot_steps,\n",
" fewshot_loss=fewshot_loss,\n",
" finetune_steps=finetune_steps,\n",
" finetune_loss=finetune_loss,\n",
" clean_ex_first=clean_ex_first,\n",
" validate_token=validate_token,\n",
" date_features=date_features,\n",
Expand Down Expand Up @@ -448,7 +448,7 @@
" num_partitions=1,\n",
" id_col=id_col,\n",
" time_col=time_col,\n",
" model='short-horizon',\n",
" model='timegpt-1',\n",
" **fcst_kwargs\n",
" )\n",
" fcst_df = fa.as_pandas(fcst_df)\n",
Expand Down Expand Up @@ -771,7 +771,7 @@
" num_partitions=1,\n",
" id_col=id_col,\n",
" time_col=time_col,\n",
" model='short-horizon',\n",
" model='timegpt-1',\n",
" **anomalies_kwargs\n",
" )\n",
" anomalies_df = fa.as_pandas(anomalies_df)\n",
Expand Down
55 changes: 23 additions & 32 deletions nbs/docs/getting-started/1_getting_started_short.ipynb

Large diffs are not rendered by default.

69 changes: 33 additions & 36 deletions nbs/docs/how-to-guides/0_distributed_fcst_spark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,7 @@
"execution_count": null,
"id": "fcf6004b-ebd0-4a3c-8c02-d5463c62f79e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/miniconda/envs/nixtlats/lib/python3.11/site-packages/statsforecast/core.py:25: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from tqdm.autonotebook import tqdm\n"
]
}
],
"outputs": [],
"source": [
"from nixtlats import TimeGPT"
]
Expand Down Expand Up @@ -176,8 +167,7 @@
"text": [
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
"23/11/09 17:49:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
"23/11/09 17:49:02 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n"
"24/04/01 03:34:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
]
}
],
Expand Down Expand Up @@ -242,10 +232,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:nixtlats.timegpt:Validating inputs... (4 + 16) / 20]\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (19 + 1) / 20]\n"
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n"
]
},
{
Expand Down Expand Up @@ -302,42 +292,42 @@
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint... (36 + 60) / 96]\n",
"INFO:nixtlats.timegpt:Validating inputs... (54 + 42) / 96]\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Validating inputs...========> (71 + 25) / 96]\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (92 + 4) / 96]\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Validating inputs... \n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...===> (76 + 20) / 96]\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (92 + 4) / 96]\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...==============> (93 + 3) / 96]\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
" \r"
]
Expand All @@ -347,7 +337,7 @@
"#| hide\n",
"# test different results for different models\n",
"fcst_df_1 = fcst_df.toPandas()\n",
"fcst_df_2 = timegpt.forecast(spark_df, h=12, model='long-horizon')\n",
"fcst_df_2 = timegpt.forecast(spark_df, h=12, model='timegpt-1-long-horizon')\n",
"fcst_df_2 = fcst_df_2.toPandas()\n",
"test_fail(\n",
" lambda: pd.testing.assert_frame_equal(fcst_df_1[['TimeGPT']], fcst_df_2[['TimeGPT']]),\n",
Expand Down Expand Up @@ -464,24 +454,31 @@
"output_type": "stream",
"text": [
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (19 + 1) / 20]\n"
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"[Stage 33:=====================================================> (19 + 1) / 20]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------+-------------------+------------------+------------------+-----------------+-----------------+------------------+\n",
"|unique_id| ds| TimeGPT| TimeGPT-lo-90| TimeGPT-lo-80| TimeGPT-hi-80| TimeGPT-hi-90|\n",
"+---------+-------------------+------------------+------------------+-----------------+-----------------+------------------+\n",
"| FR|2016-12-31 00:00:00| 64.97691027939692|60.056473801735784|61.71575274765864|68.23806781113521| 69.89734675705805|\n",
"| FR|2016-12-31 01:00:00| 60.14365519077404| 56.12626745731457|56.73784790927991|63.54946247226818| 64.16104292423351|\n",
"| FR|2016-12-31 02:00:00| 59.42375860682185| 54.84932824030574|56.52975776758845|62.31775944605525| 63.99818897333796|\n",
"| FR|2016-12-31 03:00:00| 55.11264928302748| 47.59671153125746|51.95117842731459|58.27412013874037| 62.6285870347975|\n",
"| FR|2016-12-31 04:00:00|54.400922806813526|44.925772896840385|49.65213255412798|59.14971305949907|63.876072716786666|\n",
"+---------+-------------------+------------------+------------------+-----------------+-----------------+------------------+\n",
"+---------+-------------------+------------------+------------------+------------------+-----------------+-----------------+\n",
"|unique_id| ds| TimeGPT| TimeGPT-lo-90| TimeGPT-lo-80| TimeGPT-hi-80| TimeGPT-hi-90|\n",
"+---------+-------------------+------------------+------------------+------------------+-----------------+-----------------+\n",
"| FR|2016-12-31 00:00:00| 59.39155162090687| 54.47111514324573| 56.13039408916859|62.65270915264515| 64.311988098568|\n",
"| FR|2016-12-31 01:00:00| 60.1843929541434|56.167005220683926|56.778585672649264|63.59020023563754|64.20178068760288|\n",
"| FR|2016-12-31 02:00:00| 58.12912691907976| 53.55469655256365| 55.23512607984636|61.02312775831316|62.70355728559587|\n",
"| FR|2016-12-31 03:00:00|53.825965179940155| 46.31002742817014| 50.66449432422726|56.98743603565305|61.34190293171017|\n",
"| FR|2016-12-31 04:00:00| 47.6941769331486| 38.21902702317546| 42.94538668046305|52.44296718583414|57.16932684312174|\n",
"+---------+-------------------+------------------+------------------+------------------+-----------------+-----------------+\n",
"only showing top 5 rows\n",
"\n"
]
Expand Down
Loading

0 comments on commit dc8b72c

Please sign in to comment.