Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sklearn compat #53

Merged
merged 3 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,10 @@ import numpy as np

# Your dataset
train_X = np.array([...])

# Your model parameters
params = {...}
test_X = np.array([...])

# Initialize the Fuzzy ART model
model = FuzzyART(params)
model = FuzzyART(rho=0.7, alpha = 0.0, beta=1.0)

# Fit the model
model.fit(train_X)
Expand Down
70 changes: 70 additions & 0 deletions common/BaseART.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from typing import Optional, Callable, Iterable
from collections import defaultdict
from matplotlib.axes import Axes
from warnings import warn
from sklearn.base import BaseEstimator, ClusterMixin
Expand All @@ -19,6 +20,75 @@ def __init__(self, params: dict):
self.validate_params(params)
self.params = params

def __getattr__(self, key):
if key in self.params:
return self.params[key]
else:
# If the key is not in params, raise an AttributeError
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")

def __setattr__(self, key, value):
if key in self.__dict__.get('params', {}):
# If key is in params, set its value
self.params[key] = value
else:
# Otherwise, proceed with normal attribute setting
super().__setattr__(key, value)


def get_params(self, deep: bool = True) -> dict:
"""

Parameters:
- deep: If True, will return the parameters for this class and contained subobjects that are estimators.

Returns:
Parameter names mapped to their values.

"""
return self.params

def set_params(self, **params):
"""Set the parameters of this estimator.

Specific redefinition of sklearn.BaseEstimator.set_params for ART classes

Parameters:
- **params : Estimator parameters.

Returns:
- self : estimator instance
"""

if not params:
# Simple optimization to gain speed (inspect is slow)
return self
valid_params = self.get_params(deep=True)
local_params = dict()

nested_params = defaultdict(dict) # grouped by prefix
for key, value in params.items():
key, delim, sub_key = key.partition("__")
if key not in valid_params:
local_valid_params = list(valid_params.keys())
raise ValueError(
f"Invalid parameter {key!r} for estimator {self}. "
f"Valid parameters are: {local_valid_params!r}."
)

if delim:
nested_params[key][sub_key] = value
else:
setattr(self, key, value)
valid_params[key] = value
local_params[key] = value

for key, sub_params in nested_params.items():
valid_params[key].set_params(**sub_params)
self.validate_params(local_params)
return self


@staticmethod
def prepare_data(X: np.ndarray) -> np.ndarray:
"""
Expand Down
46 changes: 44 additions & 2 deletions common/BaseARTMAP.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,54 @@
import numpy as np
from typing import Union, Optional, Iterable
from collections import defaultdict
from matplotlib.axes import Axes
from sklearn.base import BaseEstimator, ClassifierMixin, ClusterMixin

class BaseARTMAP(BaseEstimator, ClassifierMixin, ClusterMixin):
map: dict[int, int]

def map_a2b(self, y_a: Union[np.ndarray, int]) -> np.ndarray:
def __init__(self):
self.map: dict[int, int] = dict()

def set_params(self, **params):
"""Set the parameters of this estimator.

Specific redefinition of sklearn.BaseEstimator.set_params for ARTMAP classes

Parameters:
- **params : Estimator parameters.

Returns:
- self : estimator instance
"""

if not params:
# Simple optimization to gain speed (inspect is slow)
return self
valid_params = self.get_params(deep=True)
local_params = dict()

nested_params = defaultdict(dict) # grouped by prefix
for key, value in params.items():
key, delim, sub_key = key.partition("__")
if key not in valid_params:
local_valid_params = list(valid_params.keys())
raise ValueError(
f"Invalid parameter {key!r} for estimator {self}. "
f"Valid parameters are: {local_valid_params!r}."
)

if delim:
nested_params[key][sub_key] = value
else:
setattr(self, key, value)
valid_params[key] = value
local_params[key] = value

for key, sub_params in nested_params.items():
valid_params[key].set_params(**sub_params)
return self

def map_a2b(self, y_a: Union[np.ndarray, int]) -> Union[np.ndarray, int]:
if isinstance(y_a, int):
return self.map[y_a]
u, inv = np.unique(y_a, return_inverse=True)
Expand Down
7 changes: 7 additions & 0 deletions elementary/ART1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ def prepare_data(data: np.ndarray) -> np.ndarray:

class ART1(BaseART):
# implementation of ART 1
def __init__(self, rho: float, beta: float, L: float):
params = {
"rho": rho,
"beta": beta,
"L": L
}
super().__init__(params)

@staticmethod
def validate_params(params: dict):
Expand Down
7 changes: 7 additions & 0 deletions elementary/ART2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def prepare_data(data: np.ndarray) -> np.ndarray:
class ART2A(BaseART):
warn("Do Not Use ART2. It does not work. This module is provided for completeness only")
# implementation of ART 2-A
def __init__(self, rho: float, alpha: float, beta: float):
params = {
"rho": rho,
"alpha": alpha,
"beta": beta,
}
super().__init__(params)

@staticmethod
def validate_params(params: dict):
Expand Down
6 changes: 6 additions & 0 deletions elementary/BayesianART.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ def prepare_data(data: np.ndarray) -> np.ndarray:
class BayesianART(BaseART):
# implementation of Bayesian ART
pi2 = np.pi * 2
def __init__(self, rho: float, cov_init: np.ndarray):
params = {
"rho": rho,
"cov_init": cov_init,
}
super().__init__(params)

@staticmethod
def validate_params(params: dict):
Expand Down
17 changes: 17 additions & 0 deletions elementary/DualVigilanceART.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@ def __init__(self, base_module: BaseART, lower_bound: float):
self.lower_bound = lower_bound
self.map: dict[int, int] = dict()

def get_params(self, deep: bool = True) -> dict:
"""

Parameters:
- deep: If True, will return the parameters for this class and contained subobjects that are estimators.

Returns:
Parameter names mapped to their values.

"""
out = self.params
if deep:
deep_items = self.base_module.get_params().items()
out.update(("base_module" + "__" + k, val) for k, val in deep_items)
out["base_module"] = self.base_module
return out

@property
def n_clusters(self) -> int:
return len(set(c for c in self.map.values()))
Expand Down
9 changes: 9 additions & 0 deletions elementary/EllipsoidART.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@

class EllipsoidART(BaseART):
# implementation of EllipsoidART
def __init__(self, rho: float, alpha: float, beta: float, mu: float, r_hat: float):
params = {
"rho": rho,
"alpha": alpha,
"beta": beta,
"mu": mu,
"r_hat": r_hat,
}
super().__init__(params)

@staticmethod
def validate_params(params: dict):
Expand Down
7 changes: 7 additions & 0 deletions elementary/FuzzyART.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ def get_bounding_box(w: np.ndarray, n: Optional[int] = None) -> tuple[list[int],

class FuzzyART(BaseART):
# implementation of FuzzyART
def __init__(self, rho: float, alpha: float, beta: float):
params = {
"rho": rho,
"alpha": alpha,
"beta": beta,
}
super().__init__(params)

@staticmethod
def prepare_data(X: np.ndarray) -> np.ndarray:
Expand Down
6 changes: 6 additions & 0 deletions elementary/GaussianART.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
class GaussianART(BaseART):
# implementation of GaussianART
pi2 = np.pi*2
def __init__(self, rho: float, sigma_init: np.ndarray):
params = {
"rho": rho,
"sigma_init": sigma_init,
}
super().__init__(params)

@staticmethod
def validate_params(params: dict):
Expand Down
8 changes: 8 additions & 0 deletions elementary/HypersphereART.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@

class HypersphereART(BaseART):
# implementation of HypersphereART
def __init__(self, rho: float, alpha: float, beta: float, r_hat: float):
params = {
"rho": rho,
"alpha": alpha,
"beta": beta,
"r_hat": r_hat,
}
super().__init__(params)

@staticmethod
def validate_params(params: dict):
Expand Down
9 changes: 9 additions & 0 deletions elementary/QuadraticNeuronART.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ def prepare_data(data: np.ndarray) -> np.ndarray:

class QuadraticNeuronART(BaseART):
# implementation of QuadraticNeuronART
def __init__(self, rho: float, s_init: float, lr_b: float, lr_w: float, lr_s: float):
params = {
"rho": rho,
"s_init": s_init,
"lr_b": lr_b,
"lr_w": lr_w,
"lr_s": lr_s,
}
super().__init__(params)

@staticmethod
def validate_params(params: dict):
Expand Down
2 changes: 1 addition & 1 deletion examples/test_art1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def cluster_blobs():
"beta": 1.0,
"L": 1.0
}
cls = ART1(params)
cls = ART1(**params)
y = cls.fit_predict(X)

print(f"{cls.n_clusters} clusters found")
Expand Down
2 changes: 1 addition & 1 deletion examples/test_art2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def cluster_blobs():
"alpha": 0.0,
"beta": 1.0,
}
cls = ART2A(params)
cls = ART2A(**params)
y = cls.fit_predict(X)

print(f"{cls.n_clusters} clusters found")
Expand Down
4 changes: 2 additions & 2 deletions examples/test_artmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def cluster_iris():
"alpha": 0.0,
"beta": 1.0
}
art = FuzzyART(params)
art = FuzzyART(**params)

cls = SimpleARTMAP(art)

Expand Down Expand Up @@ -63,7 +63,7 @@ def cluster_blobs():
"alpha": 0.0,
"beta": 1.0
}
art = FuzzyART(params)
art = FuzzyART(**params)

cls = SimpleARTMAP(art)

Expand Down
2 changes: 1 addition & 1 deletion examples/test_bayesian_art.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def cluster_blobs():
"rho": 2e-5,
"cov_init": np.array([[1e-4, 0.0], [0.0, 1e-4]]),
}
cls = BayesianART(params)
cls = BayesianART(**params)
y = cls.fit_predict(X)

print(f"{cls.n_clusters} clusters found")
Expand Down
2 changes: 1 addition & 1 deletion examples/test_dual_vigilance_art.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def cluster_blobs():
"alpha": 0.8,
"beta": 1.0
}
base_art = FuzzyART(params)
base_art = FuzzyART(**params)
cls = DualVigilanceART(base_art, 0.78)
y = cls.fit_predict(X)

Expand Down
2 changes: 1 addition & 1 deletion examples/test_ellipsoid_art.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def cluster_blobs():
"r_hat": 0.6,
"mu": 0.8
}
cls = EllipsoidART(params)
cls = EllipsoidART(**params)
y = cls.fit_predict(X)

print(f"{cls.n_clusters} clusters found")
Expand Down
4 changes: 2 additions & 2 deletions examples/test_fusion_art.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def cluster_blobs():
"alpha": 0.0,
"beta": 1.0
}
art_a = FuzzyART(params)
art_b = FuzzyART(params)
art_a = FuzzyART(**params)
art_b = FuzzyART(**params)
cls = FusionART([art_a, art_b], gamma_values=[0.5, 0.5], channel_dims=[2,2])
y = cls.fit_predict(X)

Expand Down
2 changes: 1 addition & 1 deletion examples/test_fuzzy_art.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def cluster_blobs():
"alpha": 0.0,
"beta": 1.0
}
cls = FuzzyART(params)
cls = FuzzyART(**params)
y = cls.fit_predict(X)

print(f"{cls.n_clusters} clusters found")
Expand Down
2 changes: 1 addition & 1 deletion examples/test_gaussian_art.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def cluster_blobs():
"rho": 0.15,
"sigma_init": np.array([0.5, 0.5]),
}
cls = GaussianART(params)
cls = GaussianART(**params)
y = cls.fit_predict(X)

print(f"{cls.n_clusters} clusters found")
Expand Down
2 changes: 1 addition & 1 deletion examples/test_hypersphere_art.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def cluster_blobs():
"beta": 1.0,
"r_hat": 0.8
}
cls = HypersphereART(params)
cls = HypersphereART(**params)
y = cls.fit_predict(X)

print(f"{cls.n_clusters} clusters found")
Expand Down
Loading