From 1fb869a397fb081abe738585a35a5ac9878b8e95 Mon Sep 17 00:00:00 2001 From: niklas melton Date: Wed, 13 Mar 2024 21:49:49 -0500 Subject: [PATCH] assert parameter types --- biclustering/BARTMAP.py | 16 ++++++++++++++-- elementary/ART1.py | 3 +++ elementary/ART2.py | 3 +++ elementary/BayesianART.py | 2 ++ elementary/EllipsoidART.py | 5 +++++ elementary/FuzzyART.py | 3 +++ elementary/GaussianART.py | 2 ++ elementary/HypersphereART.py | 4 ++++ elementary/QuadraticNeuronART.py | 8 ++++++++ fusion/FusionART.py | 1 + 10 files changed, 45 insertions(+), 2 deletions(-) diff --git a/biclustering/BARTMAP.py b/biclustering/BARTMAP.py index a4bedf1..8bf115c 100644 --- a/biclustering/BARTMAP.py +++ b/biclustering/BARTMAP.py @@ -123,6 +123,7 @@ def validate_params(params): """ assert "eta" in params + assert isinstance(params["eta"], float) @property def column_labels_(self): @@ -140,9 +141,20 @@ def n_row_clusters(self): def n_column_clusters(self): return self.module_b.n_clusters - def _get_x_cb(self, X: np.ndarray, c_b: int): + def _get_x_cb(self, x: np.ndarray, c_b: int): + """ + get the components of a vector belonging to a b-side cluster + + Parameters: + - x: a sample vector + - c_b: b-side cluster label + + Returns: + x filtered to features belonging to the b-side cluster c_b + + """ b_components = self.module_b.labels_ == c_b - return X[b_components] + return x[b_components] @staticmethod def _pearsonr(a: np.ndarray, b: np.ndarray): diff --git a/elementary/ART1.py b/elementary/ART1.py index 132fd66..2c4b66e 100644 --- a/elementary/ART1.py +++ b/elementary/ART1.py @@ -46,6 +46,9 @@ def validate_params(params: dict): assert 1. >= params["rho"] >= 0. assert 1. >= params["beta"] >= 0. assert params["L"] >= 1. + assert isinstance(params["rho"], float) + assert isinstance(params["beta"], float) + assert isinstance(params["L"], float) def validate_data(self, X: np.ndarray): """ diff --git a/elementary/ART2.py b/elementary/ART2.py index c939e1a..d2c57b1 100644 --- a/elementary/ART2.py +++ b/elementary/ART2.py @@ -57,6 +57,9 @@ def validate_params(params: dict): assert 1. >= params["rho"] >= 0. assert 1. >= params["alpha"] >= 0. assert 1. >= params["beta"] >= 0. + assert isinstance(params["rho"], float) + assert isinstance(params["alpha"], float) + assert isinstance(params["beta"], float) def check_dimensions(self, X: np.ndarray): """ diff --git a/elementary/BayesianART.py b/elementary/BayesianART.py index a1d2ef8..80aaca6 100644 --- a/elementary/BayesianART.py +++ b/elementary/BayesianART.py @@ -43,6 +43,8 @@ def validate_params(params: dict): assert "rho" in params assert "cov_init" in params assert params["rho"] > 0 + assert isinstance(params["rho"], float) + assert isinstance(params["cov_init"], np.ndarray) def check_dimensions(self, X: np.ndarray): """ diff --git a/elementary/EllipsoidART.py b/elementary/EllipsoidART.py index 7e4b643..40fa669 100644 --- a/elementary/EllipsoidART.py +++ b/elementary/EllipsoidART.py @@ -54,6 +54,11 @@ def validate_params(params: dict): assert 1.0 >= params["alpha"] >= 0. assert 1.0 >= params["beta"] >= 0. assert 1.0 >= params["mu"] > 0. + assert isinstance(params["rho"], float) + assert isinstance(params["alpha"], float) + assert isinstance(params["beta"], float) + assert isinstance(params["mu"], float) + assert isinstance(params["r_hat"], float) @staticmethod def category_distance(i: np.ndarray, centroid: np.ndarray, major_axis: np.ndarray, params): diff --git a/elementary/FuzzyART.py b/elementary/FuzzyART.py index ca4f257..37be70d 100644 --- a/elementary/FuzzyART.py +++ b/elementary/FuzzyART.py @@ -98,6 +98,9 @@ def validate_params(params: dict): assert 1.0 >= params["rho"] >= 0. assert params["alpha"] >= 0. assert 1.0 >= params["beta"] > 0. + assert isinstance(params["rho"], float) + assert isinstance(params["alpha"], float) + assert isinstance(params["beta"], float) def check_dimensions(self, X: np.ndarray): """ diff --git a/elementary/GaussianART.py b/elementary/GaussianART.py index 6a812c9..0d241f9 100644 --- a/elementary/GaussianART.py +++ b/elementary/GaussianART.py @@ -39,6 +39,8 @@ def validate_params(params: dict): assert "rho" in params assert "sigma_init" in params assert 1.0 >= params["rho"] > 0. + assert isinstance(params["rho"], float) + assert isinstance(params["sigma_init"], np.ndarray) def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[float, Optional[dict]]: diff --git a/elementary/HypersphereART.py b/elementary/HypersphereART.py index 83b8665..a64a1f3 100644 --- a/elementary/HypersphereART.py +++ b/elementary/HypersphereART.py @@ -45,6 +45,10 @@ def validate_params(params: dict): assert 1.0 >= params["rho"] >= 0. assert params["alpha"] >= 0. assert 1.0 >= params["beta"] >= 0. + assert isinstance(params["rho"], float) + assert isinstance(params["alpha"], float) + assert isinstance(params["beta"], float) + assert isinstance(params["r_hat"], float) @staticmethod def category_distance(i: np.ndarray, centroid: np.ndarray, radius: float, params) -> float: diff --git a/elementary/QuadraticNeuronART.py b/elementary/QuadraticNeuronART.py index 02a0555..6ef31a3 100644 --- a/elementary/QuadraticNeuronART.py +++ b/elementary/QuadraticNeuronART.py @@ -55,6 +55,14 @@ def validate_params(params: dict): assert "lr_w" in params assert "lr_s" in params assert 1.0 >= params["rho"] >= 0. + assert 1.0 >= params["lr_b"] > 0. + assert 1.0 >= params["lr_w"] >= 0. + assert 1.0 >= params["lr_s"] >= 0. + assert isinstance(params["rho"], float) + assert isinstance(params["s_init"], float) + assert isinstance(params["lr_b"], float) + assert isinstance(params["lr_w"], float) + assert isinstance(params["lr_s"], float) def category_choice(self, i: np.ndarray, w: np.ndarray, params: dict) -> tuple[float, Optional[dict]]: """ diff --git a/fusion/FusionART.py b/fusion/FusionART.py index a7e2a6a..d8e6958 100644 --- a/fusion/FusionART.py +++ b/fusion/FusionART.py @@ -89,6 +89,7 @@ def validate_params(params: dict): assert "gamma_values" in params assert all([1.0 >= g >= 0.0 for g in params["gamma_values"]]) assert sum(params["gamma_values"]) == 1.0 + assert isinstance(params["gamma_values"], np.ndarray) def validate_data(self, X: np.ndarray):