Skip to content

Commit

Permalink
Merge pull request #92 from NiklasMelton/match-tracking-update
Browse files Browse the repository at this point in the history
Match tracking update; improve Gaussian ART; add verbose fit
  • Loading branch information
NiklasMelton authored Sep 26, 2024
2 parents 69a2f54 + 97a5121 commit 59d674f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 20 deletions.
24 changes: 16 additions & 8 deletions artlib/common/BaseART.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None, m
return 0
else:

if match_reset_method == "MY~" and match_reset_func is not None:
if match_reset_method in ["MT~"] and match_reset_func is not None:
T_values, T_cache = zip(*[
self.category_choice(x, w, params=self.params)
if match_reset_func(x, w, c_, params=self.params, cache=None)
Expand All @@ -355,17 +355,20 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None, m
w = self.W[c_]
cache = T_cache[c_]
m, cache = self.match_criterion_bin(x, w, params=self.params, cache=cache, op=mt_operator)
no_match_reset = (
match_reset_func is None or
match_reset_func(x, w, c_, params=self.params, cache=cache)
)
if match_reset_method in ["MT~"] and match_reset_func is not None:
no_match_reset = True
else:
no_match_reset = (
match_reset_func is None or
match_reset_func(x, w, c_, params=self.params, cache=cache)
)
if m and no_match_reset:
self.set_weight(c_, self.update(x, w, self.params, cache=cache))
self._set_params(base_params)
return c_
else:
T[c_] = np.nan
if not no_match_reset:
if m and not no_match_reset:
keep_searching = self._match_tracking(cache, epsilon, self.params, match_reset_method)
if not keep_searching:
T[:] = np.nan
Expand Down Expand Up @@ -427,7 +430,7 @@ def post_fit(self, X: np.ndarray):
pass


def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None, match_reset_func: Optional[Callable] = None, max_iter=1, match_reset_method:Literal["MT+", "MT-", "MT0", "MT1", "MT~"] = "MT+", epsilon: float = 0.0):
def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None, match_reset_func: Optional[Callable] = None, max_iter=1, match_reset_method:Literal["MT+", "MT-", "MT0", "MT1", "MT~"] = "MT+", epsilon: float = 0.0, verbose: bool = False):
"""
Fit the model to the data
Expand All @@ -453,7 +456,12 @@ def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None, match_reset_func: O
self.W: list[np.ndarray] = []
self.labels_ = np.zeros((X.shape[0], ), dtype=int)
for _ in range(max_iter):
for i, x in enumerate(X):
if verbose:
from tqdm import tqdm
x_iter = tqdm(enumerate(X), total=int(X.shape[0]))
else:
x_iter = enumerate(X)
for i, x in x_iter:
self.pre_step_fit(X)
c = self.step_fit(x, match_reset_func=match_reset_func, match_reset_method=match_reset_method, epsilon=epsilon)
self.labels_[i] = c
Expand Down
36 changes: 26 additions & 10 deletions artlib/elementary/GaussianART.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import numpy as np
from decimal import Decimal
from typing import Optional, Iterable, List
from matplotlib.axes import Axes
from artlib.common.BaseART import BaseART
Expand All @@ -13,20 +14,22 @@

class GaussianART(BaseART):
# implementation of GaussianART
pi2 = np.pi*2
def __init__(self, rho: float, sigma_init: np.ndarray):
def __init__(self, rho: float, sigma_init: np.ndarray, alpha: float = 1e-10):
"""
Parameters:
- rho: vigilance parameter
- sigma_init: initial estimate of the diagonal std
- alpha: used to prevent division by zero errors
"""
params = {
"rho": rho,
"sigma_init": sigma_init,
"alpha": alpha
}
super().__init__(params)


@staticmethod
def validate_params(params: dict):
"""
Expand All @@ -38,7 +41,9 @@ def validate_params(params: dict):
"""
assert "rho" in params
assert "sigma_init" in params
assert 1.0 >= params["rho"] > 0.
assert "alpha" in params
assert 1.0 >= params["rho"] >= 0.
assert params["alpha"] > 0.
assert isinstance(params["rho"], float)
assert isinstance(params["sigma_init"], np.ndarray)

Expand All @@ -57,15 +62,19 @@ def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[f
"""
mean = w[:self.dim_]
sigma = w[self.dim_:-1]
# sigma = w[self.dim_:2*self.dim]
inv_sig = w[2*self.dim_:3*self.dim_]
sqrt_det_sig = w[-2]
n = w[-1]
sig = np.diag(np.multiply(sigma,sigma))

dist = mean-i
exp_dist_sig_dist = np.exp(-0.5*np.matmul(dist.T, np.matmul(np.linalg.inv(sig), dist)))
exp_dist_sig_dist = np.exp(-0.5 * np.dot(dist, np.multiply(inv_sig, dist)))

cache = {
"exp_dist_sig_dist": exp_dist_sig_dist
}
p_i_cj = exp_dist_sig_dist/np.sqrt((self.pi2**self.dim_)*np.linalg.det(sig))
# ignore the (2*pi)^d term as that is constant
p_i_cj = exp_dist_sig_dist/(params["alpha"]+sqrt_det_sig)
p_cj = n/np.sum(w_[-1] for w_ in self.W)

activation = p_i_cj*p_cj
Expand Down Expand Up @@ -109,14 +118,18 @@ def update(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dic
"""
mean = w[:self.dim_]
sigma = w[self.dim_:-1]
sigma = w[self.dim_:2*self.dim_]
n = w[-1]

n_new = n+1
mean_new = (1-(1/n_new))*mean + (1/n_new)*i
sigma_new = np.sqrt((1-(1/n_new))*np.multiply(sigma, sigma) + (1/n_new)*((mean_new - i)**2))

return np.concatenate([mean_new, sigma_new, [n_new]])
sigma2 = np.multiply(sigma_new, sigma_new)
inv_sig = 1 / sigma2
det_sig = np.sqrt(np.prod(sigma2))

return np.concatenate([mean_new, sigma_new, inv_sig, [det_sig], [n_new]])


def new_weight(self, i: np.ndarray, params: dict) -> np.ndarray:
Expand All @@ -132,7 +145,10 @@ def new_weight(self, i: np.ndarray, params: dict) -> np.ndarray:
updated cluster weight
"""
return np.concatenate([i, params["sigma_init"], [1.]])
sigma2 = np.multiply(params["sigma_init"], params["sigma_init"])
inv_sig_init = 1 / sigma2
det_sig_init = np.sqrt(np.prod(sigma2))
return np.concatenate([i, params["sigma_init"], inv_sig_init, [det_sig_init], [1.]])

def get_cluster_centers(self) -> List[np.ndarray]:
"""
Expand Down
9 changes: 7 additions & 2 deletions artlib/supervised/SimpleARTMAP.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def step_fit(self, x: np.ndarray, c_b: int, match_reset_method: Literal["MT+", "
assert self.map[c_a] == c_b
return c_a

def fit(self, X: np.ndarray, y: np.ndarray, max_iter=1, match_reset_method: Literal["MT+", "MT-", "MT0", "MT1", "MT~"] = "MT+", epsilon: float = 1e-10):
def fit(self, X: np.ndarray, y: np.ndarray, max_iter=1, match_reset_method: Literal["MT+", "MT-", "MT0", "MT1", "MT~"] = "MT+", epsilon: float = 1e-10, verbose: bool = False):
"""
Fit the model to the data
Expand All @@ -155,7 +155,12 @@ def fit(self, X: np.ndarray, y: np.ndarray, max_iter=1, match_reset_method: Lite
self.module_a.labels_ = np.zeros((X.shape[0],), dtype=int)

for _ in range(max_iter):
for i, (x, c_b) in enumerate(zip(X, y)):
if verbose:
from tqdm import tqdm
x_y_iter = tqdm(enumerate(zip(X, y)), total=int(X.shape[0]))
else:
x_y_iter = enumerate(zip(X, y))
for i, (x, c_b) in x_y_iter:
self.module_a.pre_step_fit(X)
c_a = self.step_fit(x, c_b, match_reset_method=match_reset_method, epsilon=epsilon)
self.module_a.labels_[i] = c_a
Expand Down

0 comments on commit 59d674f

Please sign in to comment.