diff --git a/artlib/common/BaseART.py b/artlib/common/BaseART.py index 6e7255a..3e11ce4 100644 --- a/artlib/common/BaseART.py +++ b/artlib/common/BaseART.py @@ -340,7 +340,7 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None, m return 0 else: - if match_reset_method == "MY~" and match_reset_func is not None: + if match_reset_method in ["MT~"] and match_reset_func is not None: T_values, T_cache = zip(*[ self.category_choice(x, w, params=self.params) if match_reset_func(x, w, c_, params=self.params, cache=None) @@ -355,17 +355,20 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None, m w = self.W[c_] cache = T_cache[c_] m, cache = self.match_criterion_bin(x, w, params=self.params, cache=cache, op=mt_operator) - no_match_reset = ( - match_reset_func is None or - match_reset_func(x, w, c_, params=self.params, cache=cache) - ) + if match_reset_method in ["MT~"] and match_reset_func is not None: + no_match_reset = True + else: + no_match_reset = ( + match_reset_func is None or + match_reset_func(x, w, c_, params=self.params, cache=cache) + ) if m and no_match_reset: self.set_weight(c_, self.update(x, w, self.params, cache=cache)) self._set_params(base_params) return c_ else: T[c_] = np.nan - if not no_match_reset: + if m and not no_match_reset: keep_searching = self._match_tracking(cache, epsilon, self.params, match_reset_method) if not keep_searching: T[:] = np.nan @@ -427,7 +430,7 @@ def post_fit(self, X: np.ndarray): pass - def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None, match_reset_func: Optional[Callable] = None, max_iter=1, match_reset_method:Literal["MT+", "MT-", "MT0", "MT1", "MT~"] = "MT+", epsilon: float = 0.0): + def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None, match_reset_func: Optional[Callable] = None, max_iter=1, match_reset_method:Literal["MT+", "MT-", "MT0", "MT1", "MT~"] = "MT+", epsilon: float = 0.0, verbose: bool = False): """ Fit the model to the data @@ -453,7 +456,12 @@ def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None, match_reset_func: O self.W: list[np.ndarray] = [] self.labels_ = np.zeros((X.shape[0], ), dtype=int) for _ in range(max_iter): - for i, x in enumerate(X): + if verbose: + from tqdm import tqdm + x_iter = tqdm(enumerate(X), total=int(X.shape[0])) + else: + x_iter = enumerate(X) + for i, x in x_iter: self.pre_step_fit(X) c = self.step_fit(x, match_reset_func=match_reset_func, match_reset_method=match_reset_method, epsilon=epsilon) self.labels_[i] = c diff --git a/artlib/elementary/GaussianART.py b/artlib/elementary/GaussianART.py index 4c48306..ccf2b9f 100644 --- a/artlib/elementary/GaussianART.py +++ b/artlib/elementary/GaussianART.py @@ -5,6 +5,7 @@ """ import numpy as np +from decimal import Decimal from typing import Optional, Iterable, List from matplotlib.axes import Axes from artlib.common.BaseART import BaseART @@ -13,20 +14,22 @@ class GaussianART(BaseART): # implementation of GaussianART - pi2 = np.pi*2 - def __init__(self, rho: float, sigma_init: np.ndarray): + def __init__(self, rho: float, sigma_init: np.ndarray, alpha: float = 1e-10): """ Parameters: - rho: vigilance parameter - sigma_init: initial estimate of the diagonal std + - alpha: used to prevent division by zero errors """ params = { "rho": rho, "sigma_init": sigma_init, + "alpha": alpha } super().__init__(params) + @staticmethod def validate_params(params: dict): """ @@ -38,7 +41,9 @@ def validate_params(params: dict): """ assert "rho" in params assert "sigma_init" in params - assert 1.0 >= params["rho"] > 0. + assert "alpha" in params + assert 1.0 >= params["rho"] >= 0. + assert params["alpha"] > 0. assert isinstance(params["rho"], float) assert isinstance(params["sigma_init"], np.ndarray) @@ -57,15 +62,19 @@ def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[f """ mean = w[:self.dim_] - sigma = w[self.dim_:-1] + # sigma = w[self.dim_:2*self.dim] + inv_sig = w[2*self.dim_:3*self.dim_] + sqrt_det_sig = w[-2] n = w[-1] - sig = np.diag(np.multiply(sigma,sigma)) + dist = mean-i - exp_dist_sig_dist = np.exp(-0.5*np.matmul(dist.T, np.matmul(np.linalg.inv(sig), dist))) + exp_dist_sig_dist = np.exp(-0.5 * np.dot(dist, np.multiply(inv_sig, dist))) + cache = { "exp_dist_sig_dist": exp_dist_sig_dist } - p_i_cj = exp_dist_sig_dist/np.sqrt((self.pi2**self.dim_)*np.linalg.det(sig)) + # ignore the (2*pi)^d term as that is constant + p_i_cj = exp_dist_sig_dist/(params["alpha"]+sqrt_det_sig) p_cj = n/np.sum(w_[-1] for w_ in self.W) activation = p_i_cj*p_cj @@ -109,14 +118,18 @@ def update(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dic """ mean = w[:self.dim_] - sigma = w[self.dim_:-1] + sigma = w[self.dim_:2*self.dim_] n = w[-1] n_new = n+1 mean_new = (1-(1/n_new))*mean + (1/n_new)*i sigma_new = np.sqrt((1-(1/n_new))*np.multiply(sigma, sigma) + (1/n_new)*((mean_new - i)**2)) - return np.concatenate([mean_new, sigma_new, [n_new]]) + sigma2 = np.multiply(sigma_new, sigma_new) + inv_sig = 1 / sigma2 + det_sig = np.sqrt(np.prod(sigma2)) + + return np.concatenate([mean_new, sigma_new, inv_sig, [det_sig], [n_new]]) def new_weight(self, i: np.ndarray, params: dict) -> np.ndarray: @@ -132,7 +145,10 @@ def new_weight(self, i: np.ndarray, params: dict) -> np.ndarray: updated cluster weight """ - return np.concatenate([i, params["sigma_init"], [1.]]) + sigma2 = np.multiply(params["sigma_init"], params["sigma_init"]) + inv_sig_init = 1 / sigma2 + det_sig_init = np.sqrt(np.prod(sigma2)) + return np.concatenate([i, params["sigma_init"], inv_sig_init, [det_sig_init], [1.]]) def get_cluster_centers(self) -> List[np.ndarray]: """ diff --git a/artlib/supervised/SimpleARTMAP.py b/artlib/supervised/SimpleARTMAP.py index 1fbb2aa..b0bf4d9 100644 --- a/artlib/supervised/SimpleARTMAP.py +++ b/artlib/supervised/SimpleARTMAP.py @@ -129,7 +129,7 @@ def step_fit(self, x: np.ndarray, c_b: int, match_reset_method: Literal["MT+", " assert self.map[c_a] == c_b return c_a - def fit(self, X: np.ndarray, y: np.ndarray, max_iter=1, match_reset_method: Literal["MT+", "MT-", "MT0", "MT1", "MT~"] = "MT+", epsilon: float = 1e-10): + def fit(self, X: np.ndarray, y: np.ndarray, max_iter=1, match_reset_method: Literal["MT+", "MT-", "MT0", "MT1", "MT~"] = "MT+", epsilon: float = 1e-10, verbose: bool = False): """ Fit the model to the data @@ -155,7 +155,12 @@ def fit(self, X: np.ndarray, y: np.ndarray, max_iter=1, match_reset_method: Lite self.module_a.labels_ = np.zeros((X.shape[0],), dtype=int) for _ in range(max_iter): - for i, (x, c_b) in enumerate(zip(X, y)): + if verbose: + from tqdm import tqdm + x_y_iter = tqdm(enumerate(zip(X, y)), total=int(X.shape[0])) + else: + x_y_iter = enumerate(zip(X, y)) + for i, (x, c_b) in x_y_iter: self.module_a.pre_step_fit(X) c_a = self.step_fit(x, c_b, match_reset_method=match_reset_method, epsilon=epsilon) self.module_a.labels_[i] = c_a