Skip to content

Commit

Permalink
Refactored code structure.
Browse files Browse the repository at this point in the history
  • Loading branch information
christopherbunn committed Aug 17, 2023
1 parent 5189419 commit eb066a7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
45 changes: 22 additions & 23 deletions evalml/pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,6 @@ def unstack_multiseries(
# Perform the unstacking
X_unstacked_cols = []
y_unstacked_cols = []
new_time_index = None
for s_id in series_id_unique:
single_series = full_dataset[full_dataset[series_id] == s_id]

Expand Down Expand Up @@ -1478,36 +1477,22 @@ def stack_X(X, series_id_name, time_index, starting_index=None, series_id_values
"""
original_columns = set()
series_ids = series_id_values or set()
for col in X.columns:
if col == time_index:
continue
separated_name = col.split("_")
original_columns.add("_".join(separated_name[:-1]))
if series_id_values is None:
if series_id_values is None:
for col in X.columns:
if col == time_index:
continue
separated_name = col.split("_")
original_columns.add("_".join(separated_name[:-1]))
series_ids.add(separated_name[-1])

restacked_X = []

if len(series_ids) == 0:
raise ValueError(
"Unable to stack X as X had no exogenous variables and `series_id_values` is None.",
"X has no exogenous variables and `series_id_values` is None.",
)

for i, original_col in enumerate(original_columns):
# Only include the series id once (for the first column)
include_series_id = i == 0
subset_X = [col for col in X.columns if original_col in col]
restacked_X.append(
stack_data(
X[subset_X],
include_series_id=include_series_id,
series_id_name=series_id_name,
starting_index=starting_index,
),
)
time_index_col = X[time_index].repeat(len(series_ids)).reset_index(drop=True)

if len(restacked_X) == 0:
if len(original_columns) == 0:
start_index = starting_index or X.index[0]
stacked_index = pd.RangeIndex(
start=start_index,
Expand All @@ -1522,6 +1507,20 @@ def stack_X(X, series_id_name, time_index, starting_index=None, series_id_values
index=stacked_index,
)
else:
restacked_X = []
for i, original_col in enumerate(original_columns):
# Only include the series id once (for the first column)
include_series_id = i == 0
subset_X = [col for col in X.columns if original_col in col]
restacked_X.append(
stack_data(
X[subset_X],
include_series_id=include_series_id,
series_id_name=series_id_name,
starting_index=starting_index,
),
)

restacked_X = pd.concat(restacked_X, axis=1)
time_index_col.index = restacked_X.index
restacked_X[time_index] = time_index_col
Expand Down
2 changes: 1 addition & 1 deletion evalml/tests/pipeline_tests/test_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,7 @@ def test_stack_X(

with pytest.raises(
ValueError,
match="Unable to stack X as X had no exogenous variables and `series_id_values` is None.",
match="X has no exogenous variables and `series_id_values` is None.",
):
stack_X(X, "series_id", "date", starting_index=starting_index)

Expand Down

0 comments on commit eb066a7

Please sign in to comment.