Skip to content

Commit

Permalink
fix: lru-cache usage (#171)
Browse files Browse the repository at this point in the history
* Fix lru-cache for binseg

* Fix lru-cache usage for Dynp and BottomUp
  • Loading branch information
julia-shenshina authored Jun 17, 2021
1 parent eeb31fb commit ae85f9d
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 7 deletions.
5 changes: 2 additions & 3 deletions src/ruptures/detection/binseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ def __init__(self, model="l2", custom_cost=None, min_size=2, jump=5, params=None
self.jump = jump
self.n_samples = None
self.signal = None
# cache for intermediate results
self.single_bkp = lru_cache(maxsize=None)(self._single_bkp)

def _seg(self, n_bkps=None, pen=None, epsilon=None):
"""Computes the binary segmentation.
Expand Down Expand Up @@ -83,7 +81,8 @@ def _seg(self, n_bkps=None, pen=None, epsilon=None):
}
return partition

def _single_bkp(self, start, end):
@lru_cache(maxsize=None)
def single_bkp(self, start, end):
"""Return the optimal breakpoint of [start:end] (if it exists)."""
segment_cost = self.cost.error(start, end)
if np.isinf(segment_cost) and segment_cost < 0: # if constant on segment
Expand Down
4 changes: 2 additions & 2 deletions src/ruptures/detection/bottomup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(self, model="l2", custom_cost=None, min_size=2, jump=5, params=None
self.n_samples = None
self.signal = None
self.leaves = None
self.merge = lru_cache(maxsize=None)(self._merge)

def _grow_tree(self):
"""Grow the entire binary tree."""
Expand Down Expand Up @@ -66,7 +65,8 @@ def _grow_tree(self):
leaves.append(leaf)
return leaves

def _merge(self, left, right):
@lru_cache(maxsize=None)
def merge(self, left, right):
"""Merge two contiguous segments."""
assert left.end == right.start, "Segments are not contiguous."
start, end = left.start, right.end
Expand Down
4 changes: 2 additions & 2 deletions src/ruptures/detection/dynp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(self, model="l2", custom_cost=None, min_size=2, jump=5, params=None
jump (int, optional): subsample (one every *jump* points).
params (dict, optional): a dictionary of parameters for the cost instance.
"""
self.seg = lru_cache(maxsize=None)(self._seg) # dynamic programming
if custom_cost is not None and isinstance(custom_cost, BaseCost):
self.cost = custom_cost
else:
Expand All @@ -38,7 +37,8 @@ def __init__(self, model="l2", custom_cost=None, min_size=2, jump=5, params=None
self.jump = jump
self.n_samples = None

def _seg(self, start, end, n_bkps):
@lru_cache(maxsize=None)
def seg(self, start, end, n_bkps):
"""Recurrence to find the optimal partition of signal[start:end].
This method is to be memoized and then used.
Expand Down
7 changes: 7 additions & 0 deletions tests/test_detection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from itertools import product

import numpy as np
Expand Down Expand Up @@ -339,3 +340,9 @@ def test_model_small_signal_bis(signal_bkps_5D_n10, algo, model):
signal, _ = signal_bkps_5D_n10
with pytest.raises(BadSegmentationParameters):
algo(model=model, min_size=5, jump=2).fit_predict(signal, 2)


def test_binseg_deepcopy():
binseg = Binseg()
binseg_copy = deepcopy(binseg)
assert id(binseg.single_bkp) != id(binseg_copy.single_bkp)

0 comments on commit ae85f9d

Please sign in to comment.