Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test BARTMAP #50

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions biclustering/BARTMAP.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import numpy as np
from typing import Optional
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from common.BaseART import BaseART
from sklearn.base import BaseEstimator, BiclusterMixin
from scipy.stats import pearsonr
Expand All @@ -38,6 +40,14 @@ def column_labels_(self):
def row_labels_(self):
return self.module_a.labels_

@property
def n_row_clusters(self):
return self.module_a.n_clusters

@property
def n_column_clusters(self):
return self.module_b.n_clusters

def _get_x_cb(self, X: np.ndarray, c_b: int):
b_components = self.module_b.labels_ == c_b
return X[b_components]
Expand All @@ -47,8 +57,10 @@ def _pearsonr(a: np.ndarray, b: np.ndarray):
r, _ = pearsonr(a, b)
return r

def _average_pearson_corr(self, X: np.ndarray, k: int, c_a: int, c_b: int) -> float:
X_a = X[self.column_labels_ == c_a, :]
def _average_pearson_corr(self, X: np.ndarray, k: int, c_b: int) -> float:
X_a = X[self.column_labels_ == c_b, :]
if len(X_a) == 0:
raise ValueError("HERE")
X_k_cb = self._get_x_cb(X[k,:], c_b)
mean_r = np.mean(
[
Expand All @@ -57,10 +69,15 @@ def _average_pearson_corr(self, X: np.ndarray, k: int, c_a: int, c_b: int) -> fl
]
)

return mean_r
return float(mean_r)

def validate_data(self, X_a: np.ndarray, X_b: np.ndarray):
self.module_a.validate_data(X_a)
self.module_b.validate_data(X_b)

def match_criterion_bin(self, X: np.ndarray, k: int, c_a: int, c_b: int, params: dict) -> bool:
return self._average_pearson_corr(X, k, c_a, c_b) >= params["eta"]
def match_criterion_bin(self, X: np.ndarray, k: int, c_b: int, params: dict) -> bool:
M = self._average_pearson_corr(X, k, c_b)
return M >= self.params["eta"]

def match_reset_func(
self,
Expand All @@ -73,7 +90,7 @@ def match_reset_func(
) -> bool:
k = extra["k"]
for cluster_b in range(len(self.module_b.W)):
if self.match_criterion_bin(self.X, k, cluster_a, cluster_b, params):
if self.match_criterion_bin(self.X, k, cluster_b, params):
return True
return False

Expand All @@ -86,20 +103,25 @@ def step_fit(self, X: np.ndarray, k: int) -> int:

def fit(self, X: np.ndarray, max_iter=1):
# Check that X and y have correct shape
self.validate_data(X)
self.X = X

n = X.shape[0]
self.module_b = self.module_b.fit(X.T, max_iter=max_iter)
X_a = self.module_b.prepare_data(X)
X_b = self.module_b.prepare_data(X.T)
self.validate_data(X_a, X_b)


self.module_b = self.module_b.fit(X_b, max_iter=max_iter)

# init module A
self.module_a.W = []
self.module_a.labels_ = np.zeros((X.shape[0],), dtype=int)

for _ in range(max_iter):
for k in range(n):
print(k, self.module_a.n_clusters)
self.module_a.pre_step_fit(X)
c_a = self.step_fit(X, k)
c_a = self.step_fit(X_a, k)
self.module_a.labels_[k] = c_a

self.rows_ = np.vstack(
Expand All @@ -119,4 +141,17 @@ def fit(self, X: np.ndarray, max_iter=1):
return self


def visualize(
self,
cmap: Optional[Colormap] = None
):
import matplotlib.pyplot as plt

if cmap is None:
from matplotlib.pyplot import cm
cmap=plt.cm.Blues

plt.matshow(
np.outer(np.sort(self.row_labels_) + 1, np.sort(self.column_labels_) + 1),
cmap=cmap,
)
5 changes: 5 additions & 0 deletions common/BaseART.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from warnings import warn
from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.utils.validation import check_is_fitted
from common.utils import normalize


class BaseART(BaseEstimator, ClusterMixin):
Expand All @@ -12,6 +13,10 @@ def __init__(self, params: dict):
self.validate_params(params)
self.params = params

@staticmethod
def prepare_data(X: np.ndarray) -> np.ndarray:
return normalize(X)

@property
def n_clusters(self) -> int:
if hasattr(self, "W"):
Expand Down
4 changes: 4 additions & 0 deletions elementary/FuzzyART.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def get_bounding_box(w: np.ndarray, n: Optional[int] = None) -> tuple[list[int],
class FuzzyART(BaseART):
# implementation of FuzzyART

@staticmethod
def prepare_data(X: np.ndarray) -> np.ndarray:
return prepare_data(X)

@staticmethod
def validate_params(params: dict):
assert "rho" in params
Expand Down
53 changes: 53 additions & 0 deletions examples/test_bartmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@

from sklearn.datasets import make_checkerboard
import matplotlib.pyplot as plt
import path
import sys
import numpy as np

# directory reach
directory = path.Path(__file__).abspath()

print(directory.parent)
# setting path
sys.path.append(directory.parent.parent)

from biclustering.BARTMAP import BARTMAP
from elementary.FuzzyART import FuzzyART
from common.utils import normalize


def cluster_checkerboard():
n_clusters = (4, 3)
data, rows, columns = make_checkerboard(
shape=(300, 300), n_clusters=n_clusters, noise=10, shuffle=False, random_state=42
)
print("Data has shape:", data.shape)

X = normalize(data)
print("Prepared data has shape:", X.shape)

params_a = {
"rho": 0.6,
"alpha": 0.0,
"beta": 1.0
}
params_b = {
"rho": 0.6,
"alpha": 0.0,
"beta": 1.0
}
art_a = FuzzyART(params_a)
art_b = FuzzyART(params_b)
cls = BARTMAP(art_a, art_b, {"eta": -1.})
cls.fit(X)

print(f"{cls.n_row_clusters} row clusters found")
print(f"{cls.n_column_clusters} column clusters found")

cls.visualize()
plt.show()


if __name__ == "__main__":
cluster_checkerboard()