-
Notifications
You must be signed in to change notification settings - Fork 663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SurvivalProbability - Readability, performance, and algorithm changes #1995
Changes from all commits
69be6dc
3e65d43
4099187
b5e1ab2
967a569
ad4b44b
9cf3928
afb48ac
d71ed58
b711cb2
8eb1854
42704d7
9370e8f
87b6d34
3655647
c89a690
ea1c744
683f1cb
ebd1b33
53abc62
d471bd0
49882ff
5bd2e0b
d194cee
87cce21
cc4db35
c166ab2
505f549
affc0ab
a064686
b827bc0
cdcd942
351fc5d
6d35ed9
99ada4d
522e717
da0f594
cd2abb9
054228b
3928107
3cc426d
01e6fec
7cf8318
d8f4a60
b5d07e9
d348cc7
0ef1057
8470c4c
40e926d
c66f3db
ab21e39
a53a673
9a20846
9f8c677
c01c994
036ccc6
4b1ee14
2c36f37
5988f96
46a7f69
8536d54
2a1df89
441d5ba
8106a6c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -283,23 +283,24 @@ | |
|
||
import MDAnalysis | ||
from MDAnalysis.analysis.waterdynamics import SurvivalProbability as SP | ||
import matplotlib.pyplot as plt | ||
|
||
u = MDAnalysis.Universe(pdb, trajectory) | ||
universe = MDAnalysis.Universe(pdb, trajectory) | ||
selection = "byres name OH2 and sphzone 12.3 (resid 42 or resid 26 or resid 34 or resid 80) " | ||
SP_analysis = SP(universe, selection, 0, 100, 20) | ||
SP_analysis.run() | ||
#now we print data ready to graph. The graph | ||
#represents SP vs t | ||
time = 0 | ||
for sp in SP_analysis.timeseries: | ||
print("{time} {sp}".format(time=time, sp=sp)) | ||
time += 1 | ||
sp = SP(universe, selection, verbose=True) | ||
sp.run(start=0, stop=100, tau_max=20) | ||
tau_timeseries = sp.tau_timeseries | ||
sp_timeseries = sp.sp_timeseries | ||
|
||
#Plot | ||
plt.xlabel('time') | ||
# print in console | ||
for tau, sp in zip(tau_timeseries, sp_timeseries): | ||
print("{time} {sp}".format(time=tau, sp=sp)) | ||
|
||
# plot | ||
plt.xlabel('Time') | ||
plt.ylabel('SP') | ||
plt.title('Survival Probability') | ||
plt.plot(range(0,time),MSD_analysis.timeseries) | ||
plt.plot(taus, sp_timeseries) | ||
plt.show() | ||
|
||
|
||
|
@@ -376,16 +377,13 @@ | |
SurvivalProbability | ||
~~~~~~~~~~~~~~~~~~~ | ||
|
||
Survival Probability (SP) data is returned in a list, which each element | ||
represents a SP value in its respective window timestep. Data is stored in | ||
:attr:`SurvivalProbability.timeseries`:: | ||
|
||
results = [ | ||
# SP values order by window timestep | ||
<SP_t0>, <SP_t1>, ... | ||
] | ||
Survival Probability (SP) computes two lists: a list of taus (:attr:`SurvivalProbability.tau_timeseries`) and a list of their corresponding mean survival | ||
probabilities (:attr:`SurvivalProbability.sp_timeseries`). Additionally, a list :attr:`SurvivalProbability.sp_timeseries_data` is provided which contains | ||
a list of SPs for each tau, which can be used to compute their distribution, etc. | ||
|
||
results = [ tau1, tau2, ..., tau_n ], [ sp_tau1, sp_tau2, ..., sp_tau_n] | ||
|
||
Additionally, for each | ||
|
||
Classes | ||
-------- | ||
|
@@ -412,6 +410,9 @@ | |
|
||
""" | ||
from __future__ import print_function, division, absolute_import | ||
|
||
import warnings | ||
|
||
from six.moves import range, zip_longest | ||
|
||
import numpy as np | ||
|
@@ -1185,7 +1186,8 @@ class SurvivalProbability(object): | |
P(\tau) = \frac1T \sum_{t=1}^T \frac{N(t,t+\tau)}{N(t)} | ||
|
||
where :math:`T` is the maximum time of simulation, :math:`\tau` is the | ||
timestep and :math:`N` the number of particles in certain time. | ||
timestep, :math:`N(t)` the number of particles at time t, and | ||
:math:`N(t, t+\tau)` is the number of particles at every frame from t to `\tau`. | ||
|
||
|
||
Parameters | ||
|
@@ -1194,99 +1196,111 @@ class SurvivalProbability(object): | |
Universe object | ||
selection : str | ||
Selection string; any selection is allowed. With this selection you | ||
define the region/zone where to analyze, e.g.: "selection_a" and "zone" | ||
(see `SP-examples`_ ) | ||
t0 : int | ||
frame where analysis begins | ||
tf : int | ||
frame where analysis ends | ||
dtmax : int | ||
Maximum dt size, `dtmax` < `tf` or it will crash. | ||
define the region/zone where to analyze, e.g.: "resname SOL and around 5 (resname LIPID)" | ||
and "resname ION and around 10 (resid 20)" (see `SP-examples`_ ) | ||
verbose : Boolean | ||
If True, prints progress and comments to the console. | ||
|
||
|
||
.. versionadded:: 0.11.0 | ||
|
||
""" | ||
|
||
def __init__(self, universe, selection, t0, tf, dtmax): | ||
def __init__(self, universe, selection, t0=None, tf=None, dtmax=None, verbose=False): | ||
self.universe = universe | ||
self.selection = selection | ||
self.t0 = t0 | ||
self.tf = tf | ||
self.dtmax = dtmax | ||
self.timeseries = [] | ||
self.verbose = verbose | ||
|
||
# backward compatibility | ||
self.start = self.stop = self.tau_max = None | ||
if t0 is not None: | ||
self.start = t0 | ||
warnings.warn("t0 is deprecated, use run(start=t0) instead", category=DeprecationWarning) | ||
|
||
def run(self): | ||
"""Analyze trajectory and produce timeseries""" | ||
if tf is not None: | ||
self.stop = tf | ||
warnings.warn("tf is deprecated, use run(stop=tf) instead", category=DeprecationWarning) | ||
|
||
# select all frames to an array | ||
selected = self._selection_serial(self.universe, self.selection) | ||
if dtmax is not None: | ||
self.tau_max = dtmax | ||
warnings.warn("dtmax is deprecated, use run(tau_max=dtmax) instead", category=DeprecationWarning) | ||
|
||
if len(selected) < self.dtmax: | ||
print ("ERROR: Cannot select fewer frames than dtmax") | ||
return | ||
def print(self, verbose, *args): | ||
if self.verbose: | ||
print(args) | ||
elif verbose: | ||
print(args) | ||
|
||
for window_size in list(range(1, self.dtmax + 1)): | ||
output = self._getMeanOnePoint(selected, window_size) | ||
self.timeseries.append(output) | ||
def run(self, tau_max=20, start=0, stop=None, step=1, verbose=False): | ||
""" | ||
Computes and returns the survival probability timeseries | ||
|
||
Parameters | ||
---------- | ||
start : int | ||
Zero-based index of the first frame to be analysed | ||
stop : int | ||
Zero-based index of the last frame to be analysed (inclusive) | ||
step : int | ||
Jump every `step`'th frame | ||
tau_max : int | ||
Survival probability is calculated for the range :math:`1 <= \tau <= tau_max` | ||
verbose : Boolean | ||
Overwrite the constructor's verbosity | ||
|
||
Returns | ||
------- | ||
tau_timeseries : list | ||
tau from 1 to tau_max. Saved in the field tau_timeseries. | ||
sp_timeseries : list | ||
survival probability for each value of `tau`. Saved in the field sp_timeseries. | ||
""" | ||
|
||
# backward compatibility (and priority) | ||
start = self.start if self.start is not None else start | ||
stop = self.stop if self.stop is not None else stop | ||
tau_max = self.tau_max if self.tau_max is not None else tau_max | ||
|
||
def _selection_serial(self, universe, selection_str): | ||
selected = [] | ||
pm = ProgressMeter(self.tf-self.t0, interval=10, | ||
verbose=True, offset=-self.t0) | ||
for ts in universe.trajectory[self.t0:self.tf]: | ||
selected.append(universe.select_atoms(selection_str)) | ||
pm.echo(ts.frame) | ||
return selected | ||
# sanity checks | ||
if stop is not None and stop >= len(self.universe.trajectory): | ||
raise ValueError("\"stop\" must be smaller than the number of frames in the trajectory.") | ||
|
||
if stop is None: | ||
stop = len(self.universe.trajectory) | ||
else: | ||
stop = stop + 1 | ||
|
||
def _getMeanOnePoint(self, selected, window_size): | ||
""" | ||
This function gets one point of the plot P(t) vs t. It uses the | ||
_getOneDeltaPoint() function to calculate the average. | ||
""" | ||
n = 0 | ||
sumDeltaP = 0.0 | ||
for frame_no in range(len(selected) - window_size): | ||
delta = self._getOneDeltaPoint(selected, frame_no, window_size) | ||
sumDeltaP += delta | ||
n += 1 | ||
if tau_max > (stop - start): | ||
raise ValueError("Too few frames selected for given tau_max.") | ||
|
||
return sumDeltaP/n | ||
# load all frames to an array of sets | ||
selected_ids = [] | ||
for ts in self.universe.trajectory[start:stop]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah, I see you're just using step below. So every frame is a start point, but you only do deltas of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, not sure if I get the question. So this step here is tricky. I first load all frames (loading the right ones and considering the step makes it very tricky). And then during the analysis, I start the gathering the tau every step. So assume step=5 and max tau is 3. We'll have: t = 0: The main speed gain would be from the ensuring that frame 4, 9, ... etc would not be loaded. I will think about this how to get this right and not overcomplicate the algorithm. |
||
self.print(verbose, "Loading frame:", ts) | ||
selected_ids.append(set(self.universe.select_atoms(self.selection).ids)) | ||
|
||
tau_timeseries = np.arange(1, tau_max + 1) | ||
sp_timeseries_data = [[] for _ in range(tau_max)] | ||
|
||
def _getOneDeltaPoint(self, selected, t, tau): | ||
""" | ||
Gives one point to calculate the mean and | ||
gets one point of the plot C_vect vs t. | ||
- Ex: t=1 and tau=1 calculates | ||
how many selected water molecules survive from the frame 1 to 2 | ||
- Ex: t=5 and tau=3 calculates | ||
how many selected water molecules survive from the frame 5 to 8 | ||
""" | ||
for t in range(0, len(selected_ids), step): | ||
Nt = len(selected_ids[t]) | ||
|
||
Nt = len(selected[t]) | ||
if Nt == 0: | ||
return 0 | ||
if Nt == 0: | ||
self.print(verbose, | ||
"At frame {} the selection did not find any molecule. Moving on to the next frame".format(t)) | ||
continue | ||
|
||
# fraction of water molecules that survived | ||
Ntau = self._NumPart_tau(selected, t, tau) | ||
return Ntau/Nt | ||
for tau in tau_timeseries: | ||
if t + tau >= len(selected_ids): | ||
break | ||
|
||
# ids that survive from t to t + tau and at every frame in between | ||
Ntau = len(set.intersection(*selected_ids[t:t + tau + 1])) | ||
sp_timeseries_data[tau - 1].append(Ntau / float(Nt)) | ||
|
||
def _NumPart_tau(self, selected, t, tau): | ||
""" | ||
Compares the molecules in t selection and t+tau selection and | ||
select only the particles that remain from t to t+tau and | ||
at each point in between. | ||
It returns the number of remaining particles. | ||
""" | ||
survivors = set(selected[t]) | ||
i = 0 | ||
while (t + i) < t + tau and (t + i) < len(selected): | ||
next = set(selected[t + i]) | ||
survivors = survivors.intersection(next) | ||
i += 1 | ||
return len(survivors) | ||
# user can investigate the distribution and sample size | ||
self.sp_timeseries_data = sp_timeseries_data | ||
|
||
self.tau_timeseries = tau_timeseries | ||
self.sp_timeseries = [np.mean(sp) for sp in sp_timeseries_data] | ||
return self |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,29 +22,25 @@ | |
from __future__ import print_function, absolute_import | ||
import MDAnalysis | ||
from MDAnalysis.analysis import waterdynamics | ||
import pytest | ||
|
||
from MDAnalysisTests.datafiles import waterPSF, waterDCD | ||
from MDAnalysisTests.datafiles import PDB, XTC | ||
|
||
import pytest | ||
import numpy as np | ||
from mock import patch | ||
from mock import Mock | ||
from numpy.testing import assert_almost_equal | ||
|
||
SELECTION1 = "byres name OH2" | ||
SELECTION2 = "byres name P1" | ||
SELECTION3 = "around 4 (resid 151 and name OE1)" | ||
|
||
|
||
@pytest.fixture(scope='module') | ||
def universe(): | ||
return MDAnalysis.Universe(waterPSF, waterDCD) | ||
|
||
|
||
@pytest.fixture(scope='module') | ||
def universe_prot(): | ||
return MDAnalysis.Universe(PDB, XTC) | ||
|
||
|
||
def test_HydrogenBondLifetimes(universe): | ||
hbl = waterdynamics.HydrogenBondLifetimes( | ||
universe, SELECTION1, SELECTION1, 0, 5, 3) | ||
|
@@ -89,20 +85,37 @@ def test_MeanSquareDisplacement_zeroMolecules(universe): | |
assert_almost_equal(msd_zero.timeseries[1], 0.0) | ||
|
||
|
||
def test_SurvivalProbability(universe_prot): | ||
sp = waterdynamics.SurvivalProbability(universe_prot, SELECTION3, 0, 10, 4) | ||
sp.run() | ||
assert_almost_equal(sp.timeseries, [1.0, 0.354, 0.267, 0.242], decimal=3) | ||
|
||
def test_SurvivalProbability_t0tf(universe): | ||
with patch.object(universe, 'select_atoms') as select_atoms_mock: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I love these mock tests! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks |
||
ids = [(0, ), (0, ), (7, 6, 5), (6, 5, 4), (5, 4, 3), (4, 3, 2), (3, 2, 1), (0, )] | ||
select_atoms_mock.side_effect = lambda selection: Mock(ids=ids.pop(2)) # atom IDs fed set by set | ||
sp = waterdynamics.SurvivalProbability(universe, "") | ||
sp.run(tau_max=3, start=2, stop=6) | ||
assert_almost_equal(sp.sp_timeseries, [2 / 3.0, 1 / 3.0, 0]) | ||
|
||
def test_SurvivalProbability_t0Ignored(universe_prot): | ||
sp = waterdynamics.SurvivalProbability(universe_prot, SELECTION3, 3, 10, 4) | ||
sp.run() | ||
assert_almost_equal(sp.timeseries, [1.0, 0.391, 0.292, 0.261], decimal=3) | ||
|
||
def test_SurvivalProbability_definedTaus(universe): | ||
with patch.object(universe, 'select_atoms') as select_atoms_mock: | ||
ids = [(9, 8, 7), (8, 7, 6), (7, 6, 5), (6, 5, 4), (5, 4, 3), (4, 3, 2), (3, 2, 1)] | ||
select_atoms_mock.side_effect = lambda selection: Mock(ids=ids.pop()) # atom IDs fed set by set | ||
sp = waterdynamics.SurvivalProbability(universe, "") | ||
sp.run(tau_max=3, start=0, stop=6) | ||
assert_almost_equal(sp.sp_timeseries, [2 / 3.0, 1 / 3.0, 0]) | ||
|
||
|
||
def test_SurvivalProbability_zeroMolecules(universe): | ||
sp_zero = waterdynamics.SurvivalProbability(universe, SELECTION2, 0, 6, 3) | ||
sp_zero.run() | ||
assert_almost_equal(sp_zero.timeseries[1], 0.0) | ||
with patch.object(universe, 'select_atoms') as select_atoms_mock: | ||
# no atom IDs found | ||
select_atoms_mock.return_value = Mock(ids=[]) | ||
sp = waterdynamics.SurvivalProbability(universe, "") | ||
sp.run(tau_max=3, start=3, stop=6) | ||
assert all(np.isnan(sp.sp_timeseries)) | ||
|
||
|
||
def test_SurvivalProbability_alwaysPresent(universe): | ||
with patch.object(universe, 'select_atoms') as select_atoms_mock: | ||
# always the same atom IDs found | ||
select_atoms_mock.return_value = Mock(ids=[7, 8]) | ||
sp = waterdynamics.SurvivalProbability(universe, "") | ||
sp.run(tau_max=3, start=0, stop=6) | ||
assert all(np.equal(sp.sp_timeseries, 1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also add both of your names to the author list contributing to this release at the top. Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assme you mean just above the 0.18.1 in CHANGELOG. Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes just above the 0.18.1. I'm not sure if you are done yet. Github doesn't show the new names. We tend to use github names for that list.