Skip to content

Commit

Permalink
Merge pull request #59 from NiklasMelton/test-grid-search
Browse files Browse the repository at this point in the history
test gridsearchCV
  • Loading branch information
NiklasMelton authored Mar 14, 2024
2 parents 50999b6 + ae3525c commit 51951c0
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 19 deletions.
60 changes: 59 additions & 1 deletion artlib/common/BaseARTMAP.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,79 @@ def set_params(self, **params):
return self

def map_a2b(self, y_a: Union[np.ndarray, int]) -> Union[np.ndarray, int]:
"""
map an a-side label to a b-side label
Parameters:
- y_a: side a label(s)
Returns:
side B cluster label(s)
"""
if isinstance(y_a, int):
return self.map[y_a]
u, inv = np.unique(y_a, return_inverse=True)
return np.array([self.map[x] for x in u], dtype=int)[inv].reshape(y_a.shape)

def validate_data(self, X: np.ndarray, y: np.ndarray):
"""
validates the data prior to clustering
Parameters:
- X: data set A
- y: data set B
"""
raise NotImplementedError

def fit(self, X: np.ndarray, y: np.ndarray, max_iter=1):
"""
Fit the model to the data
Parameters:
- X: data set A
- y: data set B
- max_iter: number of iterations to fit the model on the same data set
"""
raise NotImplementedError

def partial_fit(self, X: np.ndarray, y: np.ndarray):
"""
Partial fit the model to the data
Parameters:
- X: data set A
- y: data set B
"""
raise NotImplementedError

def predict(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
def predict(self, X: np.ndarray) -> np.ndarray:
"""
predict labels for the data
Parameters:
- X: data set A
Returns:
B labels for the data
"""
raise NotImplementedError

def predict_ab(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
predict labels for the data, both A-side and B-side
Parameters:
- X: data set A
Returns:
A labels for the data, B labels for the data
"""
raise NotImplementedError

def plot_cluster_bounds(self, ax: Axes, colors: Iterable, linewidth: int = 1):
Expand Down
7 changes: 4 additions & 3 deletions artlib/hierarchical/DeepARTMAP.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ def get_params(self, deep: bool = True) -> dict:
"""
out = dict()
for i, module in enumerate(self.modules):
deep_items = module.get_params().items()
out.update((f"module_{i}" + "__" + k, val) for k, val in deep_items)
out[f"module_{i}"] = module
if deep:
deep_items = module.get_params().items()
out.update((f"module_{i}" + "__" + k, val) for k, val in deep_items)
return out

def set_params(self, **params):
Expand Down Expand Up @@ -221,7 +222,7 @@ def predict(self, X: Union[np.ndarray, list[np.ndarray]]) -> list[np.ndarray]:
x = X[-1]
else:
x = X
pred_a, pred_b = self.layers[-1].predict(x)
pred_a, pred_b = self.layers[-1].predict_ab(x)
pred = [pred_a, pred_b]
for layer in self.layers[:-1][::-1]:
pred.append(layer.map_a2b(pred[-1]))
Expand Down
34 changes: 25 additions & 9 deletions artlib/supervised/ARTMAP.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ def get_params(self, deep: bool = True) -> dict:
Parameter names mapped to their values.
"""
out = dict()
out = {
"module_a": self.module_a,
"module_b": self.module_b,
}

deep_a_items = self.module_a.get_params().items()
out.update(("module_a" + "__" + k, val) for k, val in deep_a_items)
out["module_a"] = self.module_a
if deep:
deep_a_items = self.module_a.get_params().items()
out.update(("module_a" + "__" + k, val) for k, val in deep_a_items)

deep_b_items = self.module_b.get_params().items()
out.update(("module_b" + "__" + k, val) for k, val in deep_b_items)
out["module_b"] = self.module_b
deep_b_items = self.module_b.get_params().items()
out.update(("module_b" + "__" + k, val) for k, val in deep_b_items)
return out


Expand Down Expand Up @@ -104,16 +106,30 @@ def partial_fit(self, X: np.ndarray, y: np.ndarray):
return self


def predict(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
def predict(self, X: np.ndarray) -> np.ndarray:
"""
predict labels for the data
Parameters:
- X: data set A
Returns:
A labels for the data, B labels for the data
B labels for the data
"""
check_is_fitted(self)
return super(ARTMAP, self).predict(X)

def predict_ab(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
predict labels for the data, both A-side and B-side
Parameters:
- X: data set A
Returns:
A labels for the data, B labels for the data
"""
check_is_fitted(self)
return super(ARTMAP, self).predict_ab(X)
29 changes: 23 additions & 6 deletions artlib/supervised/SimpleARTMAP.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def get_params(self, deep: bool = True) -> dict:
Parameter names mapped to their values.
"""
out = dict()
deep_items = self.module_a.get_params().items()
out.update(("module_a" + "__" + k, val) for k, val in deep_items)
out["module_a"] = self.module_a
out = {"module_a": self.module_a}
if deep:
deep_items = self.module_a.get_params().items()
out.update(("module_a" + "__" + k, val) for k, val in deep_items)
return out


Expand Down Expand Up @@ -192,14 +192,31 @@ def step_pred(self, x: np.ndarray) -> tuple[int, int]:
c_b = self.map[c_a]
return c_a, c_b


def predict(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
def predict(self, X: np.ndarray) -> np.ndarray:
"""
predict labels for the data
Parameters:
- X: data set A
Returns:
B labels for the data
"""
check_is_fitted(self)
y_b = np.zeros((X.shape[0],), dtype=int)
for i, x in enumerate(X):
c_a, c_b = self.step_pred(x)
y_b[i] = c_b
return y_b

def predict_ab(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
predict labels for the data, both A-side and B-side
Parameters:
- X: data set A
Returns:
A labels for the data, B labels for the data
Expand Down
Empty file added examples/test_grid_search_cv.py
Empty file.

0 comments on commit 51951c0

Please sign in to comment.