From 4c3bfae18b045b6bb7410b9bf06ed16e27e74881 Mon Sep 17 00:00:00 2001 From: Niklas Melton Date: Fri, 18 Oct 2024 12:20:41 -0500 Subject: [PATCH 1/2] revert match_criterion typing --- artlib/common/BaseART.py | 2 +- artlib/elementary/ART1.py | 4 ++-- artlib/elementary/BayesianART.py | 2 +- artlib/elementary/EllipsoidART.py | 4 ++-- artlib/elementary/FuzzyART.py | 4 ++-- artlib/elementary/GaussianART.py | 4 ++-- artlib/elementary/QuadraticNeuronART.py | 4 ++-- artlib/experimental/ConvexHullART.py | 2 +- artlib/experimental/SeqART.py | 2 +- artlib/fusion/FusionART.py | 8 ++++---- artlib/topological/TopoART.py | 2 +- 11 files changed, 19 insertions(+), 19 deletions(-) diff --git a/artlib/common/BaseART.py b/artlib/common/BaseART.py index 6ab8255..e85a7c9 100644 --- a/artlib/common/BaseART.py +++ b/artlib/common/BaseART.py @@ -219,7 +219,7 @@ def match_criterion( w: np.ndarray, params: Dict, cache: Optional[Dict] = None, - ) -> Tuple[Union[float, List[float]], Optional[Dict]]: + ) -> Tuple[float, Optional[Dict]]: """Get the match criterion of the cluster. Parameters diff --git a/artlib/elementary/ART1.py b/artlib/elementary/ART1.py index bc6b1c2..ad4e10f 100644 --- a/artlib/elementary/ART1.py +++ b/artlib/elementary/ART1.py @@ -7,7 +7,7 @@ # Processing, 37, 54 – 115. doi:10. 1016/S0734-189X(87)80014-2. import numpy as np -from typing import Optional, List, Tuple, Union, Dict +from typing import Optional, List, Tuple, Dict from artlib.common.BaseART import BaseART from artlib.common.utils import l1norm @@ -108,7 +108,7 @@ def match_criterion( w: np.ndarray, params: dict, cache: Optional[dict] = None, - ) -> Tuple[Union[float, List[float]], Optional[Dict]]: + ) -> Tuple[float, Optional[Dict]]: """Get the match criterion of the cluster. Parameters diff --git a/artlib/elementary/BayesianART.py b/artlib/elementary/BayesianART.py index 6b19f9b..63bda72 100644 --- a/artlib/elementary/BayesianART.py +++ b/artlib/elementary/BayesianART.py @@ -127,7 +127,7 @@ def match_criterion( w: np.ndarray, params: dict, cache: Optional[dict] = None, - ) -> Tuple[Union[float, List[float]], Optional[Dict]]: + ) -> Tuple[float, Optional[Dict]]: """Get the match criterion of the cluster. Parameters diff --git a/artlib/elementary/EllipsoidART.py b/artlib/elementary/EllipsoidART.py index 0c5b9d2..47796e8 100644 --- a/artlib/elementary/EllipsoidART.py +++ b/artlib/elementary/EllipsoidART.py @@ -10,7 +10,7 @@ # International Society for Optics and Photonics. doi:10.1117/12.421180. import numpy as np -from typing import Optional, List, Tuple, Union, Dict +from typing import Optional, List, Tuple, Dict from matplotlib.axes import Axes from artlib.common.BaseART import BaseART from artlib.common.utils import l2norm2, IndexableOrKeyable @@ -166,7 +166,7 @@ def match_criterion( w: np.ndarray, params: dict, cache: Optional[dict] = None, - ) -> Tuple[Union[float, List[float]], Optional[Dict]]: + ) -> Tuple[float, Optional[Dict]]: """Get the match criterion of the cluster. Parameters diff --git a/artlib/elementary/FuzzyART.py b/artlib/elementary/FuzzyART.py index 5410b7e..ef666cb 100644 --- a/artlib/elementary/FuzzyART.py +++ b/artlib/elementary/FuzzyART.py @@ -5,7 +5,7 @@ # Neural Networks, 4, 759 – 771. doi:10.1016/0893-6080(91)90056-B. import numpy as np -from typing import Optional, Iterable, List, Tuple, Union, Dict +from typing import Optional, Iterable, List, Tuple, Dict from matplotlib.axes import Axes from artlib.common.BaseART import BaseART from artlib.common.utils import ( @@ -206,7 +206,7 @@ def match_criterion( w: np.ndarray, params: dict, cache: Optional[dict] = None, - ) -> Tuple[Union[float, List[float]], Optional[Dict]]: + ) -> Tuple[float, Optional[Dict]]: """Get the match criterion of the cluster. Parameters diff --git a/artlib/elementary/GaussianART.py b/artlib/elementary/GaussianART.py index 34a917b..ee6778a 100644 --- a/artlib/elementary/GaussianART.py +++ b/artlib/elementary/GaussianART.py @@ -5,7 +5,7 @@ # Neural Networks, 9, 881 – 897. doi:10.1016/0893-6080(95)00115-8. import numpy as np -from typing import Optional, Iterable, List, Tuple, Union, Dict +from typing import Optional, Iterable, List, Tuple, Dict from matplotlib.axes import Axes from artlib.common.BaseART import BaseART from artlib.common.visualization import plot_gaussian_contours_fading @@ -109,7 +109,7 @@ def match_criterion( w: np.ndarray, params: dict, cache: Optional[dict] = None, - ) -> Tuple[Union[float, List[float]], Optional[Dict]]: + ) -> Tuple[float, Optional[Dict]]: """Get the match criterion of the cluster. Parameters diff --git a/artlib/elementary/QuadraticNeuronART.py b/artlib/elementary/QuadraticNeuronART.py index 1d0c6f7..f997597 100644 --- a/artlib/elementary/QuadraticNeuronART.py +++ b/artlib/elementary/QuadraticNeuronART.py @@ -8,7 +8,7 @@ # Pattern Recognition, 38, 1887 – 1901. doi:10.1016/j.patcog.2005.04.010. import numpy as np -from typing import Optional, Iterable, List, Tuple, Union, Dict +from typing import Optional, Iterable, List, Tuple, Dict from matplotlib.axes import Axes from artlib.common.BaseART import BaseART from artlib.common.utils import l2norm2 @@ -133,7 +133,7 @@ def match_criterion( w: np.ndarray, params: dict, cache: Optional[dict] = None, - ) -> Tuple[Union[float, List[float]], Optional[Dict]]: + ) -> Tuple[float, Optional[Dict]]: """Get the match criterion of the cluster. Parameters diff --git a/artlib/experimental/ConvexHullART.py b/artlib/experimental/ConvexHullART.py index 0929134..3d12ea4 100644 --- a/artlib/experimental/ConvexHullART.py +++ b/artlib/experimental/ConvexHullART.py @@ -271,7 +271,7 @@ def match_criterion( w: HullTypes, params: dict, cache: Optional[dict] = None, - ) -> Tuple[Union[float, List[float]], Optional[Dict]]: + ) -> Tuple[float, Optional[Dict]]: """ Get the match criterion of the cluster. diff --git a/artlib/experimental/SeqART.py b/artlib/experimental/SeqART.py index 7621f60..1e052ee 100644 --- a/artlib/experimental/SeqART.py +++ b/artlib/experimental/SeqART.py @@ -234,7 +234,7 @@ def category_choice( def match_criterion( self, i: str, w: str, params: dict, cache: Optional[dict] = None - ) -> Tuple[Union[float, List[float]], Optional[Dict]]: + ) -> Tuple[float, Optional[Dict]]: """ Get the match criterion of the cluster. diff --git a/artlib/fusion/FusionART.py b/artlib/fusion/FusionART.py index 58449ef..5880b67 100644 --- a/artlib/fusion/FusionART.py +++ b/artlib/fusion/FusionART.py @@ -289,7 +289,7 @@ def match_criterion( params: Dict, cache: Optional[Dict] = None, skip_channels: List[int] = [], - ) -> Tuple[Union[float, List[float]], Optional[Dict]]: + ) -> Tuple[float, Optional[Dict]]: """Get the match criterion for the cluster. Parameters @@ -308,7 +308,7 @@ def match_criterion( Returns ------- tuple - List of match criteria for each channel and the updated cache. + max match_criterion across channels and the updated cache. """ if cache is None: @@ -322,12 +322,12 @@ def match_criterion( cache[k], ) if k not in skip_channels - else (np.inf, {"match_criterion": np.inf}) + else (np.nan, {"match_criterion": np.inf}) for k in range(self.n) ] ) cache = {k: cache_k for k, cache_k in enumerate(caches)} - return M, cache + return np.nanmax(M), cache def match_criterion_bin( self, diff --git a/artlib/topological/TopoART.py b/artlib/topological/TopoART.py index 35efe23..66489d3 100644 --- a/artlib/topological/TopoART.py +++ b/artlib/topological/TopoART.py @@ -192,7 +192,7 @@ def match_criterion( w: np.ndarray, params: dict, cache: Optional[dict] = None, - ) -> Tuple[Union[float, List[float]], Optional[Dict]]: + ) -> Tuple[float, Optional[Dict]]: """Get the match criterion of the cluster. Parameters From fc233b8e4dfa2195ecce925f5fa90ea1fd4a22e9 Mon Sep 17 00:00:00 2001 From: Niklas Melton Date: Fri, 18 Oct 2024 12:23:32 -0500 Subject: [PATCH 2/2] revert match_criterion typing --- artlib/common/BaseART.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/artlib/common/BaseART.py b/artlib/common/BaseART.py index e85a7c9..ea40e96 100644 --- a/artlib/common/BaseART.py +++ b/artlib/common/BaseART.py @@ -88,7 +88,7 @@ def set_params(self, **params): for key, value in params.items(): key, delim, sub_key = key.partition("__") if key not in valid_params: - local_valid_params = List(valid_params.keys()) + local_valid_params = list(valid_params.keys()) raise ValueError( f"Invalid parameter {key!r} for estimator {self}. " f"Valid parameters are: {local_valid_params!r}."