Skip to content

Commit

Permalink
Merge pull request #67 from nwinner/plotting
Browse files Browse the repository at this point in the history
Plotting
  • Loading branch information
jmmshn authored Jan 5, 2023
2 parents 72ec88a + 6c5ec70 commit 3d5a99e
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 1 deletion.
209 changes: 208 additions & 1 deletion pymatgen/analysis/defects/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
from dataclasses import dataclass
from itertools import groupby
from pathlib import Path
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional

import numpy as np
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from monty.json import MSONable
from numpy.typing import ArrayLike, NDArray
from pymatgen.analysis.chempot_diagram import ChemicalPotentialDiagram
Expand Down Expand Up @@ -788,3 +791,207 @@ def fermi_dirac(energy: float, temperature: int | float) -> float:
assuming dilue limit thermodynamics (non-interacting defects) using FD statistics.
"""
return 1.0 / (1.0 + np.exp((energy) / (boltzman_eV_K * temperature)))


def plot_formation_energy_diagrams(
formation_energy_diagrams: FormationEnergyDiagram
| List[FormationEnergyDiagram]
| MultiFormationEnergyDiagram,
chempots: Dict,
alignment: float = 0.0,
xlim: list | None = None,
ylim: list | None = None,
only_lower_envelope: bool = True,
show: bool = True,
save: bool | str = False,
colors: list | None = None,
legend_prefix: str | None = None,
transition_marker: str = "*",
transition_markersize: int = 16,
linestyle: str = "-",
linewidth: int = 4,
envelope_alpha: float = 0.8,
line_alpha: float = 0.5,
band_edge_color="k",
filterfunction: Callable | None = None,
axis=None,
):
"""Plot the formation energy diagram.
Args:
formation_energy_diagrams: Which formation energy lines to plot.
chempots: Chemical potentials at which to plot the formation energy lines
Should be bounded by the chempot_limits property
alignment: shift the energy axis by this amount. For example, giving bandgap/2
will visually shift the 0 reference from the VBM to the middle of the band gap.
xlim: Limits (low, high) to use for the x-axis. Default is to use 0eV for the
VBM up to the band gap, plus a buffer of 0.2eV on each side
ylim: Limits (low, high) to use for y-axis. Default is to use the minimum and
maximum formation energy value of all defects, plus a buffer of 0.1eV
only_lower_envelope: Whether to only plot the lower envolope (concave hull). If
False, then the lower envolope will be highlighted, but all lines will be
plotted.
show: Whether to show the plot.
save: Whether to save the plot. A string can be provided to save to a specific
file. If True, will be saved to current working directory under the name,
formation_energy_diagram.png
colors: Manually select the colors to use. Must have length >= to number of
FormationEnergyDiagrams to plot.
legend_prefix: Prefix for all legend labels
transition_marker: Marker style for the charge transitions
transition_markersize: Size for charge transition markers
linestyle: Matplotlib line style
linewidth: Linewidth for the envelope and lines (if shown)
envelope_alpha: Alpha for the envelope
line_alpha: Alpha for the lines (if the are shown)
band_edge_color: Color for VBM/CBM vertical lines
filterfunction: A callable that filters formation energy diagram objects to clean up the plot
axis: Previous axis to amend
Returns:
Axis subplot
"""
if isinstance(formation_energy_diagrams, MultiFormationEnergyDiagram):
formation_energy_diagrams = formation_energy_diagrams.formation_energy_diagrams
elif isinstance(formation_energy_diagrams, FormationEnergyDiagram):
formation_energy_diagrams = [formation_energy_diagrams]

filterfunction = filterfunction if filterfunction else lambda x: True
formation_energy_diagrams = list(filter(filterfunction, formation_energy_diagrams))

band_gap = formation_energy_diagrams[0].band_gap
if not xlim and not band_gap:
raise ValueError("Must specify xlim or set band_gap attribute")

if not axis:
_, axis = plt.subplots()
if not xlim and band_gap:
xmin, xmax = np.subtract(-0.2, alignment), np.subtract(
band_gap + 0.2, alignment
)
else:
xmin, xmax = xlim
ymin, ymax = 0.0, 1.0
legends_txt = []
artists = []
fontwidth = 12
ax_fontsize = 1.3
lg_fontsize = 10

colors = (
colors
if colors
else cm.Dark2(np.linspace(0, 1, len(formation_energy_diagrams)))
if len(formation_energy_diagrams) <= 8
else cm.gist_rainbow(np.linspace(0, 1, len(formation_energy_diagrams)))
)

for fid, single_fed in enumerate(formation_energy_diagrams):
lines = single_fed._get_lines(chempots=chempots)
lowerlines = get_lower_envelope(lines)
trans = get_transitions(
lowerlines, np.add(xmin, alignment), np.add(xmax, alignment)
)

# plot lines
if not only_lower_envelope:
for line in lines:
x = np.linspace(xmin, xmax)
y = line[0] * x + line[1]
axis.plot(
np.subtract(x, alignment), y, color=colors[fid], alpha=line_alpha
)

# plot connecting envelop lines
for i, (_x, _y) in enumerate(trans[:-1]):
x = np.linspace(_x, trans[i + 1][0])
y = ((trans[i + 1][1] - _y) / (trans[i + 1][0] - _x)) * (x - _x) + _y
axis.plot(
np.subtract(x, alignment),
y,
color=colors[fid],
ls=linestyle,
lw=linewidth,
alpha=envelope_alpha,
)

# Plot transitions
for _x, _y in trans:
ymax = max((ymax, _y))
ymin = min((ymin, _y))
axis.plot(
np.subtract(_x, alignment),
_y,
marker=transition_marker,
color=colors[fid],
markersize=transition_markersize,
)

# get latex-like legend titles
dfct = single_fed.defect_entries[0].defect
flds = dfct.name.split("_")
latexname = f"${flds[0]}_{{{flds[1]}}}$"
if legend_prefix:
latexname = f"{legend_prefix} {latexname}"
legends_txt.append(latexname)
artists.append(Line2D([0], [0], color=colors[fid], lw=4))

axis.set_xlim(xmin, xmax)
axis.set_ylim(ylim[0] if ylim else ymin - 0.1, ylim[1] if ylim else ymax + 0.1)
axis.set_xlabel("Fermi energy (eV)", size=ax_fontsize * fontwidth)
axis.set_ylabel("Defect Formation\nEnergy (eV)", size=ax_fontsize * fontwidth)
axis.minorticks_on()
axis.tick_params(
which="major",
length=8,
width=2,
direction="in",
top=True,
right=True,
labelsize=fontwidth * ax_fontsize,
)
axis.tick_params(
which="minor",
length=2,
width=2,
direction="in",
top=True,
right=True,
labelsize=fontwidth * ax_fontsize,
)
for _ax in axis.spines.values():
_ax.set_linewidth(1.5)

axis.axvline(0, ls="--", color="k", lw=2, alpha=0.2)
axis.axvline(
np.subtract(0, alignment), ls="--", color=band_edge_color, lw=2, alpha=0.8
)
if band_gap:
axis.axvline(
np.subtract(band_gap, alignment),
ls="--",
color=band_edge_color,
lw=2,
alpha=0.8,
)

lg = axis.get_legend()
if lg:
handle, leg = lg.legendHandles, [txt._text for txt in lg.texts]
else:
handle, leg = [], []
axis.legend(
handles=artists + handle,
labels=legends_txt + leg,
fontsize=lg_fontsize * ax_fontsize,
ncol=3,
loc="lower center",
)

if save:
save = save if isinstance(save, str) else "formation_energy_diagram.png"
plt.savefig(save)
if show:
plt.show()

return axis
62 changes: 62 additions & 0 deletions tests/test_thermo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os

import numpy as np
import pytest
from matplotlib import pyplot as plt
from pymatgen.analysis.phase_diagram import PhaseDiagram
from pymatgen.core import PeriodicSite

Expand All @@ -14,6 +17,7 @@
ensure_stable_bulk,
get_lower_envelope,
get_transitions,
plot_formation_energy_diagrams,
)


Expand Down Expand Up @@ -224,3 +228,61 @@ def test_ensure_stable_bulk(stable_entries_Mg_Ga_N):
assert "GaN" not in [e.composition.reduced_formula for e in pd1.stable_entries]
pd2 = ensure_stable_bulk(pd, fake_bulk_ent)
assert "GaN" in [e.composition.reduced_formula for e in pd2.stable_entries]


def test_plotter(data_Mg_Ga, defect_entries_Mg_Ga, stable_entries_Mg_Ga_N, plot_fn):
bulk_vasprun = data_Mg_Ga["bulk_sc"]["vasprun"]
bulk_dos = bulk_vasprun.complete_dos
_, vbm = bulk_dos.get_cbm_vbm()
bulk_entry = bulk_vasprun.get_computed_entry(inc_structure=False)
defect_entries, _ = defect_entries_Mg_Ga
def_ent_list = list(defect_entries.values())

fed = FormationEnergyDiagram(
bulk_entry=bulk_entry,
defect_entries=def_ent_list,
vbm=vbm,
pd_entries=stable_entries_Mg_Ga_N,
inc_inf_values=False,
)
with pytest.raises(
ValueError,
match="Must specify xlim or set band_gap attribute",
):
plot_formation_energy_diagrams(
fed, chempots=fed.chempot_limits[0], show=False, save=False
)
fed.band_gap = 1
axis = plot_formation_energy_diagrams(
fed,
chempots=fed.chempot_limits[0],
show=False,
xlim=[0, 2],
ylim=[0, 4],
save=False,
)
mfed = MultiFormationEnergyDiagram(formation_energy_diagrams=[fed])
plot_formation_energy_diagrams(
mfed,
chempots=fed.chempot_limits[0],
show=False,
save=False,
only_lower_envelope=False,
axis=axis,
legend_prefix="test",
linestyle="--",
line_alpha=1,
linewidth=1,
)
plot_fn(fed, fed.chempot_limits[0])


@pytest.fixture(scope="function")
def plot_fn():
def _plot(*args):
plot_formation_energy_diagrams(*args, save=True, show=True)
yield plt.show()
plt.close("all")
os.remove("formation_energy_diagram.png")

return _plot

0 comments on commit 3d5a99e

Please sign in to comment.