Skip to content

Etna 797 #110

Merged
merged 15 commits into from
Oct 4, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Sequence anomalies ([#96](https://github.com/tinkoff-ai/etna-ts/pull/96))
- 'is_weekend' feature in DateFlagsTransform ([#101](https://github.com/tinkoff-ai/etna-ts/pull/101))
- Documentation example for models and note about inplace nature of forecast ([#112](https://github.com/tinkoff-ai/etna-ts/pull/112))
- Clustering (#[110](https://github.com/tinkoff-ai/etna-ts/pull/110))

### Changed
- SklearnTransform out column names ([#99](https://github.com/tinkoff-ai/etna-ts/pull/99))
Expand Down
8 changes: 8 additions & 0 deletions etna/clustering/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from etna.clustering.base import Clustering
from etna.clustering.distances.base import Distance
from etna.clustering.distances.distance_matrix import DistanceMatrix
from etna.clustering.distances.dtw_distance import DTWDistance
from etna.clustering.distances.euclidean_distance import EuclideanDistance
from etna.clustering.hierarchical.base import HierarchicalClustering
from etna.clustering.hierarchical.dtw_clustering import DTWClustering
from etna.clustering.hierarchical.euclidean_clustering import EuclideanClustering
33 changes: 33 additions & 0 deletions etna/clustering/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from abc import ABC
from abc import abstractmethod
from typing import Dict

import pandas as pd

from etna.core import BaseMixin


class Clustering(ABC, BaseMixin):
"""Base class for ETNA clustering algorithms."""

@abstractmethod
def fit_predict(self) -> Dict[str, int]:
"""Fit clustering algo and predict clusters.

Returns
-------
Dict[str, int]:
dict in format {segment: cluster}
"""
pass

@abstractmethod
def get_centroids(self) -> pd.DataFrame:
"""Get centroids of clusters.

Returns
-------
pd.DataFrame:
dataframe with centroids
"""
pass
4 changes: 4 additions & 0 deletions etna/clustering/distances/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from etna.clustering.distances.base import Distance
from etna.clustering.distances.distance_matrix import DistanceMatrix
from etna.clustering.distances.dtw_distance import DTWDistance
from etna.clustering.distances.euclidean_distance import EuclideanDistance
106 changes: 106 additions & 0 deletions etna/clustering/distances/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import sys
import warnings
from abc import ABC
from abc import abstractmethod

import numpy as np
import pandas as pd

from etna.core import BaseMixin


class Distance(ABC, BaseMixin):
"""Base class for distances between series."""

def __init__(self, trim_series: bool = False, inf_value: float = sys.float_info.max // 10 ** 200):
"""Init Distance.

Parameters
----------
trim_series:
if True, get common (according to timestamp index) part of series and compute distance with it; if False,
compute distance with given series without any modifications.
inf_value:
if two empty series given or series' indices interception is empty,
return inf_value as a distance between the series
"""
self.trim_series = trim_series
self.inf_value = inf_value

@abstractmethod
def _compute_distance(self, x1: np.ndarray, x2: np.ndarray) -> float:
"""Compute distance between two given arrays."""
pass

def __call__(self, x1: pd.Series, x2: pd.Series) -> float:
"""Compute distance between x1 and x2.

Parameters
----------
x1:
timestamp-indexed series
x2:
timestamp-indexed series

Returns
-------
float:
distance between x1 and x2
"""
if self.trim_series:
common_indices = x1.index.intersection(x2.index)
_x1, _x2 = x1[common_indices], x2[common_indices]
else:
_x1, _x2 = x1, x2

# TODO: better to avoid such comments
# if x1 and x2 have no interception with timestamp return inf_value as a distance
if _x1.empty and _x2.empty:
return self.inf_value

distance = self._compute_distance(x1=_x1.values, x2=_x2.values)
# TODO: better to avoid such comments
# use it to avoid clustering confusing: if the last if passes we need to clip all the distances
# to inf_value
distance = min(self.inf_value, distance)
return distance

@staticmethod
def _validate_dataset(ts: "TSDataset"):
"""Check that dataset does not contain NaNs."""
for segment in ts.segments:
series = ts[:, segment, "target"]
first_valid_index = 0
last_valid_index = series.reset_index(drop=True).last_valid_index()
series_length = last_valid_index - first_valid_index + 1
if len(series.dropna()) != series_length:
warnings.warn(
f"Timeseries contains NaN values, which will be dropped. "
f"If it is not desirable behaviour, handle them manually."
)
break

@abstractmethod
def _get_average(self, ts: "TSDataset") -> pd.DataFrame:
"""Get series that minimizes squared distance to given ones according to the Distance."""
pass

def get_average(self, ts: "TSDataset") -> pd.DataFrame:
"""Get series that minimizes squared distance to given ones according to the Distance.

Parameters
----------
ts:
TSDataset with series to be averaged

Returns
-------
pd.DataFrame:
dataframe with columns "timestamp" and "target" that contains the series
"""
self._validate_dataset(ts)
centroid = self._get_average(ts)
return centroid


__all__ = ["Distance"]
126 changes: 126 additions & 0 deletions etna/clustering/distances/distance_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import warnings
from typing import Dict
from typing import List
from typing import Optional

import numpy as np
import pandas as pd

from etna.clustering.distances.base import Distance
from etna.core import BaseMixin
from etna.loggers import ConsoleLogger
from etna.loggers import tslogger


class DistanceMatrix(BaseMixin):
"""DistanceMatrix computes distance matrix from TSDataset."""

def __init__(self, distance: Distance):
"""Init DistanceMatrix.

Parameters
----------
distance:
class for distance measurement
"""
self.distance = distance
self.matrix: Optional[np.ndarray] = None
self.series: Optional[List[np.ndarray]] = None
self.segment2idx: Dict[str, int] = {}
self.idx2segment: Dict[int, str] = {}
self.series_number: Optional[int] = None

@staticmethod
def _validate_dataset(ts: "TSDataset"):
"""Check that dataset does not contain NaNs."""
for segment in ts.segments:
series = ts[:, segment, "target"]
first_valid_index = 0
last_valid_index = series.reset_index(drop=True).last_valid_index()
series_length = last_valid_index - first_valid_index + 1
if len(series.dropna()) != series_length:
warnings.warn(
f"Timeseries contains NaN values, which will be dropped. "
f"If it is not desirable behaviour, handle them manually."
)
break

def _get_series(self, ts: "TSDataset") -> List[pd.Series]:
"""Parse given TSDataset and get timestamp-indexed segment series.
Build mapping from segment to idx in matrix and vice versa.
"""
series_list = []
for i, segment in enumerate(ts.segments):
self.segment2idx[segment] = i
self.idx2segment[i] = segment
series = ts[:, segment, "target"].dropna()
series_list.append(series)

self.series_number = len(series_list)
return series_list

def _compute_dist(self, series: List[pd.Series], idx: int) -> np.ndarray:
"""Compute distance from idx-th series to other ones."""
distances = np.array([self.distance(series[idx], series[j]) for j in range(self.series_number)])
return distances

def _compute_dist_matrix(self, series: List[pd.Series]) -> np.ndarray:
"""Compute distance matrix for given series."""
distances = np.empty(shape=(self.series_number, self.series_number))
logging_freq = self.series_number // 10
logger_id = tslogger.add(ConsoleLogger())
Copy link
Contributor

@martins0n martins0n Oct 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't do this. Tslogger is as global object and all handlers've added already.
tslogger.start_experiment() and this too. Look at code examples https://github.com/tinkoff-ai/etna-ts/blob/51afbc2df6d046e3c2ab9b102f50fc7c83529f5a/etna/datasets/tsdataset.py#L122

tslogger.start_experiment()
for idx in range(self.series_number):
distances[idx] = self._compute_dist(series=series, idx=idx)
if (idx + 1) % logging_freq == 0:
tslogger.log(f"Done {idx + 1} out of {self.series_number} ")
tslogger.finish_experiment()
tslogger.remove(logger_id)
return distances

def fit(self, ts: "TSDataset") -> "DistanceMatrix":
"""Fit distance matrix: get timeseries from ts and compute pairwise distances.

Parameters
----------
ts:
TSDataset with timeseries

Returns
-------
self:
fitted DistanceMatrix object

"""
self._validate_dataset(ts)
self.series = self._get_series(ts)
self.matrix = self._compute_dist_matrix(self.series)
return self

def predict(self) -> np.ndarray:
"""Get distance matrix.

Returns
-------
np.ndarray:
2D array with distances between series
"""
return self.matrix

def fit_predict(self, ts: "TSDataset") -> np.ndarray:
"""Compute distance matrix and return it.

Parameters
----------
ts:
TSDataset with timeseries to compute matrix with

Returns
-------
np.ndarray:
2D array with distances between series
"""
return self.fit(ts).predict()


__all__ = ["DistanceMatrix"]
Loading