Skip to content

Commit

Permalink
Merge pull request #56 from NiklasMelton/add-parameter-type-checks
Browse files Browse the repository at this point in the history
assert parameter types
  • Loading branch information
NiklasMelton authored Mar 14, 2024
2 parents c181f18 + 1fb869a commit 4d8aea0
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 2 deletions.
16 changes: 14 additions & 2 deletions biclustering/BARTMAP.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def validate_params(params):
"""
assert "eta" in params
assert isinstance(params["eta"], float)

@property
def column_labels_(self):
Expand All @@ -140,9 +141,20 @@ def n_row_clusters(self):
def n_column_clusters(self):
return self.module_b.n_clusters

def _get_x_cb(self, X: np.ndarray, c_b: int):
def _get_x_cb(self, x: np.ndarray, c_b: int):
"""
get the components of a vector belonging to a b-side cluster
Parameters:
- x: a sample vector
- c_b: b-side cluster label
Returns:
x filtered to features belonging to the b-side cluster c_b
"""
b_components = self.module_b.labels_ == c_b
return X[b_components]
return x[b_components]

@staticmethod
def _pearsonr(a: np.ndarray, b: np.ndarray):
Expand Down
3 changes: 3 additions & 0 deletions elementary/ART1.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def validate_params(params: dict):
assert 1. >= params["rho"] >= 0.
assert 1. >= params["beta"] >= 0.
assert params["L"] >= 1.
assert isinstance(params["rho"], float)
assert isinstance(params["beta"], float)
assert isinstance(params["L"], float)

def validate_data(self, X: np.ndarray):
"""
Expand Down
3 changes: 3 additions & 0 deletions elementary/ART2.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def validate_params(params: dict):
assert 1. >= params["rho"] >= 0.
assert 1. >= params["alpha"] >= 0.
assert 1. >= params["beta"] >= 0.
assert isinstance(params["rho"], float)
assert isinstance(params["alpha"], float)
assert isinstance(params["beta"], float)

def check_dimensions(self, X: np.ndarray):
"""
Expand Down
2 changes: 2 additions & 0 deletions elementary/BayesianART.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def validate_params(params: dict):
assert "rho" in params
assert "cov_init" in params
assert params["rho"] > 0
assert isinstance(params["rho"], float)
assert isinstance(params["cov_init"], np.ndarray)

def check_dimensions(self, X: np.ndarray):
"""
Expand Down
5 changes: 5 additions & 0 deletions elementary/EllipsoidART.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def validate_params(params: dict):
assert 1.0 >= params["alpha"] >= 0.
assert 1.0 >= params["beta"] >= 0.
assert 1.0 >= params["mu"] > 0.
assert isinstance(params["rho"], float)
assert isinstance(params["alpha"], float)
assert isinstance(params["beta"], float)
assert isinstance(params["mu"], float)
assert isinstance(params["r_hat"], float)

@staticmethod
def category_distance(i: np.ndarray, centroid: np.ndarray, major_axis: np.ndarray, params):
Expand Down
3 changes: 3 additions & 0 deletions elementary/FuzzyART.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def validate_params(params: dict):
assert 1.0 >= params["rho"] >= 0.
assert params["alpha"] >= 0.
assert 1.0 >= params["beta"] > 0.
assert isinstance(params["rho"], float)
assert isinstance(params["alpha"], float)
assert isinstance(params["beta"], float)

def check_dimensions(self, X: np.ndarray):
"""
Expand Down
2 changes: 2 additions & 0 deletions elementary/GaussianART.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def validate_params(params: dict):
assert "rho" in params
assert "sigma_init" in params
assert 1.0 >= params["rho"] > 0.
assert isinstance(params["rho"], float)
assert isinstance(params["sigma_init"], np.ndarray)


def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[float, Optional[dict]]:
Expand Down
4 changes: 4 additions & 0 deletions elementary/HypersphereART.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def validate_params(params: dict):
assert 1.0 >= params["rho"] >= 0.
assert params["alpha"] >= 0.
assert 1.0 >= params["beta"] >= 0.
assert isinstance(params["rho"], float)
assert isinstance(params["alpha"], float)
assert isinstance(params["beta"], float)
assert isinstance(params["r_hat"], float)

@staticmethod
def category_distance(i: np.ndarray, centroid: np.ndarray, radius: float, params) -> float:
Expand Down
8 changes: 8 additions & 0 deletions elementary/QuadraticNeuronART.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ def validate_params(params: dict):
assert "lr_w" in params
assert "lr_s" in params
assert 1.0 >= params["rho"] >= 0.
assert 1.0 >= params["lr_b"] > 0.
assert 1.0 >= params["lr_w"] >= 0.
assert 1.0 >= params["lr_s"] >= 0.
assert isinstance(params["rho"], float)
assert isinstance(params["s_init"], float)
assert isinstance(params["lr_b"], float)
assert isinstance(params["lr_w"], float)
assert isinstance(params["lr_s"], float)

def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[float, Optional[dict]]:
"""
Expand Down
1 change: 1 addition & 0 deletions fusion/FusionART.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def validate_params(params: dict):
assert "gamma_values" in params
assert all([1.0 >= g >= 0.0 for g in params["gamma_values"]])
assert sum(params["gamma_values"]) == 1.0
assert isinstance(params["gamma_values"], np.ndarray)


def validate_data(self, X: np.ndarray):
Expand Down

0 comments on commit 4d8aea0

Please sign in to comment.