Skip to content

Commit

Permalink
Merge pull request #119 from NiklasMelton/add-gif-mode
Browse files Browse the repository at this point in the history
add fit_gif method
  • Loading branch information
NiklasMelton authored Oct 29, 2024
2 parents 0a19d3e + 08404f9 commit 0a0806d
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
96 changes: 96 additions & 0 deletions artlib/common/BaseART.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,102 @@ def partial_fit(
self.labels_[i + j] = c
return self

def fit_gif(
self,
X: np.ndarray,
y: Optional[np.ndarray] = None,
match_reset_func: Optional[Callable] = None,
max_iter=1,
match_tracking: Literal["MT+", "MT-", "MT0", "MT1", "MT~"] = "MT+",
epsilon: float = 0.0,
verbose: bool = False,
ax: Optional[Axes] = None,
filename: Optional[str] = None,
colors: Optional[IndexableOrKeyable] = None,
n_cluster_estimate: int = 20,
fps: int = 5,
**kwargs,
):
"""Fit the model to the data and make a gif of the process.
Parameters
----------
X : np.ndarray
The dataset.
y : np.ndarray, optional
Not used. For compatibility.
match_reset_func : callable, optional
A callable that influences cluster creation.
max_iter : int, default=1
Number of iterations to fit the model on the same dataset.
match_tracking : {"MT+", "MT-", "MT0", "MT1", "MT~"}, default="MT+"
Method for resetting match criterion.
epsilon : float, default=0.0
Epsilon value used for adjusting match criterion.
verbose : bool, default=False
If True, displays progress of the fitting process.
ax : matplotlib.axes.Axes, optional
Figure axes.
colors : IndexableOrKeyable, optional
Colors to use for each cluster.
n_cluster_estimate : int, default=20
estimate of number of clusters. Used for coloring plot.
fps : int, default=5
gif frames per second
**kwargs : dict
see :func: `artlib.common.BaseART.visualize`
"""
import matplotlib.pyplot as plt
from matplotlib.animation import PillowWriter

if ax is None:
fig, ax = plt.subplots()
ax.set_xlim(-0.1, 1.1)
ax.set_ylim(-0.1, 1.1)
if filename is None:
filename = f"fit_gif_{self.__class__.__name__}.gif"
if colors is None:
from matplotlib.pyplot import cm

colors = cm.rainbow(np.linspace(0, 1, n_cluster_estimate))
black = np.array([[0, 0, 0, 1]]) # RGBA for black
colors = np.vstack((colors, black)) # Add black at the end

self.validate_data(X)
self.check_dimensions(X)
self.is_fitted_ = True

self.W = []
self.labels_ = -np.ones((X.shape[0],), dtype=int)

writer = PillowWriter(fps=fps)
with writer.saving(fig, filename, dpi=80):
for _ in range(max_iter):
if verbose:
from tqdm import tqdm

x_iter = tqdm(enumerate(X), total=int(X.shape[0]))
else:
x_iter = enumerate(X)
for i, x in x_iter:
self.pre_step_fit(X)
c = self.step_fit(
x,
match_reset_func=match_reset_func,
match_tracking=match_tracking,
epsilon=epsilon,
)
self.labels_[i] = c
self.post_step_fit(X)
ax.clear()
ax.set_xlim(-0.1, 1.1)
ax.set_ylim(-0.1, 1.1)
self.visualize(X, self.labels_, ax, colors=colors, **kwargs)
writer.grab_frame()
self.post_fit(X)
return self

def predict(self, X: np.ndarray) -> np.ndarray:
"""Predict labels for the data.
Expand Down
33 changes: 33 additions & 0 deletions examples/demo_make_gif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

from artlib import FuzzyART


def cluster_blobs():
data, target = make_blobs(
n_samples=150,
centers=3,
cluster_std=0.50,
random_state=0,
shuffle=False,
)
print("Data has shape:", data.shape)

params = {"rho": 0.5, "alpha": 0.0, "beta": 1.0}
cls = FuzzyART(**params)

X = cls.prepare_data(data)
print("Prepared data has shape:", X.shape)

cls = cls.fit_gif(X, filename="fit_gif_FuzzyART.gif", n_cluster_estimate=3)
y = cls.labels_

print(f"{cls.n_clusters} clusters found")

cls.visualize(X, y)
plt.show()


if __name__ == "__main__":
cluster_blobs()

0 comments on commit 0a0806d

Please sign in to comment.