Skip to content

Commit

Permalink
Merge pull request #114 from NiklasMelton/update-quick-start
Browse files Browse the repository at this point in the history
Update quick start
  • Loading branch information
NiklasMelton authored Oct 18, 2024
2 parents 6269992 + 3ea1c78 commit 4478451
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 13 deletions.
88 changes: 83 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,102 @@ Ensure you have Python 3.9 or newer installed.
<!-- START quick-start -->
## Quick Start

Here's a quick example of how to use AdaptiveResonanceLib with the Fuzzy ART model:
Here are some quick examples to get you started with AdaptiveResonanceLib:

### Clustering Data with the Fuzzy ART model

```python
from artlib import FuzzyART
import numpy as np

# Your dataset
train_X = np.array([...])
train_X = np.array([...]) # shape (n_samples, n_features)
test_X = np.array([...])

# Initialize the Fuzzy ART model
model = FuzzyART(rho=0.7, alpha = 0.0, beta=1.0)

# Prepare Data
train_X_prep = model.prepare_data(train_X)
test_X_prep = model.prepare_data(test_X)

# Fit the model
model.fit(train_X_prep)

# Predict data labels
predictions = model.predict(test_X_prep)
```

### Fitting a Classification Model with SimpleARTMAP

```python
from artlib import GaussianART, SimpleARTMAP
import numpy as np

# Your dataset
train_X = np.array([...]) # shape (n_samples, n_features)
train_y = np.array([...]) # shape (n_samples, ), must be integers
test_X = np.array([...])

# Initialize the Gaussian ART model
sigma_init = np.array([0.5]*train_X.shape[1]) # variance estimate for each feature
module_a = GaussianART(rho=0.0, sigma_init=sigma_init)

# Initialize the SimpleARTMAP model
model = SimpleARTMAP(module_a=module_a)

# Prepare Data
train_X_prep = model.prepare_data(train_X)
test_X_prep = model.prepare_data(test_X)

# Fit the model
model.fit(train_X_prep, train_y)

# Predict data labels
predictions = model.predict(test_X_prep)
```

### Fitting a Regression Model with FusionART

```python
from artlib import FuzzyART, HypersphereART, FusionART
import numpy as np

# Your dataset
train_X = np.array([...]) # shape (n_samples, n_features_X)
train_y = np.array([...]) # shape (n_samples, n_features_y)
test_X = np.array([...])

# Initialize the Fuzzy ART model
module_x = FuzzyART(rho=0.0, alpha = 0.0, beta=1.0)

# Initialize the Hypersphere ART model
r_hat = 0.5*np.sqrt(train_X.shape[1]) # no restriction on hyperpshere size
module_y = HypersphereART(rho=0.0, alpha = 0.0, beta=1.0, r_hat=r_hat)

# Initialize the FusionARTMAP model
gamma_values = [0.5, 0.5] # eqaul weight to both channels
channel_dims = [
2*train_X.shape[1], # fuzzy ART complement codes data so channel dim is 2*n_features
train_y.shape[1]
]
model = FusionART(
modules=[module_x, module_y],
gamma_values=gamma_values,
channel_dims=channel_dims
)

# Prepare Data
train_Xy = model.join_channel_data(channel_data=[train_X, train_y])
train_Xy_prep = model.prepare_data(train_Xy)
test_Xy = model.join_channel_data(channel_data=[train_X], skip_channels=[1])
test_Xy_prep = model.prepare_data(test_Xy)

# Fit the model
model.fit(train_X)
model.fit(train_X_prep, train_y)

# Predict new data points
predictions = model.predict(test_X)
# Predict y-channel values
pred_y = model.predict_regression(test_Xy_prep, target_channels=[1])
```

<!-- END quick-start -->
Expand Down
32 changes: 24 additions & 8 deletions artlib/fusion/FusionART.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def validate_params(params: Dict):
assert "gamma_values" in params
assert all([1.0 >= g >= 0.0 for g in params["gamma_values"]])
assert sum(params["gamma_values"]) == 1.0
assert isinstance(params["gamma_values"], np.ndarray)
assert isinstance(params["gamma_values"], (np.ndarray, list))

def validate_data(self, X: np.ndarray):
"""Validate the input data for clustering.
Expand All @@ -198,43 +198,59 @@ def check_dimensions(self, X: np.ndarray):
"""
assert X.shape[1] == self.dim_, "Invalid data shape"

def prepare_data(self, channel_data: List[np.ndarray]) -> np.ndarray:
def prepare_data(
self, channel_data: List[np.ndarray], skip_channels: List[int] = []
) -> np.ndarray:
"""Prepare the input data by processing each channel's data through its
respective ART module.
Parameters
----------
channel_data : list of np.ndarray
List of arrays, one for each channel.
skip_channels : list of int, optional
Channels to be skipped (default is []).
Returns
-------
np.ndarray
Processed and concatenated data.
"""
skip_channels = [self.n + k if k < 0 else k for k in skip_channels]
prepared_channel_data = [
self.modules[i].prepare_data(channel_data[i]) for i in range(self.n)
self.modules[i].prepare_data(channel_data[i])
for i in range(self.n)
if i not in skip_channels
]
return self.join_channel_data(prepared_channel_data)

def restore_data(self, X: np.ndarray) -> List[np.ndarray]:
return self.join_channel_data(
prepared_channel_data, skip_channels=skip_channels
)

def restore_data(
self, X: np.ndarray, skip_channels: List[int] = []
) -> List[np.ndarray]:
"""Restore data to its original state before preparation.
Parameters
----------
X : np.ndarray
The prepared data.
skip_channels : list of int, optional
Channels to be skipped (default is []).
Returns
-------
np.ndarray
Restored data for each channel.
"""
channel_data = self.split_channel_data(X)
skip_channels = [self.n + k if k < 0 else k for k in skip_channels]
channel_data = self.split_channel_data(X, skip_channels=skip_channels)
restored_channel_data = [
self.modules[i].restore_data(channel_data[i]) for i in range(self.n)
self.modules[i].restore_data(channel_data[i])
for i in range(self.n)
if i not in skip_channels
]
return restored_channel_data

Expand Down

0 comments on commit 4478451

Please sign in to comment.