Skip to content

Commit

Permalink
Export and clean notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
marcopeix committed Apr 3, 2024
1 parent 23513fc commit d6ded7c
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions nixtlats/nixtla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,12 +749,15 @@ def partition_by_uid(func):
def wrapper(self, num_partitions, **kwargs):
if num_partitions is None or num_partitions == 1:
return func(self, **kwargs, num_partitions=1)
main_logger.info(f"Number of partitions: {num_partitions}")
df = kwargs.pop("df")
X_df = kwargs.pop("X_df", None)
id_col = kwargs["id_col"]
uids = df["unique_id"].unique()
results_df = []
split_index = 1
for uids_split in np.array_split(uids, num_partitions):
main_logger.info(f"Partition {split_index} of {num_partitions}")
df_uids = df.query("unique_id in @uids_split")
if X_df is not None:
X_df_uids = X_df.query("unique_id in @uids_split")
Expand All @@ -767,6 +770,7 @@ def wrapper(self, num_partitions, **kwargs):
kwargs_uids["X_df"] = X_df_uids
results_uids = func(self, **kwargs_uids, num_partitions=1)
results_df.append(results_uids)
split_index += 1
results_df = pd.concat(results_df).reset_index(drop=True)
return results_df

Expand Down

0 comments on commit d6ded7c

Please sign in to comment.