Skip to content

Commit

Permalink
Merge pull request #55 from NiklasMelton/sample-counts
Browse files Browse the repository at this point in the history
cleanup and sample counting
  • Loading branch information
NiklasMelton authored Mar 14, 2024
2 parents c808080 + fb9b4eb commit c181f18
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 59 deletions.
7 changes: 7 additions & 0 deletions common/BaseART.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def __init__(self, params: dict):
"""
self.validate_params(params)
self.params = params
self.sample_counter_ = 0
self.weight_sample_counter_: list[int] = []

def __getattr__(self, key):
if key in self.params:
Expand Down Expand Up @@ -237,6 +239,7 @@ def add_weight(self, new_w: np.ndarray):
- new_w: new cluster weight to add
"""
self.weight_sample_counter_.append(1)
self.W.append(new_w)

def set_weight(self, idx: int, new_w: np.ndarray):
Expand All @@ -248,6 +251,7 @@ def set_weight(self, idx: int, new_w: np.ndarray):
- new_w: new cluster weight
"""
self.weight_sample_counter_[idx] += 1
self.W[idx] = new_w

def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -> int:
Expand All @@ -264,6 +268,7 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
cluster label of the input sample
"""
self.sample_counter_ += 1
if len(self.W) == 0:
w_new = self.new_weight(x, self.params)
self.add_weight(w_new)
Expand Down Expand Up @@ -345,6 +350,7 @@ def fit(self, X: np.ndarray, match_reset_func: Optional[Callable] = None, max_it
"""
self.validate_data(X)
self.check_dimensions(X)
self.is_fitted_ = True

self.W: list[np.ndarray] = []
self.labels_ = np.zeros((X.shape[0], ), dtype=int)
Expand All @@ -371,6 +377,7 @@ def partial_fit(self, X: np.ndarray, match_reset_func: Optional[Callable] = None

self.validate_data(X)
self.check_dimensions(X)
self.is_fitted_ = True

if not hasattr(self, 'W'):
self.W: list[np.ndarray] = []
Expand Down
4 changes: 0 additions & 4 deletions elementary/ART2.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@ def check_dimensions(self, X: np.ndarray):
- X: data set
"""
if not hasattr(self, "dim_"):
self.dim_ = X.shape[1]
else:
assert X.shape[1] == self.dim_
if not hasattr(self, "dim_"):
self.dim_ = X.shape[1]
assert self.params["alpha"] <= 1 / np.sqrt(self.dim_)
Expand Down
4 changes: 0 additions & 4 deletions elementary/BayesianART.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ def check_dimensions(self, X: np.ndarray):
- X: data set
"""
if not hasattr(self, "dim_"):
self.dim_ = X.shape[1]
else:
assert X.shape[1] == self.dim_
if not hasattr(self, "dim_"):
self.dim_ = X.shape[1]
assert self.params["cov_init"].shape[0] == self.dim_
Expand Down
14 changes: 10 additions & 4 deletions elementary/DualVigilanceART.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
import numpy as np
from typing import Optional, Callable, Iterable
from warnings import warn
from matplotlib.axes import Axes
from common.BaseART import BaseART

Expand All @@ -13,6 +14,12 @@ class DualVigilanceART(BaseART):
# implementation of Dual Vigilance ART

def __init__(self, base_module: BaseART, lower_bound: float):
assert isinstance(base_module, BaseART)
if hasattr(base_module, "base_module"):
warn(
f"{base_module.__class__.__name__} is an abstraction of the BaseART class. "
f"This module will only make use of the base_module {base_module.base_module.__class__.__name__}"
)
params = dict(base_module.params, **{"rho_lower_bound": lower_bound})
super().__init__(params)
self.base_module = base_module
Expand Down Expand Up @@ -78,10 +85,6 @@ def check_dimensions(self, X: np.ndarray):
- X: data set
"""
if not hasattr(self, "dim_"):
self.dim_ = X.shape[1]
else:
assert X.shape[1] == self.dim_
self.base_module.check_dimensions(X)

def validate_data(self, X: np.ndarray):
Expand Down Expand Up @@ -109,6 +112,8 @@ def validate_params(params: dict):
assert "rho_lower_bound" in params, \
"Dual Vigilance ART requires a lower bound 'rho' value"
assert params["rho"] > params["rho_lower_bound"] >= 0
assert isinstance(params["rho"], float)
assert isinstance(params["rho_lower_bound"], float)

def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -> int:
"""
Expand All @@ -124,6 +129,7 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
cluster label of the input sample
"""
self.sample_counter_ += 1
if len(self.base_module.W) == 0:
new_w = self.base_module.new_weight(x, self.base_module.params)
self.base_module.add_weight(new_w)
Expand Down
4 changes: 0 additions & 4 deletions elementary/FuzzyART.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,6 @@ def check_dimensions(self, X: np.ndarray):
- X: data set
"""
if not hasattr(self, "dim_"):
self.dim_ = X.shape[1]
else:
assert X.shape[1] == self.dim_
if not hasattr(self, "dim_"):
self.dim_ = X.shape[1]
self.dim_original = int(self.dim_//2)
Expand Down
4 changes: 0 additions & 4 deletions fusion/FusionART.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,6 @@ def check_dimensions(self, X: np.ndarray):
- X: data set
"""
if not hasattr(self, "dim_"):
self.dim_ = X.shape[1]
else:
assert X.shape[1] == self.dim_
assert X.shape[1] == self.dim_, "Invalid data shape"

def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[float, Optional[dict]]:
Expand Down
81 changes: 42 additions & 39 deletions topological/TopoART.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,23 @@

import numpy as np
from typing import Optional, Callable
from warnings import warn
from common.BaseART import BaseART


class TopoART(BaseART):

def __init__(self, base_module: BaseART, betta_lower: float, tau: int, phi: int):
    """
    Build a TopoART wrapper around an existing ART module.

    Parameters:
    - base_module: BaseART instance; its params are copied and extended, and the
      instance itself is kept as self.base_module
    - betta_lower: value stored under the "beta_lower" params key
    - tau: value stored under the "tau" params key
    - phi: value stored under the "phi" params key
    """
    assert isinstance(base_module, BaseART)

    # A module that itself exposes a `base_module` attribute is a wrapper;
    # tell the user which underlying module is named in the warning.
    if hasattr(base_module, "base_module"):
        message = (
            f"{base_module.__class__.__name__} is an abstraction of the BaseART class. "
            f"This module will only make use of the base_module {base_module.base_module.__class__.__name__}"
        )
        warn(message)

    # Base-module params come first; the TopoART keys are layered on top.
    topo_params = {"beta_lower": betta_lower, "tau": tau, "phi": phi}
    params = dict(base_module.params, **topo_params)
    super().__init__(params)

    self.base_module = base_module

    # Zero-dimensional placeholders for the bookkeeping arrays.
    self.adjacency = np.zeros((), dtype=int)
    self._counter = np.zeros((), dtype=int)
    self._permanent_mask = np.zeros((), dtype=bool)

@staticmethod
Expand All @@ -38,6 +44,10 @@ def validate_params(params: dict):
assert "phi" in params
assert params["beta"] >= params["beta_lower"]
assert params["phi"] <= params["tau"]
assert isinstance(params["beta"], float)
assert isinstance(params["beta_lower"], float)
assert isinstance(params["tau"], int)
assert isinstance(params["phi"], int)

@property
def W(self):
Expand Down Expand Up @@ -134,18 +144,32 @@ def new_weight(self, i: np.ndarray, params: dict) -> np.ndarray:
Returns:
updated cluster weight
"""

return self.new_weight(i, params)


def add_weight(self, new_w: np.ndarray):
    """
    add a new cluster weight
    Parameters:
    - new_w: new cluster weight to add

    Grows the adjacency matrix, the per-cluster counter and the permanence
    mask by one entry, records a sample count of 1 for the new cluster and
    appends the weight to self.W.
    """
    # __init__ seeds the bookkeeping arrays as 0-d placeholders. np.pad cannot
    # apply the 2-D pad widths below to a 0-d array, so the first cluster is
    # created directly instead of by padding.
    if self.adjacency.ndim == 2:
        self.adjacency = np.pad(self.adjacency, ((0, 1), (0, 1)), "constant")
    else:
        self.adjacency = np.zeros((1, 1), dtype=int)

    # A new cluster starts with a count of one.
    if self._counter.ndim == 1:
        self._counter = np.pad(self._counter, (0, 1), "constant", constant_values=(1,))
    else:
        self._counter = np.ones((1,), dtype=int)

    # A new cluster is not permanent yet (padded with False).
    if self._permanent_mask.ndim == 1:
        self._permanent_mask = np.pad(self._permanent_mask, (0, 1), "constant")
    else:
        self._permanent_mask = np.zeros((1,), dtype=bool)

    # The stray `return self.new_weight(i, params)` that used to sit here
    # referenced undefined names and made the two appends below unreachable.
    self.weight_sample_counter_.append(1)
    self.W.append(new_w)


def prune(self, X: np.ndarray):
self._permanent_mask += (self._counter >= self.params["phi"])
self._permanent_mask += (np.array(self.weight_sample_counter_) >= self.phi)
perm_labels = np.where(self._permanent_mask)[0]
self.W = [w for w, pm in zip(self.W, self._permanent_mask) if pm]
self._counter = self._counter[perm_labels]
self.weight_sample_counter_ = [self.weight_sample_counter_[i] for i in perm_labels]
self.adjacency = self.adjacency[perm_labels, perm_labels]
self._permanent_mask = self._permanent_mask[perm_labels]

label_map = {
label: np.where(perm_labels == label)[0][0]
Expand All @@ -157,15 +181,18 @@ def prune(self, X: np.ndarray):
if self.labels_[i] in label_map:
self.labels_[i] = label_map[self.labels_[i]]
else:
T_values, T_cache = zip(*[self.category_choice(x, w, params=self.params) for w in self.W])
T = np.array(T_values)
new_label = np.argmax(T)
self.labels_[i] = new_label
self._counter[new_label] += 1

def step_prune(self, X: np.ndarray):
sum_counter = sum(self._counter)
if sum_counter > 0 and sum_counter % self.params["tau"] == 0:
# this is a more flexible approach than that described in the paper
self.labels_[i] = self.step_pred(x)

def post_step_fit(self, X: np.ndarray):
    """
    Hook run after each sample has been fitted; triggers cluster pruning
    every `tau` samples.
    Parameters:
    - X: data set
    """
    seen = self.sample_counter_
    # Nothing to do before the first sample has been counted.
    if not seen > 0:
        return
    # Prune only when the running sample count is an exact multiple of tau.
    if seen % self.tau == 0:
        self.prune(X)

def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -> int:
Expand All @@ -182,13 +209,13 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
cluster label of the input sample
"""
self.sample_counter_ += 1
resonant_c: int = -1

if len(self.W) == 0:
new_w = self.new_weight(x, self.params)
self.add_weight(new_w)
self.adjacency = np.zeros((1, 1), dtype=int)
self._counter = np.ones((1, ), dtype=int)
self._permanent_mask = np.zeros((1, ), dtype=bool)
return 0
else:
Expand Down Expand Up @@ -231,27 +258,3 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -

return resonant_c


def fit(self, X: np.ndarray, match_reset_func: Optional[Callable] = None, max_iter=1):
    """
    Train the model on a data set, starting from an empty set of weights.
    Parameters:
    - X: data set
    - match_reset_func: a callable accepting the data sample, a cluster weight, the params dict, and the cache dict
        Permits external factors to influence cluster creation.
        Returns True if the cluster is valid for the sample, False otherwise
    - max_iter: number of passes made over the same data set
    Returns:
        the fitted object itself
    """
    self.validate_data(X)
    self.check_dimensions(X)

    # Each call discards any existing weights and labels.
    self.W: list[np.ndarray] = []
    n_samples = X.shape[0]
    self.labels_ = np.zeros((n_samples,), dtype=int)

    for _epoch in range(max_iter):
        for row, sample in enumerate(X):
            # The pruning check runs before the sample is presented.
            self.step_prune(X)
            self.labels_[row] = self.step_fit(sample, match_reset_func=match_reset_func)
    return self

0 comments on commit c181f18

Please sign in to comment.