Skip to content

Commit

Permalink
Optical Transition Matrix Element Plotting (#137)
Browse files Browse the repository at this point in the history
* deps

* deps

* python vers

* teset

* tests

* cleaner strict

* cleaner strict

* wf

* plotting

plotting

plotting

plotting

plotting

plotting

* plotting

* plotting

* notebook

notebook

notebook

notebook/test

* notebook/test
  • Loading branch information
jmmshn authored Jul 31, 2023
1 parent fd1cda9 commit dbfedf0
Show file tree
Hide file tree
Showing 5 changed files with 357 additions and 4 deletions.
56 changes: 52 additions & 4 deletions docs/source/content/photo-conduct.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@
"outputs": [],
"source": [
"dir0 = TEST_FILES / \"ccd_0_-1\" / \"optics\"\n",
"hd0 = HarmonicDefect.from_directories(directories=[dir0])\n",
"hd0 = HarmonicDefect.from_directories(directories=[dir0], store_bandstructure=True)\n",
"# Note the `store_bandstructure=True` argument is required for the matrix element plotting later in the notebook.\n",
"# but not required for the dielectric function calculation.\n",
"print(f\"The defect band is {hd0.defect_band}\")\n",
"print(f\"The vibrational frequency is omega={hd0.omega} in this case is gibberish.\")"
]
Expand Down Expand Up @@ -107,8 +109,8 @@
" if spin == Spin.up:\n",
" return 0\n",
" return 1\n",
" occ = vr.eigenvalues[Spin.up][0, :, 1]\n",
" fermi_idx = bisect.bisect_left(occ, -0.5, key=lambda x: -x) \n",
" occ = vr.eigenvalues[Spin.up][0, :, 1] * -1\n",
" fermi_idx = bisect.bisect_left(occ, -0.5) \n",
" output = collections.defaultdict(list)\n",
" for k, spin_eigs in vr.eigenvalues.items():\n",
" spin_idx = _get_spin_idx(k)\n",
Expand Down Expand Up @@ -195,6 +197,52 @@
"\n",
"Of course for a complete picture of photoconductivity, the Frank-Condon type ofr vibrational state transition should also be considered, but we are already pushing the limits of what is acceptable in the independent-particle picture so we will leave that for another time.\n"
]
},
{
"cell_type": "markdown",
"id": "1fd5f0b0",
"metadata": {},
"source": [
"## Dipole Matrix Elements\n",
"\n",
"We can also check the dipole matrix elements for the (VBM)→(defect) and (defect)→(CBM) transitions explicitly by calling the `plot_optical_transitions` method as shown below.\n",
"The function returns a summary `pandas.DataFrame` object with the dipole matrix elements as well as the `ListedColormap` and `Normalize` objects for plotting the colorbar. These objects can then be passed to other instances of the plotting function to ensure that the colorbar is consistent."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e00ac9f",
"metadata": {},
"outputs": [],
"source": [
"from pymatgen.analysis.defects.plotting.optical import plot_optical_transitions\n",
"import matplotlib as mpl\n",
"fig, ax = plt.subplots()\n",
"cm_ax = fig.add_axes([0.8,0.1,0.02,0.8])\n",
"df_k0, cmap, norm = plot_optical_transitions(hd0, kpt_index=1, band_window=5, x0=3, ax=ax)\n",
"df_k1, _, _ = plot_optical_transitions(hd0, kpt_index=0, band_window=5, x0=0, ax=ax, cmap=cmap, norm=norm)\n",
"mpl.colorbar.ColorbarBase(cm_ax,cmap=cmap,norm=norm,orientation='vertical')\n",
"ax.set_ylabel(\"Energy (eV)\")"
]
},
{
"cell_type": "markdown",
"id": "43f9e8d0",
"metadata": {},
"source": [
"The `DataFrame` object containing the dipole matrix elements can also be examined directly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d44ad8ef",
"metadata": {},
"outputs": [],
"source": [
"df_k0"
]
}
],
"metadata": {
Expand All @@ -208,7 +256,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]"
"version": "3.9.16"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions pymatgen/analysis/defects/corrections/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Finite size corrections for defects."""
1 change: 1 addition & 0 deletions pymatgen/analysis/defects/plotting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Plotting functions."""
287 changes: 287 additions & 0 deletions pymatgen/analysis/defects/plotting/optical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
"""Plotting functions."""
from __future__ import annotations

import collections
import logging

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from pymatgen.electronic_structure.core import Spin

from pymatgen.analysis.defects.ccd import HarmonicDefect

__author__ = "Jimmy Shen"
__copyright__ = "Copyright 2022, The Materials Project"
__maintainer__ = "Jimmy Shen @jmmshn"
__date__ = "July 2023"

logger = logging.getLogger(__name__)


def plot_optical_transitions(
defect: HarmonicDefect,
kpt_index: int = 0,
band_window: int = 5,
user_defect_band: tuple = tuple(),
ijdirs=((0, 0), (1, 1), (2, 2)),
shift_eig: dict[tuple, float] = None,
x0: float = 0,
x_width: float = 2,
ax=None,
cmap=None,
norm=None,
):
"""Plot the optical transitions from the defect state to all other states.
Only plot the transitions for a specific kpoint index. The arrows present the transitions
between the defect state of interest and all other states. The color of the arrows
indicate the magnitude of the matrix element (derivative of the wavefunction) for the
transition.
Args:
defect:
The HarmonicDefect object, the `relaxed_bandstructure` attribute
must be set since this contains the eigenvalues.
Please see the `store_bandstructure` option in the constructor.
kpt_index:
The kpoint index to read the eigenvalues from.
band_window:
The number of bands above and below the defect state to include in the output.
user_defect_band:
(band, kpt, spin) tuple to specify the defect state. If not provided,
the defect state will be determined automatically using the inverse
participation ratio and the `kpt_index` argument.
ijdirs:
The cartesian direction of the WAVDER tensor to sum over for the plot.
If not provided, all the absolute values of the matrix for all
three diagonal entries will be summed.
shift_eig:
A dictionary of the format `(band, kpt, spin) -> float` to apply to the
eigenvalues. This is useful for aligning the defect state with the
valence or conduction band for plotting and schematic purposes.
x0:
The x coordinate of the center of the set of arrows and the eigenvalue plot.
x_width:
The width of the set of arrows and the eigenvalue plot.
ax:
The matplotlib axis object to plot on.
cmap:
The matplotlib color map to use for the color of the arrorws.
norm:
The matplotlib normalization to use for the color map of the arrows.
"""
d_eigs = get_bs_eigenvalues(
defect=defect,
kpt_index=kpt_index,
band_window=band_window,
user_defect_band=user_defect_band,
shift_eig=shift_eig,
)
if user_defect_band:
defect_band_index = user_defect_band[0]
else:
defect_band_index = next(
filter(lambda x: x[1] == kpt_index, defect.defect_band)
)[0]

if ax is None:
ax_ = plt.gca()
else: # pragma: no cover
ax_ = ax
_plot_eigs(
d_eigs, defect.relaxed_bandstructure.efermi, ax=ax_, x0=x0, x_width=x_width
)
me_plot_data, cmap, norm = _plot_matrix_elements(
defect.waveder.cder,
d_eigs,
defect_band_index=defect_band_index,
ijdirs=ijdirs,
ax=ax_,
x0=x0,
x_width=x_width,
cmap=cmap,
norm=norm,
)
return _get_dataframe(d_eigs=d_eigs, me_plot_data=me_plot_data), cmap, norm


def get_bs_eigenvalues(
defect: HarmonicDefect,
kpt_index: int = 0,
band_window: int = 5,
user_defect_band: tuple = tuple(),
shift_eig: dict[tuple, float] = None,
) -> dict[tuple, float]:
"""Read the eigenvalues from `HarmonicDefect.relaxed_bandstructure`.
Args:
defect:
The HarmonicDefect object, the `relaxed_bandstructure` attribute
must be set since this contains the eigenvalues.
Please see the `store_bandstructure` option in the constructor.
kpt_index:
The kpoint index to read the eigenvalues from.
band_window:
The number of bands above and below the Fermi level to include.
user_defect_band:
(band, kpt, spin) tuple to specify the defect state. If not provided,
the defect state will be determined automatically using the inverse
participation ratio.
The user provided kpoint index here will overwrite the kpt_index argument.
Returns:
Dictionary of the format: (iband, ikpt, ispin) -> eigenvalue
"""

if defect.relaxed_bandstructure is None: # pragma: no cover
raise ValueError("The defect object does not have a band structure.")

if user_defect_band:
def_indices = user_defect_band
else:
def_indices = next(filter(lambda x: x[1] == kpt_index, defect.defect_band))

band_index, kpt_index, spin_index = def_indices
spin_key = Spin.up if spin_index == 0 else Spin.down
output: dict[tuple, float] = dict()
shift_dict: dict = collections.defaultdict(lambda: 0.0)
if shift_eig is not None:
shift_dict.update(shift_eig)
for ib in range(band_index - band_window, band_index + band_window + 1):
output[(ib, kpt_index, spin_index)] = (
defect.relaxed_bandstructure.bands[spin_key][ib, kpt_index]
+ shift_dict[(ib, kpt_index, spin_index)]
)
return output


def _plot_eigs(
d_eigs: dict[tuple, float],
e_fermi=None,
ax=None,
x0: float = 0.0,
x_width: float = 0.3,
**kwargs,
) -> None:
"""Plot the eigenvalues."""
if ax is None: # pragma: no cover
ax = plt.gca()

# Use current color scheme
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
collections.defaultdict(list)
eigenvalues = np.array(list(d_eigs.values()))
if e_fermi is None: # pragma: no cover
e_fermi = -np.inf

eigs_ = eigenvalues[eigenvalues <= e_fermi]
ax.hlines(
eigs_, x0 - (x_width / 2.0), x0 + (x_width / 2.0), color=colors[0], **kwargs
)
eigs_ = eigenvalues[eigenvalues > e_fermi]
ax.hlines(
eigs_, x0 - (x_width / 2.0), x0 + (x_width / 2.0), color=colors[1], **kwargs
)

# turn off x-aixs
ax.get_xaxis().set_visible(False)


def _plot_matrix_elements(
cder,
d_eig,
defect_band_index,
ijdirs=((0, 0), (1, 1), (2, 2)),
ax=None,
x0=0,
x_width=0.6,
arrow_width=0.1,
cmap=None,
norm=None,
):
"""Plot arrow for the transition from the defect state to all other states.
Args:
cder:
The matrix element (derivative of the wavefunction) for the defect state.
d_eig:
The dictionary of eigenvalues for the defect state. In the format of
(iband, ikpt, ispin) -> eigenvalue
defect_band_index:
The band index of the defect state.
ax:
The matplotlib axis object to plot on.
x0:
The x coordinate of the center of the set of arrows.
x_width:
The width of the set of arrows.
arrow_width:
The width of the arrow.
cmap:
The matplotlib color map to use.
norm:
The matplotlib normalization to use for the color map.
ijdirs:
The cartesian direction of the WAVDER tensor to sum over for the plot.
If not provided, all the absolute values of the matrix for all
three diagonal entries will be summed.
"""
if ax is None: # pragma: no cover
ax = plt.gca()
ax.set_aspect("equal")
jb, jkpt, jspin = next(filter(lambda x: x[0] == defect_band_index, d_eig.keys()))
y0 = d_eig[jb, jkpt, jspin]
plot_data = []
for (ib, ik, ispin), eig in d_eig.items():
A = 0
for idir, jdir in ijdirs:
A += np.abs(
cder[ib, jb, ik, ispin, idir]
* np.conjugate(cder[ib, jb, ik, ispin, jdir])
)
plot_data.append((jb, ib, eig, A))

if cmap is None:
cmap = plt.get_cmap("viridis")

# get the range of A values
if norm is None:
A_min, A_max = (
min(plot_data, key=lambda x: x[3])[3],
max(plot_data, key=lambda x: x[3])[3],
)
norm = Normalize(vmin=A_min, vmax=A_max)

n_arrows = len(plot_data)
x_step = x_width / n_arrows
x = x0 - x_width / 2 + x_step / 2
for ib, jb, eig, A in plot_data:
ax.arrow(
x=x,
y=y0,
dx=0,
dy=eig - y0,
width=arrow_width,
length_includes_head=True,
head_width=arrow_width * 2,
head_length=arrow_width * 2,
color=cmap(norm(A)),
zorder=20,
)
x += x_step
return plot_data, cmap, norm


def _get_dataframe(d_eigs, me_plot_data) -> pd.DataFrame:
"""Convert the eigenvalue and matrix element data into a pandas dataframe."""
_, ikpt, ispin = next(iter(d_eigs.keys()))
df = pd.DataFrame(
me_plot_data,
columns=["ib", "jb", "eig", "M.E."],
)
df["kpt"] = ikpt
df["spin"] = ispin
return df
Loading

0 comments on commit dbfedf0

Please sign in to comment.