Skip to content

Commit

Permalink
Create plotting function for viscosity function
Browse files Browse the repository at this point in the history
* Plots result of class `ViscosityHelfand`
* Write tests for plotting function
  • Loading branch information
xhgchen committed Aug 16, 2023
1 parent c6e196f commit b65d948
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
47 changes: 47 additions & 0 deletions transport_analysis/tests/test_viscosity.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,53 @@ def test_dimtype_error(self, ag, dimtype):
with pytest.raises(ValueError, match=errmsg):
VH(ag, dim_type=dimtype)

def test_plot_viscosity_function(self, visc_helfand):
# Expected data to be plotted
x_exp = visc_helfand.times
y_exp = visc_helfand.results.timeseries

# Actual data returned from plot
(line,) = visc_helfand.plot_viscosity_function()
x_act, y_act = line.get_xydata().T

assert_allclose(x_act, x_exp)
assert_allclose(y_act, y_exp)

def test_plot_viscosity_function_labels(self, visc_helfand):
# Expected labels
x_exp = "Time (ps)"
y_exp = "Viscosity Function" # TODO: Specify units

# Actual labels returned from plot
(line,) = visc_helfand.plot_viscosity_function()
x_act = line.axes.get_xlabel()
y_act = line.axes.get_ylabel()

assert x_act == x_exp
assert y_act == y_exp

def test_plot_viscosity_function_start_stop_step(
self, visc_helfand, start=1, stop=9, step=2
):
# Expected data to be plotted
x_exp = visc_helfand.times[start:stop:step]
y_exp = visc_helfand.results.timeseries[start:stop:step]

# Actual data returned from plot
(line,) = visc_helfand.plot_viscosity_function(
start=start, stop=stop, step=step
)
x_act, y_act = line.get_xydata().T

assert_allclose(x_act, x_exp)
assert_allclose(y_act, y_exp)

def test_plot_viscosity_function_exception(self, step_vtraj_full):
vis_h = VH(step_vtraj_full.atoms)
errmsg = "Analysis must be run"
with pytest.raises(RuntimeError, match=errmsg):
vis_h.plot_viscosity_function()


@pytest.mark.parametrize(
"tdim, tdim_factor",
Expand Down
40 changes: 40 additions & 0 deletions transport_analysis/viscosity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from MDAnalysis.exceptions import NoDataError
from MDAnalysis.units import constants
import numpy as np
import matplotlib.pyplot as plt

if TYPE_CHECKING:
from MDAnalysis.core.universe import AtomGroup
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(
# local
self.atomgroup = atomgroup
self.n_particles = len(self.atomgroup)
self._run_called = False

def _prepare(self):
"""
Expand Down Expand Up @@ -211,3 +213,41 @@ def _conclude(self):
)
# average over # particles and update results array
self.results.timeseries = self.results.visc_by_particle.mean(axis=1)
self._run_called = True

def plot_viscosity_function(self, start=0, stop=0, step=1):
"""
Returns a viscosity function plot via ``Matplotlib``. Usage
of this plot is recommended to help determine where to take the
slope of the viscosity function to obtain the viscosity.
Analysis must be run prior to plotting.
Parameters
----------
start : Optional[int]
The first frame of ``self.results.timeseries``
used for the plot.
stop : Optional[int]
The frame of ``self.results.timeseries`` to stop at
for the plot, non-inclusive.
step : Optional[int]
Number of frames to skip between each plotted frame.
Returns
-------
:class:`matplotlib.lines.Line2D`
A :class:`matplotlib.lines.Line2D` instance with
the desired viscosity function plotting information.
"""
if not self._run_called:
raise RuntimeError("Analysis must be run prior to plotting")

stop = self.n_frames if stop == 0 else stop

fig, ax_vacf = plt.subplots()
ax_vacf.set_xlabel("Time (ps)")
ax_vacf.set_ylabel("Viscosity Function") # TODO: Specify units
return ax_vacf.plot(
self.times[start:stop:step],
self.results.timeseries[start:stop:step],
)

0 comments on commit b65d948

Please sign in to comment.