Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow custom color scales across sisl.viz #786

Merged
merged 1 commit into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions docs/visualization/viz_module/showcase/GeometryPlot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,20 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice however that, for now, you can not mix values with strings and there is only one colorscale for all atoms."
"Notice however that, for now, you can not mix values with strings and there is only one colorscale for all atoms.\n",
"\n",
"You can also pass a custom colorscale specified as a list of colors as in `plotly` colorscales:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot.update_inputs(atoms_colorscale=[\"rgb(255, 0, 0)\", \"rgb(0, 0, 255)\"])\n",
"# or\n",
"plot.update_inputs(atoms_colorscale=[[0, \"rgb(255, 0, 0)\"], [1, \"rgb(0, 0, 255)\"]])"
]
},
{
Expand All @@ -615,7 +628,7 @@
"metadata": {},
"outputs": [],
"source": [
"plot.update_inputs(axes=\"xyz\")"
"plot.update_inputs(axes=\"xyz\", atoms_colorscale=\"viridis\")"
]
},
{
Expand All @@ -635,7 +648,7 @@
"source": [
"plot.update_inputs(\n",
" axes=\"yx\", bonds_style={\"color\": \"orange\", \"width\": 5, \"opacity\": 0.5}\n",
").get()"
")"
]
},
{
Expand Down Expand Up @@ -908,7 +921,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15"
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down
29 changes: 26 additions & 3 deletions src/sisl/viz/figure/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,29 @@
"dot": "dotted",
}.get(dash, dash)

def _sanitize_colorscale(self, colorscale):
"""Makes sure that a colorscale is either a string or a colormap."""
if isinstance(colorscale, str):
return colorscale

Check warning on line 246 in src/sisl/viz/figure/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/matplotlib.py#L246

Added line #L246 was not covered by tests
elif isinstance(colorscale, list):

def _sanitize_scale_item(item):

Check warning on line 249 in src/sisl/viz/figure/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/matplotlib.py#L249

Added line #L249 was not covered by tests

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
# Plotly uses rgb colors as a string like "rgb(r,g,b)",
# while matplotlib uses tuples
# Also plotly's range goes from 0 to 255 while matplotlib's goes from 0 to 1
if isinstance(item, (tuple, list)) and len(item) == 2:
return (item[0], _sanitize_scale_item(item[1]))
elif isinstance(item, str) and item.startswith("rgb("):
return tuple(float(x) / 255 for x in item[4:-1].split(","))

Check warning on line 256 in src/sisl/viz/figure/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/matplotlib.py#L253-L256

Added lines #L253 - L256 were not covered by tests

colorscale = [_sanitize_scale_item(item) for item in colorscale]

Check warning on line 258 in src/sisl/viz/figure/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/matplotlib.py#L258

Added line #L258 was not covered by tests

return matplotlib.colors.LinearSegmentedColormap.from_list(

Check warning on line 260 in src/sisl/viz/figure/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/matplotlib.py#L260

Added line #L260 was not covered by tests
"custom", colorscale
)
else:
return colorscale

def draw_line(
self,
x,
Expand Down Expand Up @@ -426,7 +449,7 @@
y,
c=marker.get("color"),
s=marker.get("size", 1),
cmap=marker.get("colorscale"),
cmap=self._sanitize_colorscale(marker.get("colorscale")),
alpha=marker.get("opacity"),
label=name,
zorder=zorder,
Expand All @@ -442,7 +465,7 @@
y,
c=marker.get("color"),
s=marker.get("size", 1),
cmap=marker.get("colorscale"),
cmap=self._sanitize_colorscale(marker.get("colorscale")),
label=name,
zorder=zorder,
**kwargs,
Expand Down Expand Up @@ -481,7 +504,7 @@
axes = _axes or self._get_subplot_axes(row=row, col=col)

coloraxis = self._coloraxes.get(coloraxis, {})
colorscale = coloraxis.get("colorscale")
colorscale = self._sanitize_colorscale(coloraxis.get("colorscale"))
vmin = coloraxis.get("cmin")
vmax = coloraxis.get("cmax")

Expand Down
7 changes: 4 additions & 3 deletions src/sisl/viz/figure/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@
)

def draw_scatter(self, x, y, name=None, marker={}, **kwargs):
marker.pop("dash", None)
marker = {k: v for k, v in marker.items() if k != "dash"}
self.draw_line(x, y, name, marker=marker, mode="markers", **kwargs)

def draw_multicolor_scatter(self, *args, **kwargs):
Expand All @@ -606,8 +606,9 @@

super().draw_multicolor_line_3D(x, y, z, **kwargs)

def draw_scatter_3D(self, *args, **kwargs):
self.draw_line_3D(*args, mode="markers", **kwargs)
def draw_scatter_3D(self, *args, marker={}, **kwargs):
marker = {k: v for k, v in marker.items() if k != "dash"}
self.draw_line_3D(*args, mode="markers", marker=marker, **kwargs)

Check warning on line 611 in src/sisl/viz/figure/plotly.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/plotly.py#L610-L611

Added lines #L610 - L611 were not covered by tests

def draw_multicolor_scatter_3D(self, *args, **kwargs):
kwargs["marker"] = self._handle_multicolor_scatter(kwargs["marker"], kwargs)
Expand Down
4 changes: 2 additions & 2 deletions src/sisl/viz/plots/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from sisl.viz.types import OrbitalQueries, StyleSpec
from sisl.viz.types import Colorscale, OrbitalQueries, StyleSpec

from ..data.bands import BandsData
from ..figure import Figure, get_figure
Expand Down Expand Up @@ -64,7 +64,7 @@ def bands_plot(
"dash": "solid",
},
spindown_style: StyleSpec = {"color": "blue", "width": 1},
colorscale: Optional[str] = None,
colorscale: Optional[Colorscale] = None,
gap: bool = False,
gap_tol: float = 0.01,
gap_color: str = "red",
Expand Down
8 changes: 4 additions & 4 deletions src/sisl/viz/plots/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sisl.typing import AtomsIndex
from sisl.viz.figure import Figure, get_figure
from sisl.viz.plotters import plot_actions as plot_actions
from sisl.viz.types import AtomArrowSpec, AtomsStyleSpec, Axes, StyleSpec
from sisl.viz.types import AtomArrowSpec, AtomsStyleSpec, Axes, Colorscale, StyleSpec

from ..plot import Plot
from ..plotters.cell import cell_plot_actions, get_ndim, get_z
Expand Down Expand Up @@ -100,13 +100,13 @@ def geometry_plot(
atoms: AtomsIndex = None,
atoms_style: Sequence[AtomsStyleSpec] = [],
atoms_scale: float = 1.0,
atoms_colorscale: Optional[str] = None,
atoms_colorscale: Optional[Colorscale] = None,
drawing_mode: Literal["scatter", "balls", None] = None,
bind_bonds_to_ats: bool = True,
points_per_bond: int = 20,
bonds_style: StyleSpec = {},
bonds_scale: float = 1.0,
bonds_colorscale: Optional[str] = None,
bonds_colorscale: Optional[Colorscale] = None,
show_atoms: bool = True,
show_bonds: bool = True,
show_cell: Literal["box", "axes", False] = "box",
Expand Down Expand Up @@ -290,7 +290,7 @@ def sites_plot(
sites_style: Sequence[AtomsStyleSpec] = [],
sites_scale: float = 1.0,
sites_name: str = "Sites",
sites_colorscale: Optional[str] = None,
sites_colorscale: Optional[Colorscale] = None,
drawing_mode: Literal["scatter", "balls", "line", None] = None,
show_cell: Literal["box", "axes", False] = False,
cell_style: StyleSpec = {},
Expand Down
6 changes: 3 additions & 3 deletions src/sisl/viz/plots/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
sub_grid,
tile_grid,
)
from ..types import Axes
from ..types import Axes, Colorscale
from .geometry import geometry_plot


Expand Down Expand Up @@ -69,7 +69,7 @@ def grid_plot(
interp: Tuple[int, int, int] = (1, 1, 1),
isos: Sequence[dict] = [],
smooth: bool = False,
colorscale: Optional[str] = None,
colorscale: Optional[Colorscale] = None,
crange: Optional[Tuple[float, float]] = None,
cmid: Optional[float] = None,
show_cell: Literal["box", "axes", False] = "box",
Expand Down Expand Up @@ -219,7 +219,7 @@ def wavefunction_plot(
interp: Tuple[int, int, int] = (1, 1, 1),
isos: Sequence[dict] = [],
smooth: bool = False,
colorscale: Optional[str] = None,
colorscale: Optional[Colorscale] = None,
crange: Optional[Tuple[float, float]] = None,
cmid: Optional[float] = None,
show_cell: Literal["box", "axes", False] = "box",
Expand Down
3 changes: 2 additions & 1 deletion src/sisl/viz/plots/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
matrix_as_array,
sanitize_matrix_arrows,
)
from ..types import Colorscale


def atomic_matrix_plot(
Expand All @@ -36,7 +37,7 @@ def atomic_matrix_plot(
orbital_lines: Union[bool, Dict] = False,
sc_lines: Union[bool, Dict] = False,
color_pixels: bool = True,
colorscale: Optional[str] = "RdBu",
colorscale: Optional[Colorscale] = "RdBu",
crange: Optional[Tuple[float, float]] = None,
cmid: Optional[float] = None,
text: Optional[str] = None,
Expand Down
37 changes: 8 additions & 29 deletions src/sisl/viz/plotutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations

import itertools
import os
import sys
from pathlib import Path

import numpy as np
Expand All @@ -23,31 +21,24 @@
except Exception:
tqdm_avail = False

from copy import deepcopy

from sisl._environ import get_environ_variable
from sisl.io.sile import get_sile_rules, get_siles
from sisl.messages import info

from .types import Colorscale

__all__ = [
"running_in_notebook",
"check_widgets",
"get_plot_classes",
"get_plotable_siles",
"get_plotable_variables",
"get_session_classes",
"get_avail_presets",
"get_nested_key",
"modify_nested_dict",
"dictOfLists2listOfDicts",
"get_avail_presets",
"random_color",
"load",
"find_files",
"find_plotable_siles",
"shift_trace",
"normalize_trace",
"swap_trace_axes",
]

# -------------------------------------
Expand Down Expand Up @@ -446,7 +437,7 @@
return "#" + "%06x" % random.randint(0, 0xFFFFFF)


def values_to_colors(values, scale):
def values_to_colors(values, scale: Colorscale):
"""Maps an array of numbers to colors using a colorscale.

Parameters
Expand All @@ -466,23 +457,11 @@
list
the corresponding colors in "rgb(r,g,b)" format.
"""
import matplotlib
import plotly

v_min = np.min(values)
values = (values - v_min) / (np.max(values) - v_min)

scale_colors = plotly.colors.convert_colors_to_same_type(scale, colortype="tuple")[
0
]

if not scale_colors and isinstance(scale, str):
scale_colors = plotly.colors.convert_colors_to_same_type(
scale[0].upper() + scale[1:], colortype="tuple"
)[0]
from plotly.colors import sample_colorscale

Check warning on line 461 in src/sisl/viz/plotutils.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/plotutils.py#L461

Added line #L461 was not covered by tests

cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
"my color map", scale_colors
)
# Normalize values
min_value = np.min(values)
values = (np.array(values) - min_value) / (np.max(values) - min_value)

Check warning on line 465 in src/sisl/viz/plotutils.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/plotutils.py#L464-L465

Added lines #L464 - L465 were not covered by tests

return plotly.colors.convert_colors_to_same_type([cmap(c) for c in values])[0]
return sample_colorscale(scale, values)

Check warning on line 467 in src/sisl/viz/plotutils.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/plotutils.py#L467

Added line #L467 was not covered by tests
6 changes: 5 additions & 1 deletion src/sisl/viz/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, NewType, Optional, Sequence, TypedDict, Union
from typing import Any, Literal, NewType, Optional, Sequence, Tuple, TypedDict, Union

import numpy as np

Expand All @@ -19,6 +19,10 @@

Color = NewType("Color", str)

# A colorscale can be a scale name, a sequence of colors or a sequence of
# (value, color) tuples.
Colorscale = Union[str, Sequence[Color], Sequence[Tuple[float, Color]]]

GeometryLike = Union[sisl.Geometry, Any]

Axis = Union[
Expand Down
Loading