Skip to content

Commit

Permalink
Last touches and tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
pfebrer committed Feb 19, 2024
1 parent ac5224f commit 42fc8d1
Show file tree
Hide file tree
Showing 8 changed files with 885 additions and 172 deletions.
513 changes: 513 additions & 0 deletions docs/visualization/viz_module/showcase/MatrixPlot.ipynb

Large diffs are not rendered by default.

13 changes: 12 additions & 1 deletion src/sisl/viz/figure/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,14 +258,22 @@ def init_3D(self):
return

def init_coloraxis(
self, name, cmin=None, cmax=None, cmid=None, colorscale=None, **kwargs
self,
name,
cmin=None,
cmax=None,
cmid=None,
colorscale=None,
showscale=True,
**kwargs,
):
"""Initializes a color axis to be used by the drawing functions"""
self._coloraxes[name] = {
"cmin": cmin,
"cmax": cmax,
"cmid": cmid,
"colorscale": colorscale,
"showscale": showscale,
**kwargs,
}

Expand Down Expand Up @@ -880,6 +888,9 @@ def draw_heatmap(
name=None,
zsmooth=False,
coloraxis=None,
opacity=None,
textformat=None,
textfont={},
row=None,
col=None,
**kwargs,
Expand Down
94 changes: 93 additions & 1 deletion src/sisl/viz/figure/matplotlib.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import itertools
import math

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection
Expand Down Expand Up @@ -417,6 +419,9 @@ def draw_heatmap(
name=None,
zsmooth=False,
coloraxis=None,
opacity=None,
textformat=None,
textfont={},
row=None,
col=None,
_axes=None,
Expand All @@ -433,16 +438,103 @@ def draw_heatmap(
vmin = coloraxis.get("cmin")
vmax = coloraxis.get("cmax")

axes.imshow(
im = axes.imshow(
values,
cmap=colorscale,
vmin=vmin,
vmax=vmax,
label=name,
extent=extent,
origin="lower",
alpha=opacity,
)

if textformat is not None:
self._annotate_heatmap(
im,
data=values,
valfmt="{x:" + textformat + "}",
cmap=matplotlib.colormaps.get_cmap(colorscale),
**textfont,
)

def _annotate_heatmap(
self,
im,
cmap,
data=None,
valfmt="{x:.2f}",
textcolors=("black", "white"),
**textkw,
):
"""A function to annotate a heatmap.
Parameters
----------
im
The AxesImage to be labeled.
data
Data used to annotate. If None, the image's data is used. Optional.
valfmt
The format of the annotations inside the heatmap. This should either
use the string format method, e.g. "$ {x:.2f}", or be a
`matplotlib.ticker.Formatter`. Optional.
textcolors
A pair of colors. The first is used for values below a threshold,
the second for those above. Optional.
threshold
Value in data units according to which the colors from textcolors are
applied. If None (the default) uses the middle of the colormap as
separation. Optional.
**kwargs
All other arguments are forwarded to each call to `text` used to create
the text labels.
"""

if not isinstance(data, (list, np.ndarray)):
data = im.get_array()

# Set default alignment to center, but allow it to be
# overwritten by textkw.
kw = dict(
horizontalalignment="center",
verticalalignment="center",
)
kw.update(textkw)

# Get the formatter in case a string is supplied
if isinstance(valfmt, str):
valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

def color_to_textcolor(rgb):
r, g, b = rgb
r *= 255
g *= 255
b *= 255

hsp = math.sqrt(0.299 * (r * r) + 0.587 * (g * g) + 0.114 * (b * b))
if hsp > 127.5:
return textcolors[0]
else:
return textcolors[1]

# Loop over the data and create a `Text` for each "pixel".
# Change the text's color depending on the data.
texts = []
for i in range(data.shape[0]):
for j in range(data.shape[1]):
if np.isnan(data[i, j]):
continue

if "color" not in textkw:
rgb = cmap(im.norm(data[i, j]))[:-1]
kw.update(color=color_to_textcolor(rgb))

text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
texts.append(text)

return texts

def set_axis(
self,
axis,
Expand Down
36 changes: 31 additions & 5 deletions src/sisl/viz/figure/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,14 +428,21 @@ def clear(self, frames=True, layout=False):
# METHODS TO STANDARIZE BACKENDS
# --------------------------------
def init_coloraxis(
self, name, cmin=None, cmax=None, cmid=None, colorscale=None, **kwargs
self,
name,
cmin=None,
cmax=None,
cmid=None,
colorscale=None,
showscale=True,
**kwargs,
):
if len(self._coloraxes) == 0:
kwargs["ax_name"] = "coloraxis"
else:
kwargs["ax_name"] = f"coloraxis{len(self._coloraxes) + 1}"

super().init_coloraxis(name, cmin, cmax, cmid, colorscale, **kwargs)
super().init_coloraxis(name, cmin, cmax, cmid, colorscale, showscale, **kwargs)

ax_name = kwargs["ax_name"]
self.update_layout(
Expand All @@ -445,6 +452,7 @@ def init_coloraxis(
"cmin": cmin,
"cmax": cmax,
"cmid": cmid,
"showscale": showscale,
}
}
)
Expand Down Expand Up @@ -763,6 +771,26 @@ def draw_heatmap(
col=None,
**kwargs,
):
if textformat is not None:
# If the user wants a custom color, we must define the text strings to be empty
# for NaN values. If there is not custom color, plotly handles this for us by setting
# the text color to the same as the background for those values so that they are not
# visible.
if "color" in kwargs.get("textfont", {}) and np.any(np.isnan(values)):
to_string = np.vectorize(
lambda x: "" if np.isnan(x) else f"{x:{textformat}}"
)
kwargs = {
"text": to_string(values),
"texttemplate": "%{text}",
**kwargs,
}
else:
kwargs = {
"texttemplate": "%{z:" + textformat + "}",
**kwargs,
}

self.add_trace(
{
"type": "heatmap",
Expand All @@ -773,9 +801,7 @@ def draw_heatmap(
"zsmooth": zsmooth,
"coloraxis": self._get_coloraxis_name(coloraxis),
"meta": kwargs.pop("meta", {}),
"texttemplate": "%{z:" + textformat + "}"
if textformat is not None
else None,
**kwargs,
},
row=row,
col=col,
Expand Down
48 changes: 40 additions & 8 deletions src/sisl/viz/plots/matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from scipy.sparse import spmatrix
Expand All @@ -8,28 +8,32 @@
from ..figure import Figure, get_figure
from ..plot import Plot
from ..plotters.grid import draw_grid, draw_grid_arrows
from ..plotters.matrix import draw_matrix_separators, set_matrix_axes
from ..plotters.plot_actions import combined
from ..processors.matrix import (
determine_color_midpoint,
draw_matrix_separators,
get_geometry_from_matrix,
get_matrix_mode,
matrix_as_array,
sanitize_matrix_arrows,
set_matrix_axes,
)


def atomic_matrix_plot(
matrix: Union[np.ndarray, sisl.SparseCSR, spmatrix],
dim: int = 0,
isc: Optional[int] = None,
fill_value: Optional[float] = None,
geometry: Union[sisl.Geometry, None] = None,
atom_lines: Union[bool, Dict] = False,
orbital_lines: Union[bool, Dict] = False,
sc_lines: Union[bool, Dict] = False,
colorscale: str = "RdBu",
color_pixels: bool = True,
colorscale: Optional[str] = "RdBu",
crange: Optional[Tuple[float, float]] = None,
cmid: Optional[float] = None,
text: Optional[str] = None,
textfont: Optional[dict] = {},
set_labels: bool = False,
constrain_axes: bool = True,
arrows: List[dict] = [],
Expand All @@ -44,6 +48,9 @@ def atomic_matrix_plot(
dim:
If the matrix has a third dimension (e.g. spin), which index to
plot in that third dimension.
isc:
If the matrix contains data for an auxiliary supercell, the index of the
cell to plot. If None, the whole matrix is plotted.
fill_value:
If the matrix is sparse, the value to use for the missing entries.
geometry:
Expand All @@ -58,11 +65,23 @@ def atomic_matrix_plot(
sc_lines:
If a boolean, whether to draw lines separating the supercells, using default styles.
If a dict, draws the lines with the specified plotly line styles.
color_pixels:
Whether to color the pixels of the matrix according to the colorscale.
colorscale:
The colorscale to use to color the pixels.
crange:
The minimum and maximum values of the colorscale.
cmid:
The midpoint of the colorscale. If ``crange`` is provided, this is ignored.
If None and crange is also None, the midpoint
is set to 0 if the data contains both positive and negative values.
text:
If provided, show text of pixel value with the specified format.
E.g. text=".3f" shows the value with three decimal places.
textfont:
The font to use for the text.
This is a dictionary that may contain the keys "family", "size", "color".
set_labels:
Whether to set the axes labels to the atom/orbital that each row and column corresponds to.
For orbitals the labels will be of the form "Atom: (l, m)", where `Atom` is the index of
Expand All @@ -76,27 +95,33 @@ def atomic_matrix_plot(
geometry = get_geometry_from_matrix(matrix, geometry)
mode = get_matrix_mode(matrix)

matrix_array = matrix_as_array(matrix, dim=dim, fill_value=fill_value)
matrix_array = matrix_as_array(matrix, dim=dim, isc=isc, fill_value=fill_value)

color_midpoint = determine_color_midpoint(matrix)
color_midpoint = determine_color_midpoint(matrix, cmid=cmid, crange=crange)

matrix_actions = draw_grid(
matrix_array,
crange=crange,
cmid=color_midpoint,
color_pixels_2d=color_pixels,
colorscale=colorscale,
coloraxis_name="matrix_vals",
textformat=text,
textfont=textfont,
)

arrows = sanitize_matrix_arrows(arrows)

arrow_actions = draw_grid_arrows(matrix_array, arrows)

draw_supercells = isc is None

axes_actions = set_matrix_axes(
matrix,
matrix_array,
geometry,
matrix_mode=mode,
constrain_axes=constrain_axes,
draw_supercells=draw_supercells,
set_labels=set_labels,
)

Expand All @@ -105,18 +130,25 @@ def atomic_matrix_plot(
geometry,
matrix_mode=mode,
separator_mode="supercells",
draw_supercells=draw_supercells,
showlegend=False,
)

atom_lines_actions = draw_matrix_separators(
atom_lines, geometry, matrix_mode=mode, separator_mode="atoms", showlegend=False
atom_lines,
geometry,
matrix_mode=mode,
separator_mode="atoms",
draw_supercells=draw_supercells,
showlegend=False,
)

orbital_lines_actions = draw_matrix_separators(
orbital_lines,
geometry,
matrix_mode=mode,
separator_mode="orbitals",
draw_supercells=draw_supercells,
showlegend=False,
)

Expand Down
Loading

0 comments on commit 42fc8d1

Please sign in to comment.