From 5667e2a5166bb0be2c5bbe1ce27577db849b3c88 Mon Sep 17 00:00:00 2001 From: niklas melton Date: Fri, 8 Mar 2024 12:30:29 -0600 Subject: [PATCH] dtype for labels --- elementary/BaseART.py | 4 ++-- examples/test_smart.py | 2 +- hierarchical/DeepARTMAP.py | 14 +++++++++++++- supervised/ARTMAP.py | 10 +++++----- topological/TopoART.py | 2 +- 5 files changed, 22 insertions(+), 10 deletions(-) diff --git a/elementary/BaseART.py b/elementary/BaseART.py index a0cc4f6..ce70b81 100644 --- a/elementary/BaseART.py +++ b/elementary/BaseART.py @@ -106,7 +106,7 @@ def partial_fit(self, X: np.ndarray, match_reset_func: Optional[Callable] = None if not hasattr(self, 'W'): self.W: list[np.ndarray] = [] - self.labels_ = np.zeros((X.shape[0], )) + self.labels_ = np.zeros((X.shape[0], ), dtype=int) j = 0 else: j = len(self.labels_) @@ -122,7 +122,7 @@ def predict(self, X: np.ndarray): self.validate_data(X) self.check_dimensions(X) - y = np.zeros((X.shape[0],)) + y = np.zeros((X.shape[0],), dtype=int) for i, x in enumerate(X): c = self.step_pred(x) y[i] = c diff --git a/examples/test_smart.py b/examples/test_smart.py index 2190b88..ef6b452 100644 --- a/examples/test_smart.py +++ b/examples/test_smart.py @@ -54,7 +54,7 @@ def cluster_blobs(): if j == 0: layer_colors.append(colors[k]) else: - layer_colors.append(colors[cls.layers[j-1].map_a2b(k)]) + layer_colors.append(colors[cls.map_deep(j-1, k)]) cls.modules[j].plot_bounding_boxes(ax, layer_colors) plt.show() diff --git a/hierarchical/DeepARTMAP.py b/hierarchical/DeepARTMAP.py index 4a3fce8..9e3d69e 100644 --- a/hierarchical/DeepARTMAP.py +++ b/hierarchical/DeepARTMAP.py @@ -16,6 +16,10 @@ def __init__(self, modules: list[BaseART]): def labels_(self): return self.layers[0].labels_ + @property + def labels_deep_(self): + return np.concatenate([layer.labels_ for layer in self.layers]+[self.layers[-1].labels_a]) + @property def n_modules(self): return len(self.modules) @@ -24,6 +28,14 @@ def n_modules(self): def n_layers(self): return len(self.layers) + def map_deep(self, level: int, y_a: Union[np.ndarray, int]) -> Union[np.ndarray, int]: + y_b = self.layers[level].map_a2b(y_a) + if level > 0: + return self.map_deep(level-1, y_b) + else: + return y_b + + def validate_data( self, X: list[np.ndarray], @@ -53,7 +65,7 @@ def fit(self, X: list[np.ndarray], y: Optional[np.ndarray] = None, max_iter=1): self.layers[0] = self.layers[0].fit(X[1], X[0], max_iter=max_iter) for art_i in range(1, self.n_layers): - y_i = self.layers[art_i-1].labels_ + y_i = self.layers[art_i-1].labels_a self.layers[art_i] = self.layers[art_i].fit(X[art_i], y_i, max_iter=max_iter) return self diff --git a/supervised/ARTMAP.py b/supervised/ARTMAP.py index 35b06ed..787a8d7 100644 --- a/supervised/ARTMAP.py +++ b/supervised/ARTMAP.py @@ -13,7 +13,7 @@ def map_a2b(self, y_a: Union[np.ndarray, int]) -> np.ndarray: if isinstance(y_a, int): return self.map[y_a] u, inv = np.unique(y_a, return_inverse=True) - return np.array([self.map[x] for x in u])[inv].reshape(y_a.shape) + return np.array([self.map[x] for x in u], dtype=int)[inv].reshape(y_a.shape) def validate_data(self, X: np.ndarray, y: np.ndarray): raise NotImplementedError @@ -73,7 +73,7 @@ def fit(self, X: np.ndarray, y: np.ndarray, max_iter=1): self.labels_ = y # init module A self.module_a.W = [] - self.module_a.labels_ = np.zeros((X.shape[0],)) + self.module_a.labels_ = np.zeros((X.shape[0],), dtype=int) for _ in range(max_iter): for i, (x, c_b) in enumerate(zip(X, y)): @@ -86,7 +86,7 @@ def partial_fit(self, X: np.ndarray, y: np.ndarray): if not hasattr(self, 'labels_'): self.labels_ = y self.module_a.W = [] - self.module_a.labels_ = np.zeros((X.shape[0],)) + self.module_a.labels_ = np.zeros((X.shape[0],), dtype=int) j = 0 else: j = len(self.labels_) @@ -119,8 +119,8 @@ def step_pred(self, x: np.ndarray) -> tuple[int, int]: def predict(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]: check_is_fitted(self) - y_a = np.zeros((X.shape[0],)) - y_b = np.zeros((X.shape[0],)) + y_a = np.zeros((X.shape[0],), dtype=int) + y_b = np.zeros((X.shape[0],), dtype=int) for i, x in enumerate(X): c_a, c_b = self.step_pred(x) y_a[i] = c_a diff --git a/topological/TopoART.py b/topological/TopoART.py index 3fb8da1..2f4a462 100644 --- a/topological/TopoART.py +++ b/topological/TopoART.py @@ -136,7 +136,7 @@ def fit(self, X: np.ndarray, match_reset_func: Optional[Callable] = None, max_it self.check_dimensions(X) self.W: list[np.ndarray] = [] - self.labels_ = np.zeros((X.shape[0], )) + self.labels_ = np.zeros((X.shape[0], ), dtype=int) for _ in range(max_iter): for i, x in enumerate(X): self.step_prune(X)