Skip to content

Commit

Permalink
Add examples to CatBoost models (#181)
Browse files Browse the repository at this point in the history
* Add examples to CatBoost models

* Update changelog

* Fix errors in doctest

* Align example sections correctly
  • Loading branch information
Mr-Geekman authored Oct 12, 2021
1 parent 5f6bd14 commit 04ad785
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Example for ProphetModel ([#178](https://github.com/tinkoff-ai/etna-ts/pull/178))
- Instruction notebook for custom model and transform creation ([#180](https://github.com/tinkoff-ai/etna-ts/pull/180))
- Add inverse_transform in *OutliersTransform ([#160](https://github.com/tinkoff-ai/etna-ts/pull/160))
- Examples for CatBoostModelMultiSegment and CatBoostModelPerSegment ([#181](https://github.com/tinkoff-ai/etna-ts/pull/181))

### Changed
- Delete offset from WindowStatisticsTransform ([#111](https://github.com/tinkoff-ai/etna-ts/pull/111))
Expand Down
84 changes: 82 additions & 2 deletions etna/models/catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,47 @@ def predict(self, df: pd.DataFrame) -> list:


class CatBoostModelPerSegment(PerSegmentModel):
"""Class for holding per segment Catboost model."""
"""Class for holding per segment Catboost model.
Examples
--------
>>> from etna.datasets import generate_periodic_df
>>> from etna.datasets import TSDataset
>>> from etna.models import CatBoostModelPerSegment
>>> from etna.transforms import LagTransform
>>> classic_df = generate_periodic_df(
... periods=100,
... start_time="2020-01-01",
... n_segments=4,
... period=7,
... sigma=3
... )
>>> df = TSDataset.to_dataset(df=classic_df)
>>> ts = TSDataset(df, freq="D")
>>> horizon = 7
>>> transforms = [
... LagTransform(in_column="target", lags=[horizon, horizon+1, horizon+2])
... ]
>>> ts.fit_transform(transforms=transforms)
>>> future = ts.make_future(horizon)
>>> model = CatBoostModelPerSegment()
>>> model.fit(ts=ts)
CatBoostModelPerSegment(iterations = None, depth = None, learning_rate = None,
logging_level = 'Silent', l2_leaf_reg = None, thread_count = None, )
>>> forecast = model.forecast(future)
>>> pd.options.display.float_format = '{:,.2f}'.format
>>> forecast[:, :, "target"]
segment segment_0 segment_1 segment_2 segment_3
feature target target target target
timestamp
2020-04-10 9.00 9.00 4.00 6.00
2020-04-11 5.00 2.00 7.00 9.00
2020-04-12 0.00 4.00 7.00 9.00
2020-04-13 0.00 5.00 9.00 7.00
2020-04-14 1.00 2.00 1.00 6.00
2020-04-15 5.00 7.00 4.00 7.00
2020-04-16 8.00 6.00 2.00 0.00
"""

def __init__(
self,
Expand Down Expand Up @@ -122,7 +162,47 @@ def __init__(


class CatBoostModelMultiSegment(Model):
"""Class for holding Catboost model for all segments."""
"""Class for holding Catboost model for all segments.
Examples
--------
>>> from etna.datasets import generate_periodic_df
>>> from etna.datasets import TSDataset
>>> from etna.models import CatBoostModelMultiSegment
>>> from etna.transforms import LagTransform
>>> classic_df = generate_periodic_df(
... periods=100,
... start_time="2020-01-01",
... n_segments=4,
... period=7,
... sigma=3
... )
>>> df = TSDataset.to_dataset(df=classic_df)
>>> ts = TSDataset(df, freq="D")
>>> horizon = 7
>>> transforms = [
... LagTransform(in_column="target", lags=[horizon, horizon+1, horizon+2])
... ]
>>> ts.fit_transform(transforms=transforms)
>>> future = ts.make_future(horizon)
>>> model = CatBoostModelMultiSegment()
>>> model.fit(ts=ts)
CatBoostModelMultiSegment(iterations = None, depth = None, learning_rate = None,
logging_level = 'Silent', l2_leaf_reg = None, thread_count = None, )
>>> forecast = model.forecast(future)
>>> pd.options.display.float_format = '{:,.2f}'.format
>>> forecast[:, :, "target"].round()
segment segment_0 segment_1 segment_2 segment_3
feature target target target target
timestamp
2020-04-10 9.00 9.00 4.00 6.00
2020-04-11 5.00 2.00 7.00 9.00
2020-04-12 -0.00 4.00 7.00 9.00
2020-04-13 0.00 5.00 9.00 7.00
2020-04-14 1.00 2.00 1.00 6.00
2020-04-15 5.00 7.00 4.00 7.00
2020-04-16 8.00 6.00 2.00 0.00
"""

def __init__(
self,
Expand Down

0 comments on commit 04ad785

Please sign in to comment.