From fb9b4eb2fe62937267bbd8f19eccac106d81f655 Mon Sep 17 00:00:00 2001 From: niklas melton Date: Wed, 13 Mar 2024 21:38:28 -0500 Subject: [PATCH] cleanup and sample counting --- common/BaseART.py | 7 +++ elementary/ART2.py | 4 -- elementary/BayesianART.py | 4 -- elementary/DualVigilanceART.py | 14 ++++-- elementary/FuzzyART.py | 4 -- fusion/FusionART.py | 4 -- topological/TopoART.py | 81 ++++++++++++++++++---------------- 7 files changed, 59 insertions(+), 59 deletions(-) diff --git a/common/BaseART.py b/common/BaseART.py index 03b41d1..37bc57d 100644 --- a/common/BaseART.py +++ b/common/BaseART.py @@ -19,6 +19,8 @@ def __init__(self, params: dict): """ self.validate_params(params) self.params = params + self.sample_counter_ = 0 + self.weight_sample_counter_: list[int] = [] def __getattr__(self, key): if key in self.params: @@ -237,6 +239,7 @@ def add_weight(self, new_w: np.ndarray): - new_w: new cluster weight to add """ + self.weight_sample_counter_.append(1) self.W.append(new_w) def set_weight(self, idx: int, new_w: np.ndarray): @@ -248,6 +251,7 @@ def set_weight(self, idx: int, new_w: np.ndarray): - new_w: new cluster weight """ + self.weight_sample_counter_[idx] += 1 self.W[idx] = new_w def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -> int: @@ -264,6 +268,7 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) - cluster label of the input sample """ + self.sample_counter_ += 1 if len(self.W) == 0: w_new = self.new_weight(x, self.params) self.add_weight(w_new) @@ -345,6 +350,7 @@ def fit(self, X: np.ndarray, match_reset_func: Optional[Callable] = None, max_it """ self.validate_data(X) self.check_dimensions(X) + self.is_fitted_ = True self.W: list[np.ndarray] = [] self.labels_ = np.zeros((X.shape[0], ), dtype=int) @@ -371,6 +377,7 @@ def partial_fit(self, X: np.ndarray, match_reset_func: Optional[Callable] = None self.validate_data(X) self.check_dimensions(X) + self.is_fitted_ = True if not hasattr(self, 'W'): self.W: list[np.ndarray] = [] diff --git a/elementary/ART2.py b/elementary/ART2.py index 7245403..c939e1a 100644 --- a/elementary/ART2.py +++ b/elementary/ART2.py @@ -66,10 +66,6 @@ def check_dimensions(self, X: np.ndarray): - X: data set """ - if not hasattr(self, "dim_"): - self.dim_ = X.shape[1] - else: - assert X.shape[1] == self.dim_ if not hasattr(self, "dim_"): self.dim_ = X.shape[1] assert self.params["alpha"] <= 1 / np.sqrt(self.dim_) diff --git a/elementary/BayesianART.py b/elementary/BayesianART.py index d8e2280..a1d2ef8 100644 --- a/elementary/BayesianART.py +++ b/elementary/BayesianART.py @@ -52,10 +52,6 @@ def check_dimensions(self, X: np.ndarray): - X: data set """ - if not hasattr(self, "dim_"): - self.dim_ = X.shape[1] - else: - assert X.shape[1] == self.dim_ if not hasattr(self, "dim_"): self.dim_ = X.shape[1] assert self.params["cov_init"].shape[0] == self.dim_ diff --git a/elementary/DualVigilanceART.py b/elementary/DualVigilanceART.py index 4cb8244..20a15eb 100644 --- a/elementary/DualVigilanceART.py +++ b/elementary/DualVigilanceART.py @@ -5,6 +5,7 @@ """ import numpy as np from typing import Optional, Callable, Iterable +from warnings import warn from matplotlib.axes import Axes from common.BaseART import BaseART @@ -13,6 +14,12 @@ class DualVigilanceART(BaseART): # implementation of Dual Vigilance ART def __init__(self, base_module: BaseART, lower_bound: float): + assert isinstance(base_module, BaseART) + if hasattr(base_module, "base_module"): + warn( + f"{base_module.__class__.__name__} is an abstraction of the BaseART class. " + f"This module will only make use of the base_module {base_module.base_module.__class__.__name__}" + ) params = dict(base_module.params, **{"rho_lower_bound": lower_bound}) super().__init__(params) self.base_module = base_module @@ -78,10 +85,6 @@ def check_dimensions(self, X: np.ndarray): - X: data set """ - if not hasattr(self, "dim_"): - self.dim_ = X.shape[1] - else: - assert X.shape[1] == self.dim_ self.base_module.check_dimensions(X) def validate_data(self, X: np.ndarray): @@ -109,6 +112,8 @@ def validate_params(params: dict): assert "rho_lower_bound" in params, \ "Dual Vigilance ART requires a lower bound 'rho' value" assert params["rho"] > params["rho_lower_bound"] >= 0 + assert isinstance(params["rho"], float) + assert isinstance(params["rho_lower_bound"], float) def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -> int: """ @@ -124,6 +129,7 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) - cluster label of the input sample """ + self.sample_counter_ += 1 if len(self.base_module.W) == 0: new_w = self.base_module.new_weight(x, self.base_module.params) self.base_module.add_weight(new_w) diff --git a/elementary/FuzzyART.py b/elementary/FuzzyART.py index 61938ee..ca4f257 100644 --- a/elementary/FuzzyART.py +++ b/elementary/FuzzyART.py @@ -107,10 +107,6 @@ def check_dimensions(self, X: np.ndarray): - X: data set """ - if not hasattr(self, "dim_"): - self.dim_ = X.shape[1] - else: - assert X.shape[1] == self.dim_ if not hasattr(self, "dim_"): self.dim_ = X.shape[1] self.dim_original = int(self.dim_//2) diff --git a/fusion/FusionART.py b/fusion/FusionART.py index 1e70717..a7e2a6a 100644 --- a/fusion/FusionART.py +++ b/fusion/FusionART.py @@ -112,10 +112,6 @@ def check_dimensions(self, X: np.ndarray): - X: data set """ - if not hasattr(self, "dim_"): - self.dim_ = X.shape[1] - else: - assert X.shape[1] == self.dim_ assert X.shape[1] == self.dim_, "Invalid data shape" def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[float, Optional[dict]]: diff --git a/topological/TopoART.py b/topological/TopoART.py index 624bad7..11452cc 100644 --- a/topological/TopoART.py +++ b/topological/TopoART.py @@ -10,17 +10,23 @@ import numpy as np from typing import Optional, Callable +from warnings import warn from common.BaseART import BaseART class TopoART(BaseART): def __init__(self, base_module: BaseART, betta_lower: float, tau: int, phi: int): + assert isinstance(base_module, BaseART) + if hasattr(base_module, "base_module"): + warn( + f"{base_module.__class__.__name__} is an abstraction of the BaseART class. " + f"This module will only make use of the base_module {base_module.base_module.__class__.__name__}" + ) params = dict(base_module.params, **{"beta_lower": betta_lower, "tau": tau, "phi": phi}) super().__init__(params) self.base_module = base_module self.adjacency = np.zeros([], dtype=int) - self._counter = np.zeros([], dtype=int) self._permanent_mask = np.zeros([], dtype=bool) @staticmethod @@ -38,6 +44,10 @@ def validate_params(params: dict): assert "phi" in params assert params["beta"] >= params["beta_lower"] assert params["phi"] <= params["tau"] + assert isinstance(params["beta"], float) + assert isinstance(params["beta_lower"], float) + assert isinstance(params["tau"], int) + assert isinstance(params["phi"], int) @property def W(self): @@ -134,18 +144,32 @@ def new_weight(self, i: np.ndarray, params: dict) -> np.ndarray: Returns: updated cluster weight + """ + + return self.new_weight(i, params) + + + def add_weight(self, new_w: np.ndarray): + """ + add a new cluster weight + + Parameters: + - new_w: new cluster weight to add + """ self.adjacency = np.pad(self.adjacency, ((0, 1), (0, 1)), "constant") - self._counter = np.pad(self._counter, (0, 1), "constant", constant_values=(1,)) self._permanent_mask = np.pad(self._permanent_mask, (0, 1), "constant") - return self.new_weight(i, params) + self.weight_sample_counter_.append(1) + self.W.append(new_w) def prune(self, X: np.ndarray): - self._permanent_mask += (self._counter >= self.params["phi"]) + self._permanent_mask += (np.array(self.weight_sample_counter_) >= self.phi) perm_labels = np.where(self._permanent_mask)[0] self.W = [w for w, pm in zip(self.W, self._permanent_mask) if pm] - self._counter = self._counter[perm_labels] + self.weight_sample_counter_ = [self.weight_sample_counter_[i] for i in perm_labels] + self.adjacency = self.adjacency[perm_labels, perm_labels] + self._permanent_mask = self._permanent_mask[perm_labels] label_map = { label: np.where(perm_labels == label)[0][0] @@ -157,15 +181,18 @@ def prune(self, X: np.ndarray): if self.labels_[i] in label_map: self.labels_[i] = label_map[self.labels_[i]] else: - T_values, T_cache = zip(*[self.category_choice(x, w, params=self.params) for w in self.W]) - T = np.array(T_values) - new_label = np.argmax(T) - self.labels_[i] = new_label - self._counter[new_label] += 1 - - def step_prune(self, X: np.ndarray): - sum_counter = sum(self._counter) - if sum_counter > 0 and sum_counter % self.params["tau"] == 0: + # this is a more flexible approach than that described in the paper + self.labels_[i] = self.step_pred(x) + + def post_step_fit(self, X: np.ndarray): + """ + Function called after each sample fit. Used for cluster pruning + + Parameters: + - X: data set + + """ + if self.sample_counter_ > 0 and self.sample_counter_ % self.tau == 0: self.prune(X) def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -> int: @@ -182,13 +209,13 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) - cluster label of the input sample """ + self.sample_counter_ += 1 resonant_c: int = -1 if len(self.W) == 0: new_w = self.new_weight(x, self.params) self.add_weight(new_w) self.adjacency = np.zeros((1, 1), dtype=int) - self._counter = np.ones((1, ), dtype=int) self._permanent_mask = np.zeros((1, ), dtype=bool) return 0 else: @@ -231,27 +258,3 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) - return resonant_c - - def fit(self, X: np.ndarray, match_reset_func: Optional[Callable] = None, max_iter=1): - """ - Fit the model to the data - - Parameters: - - X: data set - - match_reset_func: a callable accepting the data sample, a cluster weight, the params dict, and the cache dict - Permits external factors to influence cluster creation. - Returns True if the cluster is valid for the sample, False otherwise - - max_iter: number of iterations to fit the model on the same data set - - """ - self.validate_data(X) - self.check_dimensions(X) - - self.W: list[np.ndarray] = [] - self.labels_ = np.zeros((X.shape[0], ), dtype=int) - for _ in range(max_iter): - for i, x in enumerate(X): - self.step_prune(X) - c = self.step_fit(x, match_reset_func=match_reset_func) - self.labels_[i] = c - return self