Skip to content

Commit

Permalink
Merge branch 'main' into jo-mueller/add-alpha
Browse files Browse the repository at this point in the history
  • Loading branch information
jo-mueller committed Aug 8, 2024
2 parents 8ea5c35 + 41f58dd commit 0069da5
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 28 deletions.
2 changes: 2 additions & 0 deletions docs/artists_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
~Scatter.data
~Scatter.visible
~Scatter.color_indices
~Scatter.size
.. rubric:: Attributes Summary
Expand All @@ -77,6 +78,7 @@
.. autoattribute:: data
.. autoattribute:: visible
.. autoattribute:: color_indices
.. autoattribute:: size
.. rubric:: Attributes Documentation
Expand Down
30 changes: 23 additions & 7 deletions src/biaplotter/_tests/test_artists.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,29 @@ def on_color_indices_changed(color_indices):
assert np.all(colors[0] == scatter.categorical_colormap(0))
assert np.all(colors[50] == scatter.categorical_colormap(2))

# Test alpha property
collected_alpha_signals = []
def on_alpha_changed(alpha):
collected_alpha_signals.append(alpha)
scatter.alpha = np.linspace(start=0.1, stop=1.0, num=size)
assert np.all(scatter.alpha == np.linspace(start=0.1, stop=1.0, num=size))
assert scatter._scatter.get_alpha() is not None
# Test axis limits
x_margin = 0.05 * (np.max(data[:, 0]) - np.min(data[:, 0]))
y_margin = 0.05 * (np.max(data[:, 1]) - np.min(data[:, 1]))
assert np.isclose(ax.get_xlim(), (np.min(data[:, 0]) - x_margin, np.max(data[:, 0]) + x_margin)).all()
assert np.isclose(ax.get_ylim(), (np.min(data[:, 1]) - y_margin, np.max(data[:, 1]) + y_margin)).all()

# Test size property
scatter.size = 5.0
assert scatter.size == 5.0
sizes = scatter._scatter.get_sizes()
assert np.all(sizes == 5.0)

scatter.size = np.linspace(1, 10, size)
assert np.all(scatter.size == np.linspace(1, 10, size))
sizes = scatter._scatter.get_sizes()
assert np.all(sizes == np.linspace(1, 10, size))

# Test size reset when new data is set
new_data = np.random.rand(size, 2)
scatter.data = new_data
assert scatter.size == 50.0 # that's the default
sizes = scatter._scatter.get_sizes()
assert np.all(sizes == 50.0)


def test_histogram2d():
Expand Down
43 changes: 22 additions & 21 deletions src/biaplotter/artists.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from nap_plot_tools.cmap import cat10_mod_cmap, cat10_mod_cmap_first_transparent
from psygnal import Signal
from typing import Tuple, List
from typing import Tuple, List, Union


class Artist(ABC):
Expand Down Expand Up @@ -96,8 +96,6 @@ class Scatter(Artist):
a colormap to use for the artist, by default cat10_mod_cmap from nap-plot-tools
color_indices : (N,) np.ndarray[int] or int, optional
array of indices to map to the colormap, by default None
alpha : (N,) np.ndarray[float] or float, optional
array of alpha values for the scatter points, by default 1.0
Notes
-----
Expand All @@ -117,22 +115,21 @@ class Scatter(Artist):
>>> scatter.data = data
>>> scatter.visible = True
>>> scatter.color_indices = np.linspace(start=0, stop=5, num=100, endpoint=False, dtype=int)
>>> scatter.alpha = np.linspace(start=0.1, stop=1.0, num=100)
>>> plt.show()
"""
#: Signal emitted when the `data` is changed.
data_changed_signal: Signal = Signal(np.ndarray)
#: Signal emitted when the `color_indices` are changed.
color_indices_changed_signal: Signal = Signal(np.ndarray)

def __init__(self, ax: plt.Axes = None, data: np.ndarray = None, categorical_colormap: Colormap = cat10_mod_cmap, color_indices: np.ndarray = None, alpha: np.ndarray = 1.0):
def __init__(self, ax: plt.Axes = None, data: np.ndarray = None, categorical_colormap: Colormap = cat10_mod_cmap, color_indices: np.ndarray = None):
"""Initializes the scatter plot artist.
"""
super().__init__(ax, data, categorical_colormap, color_indices)
#: Stores the scatter plot matplotlib object
self._scatter = None
self._alpha = alpha
self.data = data
self._size = 50 # Default size
self.draw() # Initial draw of the scatter plot

@property
Expand Down Expand Up @@ -164,7 +161,7 @@ def data(self, value: np.ndarray):
# emit signal
self.data_changed_signal.emit(self._data)
if self._scatter is None:
self._scatter = self.ax.scatter(value[:, 0], value[:, 1])
self._scatter = self.ax.scatter(value[:, 0], value[:, 1], s=self._size)
self.color_indices = 0 # Set default color index
else:
# If the scatter plot already exists, just update its data
Expand All @@ -181,6 +178,13 @@ def data(self, value: np.ndarray):
# fill with zeros where new data is larger
color_indices[color_indices_size:] = 0
self.color_indices = color_indices
self.size = 50

x_margin = 0.05 * (np.max(value[:, 0]) - np.min(value[:, 0]))
y_margin = 0.05 * (np.max(value[:, 1]) - np.min(value[:, 1]))
self.ax.set_xlim(np.min(value[:, 0]) - x_margin, np.max(value[:, 0]) + x_margin)
self.ax.set_ylim(np.min(value[:, 1]) - y_margin, np.max(value[:, 1]) + y_margin)

self.draw()

@property
Expand Down Expand Up @@ -242,27 +246,24 @@ def color_indices(self, indices: np.ndarray):
self.draw()

@property
def alpha(self) -> np.ndarray:
"""Gets or sets the alpha values for the scatter plot.
def size(self) -> Union[float, np.ndarray]:
"""Gets or sets the size of the points in the scatter plot.
Triggers a draw idle command.
Returns
-------
alpha : (N,) np.ndarray[float] or float
alpha values for the scatter plot. Accepts a scalar or an array of floats.
size : float or (N,) np.ndarray[float]
size of the points in the scatter plot. Accepts a scalar or an array of floats.
"""
return self._alpha

@alpha.setter
def alpha(self, value: np.ndarray):
"""Sets the alpha values for the scatter plot and updates the display accordingly."""
# Check if value is a scalar
if np.isscalar(value):
value = np.full(len(self._data), value)
self._alpha = value
return self._size

@size.setter
def size(self, value: Union[float, np.ndarray]):
"""Sets the size of the points in the scatter plot."""
self._size = value
if self._scatter is not None:
self._scatter.set_alpha(value)
self._scatter.set_sizes(np.full(len(self._data), value) if np.isscalar(value) else value)
self.draw()

def draw(self):
Expand Down

0 comments on commit 0069da5

Please sign in to comment.