From 8b495e8fafd4eda9d2181d3d5a07e5da277c932f Mon Sep 17 00:00:00 2001 From: niklas melton Date: Wed, 8 Jan 2025 21:10:03 -0600 Subject: [PATCH] improve error message and add helper functions --- artlib/common/BaseART.py | 58 ++++++++++++++++++++++++++++++++-- unit_tests/test_FuzzyART.py | 63 +++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 2 deletions(-) diff --git a/artlib/common/BaseART.py b/artlib/common/BaseART.py index 6e6abce..409816b 100644 --- a/artlib/common/BaseART.py +++ b/artlib/common/BaseART.py @@ -106,6 +106,45 @@ def set_params(self, **params): self.validate_params(local_params) return self + def set_data_bounds(self, lower_bounds: np.ndarray, upper_bounds: np.ndarray): + """Manually set the data bounds for normalization. + + Parameters + ---------- + lower_bounds : np.ndarray + The lower bounds for each column. + + upper_bounds : np.ndarray + The upper bounds for each column. + + """ + if self.is_fitted_: + raise ValueError("Cannot change data limits after fit.") + self.d_min_ = lower_bounds + self.d_max_ = upper_bounds + + def find_data_bounds( + self, *data_batches: list[np.ndarray] + ) -> Tuple[np.ndarray, np.ndarray]: + """Manually set the data bounds for normalization. + + Parameters + ---------- + *data_batches : list[np.ndarray] + Batches of data to be presented to the model + + Returns + ------- + tuple[np.ndarray, np.ndarray] + Lower and upper bounds for data. + + """ + all_data = np.vstack(data_batches) + lower_bounds = np.min(all_data) + upper_bounds = np.max(all_data) + + return lower_bounds, upper_bounds + def prepare_data(self, X: np.ndarray) -> np.ndarray: """Prepare data for clustering. @@ -187,8 +226,23 @@ def validate_data(self, X: np.ndarray): - X: data set """ - assert np.all(X >= 0), "Data has not been normalized" - assert np.all(X <= 1.0), "Data has not been normalized" + normalization_message = ( + "Data has not been normalized or was not normalized " + "correctly. All values must fall between 0 and 1, " + "inclusively." + ) + if self.is_fitted_: + normalization_message += ( + "\nThis appears to not be the first batch of " + "data. Data boundaries must be calculated for " + "the entire data space. Prior to fitting, use " + "BaseART.set_data_bounds() to manually set the " + "bounds for your data or use " + "BaseART.find_data_bounds() to identify the " + "bounds automatically for multiple batches." + ) + assert np.all(X >= 0), normalization_message + assert np.all(X <= 1.0), normalization_message self.check_dimensions(X) def category_choice( diff --git a/unit_tests/test_FuzzyART.py b/unit_tests/test_FuzzyART.py index e07c848..e4cd338 100644 --- a/unit_tests/test_FuzzyART.py +++ b/unit_tests/test_FuzzyART.py @@ -148,3 +148,66 @@ def test_clustering(art_model): labels = art_model.fit_predict(data) assert np.all(np.equal(labels, np.array([0, 0, 1, 2, 3]))) + + +def test_validate_data(art_model): + # Test validate_data with normalized data + X = np.array([[-0.1, 0.2], [1.1, 0.4]]) + art_model.is_fitted_ = False + with pytest.raises(AssertionError): + art_model.validate_data(X) + +def test_validate_data_again(art_model): + # Test validate_data with normalized data + X = np.array([[0.1, 0.2], [0.3, 0.4]]) + art_model.is_fitted_ = False + art_model.validate_data(X) # Should pass without assertion error + + # Test validate_data with data out of bounds + X_invalid = np.array([[-0.1, 0.2], [1.1, 0.4]]) + with pytest.raises(AssertionError): + art_model.validate_data(X_invalid) + + +def test_set_data_bounds(art_model): + # Test set_data_bounds with valid bounds + lower_bounds = np.array([0.0, 0.0]) + upper_bounds = np.array([1.0, 1.0]) + art_model.is_fitted_ = False + art_model.set_data_bounds(lower_bounds, upper_bounds) + assert np.all(art_model.d_min_ == lower_bounds) + assert np.all(art_model.d_max_ == upper_bounds) + + # Test set_data_bounds after the model is fitted + art_model.is_fitted_ = True + with pytest.raises(ValueError, match="Cannot change data limits after fit."): + art_model.set_data_bounds(lower_bounds, upper_bounds) + X = np.array([[0.1, 0.2], [0.3, 0.4]]) + X_norm = art_model.prepare_data(X) + assert np.all(X_norm == X) + + +def test_find_data_bounds(art_model): + # Test find_data_bounds with multiple data batches + batch_1 = np.array([[0.1, 0.2], [0.3, 0.4]]) + batch_2 = np.array([[0.0, 0.1], [0.5, 0.6]]) + lower_bounds, upper_bounds = art_model.find_data_bounds(batch_1, batch_2) + np.testing.assert_array_equal(lower_bounds, np.array([0.0, 0.1])) + np.testing.assert_array_equal(upper_bounds, np.array([0.5, 0.6])) + + +def test_prepare_data(art_model): + # Test prepare_data with valid data + X = np.array([[0.0, 0.5], [0.2, 1.0]]) + art_model.d_min_ = np.array([0.0, 0.0]) + art_model.d_max_ = np.array([1.0, 1.0]) + normalized_X = art_model.prepare_data(X) + np.testing.assert_array_almost_equal(normalized_X, X) # Already normalized + + # Test prepare_data with data requiring normalization + X = np.array([[1.0, 10.0], [5.0, 20.0]]) + art_model.d_min_ = np.array([1.0, 10.0]) + art_model.d_max_ = np.array([5.0, 20.0]) + normalized_X = art_model.prepare_data(X) + expected_normalized_X = np.array([[0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(normalized_X, expected_normalized_X) \ No newline at end of file