From dbfedf02bb262dd605a31cf970f288d545c0889d Mon Sep 17 00:00:00 2001 From: Jimmy Shen <14003693+jmmshn@users.noreply.github.com> Date: Mon, 31 Jul 2023 13:14:29 -0700 Subject: [PATCH] Optical Transition Matrix Element Plotting (#137) * deps * deps * python vers * teset * tests * cleaner strict * cleaner strict * wf * plotting plotting plotting plotting plotting plotting * plotting * plotting * notebook notebook notebook notebook/test * notebook/test --- docs/source/content/photo-conduct.ipynb | 56 +++- .../analysis/defects/corrections/__init__.py | 1 + .../analysis/defects/plotting/__init__.py | 1 + pymatgen/analysis/defects/plotting/optical.py | 287 ++++++++++++++++++ tests/test_ccd.py | 16 + 5 files changed, 357 insertions(+), 4 deletions(-) create mode 100644 pymatgen/analysis/defects/corrections/__init__.py create mode 100644 pymatgen/analysis/defects/plotting/__init__.py create mode 100644 pymatgen/analysis/defects/plotting/optical.py diff --git a/docs/source/content/photo-conduct.ipynb b/docs/source/content/photo-conduct.ipynb index c88a4070..ba3743bb 100644 --- a/docs/source/content/photo-conduct.ipynb +++ b/docs/source/content/photo-conduct.ipynb @@ -69,7 +69,9 @@ "outputs": [], "source": [ "dir0 = TEST_FILES / \"ccd_0_-1\" / \"optics\"\n", - "hd0 = HarmonicDefect.from_directories(directories=[dir0])\n", + "hd0 = HarmonicDefect.from_directories(directories=[dir0], store_bandstructure=True)\n", + "# Note the `store_bandstructure=True` argument is required for the matrix element plotting later in the notebook.\n", + "# but not required for the dielectric function calculation.\n", "print(f\"The defect band is {hd0.defect_band}\")\n", "print(f\"The vibrational frequency is omega={hd0.omega} in this case is gibberish.\")" ] @@ -107,8 +109,8 @@ " if spin == Spin.up:\n", " return 0\n", " return 1\n", - " occ = vr.eigenvalues[Spin.up][0, :, 1]\n", - " fermi_idx = bisect.bisect_left(occ, -0.5, key=lambda x: -x) \n", + " occ = vr.eigenvalues[Spin.up][0, :, 1] * -1\n", + " fermi_idx = bisect.bisect_left(occ, -0.5) \n", " output = collections.defaultdict(list)\n", " for k, spin_eigs in vr.eigenvalues.items():\n", " spin_idx = _get_spin_idx(k)\n", @@ -195,6 +197,52 @@ "\n", "Of course for a complete picture of photoconductivity, the Frank-Condon type ofr vibrational state transition should also be considered, but we are already pushing the limits of what is acceptable in the independent-particle picture so we will leave that for another time.\n" ] + }, + { + "cell_type": "markdown", + "id": "1fd5f0b0", + "metadata": {}, + "source": [ + "## Dipole Matrix Elements\n", + "\n", + "We can also check the dipole matrix elements for the (VBM)→(defect) and (defect)→(CBM) transitions explicitly by calling the `plot_optical_transitions` method as shown below.\n", + "The function returns a summary `pandas.DataFrame` object with the dipole matrix elements as well as the `ListedColormap` and `Normalize` objects for plotting the colorbar. These objects can then be passed to other instances of the plotting function to ensure that the colorbar is consistent." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e00ac9f", + "metadata": {}, + "outputs": [], + "source": [ + "from pymatgen.analysis.defects.plotting.optical import plot_optical_transitions\n", + "import matplotlib as mpl\n", + "fig, ax = plt.subplots()\n", + "cm_ax = fig.add_axes([0.8,0.1,0.02,0.8])\n", + "df_k0, cmap, norm = plot_optical_transitions(hd0, kpt_index=1, band_window=5, x0=3, ax=ax)\n", + "df_k1, _, _ = plot_optical_transitions(hd0, kpt_index=0, band_window=5, x0=0, ax=ax, cmap=cmap, norm=norm)\n", + "mpl.colorbar.ColorbarBase(cm_ax,cmap=cmap,norm=norm,orientation='vertical')\n", + "ax.set_ylabel(\"Energy (eV)\")" + ] + }, + { + "cell_type": "markdown", + "id": "43f9e8d0", + "metadata": {}, + "source": [ + "The `DataFrame` object containing the dipole matrix elements can also be examined directly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d44ad8ef", + "metadata": {}, + "outputs": [], + "source": [ + "df_k0" + ] } ], "metadata": { @@ -208,7 +256,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]" + "version": "3.9.16" } }, "nbformat": 4, diff --git a/pymatgen/analysis/defects/corrections/__init__.py b/pymatgen/analysis/defects/corrections/__init__.py new file mode 100644 index 00000000..baa5ed7d --- /dev/null +++ b/pymatgen/analysis/defects/corrections/__init__.py @@ -0,0 +1 @@ +"""Finite size corrections for defects.""" diff --git a/pymatgen/analysis/defects/plotting/__init__.py b/pymatgen/analysis/defects/plotting/__init__.py new file mode 100644 index 00000000..44232b5d --- /dev/null +++ b/pymatgen/analysis/defects/plotting/__init__.py @@ -0,0 +1 @@ +"""Plotting functions.""" diff --git a/pymatgen/analysis/defects/plotting/optical.py b/pymatgen/analysis/defects/plotting/optical.py new file mode 100644 index 00000000..b5c05a86 --- /dev/null +++ b/pymatgen/analysis/defects/plotting/optical.py @@ -0,0 +1,287 @@ +"""Plotting functions.""" +from __future__ import annotations + +import collections +import logging + +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt +from matplotlib.colors import Normalize +from pymatgen.electronic_structure.core import Spin + +from pymatgen.analysis.defects.ccd import HarmonicDefect + +__author__ = "Jimmy Shen" +__copyright__ = "Copyright 2022, The Materials Project" +__maintainer__ = "Jimmy Shen @jmmshn" +__date__ = "July 2023" + +logger = logging.getLogger(__name__) + + +def plot_optical_transitions( + defect: HarmonicDefect, + kpt_index: int = 0, + band_window: int = 5, + user_defect_band: tuple = tuple(), + ijdirs=((0, 0), (1, 1), (2, 2)), + shift_eig: dict[tuple, float] = None, + x0: float = 0, + x_width: float = 2, + ax=None, + cmap=None, + norm=None, +): + """Plot the optical transitions from the defect state to all other states. + + Only plot the transitions for a specific kpoint index. The arrows present the transitions + between the defect state of interest and all other states. The color of the arrows + indicate the magnitude of the matrix element (derivative of the wavefunction) for the + transition. + + Args: + defect: + The HarmonicDefect object, the `relaxed_bandstructure` attribute + must be set since this contains the eigenvalues. + Please see the `store_bandstructure` option in the constructor. + kpt_index: + The kpoint index to read the eigenvalues from. + band_window: + The number of bands above and below the defect state to include in the output. + user_defect_band: + (band, kpt, spin) tuple to specify the defect state. If not provided, + the defect state will be determined automatically using the inverse + participation ratio and the `kpt_index` argument. + ijdirs: + The cartesian direction of the WAVDER tensor to sum over for the plot. + If not provided, all the absolute values of the matrix for all + three diagonal entries will be summed. + shift_eig: + A dictionary of the format `(band, kpt, spin) -> float` to apply to the + eigenvalues. This is useful for aligning the defect state with the + valence or conduction band for plotting and schematic purposes. + x0: + The x coordinate of the center of the set of arrows and the eigenvalue plot. + x_width: + The width of the set of arrows and the eigenvalue plot. + ax: + The matplotlib axis object to plot on. + cmap: + The matplotlib color map to use for the color of the arrorws. + norm: + The matplotlib normalization to use for the color map of the arrows. + + """ + d_eigs = get_bs_eigenvalues( + defect=defect, + kpt_index=kpt_index, + band_window=band_window, + user_defect_band=user_defect_band, + shift_eig=shift_eig, + ) + if user_defect_band: + defect_band_index = user_defect_band[0] + else: + defect_band_index = next( + filter(lambda x: x[1] == kpt_index, defect.defect_band) + )[0] + + if ax is None: + ax_ = plt.gca() + else: # pragma: no cover + ax_ = ax + _plot_eigs( + d_eigs, defect.relaxed_bandstructure.efermi, ax=ax_, x0=x0, x_width=x_width + ) + me_plot_data, cmap, norm = _plot_matrix_elements( + defect.waveder.cder, + d_eigs, + defect_band_index=defect_band_index, + ijdirs=ijdirs, + ax=ax_, + x0=x0, + x_width=x_width, + cmap=cmap, + norm=norm, + ) + return _get_dataframe(d_eigs=d_eigs, me_plot_data=me_plot_data), cmap, norm + + +def get_bs_eigenvalues( + defect: HarmonicDefect, + kpt_index: int = 0, + band_window: int = 5, + user_defect_band: tuple = tuple(), + shift_eig: dict[tuple, float] = None, +) -> dict[tuple, float]: + """Read the eigenvalues from `HarmonicDefect.relaxed_bandstructure`. + + Args: + defect: + The HarmonicDefect object, the `relaxed_bandstructure` attribute + must be set since this contains the eigenvalues. + Please see the `store_bandstructure` option in the constructor. + kpt_index: + The kpoint index to read the eigenvalues from. + band_window: + The number of bands above and below the Fermi level to include. + user_defect_band: + (band, kpt, spin) tuple to specify the defect state. If not provided, + the defect state will be determined automatically using the inverse + participation ratio. + The user provided kpoint index here will overwrite the kpt_index argument. + + Returns: + Dictionary of the format: (iband, ikpt, ispin) -> eigenvalue + """ + + if defect.relaxed_bandstructure is None: # pragma: no cover + raise ValueError("The defect object does not have a band structure.") + + if user_defect_band: + def_indices = user_defect_band + else: + def_indices = next(filter(lambda x: x[1] == kpt_index, defect.defect_band)) + + band_index, kpt_index, spin_index = def_indices + spin_key = Spin.up if spin_index == 0 else Spin.down + output: dict[tuple, float] = dict() + shift_dict: dict = collections.defaultdict(lambda: 0.0) + if shift_eig is not None: + shift_dict.update(shift_eig) + for ib in range(band_index - band_window, band_index + band_window + 1): + output[(ib, kpt_index, spin_index)] = ( + defect.relaxed_bandstructure.bands[spin_key][ib, kpt_index] + + shift_dict[(ib, kpt_index, spin_index)] + ) + return output + + +def _plot_eigs( + d_eigs: dict[tuple, float], + e_fermi=None, + ax=None, + x0: float = 0.0, + x_width: float = 0.3, + **kwargs, +) -> None: + """Plot the eigenvalues.""" + if ax is None: # pragma: no cover + ax = plt.gca() + + # Use current color scheme + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] + collections.defaultdict(list) + eigenvalues = np.array(list(d_eigs.values())) + if e_fermi is None: # pragma: no cover + e_fermi = -np.inf + + eigs_ = eigenvalues[eigenvalues <= e_fermi] + ax.hlines( + eigs_, x0 - (x_width / 2.0), x0 + (x_width / 2.0), color=colors[0], **kwargs + ) + eigs_ = eigenvalues[eigenvalues > e_fermi] + ax.hlines( + eigs_, x0 - (x_width / 2.0), x0 + (x_width / 2.0), color=colors[1], **kwargs + ) + + # turn off x-aixs + ax.get_xaxis().set_visible(False) + + +def _plot_matrix_elements( + cder, + d_eig, + defect_band_index, + ijdirs=((0, 0), (1, 1), (2, 2)), + ax=None, + x0=0, + x_width=0.6, + arrow_width=0.1, + cmap=None, + norm=None, +): + """Plot arrow for the transition from the defect state to all other states. + + Args: + cder: + The matrix element (derivative of the wavefunction) for the defect state. + d_eig: + The dictionary of eigenvalues for the defect state. In the format of + (iband, ikpt, ispin) -> eigenvalue + defect_band_index: + The band index of the defect state. + ax: + The matplotlib axis object to plot on. + x0: + The x coordinate of the center of the set of arrows. + x_width: + The width of the set of arrows. + arrow_width: + The width of the arrow. + cmap: + The matplotlib color map to use. + norm: + The matplotlib normalization to use for the color map. + ijdirs: + The cartesian direction of the WAVDER tensor to sum over for the plot. + If not provided, all the absolute values of the matrix for all + three diagonal entries will be summed. + """ + if ax is None: # pragma: no cover + ax = plt.gca() + ax.set_aspect("equal") + jb, jkpt, jspin = next(filter(lambda x: x[0] == defect_band_index, d_eig.keys())) + y0 = d_eig[jb, jkpt, jspin] + plot_data = [] + for (ib, ik, ispin), eig in d_eig.items(): + A = 0 + for idir, jdir in ijdirs: + A += np.abs( + cder[ib, jb, ik, ispin, idir] + * np.conjugate(cder[ib, jb, ik, ispin, jdir]) + ) + plot_data.append((jb, ib, eig, A)) + + if cmap is None: + cmap = plt.get_cmap("viridis") + + # get the range of A values + if norm is None: + A_min, A_max = ( + min(plot_data, key=lambda x: x[3])[3], + max(plot_data, key=lambda x: x[3])[3], + ) + norm = Normalize(vmin=A_min, vmax=A_max) + + n_arrows = len(plot_data) + x_step = x_width / n_arrows + x = x0 - x_width / 2 + x_step / 2 + for ib, jb, eig, A in plot_data: + ax.arrow( + x=x, + y=y0, + dx=0, + dy=eig - y0, + width=arrow_width, + length_includes_head=True, + head_width=arrow_width * 2, + head_length=arrow_width * 2, + color=cmap(norm(A)), + zorder=20, + ) + x += x_step + return plot_data, cmap, norm + + +def _get_dataframe(d_eigs, me_plot_data) -> pd.DataFrame: + """Convert the eigenvalue and matrix element data into a pandas dataframe.""" + _, ikpt, ispin = next(iter(d_eigs.keys())) + df = pd.DataFrame( + me_plot_data, + columns=["ib", "jb", "eig", "M.E."], + ) + df["kpt"] = ikpt + df["spin"] = ispin + return df diff --git a/tests/test_ccd.py b/tests/test_ccd.py index 39049176..f2b11941 100644 --- a/tests/test_ccd.py +++ b/tests/test_ccd.py @@ -1,6 +1,7 @@ from collections import namedtuple import numpy as np +import pandas as pd import pytest from pymatgen.analysis.defects.ccd import ( @@ -9,6 +10,7 @@ _get_wswq_slope, plot_pes, ) +from pymatgen.analysis.defects.plotting.optical import plot_optical_transitions @pytest.fixture(scope="session") @@ -151,6 +153,20 @@ def test_dielectric_func(test_dir): assert pytest.approx(inter_vbm, abs=0.01) == 6.31 assert pytest.approx(inter_cbm, abs=0.01) == 0.27 + df, cmap, norm = plot_optical_transitions(hd0, kpt_index=0, band_window=5) + assert isinstance(df, pd.DataFrame) + assert len(df) == 11 + + df, cmap, norm = plot_optical_transitions( + hd0, + kpt_index=-100, + band_window=5, + user_defect_band=(100, 0, 0), + shift_eig={100: 0}, + ) + assert df.iloc[5]["ib"] == 100 + assert df.iloc[5]["jb"] == 100 + def test_plot_pes(hd0): plot_pes(hd0)