From b65d9488babe3a9303c086b7f2ae442fcea22885 Mon Sep 17 00:00:00 2001 From: Xu Hong Chen <110699064+xhgchen@users.noreply.github.com> Date: Tue, 15 Aug 2023 14:30:14 -0700 Subject: [PATCH] Create plotting function for viscosity function * Plots result of class `ViscosityHelfand` * Write tests for plotting function --- transport_analysis/tests/test_viscosity.py | 47 ++++++++++++++++++++++ transport_analysis/viscosity.py | 40 ++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/transport_analysis/tests/test_viscosity.py b/transport_analysis/tests/test_viscosity.py index dace793..c72d4fd 100644 --- a/transport_analysis/tests/test_viscosity.py +++ b/transport_analysis/tests/test_viscosity.py @@ -147,6 +147,53 @@ def test_dimtype_error(self, ag, dimtype): with pytest.raises(ValueError, match=errmsg): VH(ag, dim_type=dimtype) + def test_plot_viscosity_function(self, visc_helfand): + # Expected data to be plotted + x_exp = visc_helfand.times + y_exp = visc_helfand.results.timeseries + + # Actual data returned from plot + (line,) = visc_helfand.plot_viscosity_function() + x_act, y_act = line.get_xydata().T + + assert_allclose(x_act, x_exp) + assert_allclose(y_act, y_exp) + + def test_plot_viscosity_function_labels(self, visc_helfand): + # Expected labels + x_exp = "Time (ps)" + y_exp = "Viscosity Function" # TODO: Specify units + + # Actual labels returned from plot + (line,) = visc_helfand.plot_viscosity_function() + x_act = line.axes.get_xlabel() + y_act = line.axes.get_ylabel() + + assert x_act == x_exp + assert y_act == y_exp + + def test_plot_viscosity_function_start_stop_step( + self, visc_helfand, start=1, stop=9, step=2 + ): + # Expected data to be plotted + x_exp = visc_helfand.times[start:stop:step] + y_exp = visc_helfand.results.timeseries[start:stop:step] + + # Actual data returned from plot + (line,) = visc_helfand.plot_viscosity_function( + start=start, stop=stop, step=step + ) + x_act, y_act = line.get_xydata().T + + assert_allclose(x_act, x_exp) + assert_allclose(y_act, y_exp) + + def test_plot_viscosity_function_exception(self, step_vtraj_full): + vis_h = VH(step_vtraj_full.atoms) + errmsg = "Analysis must be run" + with pytest.raises(RuntimeError, match=errmsg): + vis_h.plot_viscosity_function() + @pytest.mark.parametrize( "tdim, tdim_factor", diff --git a/transport_analysis/viscosity.py b/transport_analysis/viscosity.py index b1f104c..84a2cbe 100644 --- a/transport_analysis/viscosity.py +++ b/transport_analysis/viscosity.py @@ -13,6 +13,7 @@ from MDAnalysis.exceptions import NoDataError from MDAnalysis.units import constants import numpy as np +import matplotlib.pyplot as plt if TYPE_CHECKING: from MDAnalysis.core.universe import AtomGroup @@ -87,6 +88,7 @@ def __init__( # local self.atomgroup = atomgroup self.n_particles = len(self.atomgroup) + self._run_called = False def _prepare(self): """ @@ -211,3 +213,41 @@ def _conclude(self): ) # average over # particles and update results array self.results.timeseries = self.results.visc_by_particle.mean(axis=1) + self._run_called = True + + def plot_viscosity_function(self, start=0, stop=0, step=1): + """ + Returns a viscosity function plot via ``Matplotlib``. Usage + of this plot is recommended to help determine where to take the + slope of the viscosity function to obtain the viscosity. + Analysis must be run prior to plotting. + + Parameters + ---------- + start : Optional[int] + The first frame of ``self.results.timeseries`` + used for the plot. + stop : Optional[int] + The frame of ``self.results.timeseries`` to stop at + for the plot, non-inclusive. + step : Optional[int] + Number of frames to skip between each plotted frame. + + Returns + ------- + :class:`matplotlib.lines.Line2D` + A :class:`matplotlib.lines.Line2D` instance with + the desired viscosity function plotting information. + """ + if not self._run_called: + raise RuntimeError("Analysis must be run prior to plotting") + + stop = self.n_frames if stop == 0 else stop + + fig, ax_vacf = plt.subplots() + ax_vacf.set_xlabel("Time (ps)") + ax_vacf.set_ylabel("Viscosity Function") # TODO: Specify units + return ax_vacf.plot( + self.times[start:stop:step], + self.results.timeseries[start:stop:step], + )