Skip to content

Commit

Permalink
improve error message and add helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasMelton committed Jan 9, 2025
1 parent 2610414 commit 8b495e8
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 2 deletions.
58 changes: 56 additions & 2 deletions artlib/common/BaseART.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,45 @@ def set_params(self, **params):
self.validate_params(local_params)
return self

def set_data_bounds(self, lower_bounds: np.ndarray, upper_bounds: np.ndarray):
    """Store caller-supplied normalization bounds on the model.

    Parameters
    ----------
    lower_bounds : np.ndarray
        The lower bounds for each column.
    upper_bounds : np.ndarray
        The upper bounds for each column.

    Raises
    ------
    ValueError
        If the model is already marked as fitted (``is_fitted_`` is truthy).

    """
    already_fitted = self.is_fitted_
    if not already_fitted:
        # Bounds are stored as given; no copy or shape check is performed.
        self.d_min_, self.d_max_ = lower_bounds, upper_bounds
        return
    raise ValueError("Cannot change data limits after fit.")

def find_data_bounds(
    self, *data_batches: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Find the per-column data bounds across one or more batches.

    The batches are stacked row-wise and the minimum and maximum of every
    column are returned.

    Parameters
    ----------
    *data_batches : np.ndarray
        Batches of data to be presented to the model. All batches must have
        the same number of columns so they can be stacked.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Column-wise lower and upper bounds for the stacked data.

    Raises
    ------
    ValueError
        If no batches are supplied.

    """
    if not data_batches:
        # np.vstack would also raise ValueError here, with a less helpful
        # message; fail early with one that names the actual problem.
        raise ValueError(
            "At least one data batch is required to find data bounds."
        )
    all_data = np.vstack(data_batches)
    # axis=0 reduces over rows, giving one bound per column rather than a
    # single scalar for the whole array.
    lower_bounds = np.min(all_data, axis=0)
    upper_bounds = np.max(all_data, axis=0)

    return lower_bounds, upper_bounds

def prepare_data(self, X: np.ndarray) -> np.ndarray:
"""Prepare data for clustering.
Expand Down Expand Up @@ -187,8 +226,23 @@ def validate_data(self, X: np.ndarray):
- X: data set
"""
assert np.all(X >= 0), "Data has not been normalized"
assert np.all(X <= 1.0), "Data has not been normalized"
normalization_message = (
"Data has not been normalized or was not normalized "
"correctly. All values must fall between 0 and 1, "
"inclusively."
)
if self.is_fitted_:
normalization_message += (
"\nThis appears to not be the first batch of "
"data. Data boundaries must be calculated for "
"the entire data space. Prior to fitting, use "
"BaseART.set_data_bounds() to manually set the "
"bounds for your data or use "
"BaseART.find_data_bounds() to identify the "
"bounds automatically for multiple batches."
)
assert np.all(X >= 0), normalization_message
assert np.all(X <= 1.0), normalization_message
self.check_dimensions(X)

def category_choice(
Expand Down
63 changes: 63 additions & 0 deletions unit_tests/test_FuzzyART.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,66 @@ def test_clustering(art_model):
labels = art_model.fit_predict(data)

assert np.all(np.equal(labels, np.array([0, 0, 1, 2, 3])))


def test_validate_data(art_model):
    """validate_data must raise AssertionError for values outside [0, 1]."""
    art_model.is_fitted_ = False
    # Contains -0.1 and 1.1, i.e. the data is NOT normalized.
    out_of_range = np.array([[-0.1, 0.2], [1.1, 0.4]])
    with pytest.raises(AssertionError):
        art_model.validate_data(out_of_range)

def test_validate_data_again(art_model):
    """Check validate_data on in-range data, then on out-of-range data."""
    art_model.is_fitted_ = False

    # Every value lies within [0, 1]; the call is expected not to raise.
    in_range = np.array([[0.1, 0.2], [0.3, 0.4]])
    art_model.validate_data(in_range)

    # Contains -0.1 and 1.1, so an AssertionError is expected.
    out_of_range = np.array([[-0.1, 0.2], [1.1, 0.4]])
    with pytest.raises(AssertionError):
        art_model.validate_data(out_of_range)


def test_set_data_bounds(art_model):
# Checks that set_data_bounds stores the supplied bounds while is_fitted_ is
# False, and raises ValueError once is_fitted_ is True.
# Test set_data_bounds with valid bounds
lower_bounds = np.array([0.0, 0.0])
upper_bounds = np.array([1.0, 1.0])
art_model.is_fitted_ = False
art_model.set_data_bounds(lower_bounds, upper_bounds)
assert np.all(art_model.d_min_ == lower_bounds)
assert np.all(art_model.d_max_ == upper_bounds)

# Test set_data_bounds after the model is fitted
art_model.is_fitted_ = True
with pytest.raises(ValueError, match="Cannot change data limits after fit."):
art_model.set_data_bounds(lower_bounds, upper_bounds)
# NOTE(review): indentation was lost in this view, so it is unclear whether
# the three lines below belong inside the pytest.raises block (where they
# would never run once the call above raises) or follow it. Confirm against
# the repository before relying on this check.
X = np.array([[0.1, 0.2], [0.3, 0.4]])
X_norm = art_model.prepare_data(X)
assert np.all(X_norm == X)


def test_find_data_bounds(art_model):
    """find_data_bounds is expected to return column-wise extrema."""
    # Two batches whose combined column minima are (0.0, 0.1) and whose
    # combined column maxima are (0.5, 0.6).
    batches = (
        np.array([[0.1, 0.2], [0.3, 0.4]]),
        np.array([[0.0, 0.1], [0.5, 0.6]]),
    )
    found_min, found_max = art_model.find_data_bounds(*batches)
    np.testing.assert_array_equal(found_min, np.array([0.0, 0.1]))
    np.testing.assert_array_equal(found_max, np.array([0.5, 0.6]))


def test_prepare_data(art_model):
    """Exercise prepare_data with unit bounds and with custom bounds."""
    # NOTE(review): the first check expects a 2-column result while the
    # second expects 4 columns from a 2-column input. Confirm which shape
    # prepare_data actually returns for this model.

    # Case 1: bounds are already [0, 1], so the data is expected unchanged.
    art_model.d_min_ = np.array([0.0, 0.0])
    art_model.d_max_ = np.array([1.0, 1.0])
    unit_input = np.array([[0.0, 0.5], [0.2, 1.0]])
    unit_output = art_model.prepare_data(unit_input)
    np.testing.assert_array_almost_equal(unit_output, unit_input)

    # Case 2: bounds equal the column extrema of the raw data, so the data
    # requires normalization.
    art_model.d_min_ = np.array([1.0, 10.0])
    art_model.d_max_ = np.array([5.0, 20.0])
    raw_input = np.array([[1.0, 10.0], [5.0, 20.0]])
    scaled_output = art_model.prepare_data(raw_input)
    wanted = np.array([[0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0]])
    np.testing.assert_array_almost_equal(scaled_output, wanted)

0 comments on commit 8b495e8

Please sign in to comment.