Skip to content

Commit

Permalink
Merge 0d85ed3 into 65e2e9b
Browse files Browse the repository at this point in the history
  • Loading branch information
jo-mueller authored Jul 23, 2024
2 parents 65e2e9b + 0d85ed3 commit 9b180a7
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 13 deletions.
57 changes: 45 additions & 12 deletions docs/examples/scatter_artist_example.ipynb

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions src/biaplotter/_tests/test_artists.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ def on_color_indices_changed(color_indices):
assert np.all(colors[0] == scatter.categorical_colormap(0))
assert np.all(colors[50] == scatter.categorical_colormap(2))

# Test alpha property
collected_alpha_signals = []
def on_alpha_changed(alpha):
collected_alpha_signals.append(alpha)
scatter.alpha = np.linspace(start=0.1, stop=1.0, num=size)
assert np.all(scatter.alpha == np.linspace(start=0.1, stop=1.0, num=size))
assert scatter._scatter.get_alpha() is not None


def test_histogram2d():
# Inputs
Expand Down
30 changes: 29 additions & 1 deletion src/biaplotter/artists.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class Scatter(Artist):
a colormap to use for the artist, by default cat10_mod_cmap from nap-plot-tools
color_indices : (N,) np.ndarray[int] or int, optional
array of indices to map to the colormap, by default None
alpha : (N,) np.ndarray[float] or float, optional
array of alpha values for the scatter points, by default 1.0
Notes
-----
Expand All @@ -115,19 +117,21 @@ class Scatter(Artist):
>>> scatter.data = data
>>> scatter.visible = True
>>> scatter.color_indices = np.linspace(start=0, stop=5, num=100, endpoint=False, dtype=int)
>>> scatter.alpha = np.linspace(start=0.1, stop=1.0, num=100)
>>> plt.show()
"""
#: Signal emitted when the `data` is changed.
data_changed_signal: Signal = Signal(np.ndarray)
#: Signal emitted when the `color_indices` are changed.
color_indices_changed_signal: Signal = Signal(np.ndarray)

def __init__(self, ax: plt.Axes = None, data: np.ndarray = None, categorical_colormap: Colormap = cat10_mod_cmap, color_indices: np.ndarray = None):
def __init__(self, ax: plt.Axes = None, data: np.ndarray = None, categorical_colormap: Colormap = cat10_mod_cmap, color_indices: np.ndarray = None, alpha: np.ndarray = 1.0):
"""Initializes the scatter plot artist.
"""
super().__init__(ax, data, categorical_colormap, color_indices)
#: Stores the scatter plot matplotlib object
self._scatter = None
self._alpha = alpha
self.data = data
self.draw() # Initial draw of the scatter plot

Expand Down Expand Up @@ -236,6 +240,30 @@ def color_indices(self, indices: np.ndarray):
self.color_indices_changed_signal.emit(self._color_indices)
self.draw()

@property
def alpha(self) -> np.ndarray:
"""Gets or sets the alpha values for the scatter plot.
Triggers a draw idle command.
Returns
-------
alpha : (N,) np.ndarray[float] or float
alpha values for the scatter plot. Accepts a scalar or an array of floats.
"""
return self._alpha

@alpha.setter
def alpha(self, value: np.ndarray):
"""Sets the alpha values for the scatter plot and updates the display accordingly."""
# Check if value is a scalar
if np.isscalar(value):
value = np.full(len(self._data), value)
self._alpha = value
if self._scatter is not None:
self._scatter.set_alpha(value)
self.draw()

def draw(self):
"""Draws or redraws the scatter plot."""
self.ax.figure.canvas.draw_idle()
Expand Down
25 changes: 25 additions & 0 deletions src/biaplotter/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,28 @@ def add_selector(self, selector_type: SelectorType, selector_instance: Interacti
raise ValueError(
f"Selector '{selector_type.name}' already exists.")
self.selectors[selector_type] = selector_instance

@property
def alpha(self) -> float:
"""Gets or sets the alpha value for the scatter artist.
Returns
-------
alpha : float
The alpha value for the scatter artist.
"""
if isinstance(self._active_artist, Scatter):
return self._active_artist.alpha
return 1.0

@alpha.setter
def alpha(self, value: float):
"""Sets the alpha value for the scatter artist.
Parameters
----------
value : float
The alpha value to set for the scatter artist.
"""
if isinstance(self._active_artist, Scatter):
self._active_artist.alpha = value

0 comments on commit 9b180a7

Please sign in to comment.