Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug: introduction of meta attribute broke heatmaps #623

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sisl/viz/figure/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_scale=0.3, arrowhead_angle=15,

return self.draw_line_3D(arrows[:, 0], arrows[:, 1], arrows[:, 2], row=row, col=col, **kwargs)

def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None):
def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None, **kwargs):
"""Draws a heatmap following the specifications."""
raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_heatmap method.")

Expand Down
4 changes: 2 additions & 2 deletions src/sisl/viz/figure/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def draw_area_line(self, x, y, line={}, name=None, dependent_axis=None, row=None
else:
raise ValueError(f"dependent_axis must be one of 'x', 'y', or None, but was {dependent_axis}")

def draw_scatter(self, x, y, name=None, marker={}, text=None, zorder=2, row=None, col=None, _axes=None, **kwargs):
def draw_scatter(self, x, y, name=None, marker={}, text=None, zorder=2, row=None, col=None, _axes=None, meta={}, **kwargs):
axes = _axes or self._get_subplot_axes(row=row, col=col)
try:
return axes.scatter(x, y, c=marker.get("color"), s=marker.get("size", 1), cmap=marker.get("colorscale"), alpha=marker.get("opacity"), label=name, zorder=zorder, **kwargs)
Expand All @@ -302,7 +302,7 @@ def draw_multicolor_scatter(self, *args, **kwargs):
marker["colorscale"] = coloraxis.get("colorscale")
return super().draw_multicolor_scatter(*args, marker=marker, **kwargs)

def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None, _axes=None):
def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None, _axes=None, **kwargs):

extent = None
if x is not None and y is not None:
Expand Down
21 changes: 15 additions & 6 deletions src/sisl/viz/figure/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@
'name': name,
'line': {k: v for k, v in line.items() if k != "opacity"},
'opacity': opacity,
"meta": kwargs.pop("meta", {}),
**kwargs,
}, row=row, col=col)

Expand Down Expand Up @@ -407,7 +408,8 @@
"name": name,
"legendgroup": name,
"showlegend": kwargs.pop("showlegend", None),
"fill": "toself"
"fill": "toself",
"meta": kwargs.pop("meta", {})
}, row=row, col=col)

def draw_scatter(self, x, y, name=None, marker={}, **kwargs):
Expand Down Expand Up @@ -450,13 +452,15 @@

iterator = enumerate(zip(np.array(x), np.array(y), np.array(z), style["size"], style["color"], style["opacity"]))

meta = kwargs.pop("meta", {})
showlegend = True
for i, (sp_x, sp_y, sp_z, sp_size, sp_color, sp_opacity) in iterator:
self.draw_ball_3D(
xyz=[sp_x, sp_y, sp_z],
size=sp_size, color=sp_color, opacity=sp_opacity,
name=f"{name}_{i}",
legendgroup=name, showlegend=showlegend
legendgroup=name, showlegend=showlegend,
meta=meta
)
showlegend = False

Expand All @@ -470,8 +474,8 @@
'color': color,
'showscale': False,
'name': name,
'meta': ['({:.2f}, {:.2f}, {:.2f})'.format(*xyz)],
'hovertemplate': '%{meta[0]}',
'meta': {"position": '({:.2f}, {:.2f}, {:.2f})'.format(*xyz), "meta": kwargs.pop("meta", {})},
'hovertemplate': '%{meta.position}',
**kwargs
}, row=None, col=None)

Expand Down Expand Up @@ -506,6 +510,8 @@
rows_cols['rows'] = [row, row]
if col is not None:
rows_cols['cols'] = [col, col]

meta = kwargs.pop("meta", {})

Check warning on line 514 in src/sisl/viz/figure/plotly.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/plotly.py#L514

Added line #L514 was not covered by tests


self.figure.add_traces([{
Expand All @@ -519,6 +525,7 @@
"legendgroup": name,
"name": f"{name} lines",
"showlegend": False,
"meta": meta
},
{
"type": "cone",
Expand All @@ -536,17 +543,18 @@
"legendgroup": name,
"name": name,
"showlegend": True,
"meta": meta
}], **rows_cols)

def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None):
def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None, **kwargs):

self.add_trace({
'type': 'heatmap', 'z': values,
'x': x, 'y': y,
'name': name,
'zsmooth': zsmooth,
'coloraxis': self._get_coloraxis_name(coloraxis),
# **kwargs
'meta': kwargs.pop("meta", {}),
}, row=row, col=col)

def draw_mesh_3D(self, vertices, faces, color=None, opacity=None, name=None, row=None, col=None, **kwargs):
Expand All @@ -562,6 +570,7 @@
opacity=opacity,
name=name,
showlegend=True,
meta=kwargs.pop("meta", {}),
**kwargs
), row=row, col=col)

Expand Down
7 changes: 5 additions & 2 deletions src/sisl/viz/plots/tests/test_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from sisl.viz.data import BandsData
from sisl.viz.plots import bands_plot

@pytest.fixture(scope="module", params=["plotly", "matplotlib"])
def backend(request):
return request.param

@pytest.fixture(scope="module", params=["unpolarized", "polarized", "noncolinear", "spinorbit"])
def spin(request):
Expand All @@ -17,5 +20,5 @@ def gap():
def bands_data(spin, gap):
return BandsData.toy_example(spin=spin, gap=gap)

def test_bands_plot(bands_data):
bands_plot(bands_data)
def test_bands_plot(bands_data, backend):
bands_plot(bands_data, backend=backend)
27 changes: 27 additions & 0 deletions src/sisl/viz/plots/tests/test_geometry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest

import numpy as np

Check notice

Code scanning / CodeQL

Unused import Note test

Import of 'np' is not used.

import sisl
from sisl.viz.plots import geometry_plot


@pytest.fixture(scope="module", params=["plotly", "matplotlib"])
def backend(request):
return request.param

@pytest.fixture(scope="module", params=["x", "xy", "xyz"])
def axes(request):
return request.param

@pytest.fixture(scope="module")
def geometry():
return sisl.geom.graphene()

def test_geometry_plot(geometry, axes, backend):

if axes == "xyz" and backend == "matplotlib":
with pytest.raises(NotImplementedError):
geometry_plot(geometry, axes=axes, backend=backend)
else:
geometry_plot(geometry, axes=axes, backend=backend)
31 changes: 31 additions & 0 deletions src/sisl/viz/plots/tests/test_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

import numpy as np

import sisl

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note test

Module 'sisl' is imported with both 'import' and 'import from'.
from sisl import Grid
from sisl.viz.plots import grid_plot

@pytest.fixture(scope="module", params=["plotly", "matplotlib"])
def backend(request):
return request.param

@pytest.fixture(scope="module", params=["x", "xy", "xyz"])
def axes(request):
return request.param

@pytest.fixture(scope="module")
def grid():
geometry = sisl.geom.graphene()
grid = Grid((10, 10, 10), geometry=geometry)

grid.grid[:] = np.linspace(0, 1000, 1000).reshape(10, 10, 10)
return grid

def test_grid_plot(grid, axes, backend):

if axes == "xyz" and backend == "matplotlib":
with pytest.raises(NotImplementedError):
grid_plot(grid, axes=axes, backend=backend)
else:
grid_plot(grid, axes=axes, backend=backend)
20 changes: 20 additions & 0 deletions src/sisl/viz/plots/tests/test_pdos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest

from sisl import Spin
from sisl.viz.data import PDOSData
from sisl.viz.plots import pdos_plot

@pytest.fixture(scope="module", params=["plotly", "matplotlib"])
def backend(request):
return request.param

@pytest.fixture(scope="module", params=["unpolarized", "polarized", "noncolinear", "spinorbit"])
def spin(request):
return Spin(request.param)

@pytest.fixture(scope="module")
def pdos_data(spin):
return PDOSData.toy_example(spin=spin)

def test_pdos_plot(pdos_data, backend):
pdos_plot(pdos_data, backend=backend)