Skip to content

Commit

Permalink
Fix bug with indexing in _forecast_segment
Browse files Browse the repository at this point in the history
  • Loading branch information
d.a.bunin committed Jul 8, 2022
1 parent 116709e commit 1f0eee4
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions etna/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def forecast(

@staticmethod
def _forecast_segment(model, segment: Union[str, List[str]], ts: TSDataset) -> pd.DataFrame:
segment_features = ts[:, segment, :]
segment_features = ts.df.loc[:, pd.IndexSlice[segment, :]]
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = segment_features.reset_index()
dates = segment_features["timestamp"]
Expand Down Expand Up @@ -251,7 +251,7 @@ def get_model(self) -> Dict[str, Any]:
@staticmethod
def _forecast_segment(model: Any, segment: str, ts: TSDataset, *args, **kwargs) -> pd.DataFrame:
"""Make predictions for one segment."""
segment_features = ts[:, segment, :]
segment_features = ts.df.loc[:, pd.IndexSlice[segment, :]]
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = segment_features.reset_index()
dates = segment_features["timestamp"]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
@pytest.mark.parametrize(
"model, transforms, num_skip_timestamps",
[
(CatBoostModelPerSegment(), [LagTransform(in_column="target", lags=[2, 3])], 2),
(CatBoostModelPerSegment(), [LagTransform(in_column="target", lags=[2, 3])], 0),
(CatBoostModelMultiSegment(), [LagTransform(in_column="target", lags=[2, 3])], 0),
(LinearPerSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])], 3),
(LinearMultiSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])], 3),
Expand Down

0 comments on commit 1f0eee4

Please sign in to comment.