Skip to content

Commit

Permalink
Merge pull request #45 from NiklasMelton/update-base-art
Browse files Browse the repository at this point in the history
Pass cache through MC func
  • Loading branch information
NiklasMelton authored Mar 12, 2024
2 parents e76326c + ee4bbfe commit c90d489
Show file tree
Hide file tree
Showing 14 changed files with 76 additions and 60 deletions.
6 changes: 3 additions & 3 deletions common/BaseART.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def validate_data(self, X: np.ndarray):
def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[float, Optional[dict]]:
raise NotImplementedError

def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> float:
def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[float, dict]:
raise NotImplementedError

def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> bool:
def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[bool, dict]:
raise NotImplementedError

def update(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> np.ndarray:
Expand All @@ -60,7 +60,7 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
c_ = int(np.argmax(T))
w = self.W[c_]
cache = T_cache[c_]
m = self.match_criterion_bin(x, w, params=self.params, cache=cache)
m, cache = self.match_criterion_bin(x, w, params=self.params, cache=cache)
no_match_reset = (
match_reset_func is None or
match_reset_func(x, w, c_, params=self.params, cache=cache)
Expand Down
9 changes: 5 additions & 4 deletions elementary/ART1.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[f
w_bu = w[:self.dim_]
return float(np.dot(i, w_bu)), None

def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> float:
def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[float, dict]:
w_td = w[self.dim_:]
return l1norm(np.logical_and(i, w_td)) / l1norm(i)
return l1norm(np.logical_and(i, w_td)) / l1norm(i), cache

def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> bool:
return self.match_criterion(i, w, params, cache) >= params["rho"]
def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[bool, dict]:
M, cache = self.match_criterion(i, w, params, cache)
return M >= params["rho"], cache

def update(self, i: np.ndarray, w: np.ndarray, params, cache: Optional[dict] = None) -> np.ndarray:
w_td = w[self.dim_:]
Expand Down
11 changes: 6 additions & 5 deletions elementary/ART2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,21 @@ def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[f
cache = {"activation": activation}
return activation, cache

def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> float:
def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[float, dict]:
if cache is None:
raise ValueError("No cache provided")
# TODO: make this more efficient
M = cache["activation"]
M_u = params["alpha"]*np.sum(i)
# suppress if uncommitted activation is higher
if M < M_u:
return -1.
return -1., cache
else:
return M
return M, cache

def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> bool:
return self.match_criterion(i, w, params, cache) >= params["rho"]
def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[bool, dict]:
M, cache = self.match_criterion(i, w, params, cache)
return M >= params["rho"], cache

def update(self, i: np.ndarray, w: np.ndarray, params, cache: Optional[dict] = None) -> np.ndarray:
return params["beta"]*i + (1-params["beta"])*w
Expand Down
13 changes: 9 additions & 4 deletions elementary/BayesianART.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,28 @@ def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[f

return activation, cache

def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> float:
def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[float, dict]:
# the original paper uses the det(cov_old) for match criterion
# however, it makes logical sense to use the new_cov and results are improved when doing so
new_w = self.update(i, w, params, cache)
new_cov = new_w[self.dim_:-1].reshape((self.dim_, self.dim_))
cache["new_w"] = new_w
# if cache is None:
# raise ValueError("No cache provided")
# return cache["det_cov"]
return np.linalg.det(new_cov)
return np.linalg.det(new_cov), cache

def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> bool:
return self.match_criterion(i, w, params=params, cache=cache) <= params["rho"]
def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[bool, dict]:
M, cache = self.match_criterion(i, w, params=params, cache=cache)
return M <= params["rho"], cache

def update(self, i: np.ndarray, w: np.ndarray, params, cache: Optional[dict] = None) -> np.ndarray:
if cache is None:
raise ValueError("No cache provided")

if "new_w" in cache:
return cache["new_w"]

mean = w[:self.dim_]
cov = w[self.dim_:-1].reshape((self.dim_, self.dim_))
n = w[-1]
Expand Down
4 changes: 2 additions & 2 deletions elementary/DualVigilanceART.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
c_ = int(np.argmax(T))
w = self.base_module.W[c_]
cache = T_cache[c_]
m1 = self.base_module.match_criterion_bin(x, w, params=self.base_module.params, cache=cache)
m1, cache = self.base_module.match_criterion_bin(x, w, params=self.base_module.params, cache=cache)
no_match_reset = (
match_reset_func is None or
match_reset_func(x, w, self.map[c_], params=self.base_module.params, cache=cache)
Expand All @@ -85,7 +85,7 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
return self.map[c_]
else:
lb_params = dict(self.base_module.params, **{"rho": self.lower_bound})
m2 = self.base_module.match_criterion_bin(x, w, params=lb_params, cache=cache)
m2, _ = self.base_module.match_criterion_bin(x, w, params=lb_params, cache=cache)
if m2:
c_new = len(self.base_module.W)
w_new = self.base_module.new_weight(x, self.base_module.params)
Expand Down
11 changes: 6 additions & 5 deletions elementary/EllipsoidART.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,17 @@ def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[f
return (params["r_hat"] - radius - max(radius, dist)) / (params["r_hat"] - 2*radius + params["alpha"]), cache


def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> float:
def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[float, dict]:
radius = w[-1]
if cache is None:
raise ValueError("No cache provided")
dist = cache["dist"]

return 1 - (radius + max(radius, dist))/params["r_hat"]
return 1 - (radius + max(radius, dist))/params["r_hat"], cache

def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> bool:
return self.match_criterion(i, w, params=params, cache=cache) >= params["rho"]
def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[bool, dict]:
M, cache = self.match_criterion(i, w, params=params, cache=cache)
return M >= params["rho"], cache

def update(self, i: np.ndarray, w: np.ndarray, params, cache: Optional[dict] = None) -> np.ndarray:
centroid = w[:self.dim_]
Expand Down Expand Up @@ -111,7 +112,7 @@ def plot_cluster_bounds(self, ax: Axes, colors: Iterable, linewidth: int = 1):
centroid,
width,
height,
angle,
angle=angle,
linewidth=linewidth,
edgecolor=col,
facecolor='none'
Expand Down
9 changes: 5 additions & 4 deletions elementary/FuzzyART.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ def validate_data(self, X: np.ndarray):
def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[float, Optional[dict]]:
return l1norm(fuzzy_and(i, w)) / (params["alpha"] + l1norm(w)), None

def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> float:
return l1norm(fuzzy_and(i, w)) / self.dim_original
def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[float, dict]:
return l1norm(fuzzy_and(i, w)) / self.dim_original, cache

def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> bool:
return self.match_criterion(i, w, params) >= params["rho"]
def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[bool, dict]:
M, cache = self.match_criterion(i, w, params)
return M >= params["rho"], cache

def update(self, i: np.ndarray, w: np.ndarray, params, cache: Optional[dict] = None) -> np.ndarray:
b = params["beta"]
Expand Down
9 changes: 5 additions & 4 deletions elementary/GaussianART.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,16 @@ def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[f
return activation, cache


def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> float:
def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[float, dict]:
if cache is None:
raise ValueError("No cache provided")
exp_dist_sig_dist = cache["exp_dist_sig_dist"]
return exp_dist_sig_dist
return exp_dist_sig_dist, cache


def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> bool:
return self.match_criterion(i, w, params=params, cache=cache) >= params["rho"]
def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[bool, dict]:
M, cache = self.match_criterion(i, w, params=params, cache=cache)
return M >= params["rho"], cache


def update(self, i: np.ndarray, w: np.ndarray, params, cache: Optional[dict] = None) -> np.ndarray:
Expand Down
9 changes: 5 additions & 4 deletions elementary/HypersphereART.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,18 @@ def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[f
return (params["r_hat"] - max_radius)/(params["r_hat"] - radius + params["alpha"]), cache


def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> float:
def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[float, dict]:
radius = w[-1]
if cache is None:
raise ValueError("No cache provided")
max_radius = cache["max_radius"]

return 1 - (max(radius, max_radius)/params["r_hat"])
return 1 - (max(radius, max_radius)/params["r_hat"]), cache


def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> bool:
return self.match_criterion(i, w, params=params, cache=cache) >= params["rho"]
def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[bool, dict]:
M, cache = self.match_criterion(i, w, params=params, cache=cache)
return M >= params["rho"], cache


def update(self, i: np.ndarray, w: np.ndarray, params, cache: Optional[dict] = None) -> np.ndarray:
Expand Down
9 changes: 5 additions & 4 deletions elementary/QuadraticNeuronART.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@ def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[f
}
return activation, cache

def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> float:
def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[float, dict]:
if cache is None:
raise ValueError("No cache provided")
return cache["activation"]
return cache["activation"], cache

def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> bool:
return self.match_criterion(i, w, params, cache) >= params["rho"]
def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[bool, dict]:
M, cache = self.match_criterion(i, w, params, cache)
return M >= params["rho"], cache

def update(self, i: np.ndarray, w: np.ndarray, params, cache: Optional[dict] = None) -> np.ndarray:
s = cache["s"]
Expand Down
2 changes: 1 addition & 1 deletion examples/test_ellipsoid_art.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def cluster_blobs():
print("Prepared data has shape:", X.shape)

params = {
"rho": 0.3,
"rho": 0.2,
"alpha": 0.0,
"beta": 1.0,
"r_hat": 0.6,
Expand Down
34 changes: 19 additions & 15 deletions fusion/FusionART.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,27 +77,31 @@ def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[f
for k in range(self.n)
]
)
cache = {k: cache for k, cache in enumerate(caches)}
cache = {k: cache_k for k, cache_k in enumerate(caches)}
activation = sum([a*self.params["gamma_values"][k] for k, a in enumerate(activations)])
return activation, cache

def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> list[float]:
def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[list[float], dict]:
if cache is None:
raise ValueError("No cache provided")
M = [
self.modules[k].match_criterion(
i[self._channel_indices[k][0]:self._channel_indices[k][1]],
w[self._channel_indices[k][0]:self._channel_indices[k][1]],
self.modules[k].params,
cache[k]
)
for k in range(self.n)
]
return M
M, caches = zip(
*[
self.modules[k].match_criterion(
i[self._channel_indices[k][0]:self._channel_indices[k][1]],
w[self._channel_indices[k][0]:self._channel_indices[k][1]],
self.modules[k].params,
cache[k]
)
for k in range(self.n)
]
)
cache = {k: cache_k for k, cache_k in enumerate(caches)}
return M, cache

def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> bool:
M = self.match_criterion(i, w, params, cache)
return all(M[k] >= self.modules[k].params["rho"] for k in range(self.n))
def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[bool, dict]:
M, cache = self.match_criterion(i, w, params, cache)
#TODO make work for Bayesian ART
return all(M[k] >= self.modules[k].params["rho"] for k in range(self.n)), cache

def update(self, i: np.ndarray, w: np.ndarray, params, cache: Optional[dict] = None) -> np.ndarray:
W = [
Expand Down
4 changes: 2 additions & 2 deletions templates/ART_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ def validate_data(self, X: np.ndarray):
def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[float, Optional[dict]]:
raise NotImplementedError

def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> float:
def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[float, dict]:
raise NotImplementedError

def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> bool:
def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[bool, dict]:
raise NotImplementedError

def update(self, i: np.ndarray, w: np.ndarray, params, cache: Optional[dict] = None) -> np.ndarray:
Expand Down
6 changes: 3 additions & 3 deletions topological/TopoART.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def validate_data(self, X: np.ndarray):
def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[float, Optional[dict]]:
return self.base_module.category_choice(i, w, params)

def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> float:
def match_criterion(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[float, dict]:
return self.base_module.match_criterion(i, w, params, cache)

def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> bool:
def match_criterion_bin(self, i: np.ndarray, w: np.ndarray, params: dict, cache: Optional[dict] = None) -> tuple[bool, dict]:
return self.base_module.match_criterion_bin(i, w, params, cache)

def update(self, i: np.ndarray, w: np.ndarray, params, cache: Optional[dict] = None) -> np.ndarray:
Expand Down Expand Up @@ -107,7 +107,7 @@ def step_fit(self, x: np.ndarray, match_reset_func: Optional[Callable] = None) -
c_ = int(np.argmax(T))
w = self.W[c_]
cache = T_cache[c_]
m = self.match_criterion_bin(x, w, params=self.params, cache=cache)
m, cache = self.match_criterion_bin(x, w, params=self.params, cache=cache)
no_match_reset = (
match_reset_func is None or
match_reset_func(x, w, c_, params=self.params, cache=cache)
Expand Down

0 comments on commit c90d489

Please sign in to comment.