From 7da0684b0d7010ce8d8ade657e757dca82235bb6 Mon Sep 17 00:00:00 2001 From: Nick Papior Date: Thu, 18 Apr 2024 13:43:37 +0200 Subject: [PATCH] removed sisl.plot from the code base Signed-off-by: Nick Papior --- CHANGELOG.md | 1 + src/sisl/__init__.py | 3 -- src/sisl/_core/geometry.py | 76 ---------------------------- src/sisl/_core/lattice.py | 67 ------------------------- src/sisl/_core/orbital.py | 47 ------------------ src/sisl/_plot.py | 56 --------------------- src/sisl/tests/test_plot.py | 98 ------------------------------------- 7 files changed, 1 insertion(+), 347 deletions(-) delete mode 100644 src/sisl/_plot.py delete mode 100644 src/sisl/tests/test_plot.py diff --git a/CHANGELOG.md b/CHANGELOG.md index f3cda9338..6a2919699 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,6 +77,7 @@ we hit release version 1.0.0. be done for assigning matrix elements (it fills with 0's). ### Removed +- `sisl.plot` is removed (`sisl.viz` is replacing it!) - `cell` argument for `Geometry.translate/move` (it never worked) - removed `Selector` and `TimeSelector`, they were never used internally diff --git a/src/sisl/__init__.py b/src/sisl/__init__.py index d90c8f501..4e975d70e 100644 --- a/src/sisl/__init__.py +++ b/src/sisl/__init__.py @@ -87,9 +87,6 @@ # import the common options used from ._common import * -# Import plot routine -from ._plot import plot as plot - # Import warning classes # We currently do not import warn and info # as they are too generic names in case one does from sisl import * diff --git a/src/sisl/_core/geometry.py b/src/sisl/_core/geometry.py index 2830f5a4d..465c90fdf 100644 --- a/src/sisl/_core/geometry.py +++ b/src/sisl/_core/geometry.py @@ -34,7 +34,6 @@ ) import sisl._array as _a -import sisl._plot as plt from sisl._category import Category, GenericCategory from sisl._dispatch_class import _Dispatchs from sisl._dispatcher import AbstractDispatch, ClassDispatcher, TypeDispatcher @@ -3112,81 +3111,6 @@ def o2sc(self, orbitals: OrbitalsIndex) -> ndarray: """ return self.lattice.offset(self.o2isc(orbitals)) - def __plot__( - self, - axis=None, - lattice: bool = True, - axes=False, - atom_indices: bool = False, - *args, - **kwargs, - ): - """Plot the geometry in a specified ``matplotlib.Axes`` object. - - Parameters - ---------- - axis : array_like, optional - only plot a subset of the axis, defaults to all axis - lattice : bool, optional - If `True` also plot the lattice structure - atom_indices : bool, optional - if true, also add atomic numbering in the plot (0-based) - axes : bool or matplotlib.Axes, optional - the figure axes to plot in (if ``matplotlib.Axes`` object). - If `True` it will create a new figure to plot in. - If `False` it will try and grap the current figure and the current axes. - """ - # Default dictionary for passing to newly created figures - d = dict() - - colors = np.linspace(0, 1, num=self.atoms.nspecies, endpoint=False) - colors = colors[self.atoms.species] - if "s" in kwargs: - area = kwargs.pop("s") - else: - area = _a.arrayd(self.atoms.Z) - area[:] *= 20 * np.pi / area.min() - - if axis is None: - axis = [0, 1, 2] - - # Ensure we have a new 3D Axes3D - if len(axis) == 3: - d["projection"] = "3d" - - # The Geometry determines the axes, then we pass it to supercell. - axes = plt.get_axes(axes, **d) - - # Start by plotting the supercell - if lattice: - axes = self.lattice.__plot__(axis, axes=axes, *args, **kwargs) - - # Create short-hand - xyz = self.xyz - - if axes.__class__.__name__.startswith("Axes3D"): - # We should plot in 3D plots - axes.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], s=area, c=colors, alpha=0.8) - axes.set_zlabel("Ang") - if atom_indices: - for i, loc in enumerate(xyz): - axes.text( - loc[0], loc[1], loc[2], str(i), verticalalignment="bottom" - ) - - else: - axes.scatter(xyz[:, axis[0]], xyz[:, axis[1]], s=area, c=colors, alpha=0.8) - if atom_indices: - for i, loc in enumerate(xyz): - axes.text( - loc[axis[0]], loc[axis[1]], str(i), verticalalignment="bottom" - ) - - axes.set_xlabel("Ang") - axes.set_ylabel("Ang") - - return axes - def equal(self, other: GeometryLike, R: bool = True, tol: float = 1e-4) -> bool: """Whether two geometries are the same (optional not check of the orbital radius) diff --git a/src/sisl/_core/lattice.py b/src/sisl/_core/lattice.py index 10b4aff46..82cb5de93 100644 --- a/src/sisl/_core/lattice.py +++ b/src/sisl/_core/lattice.py @@ -18,7 +18,6 @@ from numpy import dot, ndarray import sisl._array as _a -import sisl._plot as plt from sisl._dispatch_class import _Dispatchs from sisl._dispatcher import AbstractDispatch, ClassDispatcher, TypeDispatcher from sisl._internal import set_module @@ -1062,72 +1061,6 @@ def __setstate__(self, d): self.__init__(d["cell"], d["nsc"], d["origin"]) self.sc_off = d["sc_off"] - def __plot__(self, axis=None, axes=False, *args, **kwargs): - """Plot the supercell in a specified ``matplotlib.Axes`` object. - - Parameters - ---------- - axis : array_like, optional - only plot a subset of the axis, defaults to all axis - axes : bool or matplotlib.Axes, optional - the figure axes to plot in (if ``matplotlib.Axes`` object). - If ``True`` it will create a new figure to plot in. - If ``False`` it will try and grap the current figure and the current axes. - """ - # Default dictionary for passing to newly created figures - d = dict() - - # Try and default the color and alpha - if "color" not in kwargs and len(args) == 0: - kwargs["color"] = "k" - if "alpha" not in kwargs: - kwargs["alpha"] = 0.5 - - if axis is None: - axis = [0, 1, 2] - - # Ensure we have a new 3D Axes3D - if len(axis) == 3: - d["projection"] = "3d" - - axes = plt.get_axes(axes, **d) - - # Create vector objects - o = self.origin - v = [] - for a in axis: - v.append(np.vstack((o[axis], o[axis] + self.cell[a, axis]))) - v = np.array(v) - - if axes.__class__.__name__.startswith("Axes3D"): - # We should plot in 3D plots - for vv in v: - axes.plot(vv[:, 0], vv[:, 1], vv[:, 2], *args, **kwargs) - - v0, v1 = v[0], v[1] - o - axes.plot( - v0[1, 0] + v1[:, 0], - v0[1, 1] + v1[:, 1], - v0[1, 2] + v1[:, 2], - *args, - **kwargs, - ) - - axes.set_zlabel("Ang") - - else: - for vv in v: - axes.plot(vv[:, 0], vv[:, 1], *args, **kwargs) - - v0, v1 = v[0], v[1] - o[axis] - axes.plot(v0[1, 0] + v1[:, 0], v0[1, 1] + v1[:, 1], *args, **kwargs) - axes.plot(v1[1, 0] + v0[:, 0], v1[1, 1] + v0[:, 1], *args, **kwargs) - - axes.set_xlabel("Ang") - axes.set_ylabel("Ang") - - return axes - new_dispatch = Lattice.new to_dispatch = Lattice.to diff --git a/src/sisl/_core/orbital.py b/src/sisl/_core/orbital.py index 49b3df04b..4c7d66ddd 100644 --- a/src/sisl/_core/orbital.py +++ b/src/sisl/_core/orbital.py @@ -25,7 +25,6 @@ from scipy.interpolate import UnivariateSpline import sisl._array as _a -import sisl._plot as plt from sisl._internal import set_module from sisl.constant import a0 from sisl.messages import warn @@ -279,52 +278,6 @@ def equal(self, other, psi: bool = False, radial: bool = False): def __eq__(self, other): return self.equal(other) - def __plot__(self, harmonics: bool = False, axes=False, *args, **kwargs): - """Plot the orbital radial/spherical harmonics - - Parameters - ---------- - harmonics : bool, optional - if `True` the spherical harmonics will be plotted in a 3D only plot a subset of the axis, defaults to all axis - axes : bool or matplotlib.Axes, optional - the figure axes to plot in (if ``matplotlib.Axes`` object). - If ``True`` it will create a new figure to plot in. - If ``False`` it will try and grap the current figure and the current axes. - """ - d = dict() - - if harmonics: - # We are plotting the harmonic part - d["projection"] = "polar" - - axes = plt.get_axes(axes, **d) - - # Add plots - if harmonics: - # Calculate the spherical harmonics - theta, phi = np.meshgrid(np.arange(360), np.arange(180) - 90) - s = self.spher(np.radians(theta), np.radians(phi)) - - # Plot data - cax = axes.contourf(theta, phi, s, *args, **kwargs) - cax.set_clim(s.min(), s.max()) - axes.get_figure().colorbar(cax) - axes.set_title(r"${}$".format(self.name(True))) - # I don't know how exactly to handle this... - # axes.set_xlabel(r"Azimuthal angle $\theta$") - # axes.set_ylabel(r"Polar angle $\phi$") - - else: - # Plot the radial function and 5% above 0 value - r = np.linspace(0, self.R * 1.05, 1000) - f = self.radial(r) - axes.plot(r, f, *args, **kwargs) - axes.set_xlim(left=0) - axes.set_xlabel("Radius [Ang]") - axes.set_ylabel(r"$f(r)$ [1/Ang$^{3/2}$]") - - return axes - def toGrid( self, precision: float = 0.05, c: float = 1.0, R=None, dtype=np.float64, atom=1 ): diff --git a/src/sisl/_plot.py b/src/sisl/_plot.py deleted file mode 100644 index 61c1c2b3c..000000000 --- a/src/sisl/_plot.py +++ /dev/null @@ -1,56 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from __future__ import annotations - -""" An interface routine for plotting different classes in sisl - -It merely calls the `<>.__plot__(**)` routine and returns immediately -""" - -try: - import matplotlib as mlib - import matplotlib.pyplot as mlibplt - import mpl_toolkits.mplot3d as mlib3d - - has_matplotlib = True -except Exception as _matplotlib_import_exception: - mlib = NotImplementedError - mlibplt = NotImplementedError - mlib3d = NotImplementedError - has_matplotlib = False - -__all__ = ["plot", "mlib", "mlibplt", "mlib3d", "get_axes"] - - -def get_axes(axes=False, **kwargs): - if axes is False: - try: - axes = mlibplt.gca() - except Exception: - axes = mlibplt.figure().add_subplot(111, **kwargs) - elif axes is True: - axes = mlibplt.figure().add_subplot(111, **kwargs) - return axes - - -def _plot(obj, *args, **kwargs): - try: - a = getattr(obj, "__plot__") - except AttributeError: - raise NotImplementedError( - f"{obj.__class__.__name__} does not implement the __plot__ method." - ) - return a(*args, **kwargs) - - -if has_matplotlib: - plot = _plot -else: - - def plot(obj, *args, **kwargs): - raise _matplotlib_import_exception # pylint: disable=E0601 - - -# Clean up -del has_matplotlib diff --git a/src/sisl/tests/test_plot.py b/src/sisl/tests/test_plot.py deleted file mode 100644 index 81d73b91e..000000000 --- a/src/sisl/tests/test_plot.py +++ /dev/null @@ -1,98 +0,0 @@ -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. -from __future__ import annotations - -import numpy as np -import pytest - -import sisl - -pytestmark = pytest.mark.plot - -mlib = pytest.importorskip("matplotlib") -plt = pytest.importorskip("matplotlib.pyplot") -mlib3d = pytest.importorskip("mpl_toolkits.mplot3d") - - -def test_supercell_2d(): - g = sisl.geom.graphene() - sisl.plot(g.lattice, axis=[0, 1]) - sisl.plot(g.lattice, axis=[0, 2]) - sisl.plot(g.lattice, axis=[1, 2]) - plt.close("all") - - ax = plt.subplot(111) - sisl.plot(g.lattice, axis=[1, 2], axes=ax) - plt.close("all") - - -def test_supercell_3d(): - g = sisl.geom.graphene() - sisl.plot(g.lattice) - plt.close("all") - - -def test_geometry_2d(): - g = sisl.geom.graphene() - sisl.plot(g, axis=[0, 1]) - sisl.plot(g, axis=[0, 2]) - sisl.plot(g, axis=[1, 2]) - plt.close("all") - - ax = plt.subplot(111) - sisl.plot(g, axis=[1, 2], axes=ax) - plt.close("all") - - -def test_geometry_2d_atom_indices(): - g = sisl.geom.graphene() - sisl.plot(g, axis=[0, 1]) - sisl.plot(g, axis=[0, 2]) - sisl.plot(g, axis=[1, 2]) - plt.close("all") - - ax = plt.subplot(111) - sisl.plot(g, axis=[1, 2], axes=ax, atom_indices=True) - plt.close("all") - - -def test_geometry_3d(): - g = sisl.geom.graphene() - sisl.plot(g) - plt.close("all") - - -def test_geometry_3d_atom_indices(): - g = sisl.geom.graphene() - sisl.plot(g, atom_indices=True) - plt.close("all") - - -def test_orbital_radial(): - r = np.linspace(0, 10, 1000) - f = np.exp(-r) - o = sisl.SphericalOrbital(2, (r, f)) - sisl.plot(o) - plt.close("all") - - fig = plt.figure() - sisl.plot(o, axes=fig.gca()) - plt.close("all") - - -def test_orbital_harmonics(): - r = np.linspace(0, 10, 1000) - f = np.exp(-r) - o = sisl.SphericalOrbital(2, (r, f)) - sisl.plot(o, harmonics=True) - plt.close("all") - - -def test_not_implemented(): - class Test: - pass - - t = Test() - with pytest.raises(NotImplementedError): - sisl.plot(t)