diff --git a/README.md b/README.md index 00cbaa9..70a1101 100644 --- a/README.md +++ b/README.md @@ -69,24 +69,102 @@ Ensure you have Python 3.9 or newer installed. ## Quick Start -Here's a quick example of how to use AdaptiveResonanceLib with the Fuzzy ART model: +Here are some quick examples to get you started with AdaptiveResonanceLib: + +### Clustering Data with the Fuzzy ART model ```python from artlib import FuzzyART import numpy as np # Your dataset -train_X = np.array([...]) +train_X = np.array([...]) # shape (n_samples, n_features) test_X = np.array([...]) # Initialize the Fuzzy ART model model = FuzzyART(rho=0.7, alpha = 0.0, beta=1.0) +# Prepare Data +train_X_prep = model.prepare_data(train_X) +test_X_prep = model.prepare_data(test_X) + +# Fit the model +model.fit(train_X_prep) + +# Predict data labels +predictions = model.predict(test_X_prep) +``` + +### Fitting a Classification Model with SimpleARTMAP + +```python +from artlib import GaussianART, SimpleARTMAP +import numpy as np + +# Your dataset +train_X = np.array([...]) # shape (n_samples, n_features) +train_y = np.array([...]) # shape (n_samples, ), must be integers +test_X = np.array([...]) + +# Initialize the Gaussian ART model +sigma_init = np.array([0.5]*train_X.shape[1]) # variance estimate for each feature +module_a = GaussianART(rho=0.0, sigma_init=sigma_init) + +# Initialize the SimpleARTMAP model +model = SimpleARTMAP(module_a=module_a) + +# Prepare Data +train_X_prep = model.prepare_data(train_X) +test_X_prep = model.prepare_data(test_X) + +# Fit the model +model.fit(train_X_prep, train_y) + +# Predict data labels +predictions = model.predict(test_X_prep) +``` + +### Fitting a Regression Model with FusionART + +```python +from artlib import FuzzyART, HypersphereART, FusionART +import numpy as np + +# Your dataset +train_X = np.array([...]) # shape (n_samples, n_features_X) +train_y = np.array([...]) # shape (n_samples, n_features_y) +test_X = np.array([...]) + +# Initialize the Fuzzy ART model +module_x = FuzzyART(rho=0.0, alpha = 0.0, beta=1.0) + +# Initialize the Hypersphere ART model +r_hat = 0.5*np.sqrt(train_X.shape[1]) # no restriction on hyperpshere size +module_y = HypersphereART(rho=0.0, alpha = 0.0, beta=1.0, r_hat=r_hat) + +# Initialize the FusionARTMAP model +gamma_values = [0.5, 0.5] # eqaul weight to both channels +channel_dims = [ + 2*train_X.shape[1], # fuzzy ART complement codes data so channel dim is 2*n_features + train_y.shape[1] +] +model = FusionART( + modules=[module_x, module_y], + gamma_values=gamma_values, + channel_dims=channel_dims +) + +# Prepare Data +train_Xy = model.join_channel_data(channel_data=[train_X, train_y]) +train_Xy_prep = model.prepare_data(train_Xy) +test_Xy = model.join_channel_data(channel_data=[train_X], skip_channels=[1]) +test_Xy_prep = model.prepare_data(test_Xy) + # Fit the model -model.fit(train_X) +model.fit(train_X_prep, train_y) -# Predict new data points -predictions = model.predict(test_X) +# Predict y-channel values +pred_y = model.predict_regression(test_Xy_prep, target_channels=[1]) ``` diff --git a/artlib/fusion/FusionART.py b/artlib/fusion/FusionART.py index 5880b67..781e66a 100644 --- a/artlib/fusion/FusionART.py +++ b/artlib/fusion/FusionART.py @@ -171,7 +171,7 @@ def validate_params(params: Dict): assert "gamma_values" in params assert all([1.0 >= g >= 0.0 for g in params["gamma_values"]]) assert sum(params["gamma_values"]) == 1.0 - assert isinstance(params["gamma_values"], np.ndarray) + assert isinstance(params["gamma_values"], (np.ndarray, list)) def validate_data(self, X: np.ndarray): """Validate the input data for clustering. @@ -198,7 +198,9 @@ def check_dimensions(self, X: np.ndarray): """ assert X.shape[1] == self.dim_, "Invalid data shape" - def prepare_data(self, channel_data: List[np.ndarray]) -> np.ndarray: + def prepare_data( + self, channel_data: List[np.ndarray], skip_channels: List[int] = [] + ) -> np.ndarray: """Prepare the input data by processing each channel's data through its respective ART module. @@ -206,6 +208,8 @@ def prepare_data(self, channel_data: List[np.ndarray]) -> np.ndarray: ---------- channel_data : list of np.ndarray List of arrays, one for each channel. + skip_channels : list of int, optional + Channels to be skipped (default is []). Returns ------- @@ -213,28 +217,40 @@ def prepare_data(self, channel_data: List[np.ndarray]) -> np.ndarray: Processed and concatenated data. """ + skip_channels = [self.n + k if k < 0 else k for k in skip_channels] prepared_channel_data = [ - self.modules[i].prepare_data(channel_data[i]) for i in range(self.n) + self.modules[i].prepare_data(channel_data[i]) + for i in range(self.n) + if i not in skip_channels ] - return self.join_channel_data(prepared_channel_data) - def restore_data(self, X: np.ndarray) -> List[np.ndarray]: + return self.join_channel_data( + prepared_channel_data, skip_channels=skip_channels + ) + + def restore_data( + self, X: np.ndarray, skip_channels: List[int] = [] + ) -> List[np.ndarray]: """Restore data to its original state before preparation. Parameters ---------- X : np.ndarray The prepared data. - + skip_channels : list of int, optional + Channels to be skipped (default is []). Returns ------- np.ndarray Restored data for each channel. """ - channel_data = self.split_channel_data(X) + skip_channels = [self.n + k if k < 0 else k for k in skip_channels] + channel_data = self.split_channel_data(X, skip_channels=skip_channels) restored_channel_data = [ - self.modules[i].restore_data(channel_data[i]) for i in range(self.n) + self.modules[i].restore_data(channel_data[i]) + for i in range(self.n) + if i not in skip_channels ] return restored_channel_data