Skip to content

Commit

Permalink
improve match-tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasMelton committed Sep 18, 2024
1 parent 69a2f54 commit c88bd92
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions artlib/common/BaseART.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None, m
return 0
else:

if match_reset_method == "MY~" and match_reset_func is not None:
if match_reset_method in ["MT~", "MT1"] and match_reset_func is not None:
T_values, T_cache = zip(*[
self.category_choice(x, w, params=self.params)
if match_reset_func(x, w, c_, params=self.params, cache=None)
Expand All @@ -355,17 +355,20 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None, m
w = self.W[c_]
cache = T_cache[c_]
m, cache = self.match_criterion_bin(x, w, params=self.params, cache=cache, op=mt_operator)
no_match_reset = (
match_reset_func is None or
match_reset_func(x, w, c_, params=self.params, cache=cache)
)
if match_reset_method in ["MT~", "MT1"] and match_reset_func is not None:
no_match_reset = True
else:
no_match_reset = (
match_reset_func is None or
match_reset_func(x, w, c_, params=self.params, cache=cache)
)
if m and no_match_reset:
self.set_weight(c_, self.update(x, w, self.params, cache=cache))
self._set_params(base_params)
return c_
else:
T[c_] = np.nan
if not no_match_reset:
if m and not no_match_reset:
keep_searching = self._match_tracking(cache, epsilon, self.params, match_reset_method)
if not keep_searching:
T[:] = np.nan
Expand Down

0 comments on commit c88bd92

Please sign in to comment.