Skip to content

Commit

Permalink
add examples and improve FusionART prepare_data
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasMelton committed Oct 18, 2024
1 parent e47b04d commit 3ea1c78
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
17 changes: 13 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,17 @@ module_x = FuzzyART(rho=0.0, alpha = 0.0, beta=1.0)
r_hat = 0.5*np.sqrt(train_X.shape[1]) # no restriction on hyperpshere size
module_y = HypersphereART(rho=0.0, alpha = 0.0, beta=1.0, r_hat=r_hat)

# Initialize the SimpleARTMAP model
model = FusionART(modules=[module_x, module_y])
# Initialize the FusionARTMAP model
gamma_values = [0.5, 0.5] # eqaul weight to both channels
channel_dims = [
2*train_X.shape[1], # fuzzy ART complement codes data so channel dim is 2*n_features
train_y.shape[1]
]
model = FusionART(
modules=[module_x, module_y],
gamma_values=gamma_values,
channel_dims=channel_dims
)

# Prepare Data
train_Xy = model.join_channel_data(channel_data=[train_X, train_y])
Expand All @@ -154,8 +163,8 @@ test_Xy_prep = model.prepare_data(test_Xy)
# Fit the model
model.fit(train_X_prep, train_y)

# Predict data labels
predictions = model.predict(test_X_prep)
# Predict y-channel values
pred_y = model.predict_regression(test_Xy_prep, target_channels=[1])
```

<!-- END quick-start -->
Expand Down
2 changes: 1 addition & 1 deletion artlib/fusion/FusionART.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def validate_params(params: Dict):
assert "gamma_values" in params
assert all([1.0 >= g >= 0.0 for g in params["gamma_values"]])
assert sum(params["gamma_values"]) == 1.0
assert isinstance(params["gamma_values"], np.ndarray)
assert isinstance(params["gamma_values"], (np.ndarray, list))

def validate_data(self, X: np.ndarray):
"""Validate the input data for clustering.
Expand Down

0 comments on commit 3ea1c78

Please sign in to comment.