-
Notifications
You must be signed in to change notification settings - Fork 80
Merged
Etna 797 #110
Changes from 11 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
7d060e1
Original version
d547a51
Original version tests
564eb40
Remake Distance classes + tests
4687104
Fixes + NaN handling
bc52c3a
Remake HierarchicalClustering + update CHANGELOG
1c8cbf3
Merge branch 'master' into ETNA-797
alex-hse-repository fceba7e
Fix typing + examples
4656d8a
Update poetry.lock
4e18952
Merge branch 'master' into ETNA-797
alex-hse-repository 224f88f
Remove tqdm + Add logging
0da3b11
Merge branch 'master' into ETNA-797
alex-hse-repository a9269c9
Fix logging
a432ec6
Fix logging
382b380
Fix examples in docstring
3ac73b1
Merge branch 'master' into ETNA-797
alex-hse-repository File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from etna.clustering.base import Clustering | ||
from etna.clustering.distances.base import Distance | ||
from etna.clustering.distances.distance_matrix import DistanceMatrix | ||
from etna.clustering.distances.dtw_distance import DTWDistance | ||
from etna.clustering.distances.euclidean_distance import EuclideanDistance | ||
from etna.clustering.hierarchical.base import HierarchicalClustering | ||
from etna.clustering.hierarchical.dtw_clustering import DTWClustering | ||
from etna.clustering.hierarchical.euclidean_clustering import EuclideanClustering |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from abc import ABC | ||
from abc import abstractmethod | ||
from typing import Dict | ||
|
||
import pandas as pd | ||
|
||
from etna.core import BaseMixin | ||
|
||
|
||
class Clustering(ABC, BaseMixin): | ||
"""Base class for ETNA clustering algorithms.""" | ||
|
||
@abstractmethod | ||
def fit_predict(self) -> Dict[str, int]: | ||
"""Fit clustering algo and predict clusters. | ||
|
||
Returns | ||
------- | ||
Dict[str, int]: | ||
dict in format {segment: cluster} | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_centroids(self) -> pd.DataFrame: | ||
"""Get centroids of clusters. | ||
|
||
Returns | ||
------- | ||
pd.DataFrame: | ||
dataframe with centroids | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from etna.clustering.distances.base import Distance | ||
from etna.clustering.distances.distance_matrix import DistanceMatrix | ||
from etna.clustering.distances.dtw_distance import DTWDistance | ||
from etna.clustering.distances.euclidean_distance import EuclideanDistance |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import sys | ||
import warnings | ||
from abc import ABC | ||
from abc import abstractmethod | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from etna.core import BaseMixin | ||
|
||
|
||
class Distance(ABC, BaseMixin): | ||
"""Base class for distances between series.""" | ||
|
||
def __init__(self, trim_series: bool = False, inf_value: float = sys.float_info.max // 10 ** 200): | ||
"""Init Distance. | ||
|
||
Parameters | ||
---------- | ||
trim_series: | ||
if True, get common (according to timestamp index) part of series and compute distance with it; if False, | ||
compute distance with given series without any modifications. | ||
inf_value: | ||
if two empty series given or series' indices interception is empty, | ||
return inf_value as a distance between the series | ||
""" | ||
self.trim_series = trim_series | ||
self.inf_value = inf_value | ||
|
||
@abstractmethod | ||
def _compute_distance(self, x1: np.ndarray, x2: np.ndarray) -> float: | ||
"""Compute distance between two given arrays.""" | ||
pass | ||
|
||
def __call__(self, x1: pd.Series, x2: pd.Series) -> float: | ||
"""Compute distance between x1 and x2. | ||
|
||
Parameters | ||
---------- | ||
x1: | ||
timestamp-indexed series | ||
x2: | ||
timestamp-indexed series | ||
|
||
Returns | ||
------- | ||
float: | ||
distance between x1 and x2 | ||
""" | ||
if self.trim_series: | ||
common_indices = x1.index.intersection(x2.index) | ||
_x1, _x2 = x1[common_indices], x2[common_indices] | ||
else: | ||
_x1, _x2 = x1, x2 | ||
|
||
# TODO: better to avoid such comments | ||
# if x1 and x2 have no interception with timestamp return inf_value as a distance | ||
if _x1.empty and _x2.empty: | ||
return self.inf_value | ||
|
||
distance = self._compute_distance(x1=_x1.values, x2=_x2.values) | ||
# TODO: better to avoid such comments | ||
# use it to avoid clustering confusing: if the last if passes we need to clip all the distances | ||
# to inf_value | ||
distance = min(self.inf_value, distance) | ||
return distance | ||
|
||
@staticmethod | ||
def _validate_dataset(ts: "TSDataset"): | ||
"""Check that dataset does not contain NaNs.""" | ||
for segment in ts.segments: | ||
series = ts[:, segment, "target"] | ||
first_valid_index = 0 | ||
last_valid_index = series.reset_index(drop=True).last_valid_index() | ||
series_length = last_valid_index - first_valid_index + 1 | ||
if len(series.dropna()) != series_length: | ||
warnings.warn( | ||
f"Timeseries contains NaN values, which will be dropped. " | ||
f"If it is not desirable behaviour, handle them manually." | ||
) | ||
break | ||
|
||
@abstractmethod | ||
def _get_average(self, ts: "TSDataset") -> pd.DataFrame: | ||
"""Get series that minimizes squared distance to given ones according to the Distance.""" | ||
pass | ||
|
||
def get_average(self, ts: "TSDataset") -> pd.DataFrame: | ||
"""Get series that minimizes squared distance to given ones according to the Distance. | ||
|
||
Parameters | ||
---------- | ||
ts: | ||
TSDataset with series to be averaged | ||
|
||
Returns | ||
------- | ||
pd.DataFrame: | ||
dataframe with columns "timestamp" and "target" that contains the series | ||
""" | ||
self._validate_dataset(ts) | ||
centroid = self._get_average(ts) | ||
return centroid | ||
|
||
|
||
__all__ = ["Distance"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import warnings | ||
from typing import Dict | ||
from typing import List | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from etna.clustering.distances.base import Distance | ||
from etna.core import BaseMixin | ||
from etna.loggers import ConsoleLogger | ||
from etna.loggers import tslogger | ||
|
||
|
||
class DistanceMatrix(BaseMixin): | ||
"""DistanceMatrix computes distance matrix from TSDataset.""" | ||
|
||
def __init__(self, distance: Distance): | ||
"""Init DistanceMatrix. | ||
|
||
Parameters | ||
---------- | ||
distance: | ||
class for distance measurement | ||
""" | ||
self.distance = distance | ||
self.matrix: Optional[np.ndarray] = None | ||
self.series: Optional[List[np.ndarray]] = None | ||
self.segment2idx: Dict[str, int] = {} | ||
self.idx2segment: Dict[int, str] = {} | ||
self.series_number: Optional[int] = None | ||
|
||
@staticmethod | ||
def _validate_dataset(ts: "TSDataset"): | ||
"""Check that dataset does not contain NaNs.""" | ||
for segment in ts.segments: | ||
series = ts[:, segment, "target"] | ||
first_valid_index = 0 | ||
last_valid_index = series.reset_index(drop=True).last_valid_index() | ||
series_length = last_valid_index - first_valid_index + 1 | ||
if len(series.dropna()) != series_length: | ||
warnings.warn( | ||
f"Timeseries contains NaN values, which will be dropped. " | ||
f"If it is not desirable behaviour, handle them manually." | ||
) | ||
break | ||
|
||
def _get_series(self, ts: "TSDataset") -> List[pd.Series]: | ||
"""Parse given TSDataset and get timestamp-indexed segment series. | ||
Build mapping from segment to idx in matrix and vice versa. | ||
""" | ||
series_list = [] | ||
for i, segment in enumerate(ts.segments): | ||
self.segment2idx[segment] = i | ||
self.idx2segment[i] = segment | ||
series = ts[:, segment, "target"].dropna() | ||
series_list.append(series) | ||
|
||
self.series_number = len(series_list) | ||
return series_list | ||
|
||
def _compute_dist(self, series: List[pd.Series], idx: int) -> np.ndarray: | ||
"""Compute distance from idx-th series to other ones.""" | ||
distances = np.array([self.distance(series[idx], series[j]) for j in range(self.series_number)]) | ||
return distances | ||
|
||
def _compute_dist_matrix(self, series: List[pd.Series]) -> np.ndarray: | ||
"""Compute distance matrix for given series.""" | ||
distances = np.empty(shape=(self.series_number, self.series_number)) | ||
logging_freq = self.series_number // 10 | ||
logger_id = tslogger.add(ConsoleLogger()) | ||
tslogger.start_experiment() | ||
for idx in range(self.series_number): | ||
distances[idx] = self._compute_dist(series=series, idx=idx) | ||
if (idx + 1) % logging_freq == 0: | ||
tslogger.log(f"Done {idx + 1} out of {self.series_number} ") | ||
tslogger.finish_experiment() | ||
tslogger.remove(logger_id) | ||
return distances | ||
|
||
def fit(self, ts: "TSDataset") -> "DistanceMatrix": | ||
"""Fit distance matrix: get timeseries from ts and compute pairwise distances. | ||
|
||
Parameters | ||
---------- | ||
ts: | ||
TSDataset with timeseries | ||
|
||
Returns | ||
------- | ||
self: | ||
fitted DistanceMatrix object | ||
|
||
""" | ||
self._validate_dataset(ts) | ||
self.series = self._get_series(ts) | ||
self.matrix = self._compute_dist_matrix(self.series) | ||
return self | ||
|
||
def predict(self) -> np.ndarray: | ||
"""Get distance matrix. | ||
|
||
Returns | ||
------- | ||
np.ndarray: | ||
2D array with distances between series | ||
""" | ||
return self.matrix | ||
|
||
def fit_predict(self, ts: "TSDataset") -> np.ndarray: | ||
"""Compute distance matrix and return it. | ||
|
||
Parameters | ||
---------- | ||
ts: | ||
TSDataset with timeseries to compute matrix with | ||
|
||
Returns | ||
------- | ||
np.ndarray: | ||
2D array with distances between series | ||
""" | ||
return self.fit(ts).predict() | ||
|
||
|
||
__all__ = ["DistanceMatrix"] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You shouldn't do this. Tslogger is as global object and all handlers've added already.
tslogger.start_experiment()
and this too. Look at code examples https://github.com/tinkoff-ai/etna-ts/blob/51afbc2df6d046e3c2ab9b102f50fc7c83529f5a/etna/datasets/tsdataset.py#L122