Skip to content

Commit

Permalink
fix: solve problem with duplicated dates (#190)
Browse files Browse the repository at this point in the history
  • Loading branch information
AzulGarza authored Dec 9, 2023
2 parents 990090e + cc5c0e8 commit 5897e21
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 9 deletions.
8 changes: 4 additions & 4 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ name: nixtlats
channels:
- conda-forge
dependencies:
- dask<2023.1.1
- dask
- jupyterlab
- pandas
- plotly
- prophet
- pyspark>=3.3
- pip:
- black
Expand All @@ -17,9 +18,8 @@ dependencies:
- statsforecast
- utilsforecast>=0.0.13
- requests
- duckdb<0.8
- fugue[ray]
- ray<2.4
- fugue[ray]>=0.8.7
- ray[serve-grpc]
- fire
- tabulate
- tenacity
24 changes: 22 additions & 2 deletions nbs/timegpt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@
" if self.freq is None or self.freq in special_freqs:\n",
" unique_id = df.iloc[0]['unique_id']\n",
" df_id = df.query('unique_id == @unique_id')\n",
" inferred_freq = pd.infer_freq(df_id['ds'])\n",
" inferred_freq = pd.infer_freq(df_id['ds'].sort_values())\n",
" if inferred_freq is None:\n",
" raise Exception(\n",
" 'Could not infer frequency of ds column. This could be due to '\n",
Expand Down Expand Up @@ -345,7 +345,7 @@
" future_dates = X_df['ds'].unique().tolist()\n",
" else:\n",
" future_dates = []\n",
" dates = pd.DatetimeIndex(train_dates + future_dates)\n",
" dates = pd.DatetimeIndex(np.unique(train_dates + future_dates).tolist())\n",
" date_features_df = pd.DataFrame({'ds': dates})\n",
" for feature in self.date_features:\n",
" feat_df = self.compute_date_feature(dates, feature)\n",
Expand Down Expand Up @@ -1546,6 +1546,26 @@
"from utilsforecast.data import generate_series"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test date_features with multiple series\n",
"# and different ends\n",
"test_series = generate_series(n_series=2, min_length=5, max_length=20)\n",
"h = 12\n",
"fcst_test_series = timegpt.forecast(test_series, h=12, date_features=['dayofweek'])\n",
"uids = test_series['unique_id']\n",
"for uid in uids:\n",
" test_eq(\n",
" fcst_test_series.query('unique_id == @uid')['ds'].values,\n",
" pd.date_range(periods=h + 1, start=test_series.query('unique_id == @uid')['ds'].max())[1:].astype(str),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
4 changes: 2 additions & 2 deletions nixtlats/timegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def infer_freq(self, df: pd.DataFrame):
if self.freq is None or self.freq in special_freqs:
unique_id = df.iloc[0]["unique_id"]
df_id = df.query("unique_id == @unique_id")
inferred_freq = pd.infer_freq(df_id["ds"])
inferred_freq = pd.infer_freq(df_id["ds"].sort_values())
if inferred_freq is None:
raise Exception(
"Could not infer frequency of ds column. This could be due to "
Expand Down Expand Up @@ -295,7 +295,7 @@ def add_date_features(
future_dates = X_df["ds"].unique().tolist()
else:
future_dates = []
dates = pd.DatetimeIndex(train_dates + future_dates)
dates = pd.DatetimeIndex(np.unique(train_dates + future_dates).tolist())
date_features_df = pd.DataFrame({"ds": dates})
for feature in self.date_features:
feat_df = self.compute_date_feature(dates, feature)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
long_description = fh.read()

dev = ["black", "nbdev", "plotly", "python-dotenv", "statsforecast"]
distributed = ["dask", "fugue[ray]", "pyspark"]
distributed = ["dask", "fugue[ray]>=0.8.7", "pyspark", "ray[serve-grpc]"]
plotting = ["utilsforecast[plotting]>=0.0.5"]
date_extras = ["holidays"]

Expand Down

0 comments on commit 5897e21

Please sign in to comment.