diff --git a/src/biaplotter/_tests/test_artists.py b/src/biaplotter/_tests/test_artists.py index 4eabbb7..eb11111 100644 --- a/src/biaplotter/_tests/test_artists.py +++ b/src/biaplotter/_tests/test_artists.py @@ -44,6 +44,14 @@ def on_color_indices_changed(color_indices): assert np.all(colors[0] == scatter.categorical_colormap(0)) assert np.all(colors[50] == scatter.categorical_colormap(2)) + # Test alpha property + collected_alpha_signals = [] + def on_alpha_changed(alpha): + collected_alpha_signals.append(alpha) + scatter.alpha = np.linspace(start=0.1, stop=1.0, num=size) + assert np.all(scatter.alpha == np.linspace(start=0.1, stop=1.0, num=size)) + assert scatter._scatter.get_alpha() is not None + def test_histogram2d(): # Inputs diff --git a/src/biaplotter/artists.py b/src/biaplotter/artists.py index c01b1ac..d96b6b5 100644 --- a/src/biaplotter/artists.py +++ b/src/biaplotter/artists.py @@ -96,6 +96,8 @@ class Scatter(Artist): a colormap to use for the artist, by default cat10_mod_cmap from nap-plot-tools color_indices : (N,) np.ndarray[int] or int, optional array of indices to map to the colormap, by default None + alpha : (N,) np.ndarray[float] or float, optional + array of alpha values for the scatter points, by default 1.0 Notes ----- @@ -115,6 +117,7 @@ class Scatter(Artist): >>> scatter.data = data >>> scatter.visible = True >>> scatter.color_indices = np.linspace(start=0, stop=5, num=100, endpoint=False, dtype=int) + >>> scatter.alpha = np.linspace(start=0.1, stop=1.0, num=100) >>> plt.show() """ #: Signal emitted when the `data` is changed. @@ -122,12 +125,13 @@ class Scatter(Artist): #: Signal emitted when the `color_indices` are changed. color_indices_changed_signal: Signal = Signal(np.ndarray) - def __init__(self, ax: plt.Axes = None, data: np.ndarray = None, categorical_colormap: Colormap = cat10_mod_cmap, color_indices: np.ndarray = None): + def __init__(self, ax: plt.Axes = None, data: np.ndarray = None, categorical_colormap: Colormap = cat10_mod_cmap, color_indices: np.ndarray = None, alpha: np.ndarray = 1.0): """Initializes the scatter plot artist. """ super().__init__(ax, data, categorical_colormap, color_indices) #: Stores the scatter plot matplotlib object self._scatter = None + self._alpha = alpha self.data = data self.draw() # Initial draw of the scatter plot @@ -236,6 +240,30 @@ def color_indices(self, indices: np.ndarray): self.color_indices_changed_signal.emit(self._color_indices) self.draw() + @property + def alpha(self) -> np.ndarray: + """Gets or sets the alpha values for the scatter plot. + + Triggers a draw idle command. + + Returns + ------- + alpha : (N,) np.ndarray[float] or float + alpha values for the scatter plot. Accepts a scalar or an array of floats. + """ + return self._alpha + + @alpha.setter + def alpha(self, value: np.ndarray): + """Sets the alpha values for the scatter plot and updates the display accordingly.""" + # Check if value is a scalar + if np.isscalar(value): + value = np.full(len(self._data), value) + self._alpha = value + if self._scatter is not None: + self._scatter.set_alpha(value) + self.draw() + def draw(self): """Draws or redraws the scatter plot.""" self.ax.figure.canvas.draw_idle() diff --git a/src/biaplotter/plotter.py b/src/biaplotter/plotter.py index 7bed575..ae7baf6 100644 --- a/src/biaplotter/plotter.py +++ b/src/biaplotter/plotter.py @@ -263,3 +263,28 @@ def add_selector(self, selector_type: SelectorType, selector_instance: Interacti raise ValueError( f"Selector '{selector_type.name}' already exists.") self.selectors[selector_type] = selector_instance + + @property + def alpha(self) -> float: + """Gets or sets the alpha value for the scatter artist. + + Returns + ------- + alpha : float + The alpha value for the scatter artist. + """ + if isinstance(self._active_artist, Scatter): + return self._active_artist.alpha + return 1.0 + + @alpha.setter + def alpha(self, value: float): + """Sets the alpha value for the scatter artist. + + Parameters + ---------- + value : float + The alpha value to set for the scatter artist. + """ + if isinstance(self._active_artist, Scatter): + self._active_artist.alpha = value