Skip to content

Commit

Permalink
Add function ptable_hists (#100)
Browse files Browse the repository at this point in the history
* ptable_heatmap() add keywords cbar_coords (tuple[float, float, float, float]) and rare_earth_voffset (float)

* ptable.py add new function ptable_hists()

* ptable_hists() add keywords
- cbar_title_kwds: dict[str, Any] | None = None
- symbol_pos: tuple[float, float] = (0.5, 0.8)
- log: bool = False
- anno_kwds: dict[str, Any] | None = None

rename 1st arg srs to data: pd.DataFrame | pd.Series | dict[str, list[float]]

* add test_ptable_hists()

* fix SyntaxError

row, group = df_ptable.loc[symbol := element.symbol, ["row", "column"]]

* fix docs
  • Loading branch information
janosh authored Nov 29, 2023
1 parent 12556f5 commit 26e5f61
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 7 deletions.
1 change: 1 addition & 0 deletions pymatviz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ptable_heatmap,
ptable_heatmap_plotly,
ptable_heatmap_ratio,
ptable_hists,
)
from pymatviz.relevance import precision_recall_curve, roc_curve
from pymatviz.sankey import sankey_from_2_df_cols
Expand Down
1 change: 1 addition & 0 deletions pymatviz/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def save_fig(
if not isinstance(fig, go.Figure):
raise TypeError(
f"Unsupported figure type {type(fig)}, expected plotly or matplotlib Figure"
" or plt.Axes"
)
is_pdf = path.lower().endswith((".pdf", ".pdfa"))
if path.lower().endswith((".svelte", ".html")):
Expand Down
182 changes: 175 additions & 7 deletions pymatviz/ptable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import itertools
import math
import sys
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Literal, get_args
Expand All @@ -14,7 +15,7 @@
from matplotlib.colors import LogNorm, Normalize
from matplotlib.patches import Rectangle
from pandas.api.types import is_numeric_dtype, is_string_dtype
from pymatgen.core import Composition
from pymatgen.core import Composition, Element

from pymatviz.utils import df_ptable, pick_bw_for_contrast

Expand Down Expand Up @@ -161,6 +162,7 @@ def ptable_heatmap(
label_font_size: int = 16,
value_font_size: int = 12,
tile_size: float | tuple[float, float] = 0.9,
rare_earth_voffset: float = 0.5,
**kwargs: Any,
) -> plt.Axes:
"""Plot a heatmap across the periodic table of elements.
Expand Down Expand Up @@ -215,6 +217,11 @@ def ptable_heatmap(
tile_size (float | tuple[float, float]): Size of each tile in the periodic
table as a fraction of available space before touching neighboring tiles.
1 or (1, 1) means no gaps between tiles. Defaults to 0.9.
cbar_coords (tuple[float, float, float, float]): Color bar position and size:
[x, y, width, height] anchored at lower left corner of the bar. Defaults to
(0.18, 0.8, 0.42, 0.05).
rare_earth_voffset (float): Vertical offset for lanthanides and actinides
(row 6 and 7) from the rest of the periodic table. Defaults to 0.5.
**kwargs: Additional keyword arguments passed to plt.figure().
Returns:
Expand Down Expand Up @@ -297,8 +304,8 @@ def ptable_heatmap(
label = f"{tile_value:{fmt}}"
# replace shortens scientific notation 1e+01 to 1e1 so it fits inside cells
label = label.replace("e+0", "e")
if row < 3: # vertical offset for lanthanide + actinide series
row += 0.5
if row < 3: # vertical offset for lanthanides + actinides
row += rare_earth_voffset
rect = Rectangle(
(column, row), tile_width, tile_height, edgecolor="gray", facecolor=color
)
Expand Down Expand Up @@ -343,9 +350,9 @@ def ptable_heatmap(
if heat_mode is not None:
# color bar position and size: [x, y, width, height]
# anchored at lower left corner
cb_ax = ax.inset_axes(cbar_coords, transform=ax.transAxes)
cbar_ax = ax.inset_axes(cbar_coords, transform=ax.transAxes)
# format major and minor ticks
cb_ax.tick_params(which="both", labelsize=14, width=1)
cbar_ax.tick_params(which="both", labelsize=14, width=1)

mappable = plt.cm.ScalarMappable(norm=norm, cmap=colorscale)

Expand All @@ -360,7 +367,7 @@ def tick_fmt(val: float, _pos: int) -> str:
cbar_kwargs = cbar_kwargs or {}
cbar = fig.colorbar(
mappable,
cax=cbar_kwargs.pop("cax", cb_ax),
cax=cbar_kwargs.pop("cax", cbar_ax),
orientation=cbar_kwargs.pop("orientation", "horizontal"),
format=cbar_kwargs.pop(
"format", cbar_fmt if callable(cbar_fmt) else tick_fmt
Expand All @@ -369,7 +376,7 @@ def tick_fmt(val: float, _pos: int) -> str:
)

cbar.outline.set_linewidth(1)
cb_ax.set_title(cbar_title, pad=10, **text_style)
cbar_ax.set_title(cbar_title, pad=10, **text_style)

plt.ylim(0.3, n_rows + 0.1)
plt.xlim(0.9, n_columns + 1)
Expand Down Expand Up @@ -696,3 +703,164 @@ def ptable_heatmap_plotly(

fig.update_traces(colorbar=dict(lenmode="fraction", thickness=15, **color_bar))
return fig


def ptable_hists(
data: pd.DataFrame | pd.Series | dict[str, list[float]],
bins: int = 20,
colormap: str = "viridis",
cbar_coords: tuple[float, float, float, float] = (0.18, 0.8, 0.42, 0.02),
x_range: tuple[float | None, float | None] | None = None,
symbol_kwargs: Any = None,
symbol_text: str | Callable[[Element], str] = lambda elem: elem.symbol,
cbar_title: str = "Values",
cbar_title_kwds: dict[str, Any] | None = None,
symbol_pos: tuple[float, float] = (0.5, 0.8),
log: bool = False,
anno_kwds: dict[str, Any] | None = None,
**kwargs: Any,
) -> plt.Figure:
"""Plot histograms of values across the periodic table of elements.
Args:
data (pd.DataFrame | pd.Series | dict[str, list[float]]): Map from element
symbols to histogram values. E.g. if dict, {"Fe": [1, 2, 3], "O": [4, 5]}.
If pd.Series, index is element symbols and values lists. If pd.DataFrame,
column names are element symbols histograms are plotted from each column.
bins (int): Number of bins for the histograms. Defaults to 20.
colormap (str): Matplotlib colormap name to use. Defaults to "viridis".
See https://matplotlib.org/stable/users/explain/colors/colormaps
for available options.
cbar_coords (tuple[float, float, float, float]): Color bar position and size:
[x, y, width, height] anchored at lower left corner of the bar. Defaults to
(0.25, 0.77, 0.35, 0.02).
x_range (tuple[float | None, float | None]): x-axis range for all histograms.
Defaults to None.
symbol_text (str | Callable[[Element], str]): Text to display for each element
symbol. Defaults to lambda elem: elem.symbol.
symbol_kwargs (dict): Keyword arguments passed to plt.text() for element
symbols. Defaults to None.
cbar_title (str): Color bar title. Defaults to "Histogram Value".
cbar_title_kwds (dict): Keyword arguments passed to cbar.ax.set_title().
Defaults to dict(fontsize=12, pad=10).
symbol_pos (tuple[float, float]): Position of element symbols relative to the
lower left corner of each tile. Defaults to (0.5, 0.8). (1, 1) is the upper
right corner.
log (bool): Whether to log scale y-axis of each histogram. Defaults to False.
anno_kwds (dict): Keyword arguments passed to plt.annotate() for element
annotations. Defaults to None. Useful for adding e.g. number of data points
in each histogram. For that, use
anno_kwds=dict(text=lambda hist_vals: str(len(hist_vals))).
Recognized keys are text, xy, xycoords, fontsize, and any other
plt.annotate() keywords.
**kwargs: Additional keyword arguments passed to plt.subplots().
figsize is set to (0.75 * n_columns, 0.75 * n_rows) where n_columns and
n_rows are the number of columns and rows in the periodic table.
Returns:
plt.Figure: periodic table with a histogram in each element tile.
"""
symbol_kwargs = symbol_kwargs or {}
n_rows = df_ptable.row.max()
n_columns = df_ptable.column.max()

kwargs.setdefault("figsize", (0.75 * n_columns, 0.75 * n_rows))
fig, axes = plt.subplots(n_rows, n_columns, **kwargs)
plt.subplots_adjust(wspace=0.4, hspace=0.4)

if isinstance(data, pd.Series):
# use series name as color bar title if available and no title was passed
if cbar_title == "Values" and data.name:
cbar_title = data.name
data = data.to_dict()
elif isinstance(data, pd.DataFrame):
data = data.to_dict(orient="list")

# create a normalized color map
flat_list = [
val
for sublist in (data.values() if isinstance(data, dict) else data)
for val in sublist
]
norm = Normalize(vmin=min(flat_list), vmax=max(flat_list))
cmap = plt.get_cmap(colormap)

# turn off axis of subplots on the grid that don't correspond to elements
for ax in axes.flat:
ax.axis("off")

for Z in range(1, 119):
element = Element.from_Z(Z)
symbol = element.symbol
row, group = df_ptable.loc[symbol, ["row", "column"]]

ax = axes[row - 1][group - 1]
symbol_kwargs.setdefault("fontsize", 10)
ax.text(
*symbol_pos,
symbol_text(element)
if callable(symbol_text)
else symbol_text.format(elem=element),
ha="center",
va="center",
transform=ax.transAxes,
**symbol_kwargs,
)
ax.axis("on") # re-enable axes of elements that exist

hist_data = data.get(symbol, [])
if anno_kwds:
anno_kwds.setdefault("xy", (0.8, 0.8))
anno_kwds.setdefault("xycoords", "axes fraction")
anno_kwds.setdefault("fontsize", 8)
anno_kwds.setdefault("horizontalalignment", "center")
anno_kwds.setdefault("verticalalignment", "center")
anno_text = anno_kwds.get("text")
if isinstance(anno_text, dict):
anno_text = anno_text.get(symbol)
elif callable(anno_text):
anno_text = anno_text(hist_data)
ax.annotate(**(anno_kwds | {"text": anno_text}))
if hist_data:
_n, bins_array, patches = ax.hist(
hist_data, bins=bins, color="C0", alpha=1, log=log
)
if x_range:
ax.set_xlim(x_range)
x_min, x_max = math.floor(min(bins_array)), math.ceil(max(bins_array))
x_ticks = list(x_range or [x_min, x_max])
if x_ticks[0] is None:
x_ticks[0] = x_min
if x_ticks[1] is None:
x_ticks[1] = x_max
if x_min < 0 < x_max:
x_ticks = [x_min, 0, x_max]

for patch, x_val in zip(patches, bins_array[:-1]):
plt.setp(patch, "facecolor", cmap(norm(x_val)))
ax.set_xticks(x_ticks)
ax.set_yticks([])
ax.tick_params(labelsize=8, direction="in")
else: # disable ticks for elements without data
ax.set_xticks([])
ax.set_yticks([])
for side in ("left", "right", "top"):
ax.spines[side].set_visible(False)
# also hide tick marks
ax.tick_params(axis="y", which="both", length=0)

# add colorbar
cbar_ax = fig.add_axes(cbar_coords)
_cbar = fig.colorbar(
plt.cm.ScalarMappable(norm=norm, cmap=cmap),
cax=cbar_ax,
orientation="horizontal",
)
# set color bar title
cbar_title_kwds = cbar_title_kwds or {}
cbar_title_kwds.setdefault("fontsize", 12)
cbar_title_kwds.setdefault("pad", 10)
cbar_title_kwds["label"] = cbar_title
cbar_ax.set_title(**cbar_title_kwds)

return fig
21 changes: 21 additions & 0 deletions tests/test_ptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ptable_heatmap,
ptable_heatmap_plotly,
ptable_heatmap_ratio,
ptable_hists,
)
from pymatviz.utils import df_ptable, si_fmt

Expand Down Expand Up @@ -359,3 +360,23 @@ def test_ptable_heatmap_plotly_label_map(
any(val in anno.text for val in label_map.values())
for anno in fig.layout.annotations
)


@pytest.mark.parametrize(
"data, symbol_pos, anno_kwds",
[
(pd.DataFrame({"H": [1, 2, 3], "He": [4, 5, 6]}), (0, 0), {}),
(dict(H=[1, 2, 3], He=[4, 5, 6]), (1, 1), dict(text=lambda x: f"{len(x):,}")),
(pd.Series([[1, 2, 3], [4, 5, 6]], index=["H", "He"]), (1, 1), dict(xy=(0, 0))),
],
)
def test_ptable_hists(
data: pd.DataFrame | pd.Series | dict[str, list[int]],
symbol_pos: tuple[int, int],
anno_kwds: dict[str, Any],
) -> None:
# Test the function with a valid DataFrame
fig = ptable_hists(data, symbol_pos=symbol_pos, anno_kwds=anno_kwds)
assert isinstance(
fig, plt.Figure
), "The function should return a matplotlib Figure object"

0 comments on commit 26e5f61

Please sign in to comment.