Skip to content

Commit

Permalink
Merge branch 'master' into mp-dos
Browse files Browse the repository at this point in the history
  • Loading branch information
ajjackson authored May 24, 2021
2 parents 6245975 + 294a4a2 commit 06d8765
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ Versioning <http://semver.org/>`__. The changelog format is inspired by
dependencies and different Python versions.
- Fix some incorrect values in Al k-alpha XPS cross-sections
- BUGFIX: Pymatgen CompleteDOS was not correctly accepted by galore.process_pdos()
- Implement previously ineffective "offset" option in
galore.plot.plot_pdos(), add a matching option to
galore.plot.plot_tdos()

`[0.6.1] <https://github.com/smtg-ucl/galore/compare/0.6.0...0.6.1>`__ - 2018-11-19
-----------------------------------------------------------------------------------
Expand Down
13 changes: 8 additions & 5 deletions galore/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def add_overlay(plt, overlay, overlay_scale=None, overlay_offset=0.,


def plot_pdos(pdos_data, ax=None, total=True, show_orbitals=True,
offset=0, flipx=False, **kwargs):
offset=0., flipx=False, **kwargs):
"""Plot a projected density of states (PDOS)
Args:
Expand Down Expand Up @@ -141,9 +141,9 @@ def plot_pdos(pdos_data, ax=None, total=True, show_orbitals=True,
# Field 'energy' must be present, other fields are orbitals
assert 'energy' in el_data.keys()
if flipx:
x_data = -el_data['energy']
x_data = -el_data['energy'] + offset
else:
x_data = el_data['energy']
x_data = el_data['energy'] + offset

orbitals = list(el_data.keys())
orbitals.remove('energy')
Expand Down Expand Up @@ -183,14 +183,15 @@ def plot_pdos(pdos_data, ax=None, total=True, show_orbitals=True,
return plt


def plot_tdos(xdata, ydata, ax=None, **kwargs):
def plot_tdos(xdata, ydata, ax=None, offset=0., **kwargs):
"""Plot a total DOS (i.e. 1D dataset)
Args:
xdata (iterable): x-values (energy, frequency etc.)
ydata (iterable): Corresponding y-values (DOS or measurement intensity)
show (bool): Display plot
offset (float): Energy shift to x-axis
ax (matplotlib.Axes): If provided, plot onto existing Axes object. If
None, a new Figure will be created and the pyplot instance will be
returned.
Expand All @@ -212,7 +213,9 @@ def plot_tdos(xdata, ydata, ax=None, **kwargs):
ax = fig.add_subplot(1, 1, 1)

if kwargs['flipx']:
xdata = -xdata
xdata = -xdata + offset
else:
xdata = xdata + offset

ax.plot(xdata, ydata, 'C0-')
ax.set_xlim([min(xdata), max(xdata)])
Expand Down
19 changes: 12 additions & 7 deletions test/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,29 @@ def test_plot_pdos(self):
('energy', np.array([1, 2, 3, 4, 5])),
('s', np.array([1, 1, 0, 0, 0])),
('p', np.array([0, 1, 2, 2, 1]))]))])
offset = 0.5

galore.plot.plot_pdos(pdos_data, ax=ax)

galore.plot.plot_pdos(pdos_data, ax=ax, offset=offset)

line1 = ax.lines[0]
xy1 = line1.get_xydata()
self.assertEqual(xy1[1, 0], 2)
self.assertEqual(xy1[1, 0], 2 + offset)
self.assertEqual(xy1[1, 1], 0)

line2 = ax.lines[1]
xy2 = line2.get_xydata()
self.assertEqual(xy2[2, 0], 3)
self.assertEqual(xy2[2, 0], 3 + offset)
self.assertEqual(xy2[2, 1], 4)

line3 = ax.lines[2]
xy3 = line3.get_xydata()
self.assertEqual(xy3[0, 0], 1)
self.assertEqual(xy3[0, 0], 1 + offset)
self.assertEqual(xy2[0, 1], 1)

tdos = ax.lines[4]
xyt = tdos.get_xydata()
self.assertEqual(xyt[2, 0], 3 + offset)
self.assertEqual(xyt[3, 1], 1 + 1 + 0 + 2)


Expand All @@ -99,11 +102,13 @@ def test_plot_tdos(self):
ax = fig.add_subplot(1, 1, 1)

xvals = np.linspace(-5, 5, 21)
galore.plot.plot_tdos(xvals, xvals**2, ax=ax)
offset = 0.8

galore.plot.plot_tdos(xvals, xvals**2, ax=ax, offset=offset)

self.assertEqual(len(ax.lines), 1)
self.assertEqual(ax.lines[0].get_xydata()[11, 0], 0.5)
self.assertAlmostEqual(ax.lines[0].get_xydata()[18, 0], 4.0)
self.assertEqual(ax.lines[0].get_xydata()[11, 0], 0.5 + offset)
self.assertAlmostEqual(ax.lines[0].get_xydata()[18, 1], 4.0**2)

self.assertAlmostEqual(ax.get_ylim()[0], 0)
self.assertAlmostEqual(ax.get_ylim()[1], 1.1 * 5**2)

0 comments on commit 06d8765

Please sign in to comment.