Commit fb436c6

test FusionART
NiklasMelton committed Mar 13, 2024
1 parent 37c1977 commit fb436c6
Showing 5 changed files with 94 additions and 16 deletions.
13 changes: 10 additions & 3 deletions common/BaseART.py
@@ -49,9 +49,16 @@ def update(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dic
     def new_weight(self, i: np.ndarray, params: dict) -> np.ndarray:
         raise NotImplementedError
 
+    def add_weight(self, new_w: np.ndarray):
+        self.W.append(new_w)
+
+    def set_weight(self, idx: int, new_w: np.ndarray):
+        self.W[idx] = new_w
+
     def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -> int:
         if len(self.W) == 0:
-            self.W.append(self.new_weight(x, self.params))
+            w_new = self.new_weight(x, self.params)
+            self.add_weight(w_new)
             return 0
         else:
             T_values, T_cache = zip(*[self.category_choice(x, w, params=self.params) for w in self.W])
@@ -66,14 +73,14 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
                     match_reset_func(x, w, c_, params=self.params, cache=cache)
                 )
                 if m and no_match_reset:
-                    self.W[c_] = self.update(x, w, self.params, cache=cache)
+                    self.set_weight(c_, self.update(x, w, self.params, cache=cache))
                     return c_
                 else:
                     T[c_] = -1
 
             c_new = len(self.W)
             w_new = self.new_weight(x, self.params)
-            self.W.append(w_new)
+            self.add_weight(w_new)
             return c_new
 
     def step_pred(self, x) -> int:
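Aside (illustrative sketch, not part of this commit): routing step_fit's weight writes through add_weight and set_weight lets subclasses intercept every weight creation and update. A minimal sketch of that hook pattern, using a hypothetical ToyCompositeART with made-up channel modules, mirroring how FusionART overrides the same two methods further down:

    import numpy as np

    class ToyCompositeART:
        # Illustrative only: fans a concatenated weight vector out to per-channel modules.
        def __init__(self, channel_modules, channel_dims):
            self.modules = channel_modules
            # (start, end) column slice for each channel, e.g. [(0, 2), (2, 4)]
            bounds = np.cumsum([0] + list(channel_dims))
            self._channel_indices = list(zip(bounds[:-1], bounds[1:]))

        def add_weight(self, new_w: np.ndarray):
            for k, (lo, hi) in enumerate(self._channel_indices):
                self.modules[k].add_weight(new_w[lo:hi])

        def set_weight(self, idx: int, new_w: np.ndarray):
            for k, (lo, hi) in enumerate(self._channel_indices):
                self.modules[k].set_weight(idx, new_w[lo:hi])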
10 changes: 6 additions & 4 deletions elementary/DualVigilanceART.py
@@ -64,7 +64,8 @@ def validate_params(params: dict):
 
     def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -> int:
         if len(self.base_module.W) == 0:
-            self.base_module.W.append(self.base_module.new_weight(x, self.base_module.params))
+            new_w = self.base_module.new_weight(x, self.base_module.params)
+            self.base_module.add_weight(new_w)
             self.map[0] = 0
             return 0
         else:
@@ -87,22 +88,23 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
 
                 if no_match_reset:
                     if m1:
-                        self.base_module.W[c_] = self.base_module.update(x, w, self.params, cache=cache)
+                        new_w = self.base_module.update(x, w, self.params, cache=cache)
+                        self.base_module.set_weight(c_, new_w)
                         return self.map[c_]
                     else:
                         lb_params = dict(self.base_module.params, **{"rho": self.lower_bound})
                         m2, _ = self.base_module.match_criterion_bin(x, w, params=lb_params, cache=cache)
                         if m2:
                             c_new = len(self.base_module.W)
                             w_new = self.base_module.new_weight(x, self.base_module.params)
-                            self.base_module.W.append(w_new)
+                            self.base_module.add_weight(w_new)
                             self.map[c_new] = self.map[c_]
                             return self.map[c_new]
                 T[c_] = -1
 
             c_new = len(self.base_module.W)
             w_new = self.base_module.new_weight(x, self.base_module.params)
-            self.base_module.W.append(w_new)
+            self.base_module.add_weight(w_new)
             self.map[c_new] = max(self.map.values()) + 1
             return self.map[c_new]
 
48 changes: 48 additions & 0 deletions examples/test_fusion_art.py
@@ -0,0 +1,48 @@
+import numpy as np
+from sklearn.datasets import make_blobs
+import matplotlib.pyplot as plt
+import path
+import sys
+
+# directory reach
+directory = path.Path(__file__).abspath()
+
+print(directory.parent)
+# setting path
+sys.path.append(directory.parent.parent)
+
+from fusion.FusionART import FusionART
+from elementary.FuzzyART import FuzzyART, prepare_data
+
+
+def cluster_blobs():
+    data, target = make_blobs(n_samples=150, centers=3, cluster_std=0.50, random_state=0, shuffle=False)
+    print("Data has shape:", data.shape)
+
+    data_channel_a = data[:,0].reshape((-1,1))
+    data_channel_b = data[:,1].reshape((-1,1))
+
+    X_channel_a = prepare_data(data_channel_a)
+    X_channel_b = prepare_data(data_channel_b)
+
+    X = np.hstack([X_channel_a, X_channel_b])
+    print("Prepared data has shape:", X.shape)
+
+    params = {
+        "rho": 0.5,
+        "alpha": 0.0,
+        "beta": 1.0
+    }
+    art_a = FuzzyART(params)
+    art_b = FuzzyART(params)
+    cls = FusionART([art_a, art_b], gamma_values=[0.5, 0.5], channel_dims=[2,2])
+    y = cls.fit_predict(X)
+
+    print(f"{cls.n_clusters} clusters found")
+
+    cls.visualize(data, y)
+    plt.show()
+
+
+if __name__ == "__main__":
+    cluster_blobs()
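For orientation (not from this diff): channel_dims=[2, 2] in the example assumes that prepare_data applies FuzzyART-style complement coding, so each single-column channel comes back as two columns before the hstack. A quick sanity check of that assumption, run from the repository root (or with the same sys.path tweak as the script above):

    import numpy as np
    from elementary.FuzzyART import prepare_data

    channel = np.array([[0.0], [0.5], [1.0]])
    # Assumption: complement coding maps x -> [x, 1 - x], doubling the column count.
    prepared = prepare_data(channel)
    print(prepared.shape)  # expected (3, 2) if that assumption holds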
31 changes: 25 additions & 6 deletions fusion/FusionART.py
@@ -37,12 +37,19 @@ def __init__(
         self._channel_indices = get_channel_position_tuples(self.channel_dims)
         self.dim_ = sum(channel_dims)
 
+
     @property
     def W(self):
-        W = np.concatenate(
-            self.modules[k].W
-            for k in range(self.n)
-        )
+        W = [
+            np.concatenate(
+                [
+                    self.modules[k].W[i]
+                    for k in range(self.n)
+                ]
+            )
+            for i
+            in range(self.modules[0].n_clusters)
+        ]
         return W
 
     @W.setter
@@ -60,8 +67,10 @@ def validate_params(params: dict):
         assert sum(params["gamma_values"]) == 1.0
 
     def validate_data(self, X: np.ndarray):
-        assert np.all(X >= 0), "Data has not been normalized"
-        assert np.all(X <= 1.0), "Data has not been normalized"
         self.check_dimensions(X)
+        for k in range(self.n):
+            X_k = X[:, self._channel_indices[k][0]:self._channel_indices[k][1]]
+            self.modules[k].validate_data(X_k)
 
     def check_dimensions(self, X: np.ndarray):
         assert X.shape[1] == self.dim_, "Invalid data shape"
@@ -136,3 +145,13 @@ def new_weight(self, i: np.ndarray, params: dict) -> np.ndarray:
             for k in range(self.n)
         ]
         return np.concatenate(W)
+
+    def add_weight(self, new_w: np.ndarray):
+        for k in range(self.n):
+            new_w_k = new_w[self._channel_indices[k][0]:self._channel_indices[k][1]]
+            self.modules[k].add_weight(new_w_k)
+
+    def set_weight(self, idx: int, new_w: np.ndarray):
+        for k in range(self.n):
+            new_w_k = new_w[self._channel_indices[k][0]:self._channel_indices[k][1]]
+            self.modules[k].set_weight(idx, new_w_k)
8 changes: 5 additions & 3 deletions topological/TopoART.py
@@ -95,7 +95,8 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
         resonant_c: int = -1
 
         if len(self.W) == 0:
-            self.W.append(self.new_weight(x, self.params))
+            new_w = self.new_weight(x, self.params)
+            self.add_weight(new_w)
             self.adjacency = np.zeros((1, 1), dtype=int)
             self._counter = np.ones((1, ), dtype=int)
             self._permanent_mask = np.zeros((1, ), dtype=bool)
@@ -118,12 +119,13 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
                    else:
                        params = dict(self.params, **{"beta": self.params["beta_lower"]})
                    #TODO: make compatible with DualVigilanceART
-                   self.W[c_] = self.update(
+                   new_w = self.update(
                        x,
                        w,
                        params=params,
                        cache=dict(cache, **{"resonant_c": resonant_c, "current_c": c_})
                    )
+                   self.set_weight(c_, new_w)
                    if resonant_c < 0:
                        resonant_c = c_
                    else:
@@ -134,7 +136,7 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
            if resonant_c < 0:
                c_new = len(self.W)
                w_new = self.new_weight(x, self.params)
-               self.W.append(w_new)
+               self.add_weight(w_new)
                return c_new
 
            return resonant_c
