diff --git a/CHANGELOG.md b/CHANGELOG.md index d12ade3d1d..9ac9c477bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ we hit release version 1.0.0. as well as `sisl.geom.cgnr`) - Creation of [n]-triangulenes (`sisl.geom.triangulene`) - added `offset` argument in `Geometry.add_vacuum` to enable shifting atomic coordinates +- A new `AtomicMatrixPlot` to plot sparse matrices, #668 ### Fixed - `SparseCSR` ufunc handling, in some corner cases could the dtype casting do things @@ -40,7 +41,7 @@ we hit release version 1.0.0. - `Lattice` objects now issues a warning when created with 0-length vectors - HSX file reads should respect input geometry arguments - enabled slicing in matrix assignments, #650 -- changed `Shape.volume()` to `Shape.volume` +- changed `Shape.volume()` to `Shape.volume` - growth direction for zigzag heteroribbons - `BandStructure` points can now automatically add the `nsc == 1` axis as would be done for assigning matrix elements (it fills with 0's). diff --git a/docs/visualization/viz_module/index.rst b/docs/visualization/viz_module/index.rst index d33b7d0091..48b0df28b5 100644 --- a/docs/visualization/viz_module/index.rst +++ b/docs/visualization/viz_module/index.rst @@ -9,7 +9,7 @@ with your results can be as fast as possible. The plots that you can generate with it are **not bound to a specific visualization framework**. Instead, the users can choose the one that they want based on their taste or on what is available in their environment. Currently, there is support for visualizing the plots with `plotly`_, `matplotlib`_, `blender `_. The flexibility of the framework -allows for the user to **extend the visualizing options** quite simply without modifying ``sisl``'s internal code. +allows for the user to **extend the visualizing options** quite simply without modifying ``sisl``'s internal code. The framework started as a GUI, but then evolved to make it usable by ``sisl`` users directly. Therefore, it can serve as a very robust (highly tested) and featureful **backend to integrate visualizations into graphical interfaces**. @@ -37,6 +37,7 @@ The following notebooks will help you develop a deeper understanding of what eac showcase/GeometryPlot.ipynb showcase/SitesPlot.ipynb showcase/GridPlot.ipynb + showcase/AtomicMatrixPlot.ipynb showcase/BandsPlot.ipynb showcase/FatbandsPlot.ipynb showcase/PdosPlot.ipynb diff --git a/docs/visualization/viz_module/showcase/AtomicMatrixPlot.ipynb b/docs/visualization/viz_module/showcase/AtomicMatrixPlot.ipynb new file mode 100644 index 0000000000..be8bc26389 --- /dev/null +++ b/docs/visualization/viz_module/showcase/AtomicMatrixPlot.ipynb @@ -0,0 +1,513 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "notebook-header" + ] + }, + "source": [ + "[![GitHub issues by-label](https://img.shields.io/github/issues-raw/pfebrer/sisl/AtomicmatrixPlot?style=for-the-badge)](https://github.com/pfebrer/sisl/labels/AtomicMatrixPlot)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "AtomicMatrixPlot\n", + "=========\n", + "\n", + "`AtomicMatrixPlot` allows you to visualize sparse matrices. This can help you:\n", + "\n", + "- **Understand sparse matrices better**, if you are new to them.\n", + "- Easily **introspect matrices** to debug or understand how to implement new functionality." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sisl\n", + "import sisl.viz\n", + "\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create a toy **Hamiltonian** that we will play with. We will use a chain with two atoms: \n", + "\n", + " - C: With one s orbital and a set of p orbitals.\n", + " - H: With one s orbital." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a C atom with s and p orbitals\n", + "C = sisl.Atom(\n", + " \"C\",\n", + " orbitals=[\n", + " sisl.AtomicOrbital(\"2s\", R=0.8),\n", + " sisl.AtomicOrbital(\"2py\", R=0.8),\n", + " sisl.AtomicOrbital(\"2pz\", R=0.8),\n", + " sisl.AtomicOrbital(\"2px\", R=0.8),\n", + " ],\n", + ")\n", + "\n", + "# Create a H atom with one s orbital\n", + "H = sisl.Atom(\"H\", orbitals=[sisl.AtomicOrbital(\"1s\", R=0.4)])\n", + "\n", + "# Create a chain along X\n", + "geom = sisl.Geometry(\n", + " [[0, 0, 0], [1, 0, 0]],\n", + " atoms=[C, H],\n", + " lattice=sisl.Lattice([2, 10, 10], nsc=[3, 1, 1]),\n", + ")\n", + "\n", + "# Random Hamiltonian with non-zero elements only for orbitals that overlap\n", + "H = sisl.Hamiltonian(geom)\n", + "for i in range(geom.no):\n", + " for j in range(geom.no * 3):\n", + " dist = geom.rij(*geom.o2a([i, j]))\n", + " if dist < 1.2:\n", + " H[i, j] = (np.random.random() - 0.5) * (1 if j < geom.no else 0.2)\n", + "# Symmetrize it to make it more realistic\n", + "H = H + H.transpose()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Matrix as an image\n", + "\n", + "The default mode of `AtomicMatrixPlot` is simply to plot an image where the values of the matrix are encoded as colors:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot = H.plot()\n", + "plot.get()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Changing the colorscale\n", + "\n", + "The most obvious thing to tweak here is the colorscale, you can do so by using the `colorscale` input, which, as usual, accepts any colorscale that the plotting backend can understand." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(colorscale=\"temps\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note how the `temps` colorscale makes it more clear that there are elements of the matrix that are not set (because the matrix is sparse). Those elements are not displayed.\n", + "\n", + "The range of colors **by default is set from min to max**. **Unless there are negative and positive values**. In that case, the colorscale is just **centered at 0** by default.\n", + "\n", + "However, you can set the `crange` and `cmid` to customize the colorscale as you wish. For example, to center the scale at `0.5`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(cmid=0.5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And to set the range of the scale from `-4` to `1`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(crange=(-4, 1))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice how **`crange` takes precedence over `cmid`**. Now, to go back to the default range, just set both to `None`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(crange=None, cmid=None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Show values as text\n", + "\n", + "Colors are nice to give you a quick impression of the relative magnitude of the matrix elements. However, you might want to know the exact value. Although `plotly` shows them when you pass the mouse over the matrix elements, sometimes it might be more convenient to directly display them on top.\n", + "\n", + "To do this, you need to pass a formatting string to the `text` input. For example, to show two decimals:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(text=\".2f\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can tweak the style of text with the `textfont` input, which is a dictionary with three (optional) keys:\n", + "\n", + "- `color`: Text color.\n", + "- `family`: Font family for the text. Note that different backends might support different fonts.\n", + "- `size`: The size of the font.\n", + "\n", + "The default value will be used for any key that you don't include in the dictionary." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(textfont={\"color\": \"blue\", \"family\": \"times\", \"size\": 15})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you want the text to go away, set the `text` input to `None`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(text=None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot = plot.update_inputs(textfont={}, text=\".2f\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Separators\n", + "\n", + "Sparse atomic matrices may be hard to interpret because it's hard to know by eye to which atoms/orbitals an element belongs.\n", + "\n", + "Separators come to the rescue by providing a guide for your eye to quickly pinpoint what each element is. There are three types of separators:\n", + "\n", + "- `sc_lines`: Draw lines separating the **different cells of the auxiliary supercell**.\n", + "- `atom_lines`: Draw lines separating the blocks corresponding to **each atom-atom interaction**.\n", + "- `orbital_lines`: Within each atom, draw lines that **isolate the interactions between two sets of orbitals**. E.g. a set of 3 `p` orbitals from one atom and an `s` orbital from another atom.\n", + "\n", + "They all can be activated (deactivated) by setting the corresponding input to `True` (`False`)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(\n", + " sc_lines=True,\n", + " atom_lines=True,\n", + " orbital_lines=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Sometimes, the **default styles for the lines might not suit your visualization**. For example, they might not play well with your chosen colorscale. In that case, you can **pass a dictionary of line styles** to the inputs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(\n", + " orbital_lines={\"color\": \"pink\", \"width\": 5, \"dash\": \"dash\", \"opacity\": 0.8},\n", + " sc_lines={\"width\": 4},\n", + " atom_lines={\"color\": \"gray\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Labels\n", + "\n", + "You might want to have a clearer idea of the orbitals that correspond to a given matrix element. You can turn on labels with `set_labels`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(set_labels=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Labels have the format: `Atom index: (l, m)`. where l and m are the quantum numbers of the orbital." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot = plot.update_inputs(set_labels=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Showing only one cell\n", + "\n", + "If you only want to visualize a given cell in the supercell, you can pass the index to the `isc` input." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(isc=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To go back to visualizing the whole supercell, just set `isc` to `None`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(isc=None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Arrows\n", + "\n", + "One can ask for an arrow to be drawn on each matrix element. You are free to represent whatever you like as arrows.\n", + "\n", + "The arrow specification works the same as for atom arrows in `GeometryPlot`. It is a dictionary with the key `data` containing the arrow data and the styling keys (`color`, `width`, `opacity`...) to tweak the style.\n", + "\n", + "However, there is one main difference. If `data` is skipped, vertical arrows are drawn, with the value of the corresponding matrix element defining the length of the arrow:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(arrows={\"color\": \"blue\"})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Arrows are normalized so that they fit the box** of their matrix element.\n", + "\n", + "It may be that pixel colors and numbers make it difficult to visualize the arrows. In that case, you can disable them both. We have already seen how to disable text. For pixel colors there's the `color_pixels` input, which is a switch to turn them on or off:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(color_pixels=False, text=None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you want to plot custom data for the arrows, you have to pass a sparse matrix where the last dimension is the cartesian coordinate `(X, Y)`. Let's create some random data to display it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize a sparse matrix with the same sparsity pattern as our Hamiltonian,\n", + "# but with an extra dimension that will host the X and Y coordinates.\n", + "arrow_data = sisl.SparseCSR.fromsp(H, H)\n", + "\n", + "# The X coordinate will be the Hamiltonian's value,\n", + "# while the Y coordinate will be just random.\n", + "for i in range(arrow_data.shape[0]):\n", + " for j in range(arrow_data.shape[1]):\n", + " if arrow_data[i, j, 1] != 0:\n", + " arrow_data[i, j, 1] *= np.random.random()\n", + "\n", + "# Let's display the data\n", + "plot.update_inputs(arrows={\"data\": arrow_data, \"color\": \"red\"})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we showcase how **you can have multiple specifications of arrows**.\n", + "\n", + "We also show how you can use the `center` key. You can set `center` to `\"start\"`, `\"middle\"` or `\"end\"`. It determines which part of the arrow is pinned to the center of the matrix element (the default is `\"middle\"`):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot.update_inputs(\n", + " arrows=[\n", + " {\"color\": \"blue\", \"center\": \"start\", \"name\": \"Hamiltonian value\"},\n", + " {\"data\": arrow_data, \"color\": \"red\", \"center\": \"end\", \"name\": \"Some data\"},\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We hope you enjoyed what you learned!\n", + "\n", + "-----\n", + "This next cell is just to create the thumbnail for the notebook in the docs " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "nbsphinx-thumbnail" + ] + }, + "outputs": [], + "source": [ + "thumbnail_plot = plot.update_inputs(color_pixels=True, text=\".2f\").update_layout(\n", + " legend_orientation=\"h\"\n", + ")\n", + "\n", + "if thumbnail_plot:\n", + " thumbnail_plot.show(\"png\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "notebook-footer" + ] + }, + "source": [ + "-------------" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.15" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/sisl/nodes/syntax_nodes.py b/src/sisl/nodes/syntax_nodes.py index 7dc2e341a3..92d2e67434 100644 --- a/src/sisl/nodes/syntax_nodes.py +++ b/src/sisl/nodes/syntax_nodes.py @@ -123,6 +123,9 @@ class CompareSyntaxNode(SyntaxNode): "lt": "<", "ge": ">=", "le": "<=", + "is_": "is", + "is_not": "is not", + "contains": "in", None: "compare", } diff --git a/src/sisl/nodes/workflow.py b/src/sisl/nodes/workflow.py index 37dedc6d8e..24e9fc1bf1 100644 --- a/src/sisl/nodes/workflow.py +++ b/src/sisl/nodes/workflow.py @@ -1036,6 +1036,9 @@ class NodeConverter(ast.NodeTransformer): ast.LtE: "le", ast.Gt: "gt", ast.GtE: "ge", + ast.Is: "is_", + ast.IsNot: "is_not", + ast.In: "contains", } def __init__( diff --git a/src/sisl/viz/_plotables_register.py b/src/sisl/viz/_plotables_register.py index 12d3a49e1e..77c20729fd 100644 --- a/src/sisl/viz/_plotables_register.py +++ b/src/sisl/viz/_plotables_register.py @@ -22,6 +22,25 @@ __all__ = [] +register = register_plotable + +# # ----------------------------------------------------- +# # Register plotable sisl objects +# # ----------------------------------------------------- + +# Matrices +register(sisl.SparseCSR, AtomicMatrixPlot, "matrix", default=True) +register(sisl.SparseOrbital, AtomicMatrixPlot, "matrix", default=True) +register(sisl.SparseAtom, AtomicMatrixPlot, "matrix", default=True) + +# # Geometry +register(sisl.Geometry, GeometryPlot, "geometry", default=True) + +# # Grid +register(sisl.Grid, GridPlot, "grid", default=True) + +# Brilloiun zone +register(sisl.BrillouinZone, SitesPlot, "sites_obj") # ----------------------------------------------------- # Register data sources @@ -48,25 +67,12 @@ # Register plotable siles # ----------------------------------------------------- -register = register_plotable - for GeomSile in get_siles(attrs=["read_geometry"]): register_sile_method(GeomSile, "read_geometry", GeometryPlot, "geometry") for GridSile in get_siles(attrs=["read_grid"]): register_sile_method(GridSile, "read_grid", GridPlot, "grid", default=True) -# # ----------------------------------------------------- -# # Register plotable sisl objects -# # ----------------------------------------------------- - -# # Geometry -register(sisl.Geometry, GeometryPlot, "geometry", default=True) - -# # Grid -register(sisl.Grid, GridPlot, "grid", default=True) - -# Brilloiun zone -register(sisl.BrillouinZone, SitesPlot, "sites_obj") sisl.BandStructure.plot.set_default("bands") +sisl.Hamiltonian.plot.set_default("atomicmatrix") diff --git a/src/sisl/viz/figure/figure.py b/src/sisl/viz/figure/figure.py index e5e7f2cc17..734c1870ec 100644 --- a/src/sisl/viz/figure/figure.py +++ b/src/sisl/viz/figure/figure.py @@ -258,7 +258,14 @@ def init_3D(self): return def init_coloraxis( - self, name, cmin=None, cmax=None, cmid=None, colorscale=None, **kwargs + self, + name, + cmin=None, + cmax=None, + cmid=None, + colorscale=None, + showscale=True, + **kwargs, ): """Initializes a color axis to be used by the drawing functions""" self._coloraxes[name] = { @@ -266,6 +273,7 @@ def init_coloraxis( "cmax": cmax, "cmid": cmid, "colorscale": colorscale, + "showscale": showscale, **kwargs, } @@ -880,6 +888,9 @@ def draw_heatmap( name=None, zsmooth=False, coloraxis=None, + opacity=None, + textformat=None, + textfont={}, row=None, col=None, **kwargs, diff --git a/src/sisl/viz/figure/matplotlib.py b/src/sisl/viz/figure/matplotlib.py index c05a1e635f..22eae27c42 100644 --- a/src/sisl/viz/figure/matplotlib.py +++ b/src/sisl/viz/figure/matplotlib.py @@ -1,5 +1,7 @@ import itertools +import math +import matplotlib import matplotlib.pyplot as plt import numpy as np from matplotlib.collections import LineCollection @@ -417,6 +419,9 @@ def draw_heatmap( name=None, zsmooth=False, coloraxis=None, + opacity=None, + textformat=None, + textfont={}, row=None, col=None, _axes=None, @@ -433,7 +438,7 @@ def draw_heatmap( vmin = coloraxis.get("cmin") vmax = coloraxis.get("cmax") - axes.imshow( + im = axes.imshow( values, cmap=colorscale, vmin=vmin, @@ -441,8 +446,95 @@ def draw_heatmap( label=name, extent=extent, origin="lower", + alpha=opacity, ) + if textformat is not None: + self._annotate_heatmap( + im, + data=values, + valfmt="{x:" + textformat + "}", + cmap=matplotlib.colormaps.get_cmap(colorscale), + **textfont, + ) + + def _annotate_heatmap( + self, + im, + cmap, + data=None, + valfmt="{x:.2f}", + textcolors=("black", "white"), + **textkw, + ): + """A function to annotate a heatmap. + + Parameters + ---------- + im + The AxesImage to be labeled. + data + Data used to annotate. If None, the image's data is used. Optional. + valfmt + The format of the annotations inside the heatmap. This should either + use the string format method, e.g. "$ {x:.2f}", or be a + `matplotlib.ticker.Formatter`. Optional. + textcolors + A pair of colors. The first is used for values below a threshold, + the second for those above. Optional. + threshold + Value in data units according to which the colors from textcolors are + applied. If None (the default) uses the middle of the colormap as + separation. Optional. + **kwargs + All other arguments are forwarded to each call to `text` used to create + the text labels. + """ + + if not isinstance(data, (list, np.ndarray)): + data = im.get_array() + + # Set default alignment to center, but allow it to be + # overwritten by textkw. + kw = dict( + horizontalalignment="center", + verticalalignment="center", + ) + kw.update(textkw) + + # Get the formatter in case a string is supplied + if isinstance(valfmt, str): + valfmt = matplotlib.ticker.StrMethodFormatter(valfmt) + + def color_to_textcolor(rgb): + r, g, b = rgb + r *= 255 + g *= 255 + b *= 255 + + hsp = math.sqrt(0.299 * (r * r) + 0.587 * (g * g) + 0.114 * (b * b)) + if hsp > 127.5: + return textcolors[0] + else: + return textcolors[1] + + # Loop over the data and create a `Text` for each "pixel". + # Change the text's color depending on the data. + texts = [] + for i in range(data.shape[0]): + for j in range(data.shape[1]): + if np.isnan(data[i, j]): + continue + + if "color" not in textkw: + rgb = cmap(im.norm(data[i, j]))[:-1] + kw.update(color=color_to_textcolor(rgb)) + + text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) + texts.append(text) + + return texts + def set_axis( self, axis, diff --git a/src/sisl/viz/figure/plotly.py b/src/sisl/viz/figure/plotly.py index f0f293d454..2a7ce96085 100644 --- a/src/sisl/viz/figure/plotly.py +++ b/src/sisl/viz/figure/plotly.py @@ -428,14 +428,21 @@ def clear(self, frames=True, layout=False): # METHODS TO STANDARIZE BACKENDS # -------------------------------- def init_coloraxis( - self, name, cmin=None, cmax=None, cmid=None, colorscale=None, **kwargs + self, + name, + cmin=None, + cmax=None, + cmid=None, + colorscale=None, + showscale=True, + **kwargs, ): if len(self._coloraxes) == 0: kwargs["ax_name"] = "coloraxis" else: kwargs["ax_name"] = f"coloraxis{len(self._coloraxes) + 1}" - super().init_coloraxis(name, cmin, cmax, cmid, colorscale, **kwargs) + super().init_coloraxis(name, cmin, cmax, cmid, colorscale, showscale, **kwargs) ax_name = kwargs["ax_name"] self.update_layout( @@ -445,6 +452,7 @@ def init_coloraxis( "cmin": cmin, "cmax": cmax, "cmid": cmid, + "showscale": showscale, } } ) @@ -758,10 +766,31 @@ def draw_heatmap( name=None, zsmooth=False, coloraxis=None, + textformat=None, row=None, col=None, **kwargs, ): + if textformat is not None: + # If the user wants a custom color, we must define the text strings to be empty + # for NaN values. If there is not custom color, plotly handles this for us by setting + # the text color to the same as the background for those values so that they are not + # visible. + if "color" in kwargs.get("textfont", {}) and np.any(np.isnan(values)): + to_string = np.vectorize( + lambda x: "" if np.isnan(x) else f"{x:{textformat}}" + ) + kwargs = { + "text": to_string(values), + "texttemplate": "%{text}", + **kwargs, + } + else: + kwargs = { + "texttemplate": "%{z:" + textformat + "}", + **kwargs, + } + self.add_trace( { "type": "heatmap", @@ -772,6 +801,7 @@ def draw_heatmap( "zsmooth": zsmooth, "coloraxis": self._get_coloraxis_name(coloraxis), "meta": kwargs.pop("meta", {}), + **kwargs, }, row=row, col=col, @@ -819,7 +849,9 @@ def set_axis(self, axis, _active_axes={}, **kwargs): updates = {} if ax_name.endswith("axis"): - updates = {f"scene_{ax_name}": kwargs} + scene_updates = {**kwargs} + scene_updates.pop("constrain", None) + updates = {f"scene_{ax_name}": scene_updates} if axis != "z": updates.update({ax_name: kwargs}) diff --git a/src/sisl/viz/plots/__init__.py b/src/sisl/viz/plots/__init__.py index 6c74428662..a846a08e5a 100644 --- a/src/sisl/viz/plots/__init__.py +++ b/src/sisl/viz/plots/__init__.py @@ -5,5 +5,6 @@ from .bands import BandsPlot, FatbandsPlot, bands_plot, fatbands_plot from .geometry import GeometryPlot, SitesPlot, geometry_plot, sites_plot from .grid import GridPlot, WavefunctionPlot, grid_plot, wavefunction_plot +from .matrix import AtomicMatrixPlot, atomic_matrix_plot from .merged import merge_plots from .pdos import PdosPlot, pdos_plot diff --git a/src/sisl/viz/plots/matrix.py b/src/sisl/viz/plots/matrix.py new file mode 100644 index 0000000000..7476cf38a8 --- /dev/null +++ b/src/sisl/viz/plots/matrix.py @@ -0,0 +1,169 @@ +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +from scipy.sparse import spmatrix + +import sisl + +from ..figure import Figure, get_figure +from ..plot import Plot +from ..plotters.grid import draw_grid, draw_grid_arrows +from ..plotters.matrix import draw_matrix_separators, set_matrix_axes +from ..plotters.plot_actions import combined +from ..processors.matrix import ( + determine_color_midpoint, + get_geometry_from_matrix, + get_matrix_mode, + matrix_as_array, + sanitize_matrix_arrows, +) + + +def atomic_matrix_plot( + matrix: Union[ + np.ndarray, sisl.SparseCSR, sisl.SparseAtom, sisl.SparseOrbital, spmatrix + ], + dim: int = 0, + isc: Optional[int] = None, + fill_value: Optional[float] = None, + geometry: Union[sisl.Geometry, None] = None, + atom_lines: Union[bool, Dict] = False, + orbital_lines: Union[bool, Dict] = False, + sc_lines: Union[bool, Dict] = False, + color_pixels: bool = True, + colorscale: Optional[str] = "RdBu", + crange: Optional[Tuple[float, float]] = None, + cmid: Optional[float] = None, + text: Optional[str] = None, + textfont: Optional[dict] = {}, + set_labels: bool = False, + constrain_axes: bool = True, + arrows: List[dict] = [], + backend: str = "plotly", +) -> Figure: + """Plots a (possibly sparse) matrix where rows and columns are either orbitals or atoms. + + Parameters + ----------- + matrix: + the matrix, either as a numpy array or as a sisl sparse matrix. + dim: + If the matrix has a third dimension (e.g. spin), which index to + plot in that third dimension. + isc: + If the matrix contains data for an auxiliary supercell, the index of the + cell to plot. If None, the whole matrix is plotted. + fill_value: + If the matrix is sparse, the value to use for the missing entries. + geometry: + Only needed if the matrix does not contain a geometry (e.g. it is a numpy array) + and separator lines or labels are requested. + atom_lines: + If a boolean, whether to draw lines separating atom blocks, using default styles. + If a dict, draws the lines with the specified plotly line styles. + orbital_lines: + If a boolean, whether to draw lines separating blocks of orbital sets, using default styles. + If a dict, draws the lines with the specified plotly line styles. + sc_lines: + If a boolean, whether to draw lines separating the supercells, using default styles. + If a dict, draws the lines with the specified plotly line styles. + color_pixels: + Whether to color the pixels of the matrix according to the colorscale. + colorscale: + The colorscale to use to color the pixels. + crange: + The minimum and maximum values of the colorscale. + cmid: + The midpoint of the colorscale. If ``crange`` is provided, this is ignored. + + If None and crange is also None, the midpoint + is set to 0 if the data contains both positive and negative values. + text: + If provided, show text of pixel value with the specified format. + E.g. text=".3f" shows the value with three decimal places. + textfont: + The font to use for the text. + This is a dictionary that may contain the keys "family", "size", "color". + set_labels: + Whether to set the axes labels to the atom/orbital that each row and column corresponds to. + For orbitals the labels will be of the form "Atom: (l, m)", where `Atom` is the index of + the atom and l and m are the quantum numbers of the orbital. + constrain_axes: + Whether to set the ranges of the axes to exactly fit the matrix. + backend: + The backend to use for plotting. + """ + + geometry = get_geometry_from_matrix(matrix, geometry) + mode = get_matrix_mode(matrix) + + matrix_array = matrix_as_array(matrix, dim=dim, isc=isc, fill_value=fill_value) + + color_midpoint = determine_color_midpoint(matrix, cmid=cmid, crange=crange) + + matrix_actions = draw_grid( + matrix_array, + crange=crange, + cmid=color_midpoint, + color_pixels_2d=color_pixels, + colorscale=colorscale, + coloraxis_name="matrix_vals", + textformat=text, + textfont=textfont, + ) + + arrows = sanitize_matrix_arrows(arrows) + + arrow_actions = draw_grid_arrows(matrix_array, arrows) + + draw_supercells = isc is None + + axes_actions = set_matrix_axes( + matrix_array, + geometry, + matrix_mode=mode, + constrain_axes=constrain_axes, + set_labels=set_labels, + ) + + sc_lines_actions = draw_matrix_separators( + sc_lines, + geometry, + matrix_mode=mode, + separator_mode="supercells", + draw_supercells=draw_supercells, + showlegend=False, + ) + + atom_lines_actions = draw_matrix_separators( + atom_lines, + geometry, + matrix_mode=mode, + separator_mode="atoms", + draw_supercells=draw_supercells, + showlegend=False, + ) + + orbital_lines_actions = draw_matrix_separators( + orbital_lines, + geometry, + matrix_mode=mode, + separator_mode="orbitals", + draw_supercells=draw_supercells, + showlegend=False, + ) + + all_actions = combined( + matrix_actions, + arrow_actions, + orbital_lines_actions, + atom_lines_actions, + sc_lines_actions, + axes_actions, + ) + + return get_figure(backend, all_actions) + + +class AtomicMatrixPlot(Plot): + function = staticmethod(atomic_matrix_plot) diff --git a/src/sisl/viz/plots/tests/test_matrix.py b/src/sisl/viz/plots/tests/test_matrix.py new file mode 100644 index 0000000000..e506b9b9d7 --- /dev/null +++ b/src/sisl/viz/plots/tests/test_matrix.py @@ -0,0 +1,10 @@ +import sisl +from sisl.viz.plots import atomic_matrix_plot + + +def test_atomic_matrix_plot(): + + graphene = sisl.geom.graphene() + H = sisl.Hamiltonian(graphene) + + atomic_matrix_plot(H) diff --git a/src/sisl/viz/plotters/grid.py b/src/sisl/viz/plotters/grid.py index 5960c835d5..ee278882b8 100644 --- a/src/sisl/viz/plotters/grid.py +++ b/src/sisl/viz/plotters/grid.py @@ -1,44 +1,77 @@ +from typing import List, Optional, Tuple + +import numpy as np +import xarray as xr + import sisl.viz.plotters.plot_actions as plot_actions from sisl.viz.processors.grid import get_isos -def draw_grid(data, isos=[], colorscale=None, crange=None, cmid=None, smooth=False): +def draw_grid( + data, + isos: List[dict] = [], + colorscale: Optional[str] = None, + crange: Optional[Tuple[float, float]] = None, + cmid: Optional[float] = None, + smooth: bool = False, + color_pixels_2d: bool = True, + textformat: Optional[str] = None, + textfont: dict = {}, + name: Optional[str] = None, + coloraxis_name: Optional[str] = "grid_color", + set_equal_axes: bool = True, +): to_plot = [] + # If it is a numpy array, convert it to a DataArray + if not isinstance(data, xr.DataArray): + # If it's 2D, here we assume that the array is a matrix. + # Therefore, rows are y and columns are x. Otherwise, we + # assume that dimensions are cartesian coordinates. + if data.ndim == 2: + dims = ["y", "x"] + else: + dims = ["x", "y", "z"][: data.ndim] + data = xr.DataArray(data, dims=dims) + ndim = data.ndim if ndim == 1: - to_plot.append(plot_actions.draw_line(x=data.x, y=data.values)) + to_plot.append(plot_actions.draw_line(x=data.x, y=data.values, name=name)) elif ndim == 2: - transposed = data.transpose("y", "x") + data = data.transpose("y", "x") cmin, cmax = crange if crange is not None else (None, None) to_plot.append( plot_actions.init_coloraxis( - name="grid_color", + name=coloraxis_name, cmin=cmin, cmax=cmax, cmid=cmid, colorscale=colorscale, + showscale=color_pixels_2d, ) ) + if not color_pixels_2d: + textfont = {"color": "black", **textfont} + to_plot.append( plot_actions.draw_heatmap( - values=transposed.values, - x=data.x, - y=data.y, - name="HEAT", + values=data.values, + x=data.x if "x" in data.coords else None, + y=data.y if "y" in data.coords else None, + name=name, + opacity=1 if color_pixels_2d else 0, zsmooth="best" if smooth else False, - coloraxis="grid_color", + coloraxis=coloraxis_name, + textformat=textformat, + textfont=textfont, ) ) - dx = data.x[1] - data.x[0] - dy = data.y[1] - data.y[0] - - iso_lines = get_isos(transposed, isos) + iso_lines = get_isos(data, isos) for iso_line in iso_lines: iso_line["line"] = { "color": iso_line.pop("color", None), @@ -53,7 +86,84 @@ def draw_grid(data, isos=[], colorscale=None, crange=None, cmid=None, smooth=Fal for isosurface in isosurfaces: to_plot.append(plot_actions.draw_mesh_3D(**isosurface)) - if ndim > 1: + if set_equal_axes and ndim > 1: to_plot.append(plot_actions.set_axes_equal()) return to_plot + + +def draw_grid_arrows(data, arrows: List[dict]): + to_plot = [] + + # If it is a numpy array, convert it to a DataArray + if not isinstance(data, xr.DataArray): + # If it's 2D, here we assume that the array is a matrix. + # Therefore, rows are y and columns are x. Otherwise, we + # assume that dimensions are cartesian coordinates. + if data.ndim == 2: + dims = ["y", "x"] + else: + dims = ["x", "y", "z"][: data.ndim] + data = xr.DataArray(data, dims=dims) + + ndim = data.ndim + + if ndim == 1: + return [] + elif ndim == 2: + coords = np.array(np.meshgrid(data.x, data.y)) + coords = coords.transpose(1, 2, 0) + flat_coords = coords.reshape(-1, coords.shape[-1]) + + for arrow_data in arrows: + center = arrow_data.get("center", "middle") + + values = ( + arrow_data["data"] + if "data" in arrow_data + else np.stack([np.zeros_like(data.values), -data.values], axis=-1) + ) + arrows_array = xr.DataArray(values, dims=["y", "x", "arrow_coords"]) + + arrow_norms = arrows_array.reduce(np.linalg.norm, "arrow_coords") + arrow_max = np.nanmax(arrow_norms) + normed_arrows = ( + arrows_array / arrow_max * (1 if center == "middle" else 0.5) + ) + + flat_normed_arrows = normed_arrows.values.reshape(-1, coords.shape[-1]) + + x = flat_coords[:, 0] + y = flat_coords[:, 1] + if center == "middle": + x = x - flat_normed_arrows[:, 0] / 2 + y = y - flat_normed_arrows[:, 1] / 2 + elif center == "end": + x = x - flat_normed_arrows[:, 0] + y = y - flat_normed_arrows[:, 1] + elif center != "start": + raise ValueError( + f"Invalid value for 'center' in arrow data: {center}. Must be 'start', 'middle' or 'end'." + ) + + to_plot.append( + plot_actions.draw_arrows( + x=x, + y=y, + dxy=flat_normed_arrows, + name=arrow_data.get("name", None), + line=dict( + width=arrow_data.get("width", None), + color=arrow_data.get("color", None), + opacity=arrow_data.get("opacity", None), + dash=arrow_data.get("dash", None), + ), + arrowhead_scale=arrow_data.get("arrowhead_scale", 0.2), + arrowhead_angle=arrow_data.get("arrowhead_angle", 20), + ) + ) + + elif ndim == 3: + return [] + + return to_plot diff --git a/src/sisl/viz/plotters/matrix.py b/src/sisl/viz/plotters/matrix.py new file mode 100644 index 0000000000..fbcd7e57b1 --- /dev/null +++ b/src/sisl/viz/plotters/matrix.py @@ -0,0 +1,188 @@ +from typing import List, Literal, Union + +import numpy as np + +import sisl + +from ..processors.matrix import get_orbital_sets_positions +from . import plot_actions + + +def draw_matrix_separators( + line: Union[bool, dict], + geometry: sisl.Geometry, + matrix_mode: Literal["orbitals", "atoms"], + separator_mode: Literal["orbitals", "atoms", "supercells"], + draw_supercells: bool = True, + showlegend: bool = True, +) -> List[dict]: + """Returns the actions to draw separators in a matrix. + + Parameters + ---------- + line: + If False, no lines are drawn. + If True, the default line style is used, which depends on `separator_mode`. + If a dictionary, it must contain the line style. + geometry: + The geometry associated to the matrix. + matrix_mode: + Whether the elements of the matrix belong to orbitals or atoms. + separator_mode: + What the separators should separate. + draw_supercells: + Whether to draw separators for the whole matrix (not just the unit cell). + showlegend: + Show the separator lines in the legend. + """ + # Orbital separators don't make sense if it is an atom matrix. + if separator_mode == "orbitals" and matrix_mode == "atoms": + return [] + + # Sanitize the line argument + if line is False: + return [] + elif line is True: + line = {} + + # Determine line styles from the defaults and the provided styles. + default_line = { + "orbitals": {"color": "black", "dash": "dot"}, + "atoms": {"color": "orange"}, + "supercells": {"color": "black"}, + } + + line = {**default_line[separator_mode], **line} + + # Initialize list that will hold the positions of all lines + line_positions = [] + + # Determine the shape of the matrix (how many rows) + sc_len = geometry.no if matrix_mode == "orbitals" else geometry.na + + # If the user just wants to draw a given cell, this is effectively as if + # the supercell was (1,1,1) + n_supercells = geometry.n_s if draw_supercells else 1 + + # Find out the line positions depending on what the separators must separate + if separator_mode == "orbitals": + species_lines = get_orbital_sets_positions(geometry.atoms) + + for atom_specie, atom_first_o in zip(geometry.atoms.specie, geometry.firsto): + lines = species_lines[atom_specie][1:] + for line_pos in lines: + line_positions.append(line_pos + atom_first_o - 0.5) + elif separator_mode == "atoms": + for atom_last_o in geometry.lasto[:-1]: + line_pos = atom_last_o + 0.5 + line_positions.append(line_pos) + elif separator_mode == "supercells": + if n_supercells > 1: + line_positions.append(float(sc_len) - 0.5) + else: + raise ValueError( + "separator_mode must be one of 'orbitals', 'atoms', 'supercells'." + ) + + # If there are no lines to draw, exit + if len(line_positions) == 0: + return [] + + # Horizontal lines: determine X and Y coordinates + if separator_mode == "supercells": + hor_x = hor_y = [] + else: + hor_y = np.repeat(line_positions, 3) + hor_y[2::3] = np.nan + hor_x = np.tile((0, sc_len * n_supercells, np.nan), len(line_positions)) - 0.5 + + # Vertical lines: determine X and Y coordinates (for all supercells) + if n_supercells == 1: + vert_line_positions = line_positions + else: + n_repeats = n_supercells - 1 if separator_mode == "supercells" else n_supercells + + vert_line_positions = np.tile(line_positions, n_repeats).reshape(n_repeats, -1) + vert_line_positions += (np.arange(n_repeats) * sc_len).reshape(-1, 1) + vert_line_positions = vert_line_positions.ravel() + + vert_x = np.repeat(vert_line_positions, 3) + vert_x[2::3] = np.nan + + vert_y = np.tile((0, sc_len, np.nan), len(vert_line_positions)) - 0.5 + + return [ + plot_actions.draw_line( + x=np.concatenate([hor_x, vert_x]), + y=np.concatenate([hor_y, vert_y]), + line=line, + name=f"{separator_mode} separators", + showlegend=showlegend, + ) + ] + + +def set_matrix_axes( + matrix, + geometry: sisl.Geometry, + matrix_mode: Literal["orbitals", "atoms"], + constrain_axes: bool = True, + set_labels: bool = False, +) -> List[dict]: + """Configure the axes of a matrix plot + + Parameters + ---------- + matrix: + The matrix that is plotted. + geometry: + The geometry associated to the matrix + matrix_mode: + Whether the elements of the matrix belong to orbitals or atoms. + constrain_axes: + Whether to try to constrain the axes to the domain of the matrix. + set_labels: + Whether to set the axis labels for each element of the matrix. + """ + actions = [] + + actions.append(plot_actions.set_axes_equal()) + + x_kwargs = {} + y_kwargs = {} + + if constrain_axes: + x_kwargs["range"] = [-0.5, matrix.shape[1] - 0.5] + x_kwargs["constrain"] = "domain" + + y_kwargs["range"] = [matrix.shape[0] - 0.5, -0.5] + y_kwargs["constrain"] = "domain" + + if set_labels: + if matrix_mode == "orbitals": + atoms_ticks = [] + atoms = geometry.atoms.atom + for i, atom in enumerate(atoms): + atom_ticks = [] + atoms_ticks.append(atom_ticks) + for orb in atom.orbitals: + atom_ticks.append(f"({orb.l}, {orb.m})") + + ticks = [] + for i, specie in enumerate(geometry.atoms.specie): + ticks.extend([f"{i}: {orb}" for orb in atoms_ticks[specie]]) + else: + ticks = np.arange(matrix.shape[0]).astype(str) + + x_kwargs["ticktext"] = np.tile(ticks, geometry.n_s) + x_kwargs["tickvals"] = np.arange(matrix.shape[1]) + + y_kwargs["ticktext"] = ticks + y_kwargs["tickvals"] = np.arange(matrix.shape[0]) + + if len(x_kwargs) > 0: + actions.append(plot_actions.set_axis(axis="x", **x_kwargs)) + if len(y_kwargs) > 0: + actions.append(plot_actions.set_axis(axis="y", **y_kwargs)) + + return actions diff --git a/src/sisl/viz/plotters/tests/test_matrix.py b/src/sisl/viz/plotters/tests/test_matrix.py new file mode 100644 index 0000000000..92b2ac1425 --- /dev/null +++ b/src/sisl/viz/plotters/tests/test_matrix.py @@ -0,0 +1,129 @@ +import itertools + +import numpy as np +import pytest + +import sisl +from sisl.viz.plotters.matrix import draw_matrix_separators, set_matrix_axes + + +def test_draw_matrix_separators_empty(): + C = sisl.Atom( + "C", + orbitals=[ + sisl.AtomicOrbital("2s"), + sisl.AtomicOrbital("2px"), + sisl.AtomicOrbital("2py"), + sisl.AtomicOrbital("2pz"), + ], + ) + geom = sisl.geom.graphene(atoms=C) + + # Check combinations that should give no lines + assert draw_matrix_separators(False, geom, "orbitals", "orbitals") == [] + assert draw_matrix_separators(True, geom, "atoms", "orbitals") == [] + + +@pytest.mark.parametrize( + "draw_supercells,separator_mode", + itertools.product([True, False], ["atoms", "orbitals", "supercells"]), +) +def test_draw_matrix_separators(draw_supercells, separator_mode): + C = sisl.Atom( + "C", + orbitals=[ + sisl.AtomicOrbital("2s"), + sisl.AtomicOrbital("2px"), + sisl.AtomicOrbital("2py"), + sisl.AtomicOrbital("2pz"), + ], + ) + geom = sisl.geom.graphene(atoms=C) + + lines = draw_matrix_separators( + {"color": "red"}, + geom, + "orbitals", + separator_mode=separator_mode, + draw_supercells=draw_supercells, + ) + + if not draw_supercells and separator_mode == "supercells": + assert len(lines) == 0 + return + + assert len(lines) == 1 + assert isinstance(lines[0], dict) + action = lines[0] + assert action["method"] == "draw_line" + # Check that the number of points in the line is fine + n_expected_points = { + ("atoms", False): 6, + ("atoms", True): 30, + ("orbitals", False): 12, + ("orbitals", True): 60, + ("supercells", True): 24, + }[separator_mode, draw_supercells] + + assert action["kwargs"]["x"].shape == (n_expected_points,) + assert action["kwargs"]["y"].shape == (n_expected_points,) + + assert action["kwargs"]["line"]["color"] == "red" + + +def test_set_matrix_axes(): + + C = sisl.Atom( + "C", + orbitals=[ + sisl.AtomicOrbital("2s"), + sisl.AtomicOrbital("2px"), + sisl.AtomicOrbital("2py"), + sisl.AtomicOrbital("2pz"), + ], + ) + geom = sisl.geom.graphene(atoms=C) + + matrix = np.zeros((geom.no, geom.no * geom.n_s)) + + actions = set_matrix_axes( + matrix, geom, "orbitals", constrain_axes=False, set_labels=False + ) + assert len(actions) == 1 + assert actions[0]["method"] == "set_axes_equal" + + # Test without labels + actions = set_matrix_axes( + matrix, geom, "orbitals", constrain_axes=True, set_labels=False + ) + assert len(actions) == 3 + assert actions[0]["method"] == "set_axes_equal" + assert actions[1]["method"] == "set_axis" + assert actions[1]["kwargs"]["axis"] == "x" + assert actions[1]["kwargs"]["range"] == [-0.5, geom.no * geom.n_s - 0.5] + assert "tickvals" not in actions[1]["kwargs"] + assert "ticktext" not in actions[1]["kwargs"] + + assert actions[2]["method"] == "set_axis" + assert actions[2]["kwargs"]["axis"] == "y" + assert actions[2]["kwargs"]["range"] == [geom.no - 0.5, -0.5] + assert "tickvals" not in actions[2]["kwargs"] + assert "ticktext" not in actions[2]["kwargs"] + + # Test with labels + actions = set_matrix_axes( + matrix, geom, "orbitals", constrain_axes=True, set_labels=True + ) + assert len(actions) == 3 + assert actions[0]["method"] == "set_axes_equal" + assert actions[1]["method"] == "set_axis" + assert actions[1]["kwargs"]["axis"] == "x" + assert actions[1]["kwargs"]["range"] == [-0.5, geom.no * geom.n_s - 0.5] + assert np.all(actions[1]["kwargs"]["tickvals"] == np.arange(geom.no * geom.n_s)) + assert len(actions[1]["kwargs"]["ticktext"]) == geom.no * geom.n_s + + assert actions[2]["method"] == "set_axis" + assert actions[2]["kwargs"]["axis"] == "y" + assert actions[2]["kwargs"]["range"] == [geom.no - 0.5, -0.5] + assert np.all(actions[2]["kwargs"]["tickvals"] == np.arange(geom.no)) + assert len(actions[2]["kwargs"]["ticktext"]) == geom.no diff --git a/src/sisl/viz/processors/matrix.py b/src/sisl/viz/processors/matrix.py new file mode 100644 index 0000000000..b7beb5f784 --- /dev/null +++ b/src/sisl/viz/processors/matrix.py @@ -0,0 +1,173 @@ +from typing import List, Literal, Optional, Tuple, Union + +import numpy as np +from scipy.sparse import issparse + +import sisl + + +def get_orbital_sets_positions(atoms: List[sisl.Atom]) -> List[List[int]]: + """Gets the orbital indices where an orbital set starts for each atom. + + An "orbital set" is a group of 2l + 1 orbitals with an angular momentum l + and different m. + + Parameters + ---------- + atoms : + List of atoms for which the orbital sets positions are desired. + """ + specie_orb_sets = [] + for at in atoms: + orbitals = at.orbitals + + i_orb = 0 + positions = [] + while i_orb < len(orbitals): + positions.append(i_orb) + + i_orb = i_orb + 1 + 2 * orbitals[i_orb].l + + specie_orb_sets.append(positions) + + return specie_orb_sets + + +def get_geometry_from_matrix( + matrix: Union[sisl.SparseCSR, sisl.SparseAtom, sisl.SparseOrbital, np.ndarray], + geometry: Optional[sisl.Geometry] = None, +): + """Returns the geometry associated to a matrix. + + Parameters + ---------- + matrix : + The matrix for which the geometry is desired, which may have + an associated geometry. + geometry : + The geometry to be returned. This is to be used when we already + have a geometry and we don't want to extract it from the matrix. + """ + if geometry is not None: + pass + elif hasattr(matrix, "geometry"): + geometry = matrix.geometry + + return geometry + + +def matrix_as_array( + matrix: Union[sisl.SparseCSR, sisl.SparseAtom, sisl.SparseOrbital, np.ndarray], + dim: Optional[int] = 0, + isc: Optional[int] = None, + fill_value: Optional[float] = None, +) -> np.ndarray: + """Converts any type of matrix to a numpy array. + + Parameters + ---------- + matrix : + The matrix to be converted. + dim : + If the matrix is a sisl sparse matrix and it has a third dimension, the + index to get in that third dimension. + isc : + If the matrix is a sisl SparseAtom or SparseOrbital, the index of the + cell within the auxiliary supercell. + + If None, the whole matrix is returned. + fill_value : + If the matrix is a sparse matrix, the value to fill the unset elements. + """ + if isinstance(matrix, (sisl.SparseCSR, sisl.SparseAtom, sisl.SparseOrbital)): + if dim is None: + if isinstance(matrix, (sisl.SparseAtom, sisl.SparseOrbital)): + matrix = matrix._csr + + matrix = matrix.todense() + else: + matrix = matrix.tocsr(dim=dim) + + if issparse(matrix): + matrix = matrix.toarray() + matrix[matrix == 0] = fill_value + + if isc is not None: + matrix = matrix[:, matrix.shape[0] * isc : matrix.shape[0] * (isc + 1)] + + matrix = np.array(matrix) + + return matrix + + +def determine_color_midpoint( + matrix: np.ndarray, + cmid: Optional[float] = None, + crange: Optional[Tuple[float, float]] = None, +) -> Optional[float]: + """Determines the midpoint of a colorscale given a matrix of values. + + If ``cmid`` or ``crange`` are provided, this function just returns ``cmid``. + However, if none of them are provided, it returns 0 if the matrix has both + positive and negative values, and None otherwise. + + Parameters + ---------- + matrix : + The matrix of values for which the colorscale is to be determined. + cmid : + Possible already determined midpoint. + crange : + Possible already determined range. + """ + if cmid is not None: + return cmid + elif crange is not None: + return cmid + elif np.sum(matrix < 0) > 0 and np.sum(matrix > 0) > 0: + return 0 + else: + return None + + +def get_matrix_mode(matrix) -> Literal["atoms", "orbitals"]: + """Returns what the elements of the matrix represent. + + If the matrix is a sisl SparseAtom, the elements are atoms. + Otherwise, they are assumed to be orbitals. + + Parameters + ---------- + matrix : + The matrix for which the mode is desired. + """ + return "atoms" if isinstance(matrix, sisl.SparseAtom) else "orbitals" + + +def sanitize_matrix_arrows(arrows: Union[dict, List[dict]]) -> List[dict]: + """Sanitizes an ``arrows`` argument to a list of sanitized specifications. + + Parameters + ---------- + arrows : + The arrows argument to be sanitized. If it is a dictionary, it is converted to a list + with a single element. + """ + if isinstance(arrows, dict): + arrows = [arrows] + + san_arrows = [] + for arrow in arrows: + arrow = arrow.copy() + san_arrows.append(arrow) + + if "data" in arrow: + arrow["data"] = matrix_as_array(arrow["data"], dim=None) + + # Matrices have the y axis reverted. + arrow["data"][..., 1] *= -1 + + if "center" not in arrow: + arrow["center"] = "middle" + + return san_arrows diff --git a/src/sisl/viz/processors/tests/test_matrix.py b/src/sisl/viz/processors/tests/test_matrix.py new file mode 100644 index 0000000000..9052fc2fcd --- /dev/null +++ b/src/sisl/viz/processors/tests/test_matrix.py @@ -0,0 +1,151 @@ +import numpy as np +import pytest + +import sisl +from sisl.viz.processors.matrix import ( + determine_color_midpoint, + get_geometry_from_matrix, + get_matrix_mode, + get_orbital_sets_positions, + matrix_as_array, + sanitize_matrix_arrows, +) + +pytestmark = [pytest.mark.viz, pytest.mark.processors] + + +def test_orbital_positions(): + + C = sisl.Atom( + 6, + orbitals=[ + sisl.AtomicOrbital("2s"), + sisl.AtomicOrbital("2px"), + sisl.AtomicOrbital("2py"), + sisl.AtomicOrbital("2pz"), + sisl.AtomicOrbital("2px"), + sisl.AtomicOrbital("2py"), + sisl.AtomicOrbital("2pz"), + ], + ) + + H = sisl.Atom(1, orbitals=[sisl.AtomicOrbital("1s")]) + + positions = get_orbital_sets_positions([C, H]) + + assert len(positions) == 2 + + assert positions[0] == [0, 1, 4] + assert positions[1] == [0] + + +def test_get_geometry_from_matrix(): + + geom = sisl.geom.graphene() + + matrix = sisl.Hamiltonian(geom) + + assert get_geometry_from_matrix(matrix) is geom + + geom_copy = geom.copy() + + assert get_geometry_from_matrix(matrix, geom_copy) is geom_copy + + # Check that if we pass something without an associated geometry + # but we provide a geometry it will work + assert get_geometry_from_matrix(np.array([1, 2]), geom) is geom + + +def test_matrix_as_array(): + + matrix = sisl.SparseCSR((2, 2, 2)) + + matrix[0, 0, 0] = 1 + matrix[0, 0, 1] = 2 + + array = matrix_as_array(matrix, fill_value=0) + assert np.allclose(array, np.array([[1, 0], [0, 0]])) + + array = matrix_as_array(matrix, dim=1, fill_value=0) + assert np.allclose(array, np.array([[2, 0], [0, 0]])) + + array = matrix_as_array(matrix) + assert array[0, 0] == 1 + assert np.isnan(array).sum() == 3 + + # Check that it can work with auxiliary supercells + geom = sisl.geom.graphene( + atoms=sisl.Atom("C", orbitals=[sisl.AtomicOrbital("2pz")]) + ) + matrix = sisl.Hamiltonian(geom) + + array = matrix_as_array(matrix) + assert array.shape == matrix.shape[:-1] + + array = matrix_as_array(matrix, isc=1) + assert array.shape == (geom.no, geom.no) + + # Check that a numpy array is kept untouched + matrix = np.array([[1, 2], [3, 4]]) + assert np.allclose(matrix_as_array(matrix), matrix) + + +def test_determine_color_midpoint(): + + # With the matrix containing only positive values + matrix = np.array([1, 2]) + + assert determine_color_midpoint(matrix) is None + assert determine_color_midpoint(matrix, cmid=1, crange=(0, 1)) == 1 + assert determine_color_midpoint(matrix, crange=(0, 1)) is None + + # With the matrix containing only negative values + matrix = np.array([-1, -2]) + + assert determine_color_midpoint(matrix) is None + assert determine_color_midpoint(matrix, cmid=1, crange=(0, 1)) == 1 + assert determine_color_midpoint(matrix, crange=(0, 1)) is None + + # With the matrix containing both positive and negative values + matrix = np.array([-1, 1]) + + assert determine_color_midpoint(matrix) == 0 + assert determine_color_midpoint(matrix, cmid=1, crange=(-1, 1)) == 1 + assert determine_color_midpoint(matrix, crange=(-1, 1)) is None + + +def test_get_matrix_mode(): + + geom = sisl.geom.graphene() + + matrix = sisl.SparseAtom(geom) + assert get_matrix_mode(matrix) == "atoms" + + matrix = sisl.Hamiltonian(geom) + assert get_matrix_mode(matrix) == "orbitals" + + matrix = sisl.SparseCSR((2, 2)) + assert get_matrix_mode(matrix) == "orbitals" + + matrix = np.array([[1, 2], [3, 4]]) + assert get_matrix_mode(matrix) == "orbitals" + + +def test_sanitize_matrix_arrows(): + + arrows = {} + assert sanitize_matrix_arrows(arrows) == [{"center": "middle"}] + + geom = sisl.geom.graphene() + data = sisl.Hamiltonian(geom, dim=2) + data[0, 0, 0] = 1 + data[0, 0, 1] = 2 + + arrows = [{"data": data}] + sanitized = sanitize_matrix_arrows(arrows) + + assert len(sanitized) == 1 + assert sanitized[0]["data"].shape == data.shape + san_data = sanitized[0]["data"] + assert san_data[0, 0, 0] == 1 + assert san_data[0, 0, 1] == -2