Skip to content

Commit

Permalink
Merge pull request #57 from NiklasMelton/test-topo-art
Browse files Browse the repository at this point in the history
test topo ART
  • Loading branch information
NiklasMelton authored Mar 14, 2024
2 parents 4d8aea0 + 4642a39 commit d1ff8a7
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 8 deletions.
2 changes: 1 addition & 1 deletion common/BaseART.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def visualize(
try:
self.plot_cluster_bounds(ax, colors, linewidth)
except NotImplementedError:
warn(f"{self.__class__.__name__} does not support plotting cluster bounds." )
warn(f"{self.__class__.__name__} does not support plotting cluster bounds.")



Expand Down
6 changes: 5 additions & 1 deletion elementary/DualVigilanceART.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,4 +209,8 @@ def plot_cluster_bounds(self, ax: Axes, colors: Iterable, linewidth: int = 1):
colors_base = []
for k_a in range(self.base_module.n_clusters):
colors_base.append(colors[self.map[k_a]])
self.base_module.plot_cluster_bounds(ax, colors_base, linewidth)

try:
self.base_module.plot_cluster_bounds(ax=ax, colors=colors_base, linewidth=linewidth)
except NotImplementedError:
warn(f"{self.base_module.__class__.__name__} does not support plotting cluster bounds.")
45 changes: 45 additions & 0 deletions examples/test_topo_art.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import path
import sys

# directory reach
directory = path.Path(__file__).abspath()

print(directory.parent)
# setting path
sys.path.append(directory.parent.parent)

from topological.TopoART import TopoART
from elementary.FuzzyART import FuzzyART, prepare_data


def cluster_blobs():
data, target = make_blobs(n_samples=150, centers=3, cluster_std=0.50, random_state=0, shuffle=False)
print("Data has shape:", data.shape)

X = prepare_data(data)
print("Prepared data has shape:", X.shape)

params = {
"rho": 0.6,
"alpha": 0.8,
"beta": 1.0
}
base_art = FuzzyART(**params)
cls = TopoART(base_art, betta_lower=0.3, tau=150, phi=35)
cls = cls.fit(X, max_iter=5)
y = cls.labels_

print(f"{cls.n_clusters} clusters found")
print(f"{cls.base_module.n_clusters} internal clusters found")

print("Adjacency Matrix:")
print(cls.adjacency)

cls.visualize(X, y)
plt.show()


if __name__ == "__main__":
cluster_blobs()
34 changes: 28 additions & 6 deletions topological/TopoART.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
"""

import numpy as np
from typing import Optional, Callable
from typing import Optional, Callable, Iterable
from matplotlib.axes import Axes
from warnings import warn
from common.BaseART import BaseART

Expand Down Expand Up @@ -146,7 +147,7 @@ def new_weight(self, i: np.ndarray, params: dict) -> np.ndarray:
"""

return self.new_weight(i, params)
return self.base_module.new_weight(i, params)


def add_weight(self, new_w: np.ndarray):
Expand All @@ -157,7 +158,10 @@ def add_weight(self, new_w: np.ndarray):
- new_w: new cluster weight to add
"""
self.adjacency = np.pad(self.adjacency, ((0, 1), (0, 1)), "constant")
if len(self.W) == 0:
self.adjacency = np.zeros((1, 1))
else:
self.adjacency = np.pad(self.adjacency, ((0, 1), (0, 1)), "constant")
self._permanent_mask = np.pad(self._permanent_mask, (0, 1), "constant")
self.weight_sample_counter_.append(1)
self.W.append(new_w)
Expand All @@ -166,9 +170,10 @@ def add_weight(self, new_w: np.ndarray):
def prune(self, X: np.ndarray):
self._permanent_mask += (np.array(self.weight_sample_counter_) >= self.phi)
perm_labels = np.where(self._permanent_mask)[0]

self.W = [w for w, pm in zip(self.W, self._permanent_mask) if pm]
self.weight_sample_counter_ = [self.weight_sample_counter_[i] for i in perm_labels]
self.adjacency = self.adjacency[perm_labels, perm_labels]
self.adjacency = self.adjacency[perm_labels][:, perm_labels]
self._permanent_mask = self._permanent_mask[perm_labels]

label_map = {
Expand All @@ -180,9 +185,11 @@ def prune(self, X: np.ndarray):
for i, x in enumerate(X):
if self.labels_[i] in label_map:
self.labels_[i] = label_map[self.labels_[i]]
else:
elif len(self.W) > 0:
# this is a more flexible approach than that described in the paper
self.labels_[i] = self.step_pred(x)
else:
self.labels_[i] = -1

def post_step_fit(self, X: np.ndarray):
"""
Expand Down Expand Up @@ -240,11 +247,12 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
x,
w,
params=params,
cache=dict(cache, **{"resonant_c": resonant_c, "current_c": c_})
cache=dict((cache if cache else {}), **{"resonant_c": resonant_c, "current_c": c_})
)
self.set_weight(c_, new_w)
if resonant_c < 0:
resonant_c = c_
T[c_] = -1
else:
return resonant_c
else:
Expand All @@ -258,3 +266,17 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -

return resonant_c

def plot_cluster_bounds(self, ax: Axes, colors: Iterable, linewidth: int = 1):
"""
undefined function for visualizing the bounds of each cluster
Parameters:
- ax: figure axes
- colors: colors to use for each cluster
- linewidth: width of boundary line
"""
try:
self.base_module.plot_cluster_bounds(ax=ax, colors=colors, linewidth=linewidth)
except NotImplementedError:
warn(f"{self.base_module.__class__.__name__} does not support plotting cluster bounds.")

0 comments on commit d1ff8a7

Please sign in to comment.