diff --git a/docs/artists_api.md b/docs/artists_api.md index d3ae2f0..e035bf7 100644 --- a/docs/artists_api.md +++ b/docs/artists_api.md @@ -51,6 +51,7 @@ ~Scatter.data ~Scatter.visible ~Scatter.color_indices + ~Scatter.size .. rubric:: Attributes Summary @@ -77,6 +78,7 @@ .. autoattribute:: data .. autoattribute:: visible .. autoattribute:: color_indices + .. autoattribute:: size .. rubric:: Attributes Documentation diff --git a/src/biaplotter/_tests/test_artists.py b/src/biaplotter/_tests/test_artists.py index eb11111..4142482 100644 --- a/src/biaplotter/_tests/test_artists.py +++ b/src/biaplotter/_tests/test_artists.py @@ -44,13 +44,29 @@ def on_color_indices_changed(color_indices): assert np.all(colors[0] == scatter.categorical_colormap(0)) assert np.all(colors[50] == scatter.categorical_colormap(2)) - # Test alpha property - collected_alpha_signals = [] - def on_alpha_changed(alpha): - collected_alpha_signals.append(alpha) - scatter.alpha = np.linspace(start=0.1, stop=1.0, num=size) - assert np.all(scatter.alpha == np.linspace(start=0.1, stop=1.0, num=size)) - assert scatter._scatter.get_alpha() is not None + # Test axis limits + x_margin = 0.05 * (np.max(data[:, 0]) - np.min(data[:, 0])) + y_margin = 0.05 * (np.max(data[:, 1]) - np.min(data[:, 1])) + assert np.isclose(ax.get_xlim(), (np.min(data[:, 0]) - x_margin, np.max(data[:, 0]) + x_margin)).all() + assert np.isclose(ax.get_ylim(), (np.min(data[:, 1]) - y_margin, np.max(data[:, 1]) + y_margin)).all() + + # Test size property + scatter.size = 5.0 + assert scatter.size == 5.0 + sizes = scatter._scatter.get_sizes() + assert np.all(sizes == 5.0) + + scatter.size = np.linspace(1, 10, size) + assert np.all(scatter.size == np.linspace(1, 10, size)) + sizes = scatter._scatter.get_sizes() + assert np.all(sizes == np.linspace(1, 10, size)) + + # Test size reset when new data is set + new_data = np.random.rand(size, 2) + scatter.data = new_data + assert scatter.size == 50.0 # that's the default + sizes = scatter._scatter.get_sizes() + assert np.all(sizes == 50.0) def test_histogram2d(): diff --git a/src/biaplotter/artists.py b/src/biaplotter/artists.py index a782aa9..39afd19 100644 --- a/src/biaplotter/artists.py +++ b/src/biaplotter/artists.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from nap_plot_tools.cmap import cat10_mod_cmap, cat10_mod_cmap_first_transparent from psygnal import Signal -from typing import Tuple, List +from typing import Tuple, List, Union class Artist(ABC): @@ -96,8 +96,6 @@ class Scatter(Artist): a colormap to use for the artist, by default cat10_mod_cmap from nap-plot-tools color_indices : (N,) np.ndarray[int] or int, optional array of indices to map to the colormap, by default None - alpha : (N,) np.ndarray[float] or float, optional - array of alpha values for the scatter points, by default 1.0 Notes ----- @@ -117,7 +115,6 @@ class Scatter(Artist): >>> scatter.data = data >>> scatter.visible = True >>> scatter.color_indices = np.linspace(start=0, stop=5, num=100, endpoint=False, dtype=int) - >>> scatter.alpha = np.linspace(start=0.1, stop=1.0, num=100) >>> plt.show() """ #: Signal emitted when the `data` is changed. @@ -125,14 +122,14 @@ class Scatter(Artist): #: Signal emitted when the `color_indices` are changed. color_indices_changed_signal: Signal = Signal(np.ndarray) - def __init__(self, ax: plt.Axes = None, data: np.ndarray = None, categorical_colormap: Colormap = cat10_mod_cmap, color_indices: np.ndarray = None, alpha: np.ndarray = 1.0): + def __init__(self, ax: plt.Axes = None, data: np.ndarray = None, categorical_colormap: Colormap = cat10_mod_cmap, color_indices: np.ndarray = None): """Initializes the scatter plot artist. """ super().__init__(ax, data, categorical_colormap, color_indices) #: Stores the scatter plot matplotlib object self._scatter = None - self._alpha = alpha self.data = data + self._size = 50 # Default size self.draw() # Initial draw of the scatter plot @property @@ -164,7 +161,7 @@ def data(self, value: np.ndarray): # emit signal self.data_changed_signal.emit(self._data) if self._scatter is None: - self._scatter = self.ax.scatter(value[:, 0], value[:, 1]) + self._scatter = self.ax.scatter(value[:, 0], value[:, 1], s=self._size) self.color_indices = 0 # Set default color index else: # If the scatter plot already exists, just update its data @@ -181,6 +178,13 @@ def data(self, value: np.ndarray): # fill with zeros where new data is larger color_indices[color_indices_size:] = 0 self.color_indices = color_indices + self.size = 50 + + x_margin = 0.05 * (np.max(value[:, 0]) - np.min(value[:, 0])) + y_margin = 0.05 * (np.max(value[:, 1]) - np.min(value[:, 1])) + self.ax.set_xlim(np.min(value[:, 0]) - x_margin, np.max(value[:, 0]) + x_margin) + self.ax.set_ylim(np.min(value[:, 1]) - y_margin, np.max(value[:, 1]) + y_margin) + self.draw() @property @@ -242,27 +246,24 @@ def color_indices(self, indices: np.ndarray): self.draw() @property - def alpha(self) -> np.ndarray: - """Gets or sets the alpha values for the scatter plot. + def size(self) -> Union[float, np.ndarray]: + """Gets or sets the size of the points in the scatter plot. Triggers a draw idle command. Returns ------- - alpha : (N,) np.ndarray[float] or float - alpha values for the scatter plot. Accepts a scalar or an array of floats. + size : float or (N,) np.ndarray[float] + size of the points in the scatter plot. Accepts a scalar or an array of floats. """ - return self._alpha - - @alpha.setter - def alpha(self, value: np.ndarray): - """Sets the alpha values for the scatter plot and updates the display accordingly.""" - # Check if value is a scalar - if np.isscalar(value): - value = np.full(len(self._data), value) - self._alpha = value + return self._size + + @size.setter + def size(self, value: Union[float, np.ndarray]): + """Sets the size of the points in the scatter plot.""" + self._size = value if self._scatter is not None: - self._scatter.set_alpha(value) + self._scatter.set_sizes(np.full(len(self._data), value) if np.isscalar(value) else value) self.draw() def draw(self):