diff --git a/artlib/common/BaseART.py b/artlib/common/BaseART.py index 3af9e7f..6e6abce 100644 --- a/artlib/common/BaseART.py +++ b/artlib/common/BaseART.py @@ -646,6 +646,102 @@ def partial_fit( self.labels_[i + j] = c return self + def fit_gif( + self, + X: np.ndarray, + y: Optional[np.ndarray] = None, + match_reset_func: Optional[Callable] = None, + max_iter=1, + match_tracking: Literal["MT+", "MT-", "MT0", "MT1", "MT~"] = "MT+", + epsilon: float = 0.0, + verbose: bool = False, + ax: Optional[Axes] = None, + filename: Optional[str] = None, + colors: Optional[IndexableOrKeyable] = None, + n_cluster_estimate: int = 20, + fps: int = 5, + **kwargs, + ): + """Fit the model to the data and make a gif of the process. + + Parameters + ---------- + X : np.ndarray + The dataset. + y : np.ndarray, optional + Not used. For compatibility. + match_reset_func : callable, optional + A callable that influences cluster creation. + max_iter : int, default=1 + Number of iterations to fit the model on the same dataset. + match_tracking : {"MT+", "MT-", "MT0", "MT1", "MT~"}, default="MT+" + Method for resetting match criterion. + epsilon : float, default=0.0 + Epsilon value used for adjusting match criterion. + verbose : bool, default=False + If True, displays progress of the fitting process. + ax : matplotlib.axes.Axes, optional + Figure axes. + colors : IndexableOrKeyable, optional + Colors to use for each cluster. + n_cluster_estimate : int, default=20 + estimate of number of clusters. Used for coloring plot. + fps : int, default=5 + gif frames per second + **kwargs : dict + see :func: `artlib.common.BaseART.visualize` + + """ + import matplotlib.pyplot as plt + from matplotlib.animation import PillowWriter + + if ax is None: + fig, ax = plt.subplots() + ax.set_xlim(-0.1, 1.1) + ax.set_ylim(-0.1, 1.1) + if filename is None: + filename = f"fit_gif_{self.__class__.__name__}.gif" + if colors is None: + from matplotlib.pyplot import cm + + colors = cm.rainbow(np.linspace(0, 1, n_cluster_estimate)) + black = np.array([[0, 0, 0, 1]]) # RGBA for black + colors = np.vstack((colors, black)) # Add black at the end + + self.validate_data(X) + self.check_dimensions(X) + self.is_fitted_ = True + + self.W = [] + self.labels_ = -np.ones((X.shape[0],), dtype=int) + + writer = PillowWriter(fps=fps) + with writer.saving(fig, filename, dpi=80): + for _ in range(max_iter): + if verbose: + from tqdm import tqdm + + x_iter = tqdm(enumerate(X), total=int(X.shape[0])) + else: + x_iter = enumerate(X) + for i, x in x_iter: + self.pre_step_fit(X) + c = self.step_fit( + x, + match_reset_func=match_reset_func, + match_tracking=match_tracking, + epsilon=epsilon, + ) + self.labels_[i] = c + self.post_step_fit(X) + ax.clear() + ax.set_xlim(-0.1, 1.1) + ax.set_ylim(-0.1, 1.1) + self.visualize(X, self.labels_, ax, colors=colors, **kwargs) + writer.grab_frame() + self.post_fit(X) + return self + def predict(self, X: np.ndarray) -> np.ndarray: """Predict labels for the data. diff --git a/examples/demo_make_gif.py b/examples/demo_make_gif.py new file mode 100644 index 0000000..687e4a0 --- /dev/null +++ b/examples/demo_make_gif.py @@ -0,0 +1,33 @@ +from sklearn.datasets import make_blobs +import matplotlib.pyplot as plt + +from artlib import FuzzyART + + +def cluster_blobs(): + data, target = make_blobs( + n_samples=150, + centers=3, + cluster_std=0.50, + random_state=0, + shuffle=False, + ) + print("Data has shape:", data.shape) + + params = {"rho": 0.5, "alpha": 0.0, "beta": 1.0} + cls = FuzzyART(**params) + + X = cls.prepare_data(data) + print("Prepared data has shape:", X.shape) + + cls = cls.fit_gif(X, filename="fit_gif_FuzzyART.gif", n_cluster_estimate=3) + y = cls.labels_ + + print(f"{cls.n_clusters} clusters found") + + cls.visualize(X, y) + plt.show() + + +if __name__ == "__main__": + cluster_blobs()