Skip to content

Commit

Permalink
Experimental: add show_bonds: bool | NearNeighbors = False to `stru…
Browse files Browse the repository at this point in the history
…cture_(2|3)d_plotly` (#233)

* add crystal symmetry props to Key enum

* rename _add_unit_cell to draw_unit_cell, add tests, then simplify

* rename generate_subplot_title to get_subplot_title

* add draw_bonds helper used by structure_(2|3)d_plotly to show bonds based on any pymatgen NearNeighbors subclass

defaults to CrystalNN

* rename get_image_sites->get_image_atoms

* fix and improve test coverage of get_image_sites

* tweak element_pair_rdfs axes label placement

* add and test element_symbol_map: dict[str, str] | None = None to ptable_heatmap_plotly

* default show_bonds to False (since half-baked and still experimental)

* fix test_structure_2d_plotly_multiple and make more accurate
  • Loading branch information
janosh authored Oct 12, 2024
1 parent 25f90bd commit b7adecd
Show file tree
Hide file tree
Showing 11 changed files with 440 additions and 101 deletions.
4 changes: 2 additions & 2 deletions examples/make_assets/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
for struct in df_phonons.nlargest(2, Key.n_sites)[Key.structure]:
fig = pmv.element_pair_rdfs(struct, n_bins=100, cutoff=10)
formula = struct.formula
fig.layout.title.update(text=f"Pairwise RDFs - {formula}", x=0.5, y=0.98)
fig.layout.margin = dict(l=40, r=0, t=50, b=0)
fig.layout.title.update(text=f"Pairwise RDFs - {formula}", x=0.5, y=0.99)
fig.layout.margin.t = 55

fig.show()
pmv.io.save_and_compress_svg(fig, f"element-pair-rdfs-{formula.replace(' ', '')}")
Expand Down
15 changes: 13 additions & 2 deletions pymatviz/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,23 @@ class Key(LabelEnum):
bond_lens = "bond_lengths", f"Bond Lengths {angstrom}"
bond_angles = "bond_angles", "Bond Angles (°)"
packing_fraction = "packing_fraction", "Packing Fraction"
max_pair_dist = "max_pair_dist", f"Maximum Pair Distance {angstrom}"

# Crystal Symmetry Properties
choice_symbol = "choice_symbol", "Choice symbol"
hall_num = "hall_num", "Hall number"
hall_symbol = "hall_symbol", "Hall symbol"
n_rot_ops = "n_rot_ops", "Number of rotational operations"
n_sym_ops = "n_sym_ops", "Total number of symmetry operations"
n_trans_ops = "n_trans_ops", "Number of translational operations"
wyckoff = "wyckoff", "AFLOW-style Label with Chemical System"
wyckoff_spglib = "wyckoff_spglib", "Wyckoff Label (spglib)"
wyckoff_symbols = "wyckoff_symbols", "Wyckoff symbols"

# Structure Prototyping
# AFLOW-style prototype label with appended chemical system
protostructure = "protostructure", "Protostructure Label"
# Deprecated name for the protostructure
wyckoff = "wyckoff", "AFLOW-style Label with Chemical System"
wyckoff_spglib = "wyckoff_spglib", "Wyckoff Label (spglib)"
prototype = "prototype", "Prototype Label"
aflow_prototype = "aflow_prototype", "AFLOW-style Prototype Label"
# AFLOW-style prototype label that has been canonicalized
Expand Down Expand Up @@ -471,6 +481,7 @@ class Key(LabelEnum):
spearman = "Spearman", "Spearman Correlation"
kendall = "Kendall", "Kendall Correlation"
rmse = "RMSE", "Root Mean Squared Error"
rmsd = "rmsd", "Root Mean Square Deviation"
mape = "MAPE", "Mean Absolute Percentage Error"
variance = "variance", "Variance"
std_dev = "std_dev", "Standard Deviation"
Expand Down
11 changes: 9 additions & 2 deletions pymatviz/ptable/ptable_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def ptable_heatmap_plotly(
exclude_elements: Sequence[str] = (),
log: bool = False,
fill_value: float | None = None,
element_symbol_map: dict[str, str] | None = None,
label_map: dict[str, str] | Callable[[str], str] | Literal[False] | None = None,
border: dict[str, Any] | None | Literal[False] = None,
scale: float = 1.0,
Expand Down Expand Up @@ -118,6 +119,9 @@ def ptable_heatmap_plotly(
log (bool): Whether to use a logarithmic color scale. Defaults to False.
Piece of advice: colorscale="viridis" and log=True go well together.
fill_value (float | None): Value to fill in for missing elements. Defaults to 0.
element_symbol_map (dict[str, str] | None): A dictionary to map element symbols
to custom strings. If provided, these custom strings will be displayed
instead of the standard element symbols. Defaults to None.
label_map (dict[str, str] | Callable[[str], str] | None): Map heat values (after
string formatting) to target strings. Set to False to disable. Defaults to
dict.fromkeys((np.nan, None, "nan", "nan%"), "-") so as not to display "nan"
Expand Down Expand Up @@ -200,14 +204,17 @@ def ptable_heatmap_plotly(
label = label_map(label)
elif isinstance(label_map, dict):
label = label_map.get(label, label)
# Apply custom element symbol if provided
display_symbol = (element_symbol_map or {}).get(symbol, symbol)

style = f"font-weight: bold; font-size: {1.5 * (font_size or 12) * scale};"
tile_text = f"<span {style=}>{symbol}</span>"
tile_text = f"<span {style=}>{display_symbol}</span>"
if show_values and label:
tile_text += f"<br>{label}"

tile_texts[row][col] = tile_text

hover_text = name
hover_text = f"{name} ({symbol})"

if heat_val := heat_value_element_map.get(symbol):
if all_ints:
Expand Down
8 changes: 4 additions & 4 deletions pymatviz/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ def element_pair_rdfs(
"<extra></extra>",
)

# Add x-axis label
fig.update_xaxes(title_text="r (Å)", row=row, col=col)
# only show x-axis label on the bottom row of subplots
fig.update_xaxes(title_text="r (Å)", title_standoff=9, row=n_rows)

# Add reference line if specified
if reference_line is not None:
Expand All @@ -272,11 +272,11 @@ def element_pair_rdfs(

# Set subplot height/width and y-axis labels
fig.update_layout(height=300 * n_rows, width=450 * actual_cols)
fig.update_yaxes(title_text="g(r)", col=1)
fig.update_yaxes(title_text="g(r)", title_standoff=9, col=1)

# show legend centered above subplots only if multiple structures were passed
if len(structures) > 1:
fig.layout.legend = dict(
fig.layout.legend.update(
orientation="h",
xanchor="center",
x=0.5,
Expand Down
146 changes: 100 additions & 46 deletions pymatviz/structure_viz/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

import functools
import itertools
import math
import warnings
Expand All @@ -15,6 +16,7 @@
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from pymatgen.analysis.local_env import NearNeighbors
from pymatgen.core import Composition, Lattice, PeriodicSite, Species, Structure

from pymatviz.colors import ELEM_COLORS_JMOL, ELEM_COLORS_VESTA
Expand All @@ -28,6 +30,7 @@

import plotly.graph_objects as go
from numpy.typing import ArrayLike
from pymatgen.analysis.local_env import NearNeighbors


# fallback value (in nanometers) for covalent radius of an element
Expand Down Expand Up @@ -91,23 +94,28 @@ def _angles_to_rotation_matrix(
return rotation


def get_image_atoms(
def get_image_sites(
site: PeriodicSite, lattice: Lattice, tol: float = 0.02
) -> np.ndarray:
"""Get image atoms for a given site."""
coords_image_atoms: list[np.ndarray] = []
"""Get images for a given site in a lattice.
# If the site is at the lattice origin, return an empty array
if np.allclose(site.frac_coords, (0, 0, 0), atol=tol):
return np.array(coords_image_atoms)
Images are sites that are integer translations of the given site that are within a
tolerance of the unit cell edges.
# Generate all possible combinations of lattice vector offsets
offsets = list(itertools.product([0, 1], repeat=3))
Args:
site (PeriodicSite): The site to get images for.
lattice (Lattice): The lattice to get images for.
tol (float): The tolerance for being on the unit cell edge. Defaults to 0.02.
for offset in offsets:
if offset == (0, 0, 0):
continue # Skip the original atom
Returns:
np.ndarray: Coordinates of all image sites.
"""
coords_image_atoms: list[np.ndarray] = []

# Generate all possible combinations of lattice vector offsets (except zero offset)
offsets = set(itertools.product([-1, 0, 1], repeat=3)) - {(0, 0, 0)}

for offset in offsets:
new_frac = site.frac_coords + offset
new_cart = lattice.get_cartesian_coords(new_frac)

Expand Down Expand Up @@ -203,7 +211,7 @@ def generate_site_label(
)


def generate_subplot_title(
def get_subplot_title(
struct_i: Structure,
struct_key: Any,
idx: int,
Expand Down Expand Up @@ -347,7 +355,7 @@ def get_structures(
raise TypeError(f"Expected pymatgen Structure or Sequence of them, got {struct=}")


def _add_unit_cell(
def draw_unit_cell(
fig: go.Figure,
structure: Structure,
unit_cell_kwargs: dict[str, Any],
Expand All @@ -357,38 +365,22 @@ def _add_unit_cell(
col: int | None = None,
scene: str | None = None,
) -> go.Figure:
"""Draw the unit cell of a structure in a 2D or 3D Plotly figure."""
corners = np.array(list(itertools.product((0, 1), (0, 1), (0, 1))))
cart_corners = structure.lattice.get_cartesian_coords(corners)

alpha, beta, gamma = structure.lattice.angles

def add_trace(
x: float | Sequence[float],
y: float | Sequence[float],
z: float | Sequence[float] | None = None,
mode: str = "lines",
marker: dict[str, Any] | None = None,
line: dict[str, Any] | None = None,
hovertext: str | list[str | None] | None = None,
) -> None:
trace_kwargs = dict(
mode=mode,
hoverinfo="text",
hovertext=hovertext,
showlegend=False,
marker=marker,
line=line,
)

if is_3d:
fig.add_scatter3d(x=x, y=y, z=z, scene=scene, **trace_kwargs)
else:
fig.add_scatter(x=x, y=y, row=row, col=col, **trace_kwargs)
trace_adder = ( # prefill args for add_scatter or add_scatter3d
functools.partial(fig.add_scatter3d, scene=scene)
if is_3d
else functools.partial(fig.add_scatter, row=row, col=col)
)

# Add edges
edge_defaults = dict(color="black", width=1, dash="dash")
edge_kwargs = edge_defaults | unit_cell_kwargs.get("edge", {})
for start, end in UNIT_CELL_EDGES:
for idx, (start, end) in enumerate(UNIT_CELL_EDGES):
start_point = cart_corners[start]
end_point = cart_corners[end]
mid_point = (start_point + end_point) / 2
Expand All @@ -403,25 +395,30 @@ def add_trace(
f"[{', '.join(f'{c:.3g}' for c in corners[end])}]"
)

add_trace(
coords = dict(
x=[start_point[0], mid_point[0], end_point[0]],
y=[start_point[1], mid_point[1], end_point[1]],
z=[start_point[2], mid_point[2], end_point[2]] if is_3d else None,
)
if is_3d:
coords["z"] = [start_point[2], mid_point[2], end_point[2]]
trace_adder(
**coords,
mode="lines",
line=edge_kwargs,
hovertext=[None, hover_text, None],
name=f"edge {idx}",
)

# Add corner spheres
node_defaults = dict(size=3, color="black")
node_kwargs = node_defaults | unit_cell_kwargs.get("node", {})
for i, (frac_coord, cart_coord) in enumerate(
for idx, (frac_coord, cart_coord) in enumerate(
zip(corners, cart_corners, strict=False)
):
adjacent_angles = []
for _ in range(3):
v1 = cart_corners[(i + 1) % 8] - cart_coord
v2 = cart_corners[(i + 2) % 8] - cart_coord
v1 = cart_corners[(idx + 1) % 8] - cart_coord
v2 = cart_corners[(idx + 2) % 8] - cart_coord
angle = np.degrees(
np.arccos(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))
)
Expand All @@ -432,14 +429,15 @@ def add_trace(
f"[{', '.join(f'{c:.3g}' for c in frac_coord)}]<br>"
f"α = {alpha:.3g}°, β = {beta:.3g}°, γ = {gamma:.3g}°" # noqa: RUF001
)

add_trace(
x=[cart_coord[0]],
y=[cart_coord[1]],
z=[cart_coord[2]] if is_3d else None,
coords = dict(x=[cart_coord[0]], y=[cart_coord[1]])
if is_3d:
coords["z"] = [cart_coord[2]]
trace_adder(
**coords,
mode="markers",
marker=node_kwargs,
hovertext=hover_text,
name=f"node {idx}",
)

return fig
Expand Down Expand Up @@ -576,3 +574,59 @@ def get_first_matching_site_prop(
warnings.warn(warn_msg, UserWarning, stacklevel=2)

return None


def draw_bonds(
fig: go.Figure,
structure: Structure,
nn: NearNeighbors,
*,
is_3d: bool = True,
bond_kwargs: dict[str, Any] | None = None,
row: int | None = None,
col: int | None = None,
scene: str | None = None,
visible_image_atoms: set[tuple[float, float, float]] | None = None,
) -> None:
"""Draw bonds between atoms in the structure."""
default_bond_kwargs = dict(color="white", width=4)
bond_kwargs = default_bond_kwargs | (bond_kwargs or {})

for i, site in enumerate(structure):
neighbors = nn.get_nn_info(structure, i)
for neighbor in neighbors:
end_site = neighbor["site"]
end_coords = tuple(end_site.coords)

# Check if the end site is within the unit cell or a visible image atom
is_in_unit_cell = all(0 <= c < 1 for c in end_site.frac_coords)
is_visible_image = visible_image_atoms and end_coords in visible_image_atoms

if is_in_unit_cell or is_visible_image:
start = site.coords
end = end_site.coords

trace_kwargs = dict(
mode="lines",
line=bond_kwargs,
showlegend=False,
hoverinfo="skip",
name=f"bond {i}-{neighbor['site_index']}",
)

if is_3d:
fig.add_scatter3d(
x=[start[0], end[0]],
y=[start[1], end[1]],
z=[start[2], end[2]],
scene=scene,
**trace_kwargs,
)
else:
fig.add_scatter(
x=[start[0], end[0]],
y=[start[1], end[1]],
row=row,
col=col,
**trace_kwargs,
)
Loading

0 comments on commit b7adecd

Please sign in to comment.