Skip to content

Commit

Permalink
Merge pull request #4 from fjclark/min_mcse
Browse files Browse the repository at this point in the history
Min mcse
  • Loading branch information
fjclark authored Jan 22, 2024
2 parents 0e501e1 + cfddc77 commit 1c810b3
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 72 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ deea
==============================
[//]: # (Badges)
[![GitHub Actions Build Status](https://github.com/fjclark/deea/workflows/CI/badge.svg)](https://github.com/fjclark/deea/actions?query=workflow%3ACI)
[![Codacy Badge](https://app.codacy.com/project/badge/Grade/fff40e5573f847399bee98eef495f8c6)](https://app.codacy.com/gh/fjclark/deea/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
[![codecov](https://codecov.io/gh/fjclark/deea/branch/main/graph/badge.svg)](https://codecov.io/gh/fjclark/deea/branch/main)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
Expand Down
6 changes: 3 additions & 3 deletions deea/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from warnings import warn as _warn

import numpy as np
import numpy as _np

from ._exceptions import InvalidInputError


def check_data(data: np.ndarray, one_dim_allowed: bool = False) -> np.ndarray:
def check_data(data: _np.ndarray, one_dim_allowed: bool = False) -> _np.ndarray:
"""
Assert that data passed is a numpy array where
the first dimension is the number of chains and
Expand All @@ -29,7 +29,7 @@ def check_data(data: np.ndarray, one_dim_allowed: bool = False) -> np.ndarray:
Data with shape (n_chains, n_samples).
"""
# Check that data is a numpy array.
if not isinstance(data, np.ndarray):
if not isinstance(data, _np.ndarray):
raise InvalidInputError("Data must be a numpy array.")

# Check that data has right number of dimensions.
Expand Down
34 changes: 17 additions & 17 deletions deea/equilibration.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Functions for selecting the equilibration time of a time series."""

from pathlib import Path
from pathlib import Path as _Path
from typing import Callable as _Callable
from typing import Optional as _Optional
from typing import Tuple as _Tuple
from typing import Union as _Union

import matplotlib.pyplot as plt
import matplotlib.pyplot as _plt
import numpy as _np
from matplotlib import gridspec
from scipy.stats import ttest_rel
from matplotlib import gridspec as _gridspec
from scipy.stats import ttest_rel as _ttest_rel

from ._exceptions import EquilibrationNotDetectedError, InvalidInputError
from ._validation import check_data
Expand All @@ -30,7 +30,7 @@ def detect_equilibration_init_seq(
smooth_lag_times: bool = False,
frac_padding: float = 0.1,
plot: bool = False,
plot_name: _Union[str, Path] = "equilibration_sse_init_seq.png",
plot_name: _Union[str, _Path] = "equilibration_sse_init_seq.png",
time_units: str = "ns",
data_y_label: str = r"$\Delta G$ / kcal mol$^{-1}$",
plot_max_lags: bool = True,
Expand Down Expand Up @@ -112,7 +112,7 @@ def detect_equilibration_init_seq(
"""
# Check that data is valid.
data = check_data(data, one_dim_allowed=True)
n_runs, n_samples = data.shape
_, n_samples = data.shape

# Check that method is valid.
valid_methods = ["min_sse", "max_ess"]
Expand Down Expand Up @@ -163,8 +163,8 @@ def detect_equilibration_init_seq(

if plot:
# Create a figure.
fig = plt.figure(figsize=(6, 4))
gridspec_obj = gridspec.GridSpec(1, 1, figure=fig)
fig = _plt.figure(figsize=(6, 4))
gridspec_obj = _gridspec.GridSpec(1, 1, figure=fig)

# Plot the (reciprocal of the) SSE estimates underneath the time series.
plot_equilibration_min_sse(
Expand Down Expand Up @@ -196,7 +196,7 @@ def detect_equilibration_window(
window_size_fn: _Optional[_Callable[[int], int]] = lambda x: round(x**1 / 2),
window_size: _Optional[int] = None,
plot: bool = False,
plot_name: _Union[str, Path] = "equilibration_sse_window.png",
plot_name: _Union[str, _Path] = "equilibration_sse_window.png",
time_units: str = "ns",
data_y_label: str = r"$\Delta G$ / kcal mol$^{-1}$",
plot_window_size: bool = True,
Expand Down Expand Up @@ -307,8 +307,8 @@ def detect_equilibration_window(

if plot:
# Create a figure.
fig = plt.figure(figsize=(6, 4))
gridspec_obj = gridspec.GridSpec(1, 1, figure=fig)
fig = _plt.figure(figsize=(6, 4))
gridspec_obj = _gridspec.GridSpec(1, 1, figure=fig)

# Plot the (reciprocal of the) SSE estimates underneath the time series.
plot_equilibration_min_sse(
Expand Down Expand Up @@ -456,9 +456,9 @@ def get_paired_t_p_timeseries(
:, -round(n_truncated_samples * final_block_size) :
].mean(axis=1)
# Compute the paired t-test.
p_vals[i] = ttest_rel(initial_block, final_block, alternative=t_test_sidedness)[
1
]
p_vals[i] = _ttest_rel(
initial_block, final_block, alternative=t_test_sidedness
)[1]
p_val_indices[i] = idx
time_vals[i] = times[idx]

Expand All @@ -475,7 +475,7 @@ def detect_equilibration_paired_t_test(
final_block_size: float = 0.5,
t_test_sidedness: str = "two-sided",
plot: bool = False,
plot_name: _Union[str, Path] = "equilibration_paired_t_test.png",
plot_name: _Union[str, _Path] = "equilibration_paired_t_test.png",
time_units: str = "ns",
data_y_label: str = r"$\Delta G$ / kcal mol$^{-1}$",
) -> int:
Expand Down Expand Up @@ -573,8 +573,8 @@ def detect_equilibration_paired_t_test(

# Plot the p values.
if plot:
fig = plt.figure(figsize=(6, 4))
gridspec_obj = gridspec.GridSpec(1, 1, figure=fig)
fig = _plt.figure(figsize=(6, 4))
gridspec_obj = _gridspec.GridSpec(1, 1, figure=fig)
plot_equilibration_paired_t_test(
fig=fig,
subplot_spec=gridspec_obj[0],
Expand Down
10 changes: 5 additions & 5 deletions deea/gelman_rubin.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Compute the Gelman-Rubin diagnostic."""

import numpy as np
import numpy as _np

from ._validation import check_data
from .variance import inter_run_variance, intra_run_variance, lugsail_variance


def gelman_rubin(data: np.ndarray) -> float:
def gelman_rubin(data: _np.ndarray) -> float:
"""
Compute the Gelman-Rubin diagnostic according to
equation 4 in Statist. Sci. 36(4): 518-529
Expand Down Expand Up @@ -38,12 +38,12 @@ def gelman_rubin(data: np.ndarray) -> float:
) / n_samples * intra_run_variance_est + inter_run_variance_est / n_samples

# Compute GR diagnostic.
gelman_rubin_diagnostic = np.sqrt(combined_variance_est / intra_run_variance_est)
gelman_rubin_diagnostic = _np.sqrt(combined_variance_est / intra_run_variance_est)

return gelman_rubin_diagnostic


def stable_gelman_rubin(data: np.ndarray, n_pow: float = 1 / 3) -> float:
def stable_gelman_rubin(data: _np.ndarray, n_pow: float = 1 / 3) -> float:
"""
Compute the stable Gelman-Rubin diagnostic according to
equation 7 in Statist. Sci. 36(4): 518-529
Expand All @@ -64,6 +64,6 @@ def stable_gelman_rubin(data: np.ndarray, n_pow: float = 1 / 3) -> float:
) / n_samples * intra_run_variance_est + lugsail_variance_est / n_samples

# Compute GR diagnostic.
gelman_rubin_diagnostic = np.sqrt(combined_variance_est / intra_run_variance_est)
gelman_rubin_diagnostic = _np.sqrt(combined_variance_est / intra_run_variance_est)

return gelman_rubin_diagnostic
95 changes: 51 additions & 44 deletions deea/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,26 @@

from typing import List as _List

import matplotlib.pyplot as plt
from matplotlib import figure, gridspec
from matplotlib.axes import Axes
import matplotlib.pyplot as _plt
from matplotlib import figure as _figure
from matplotlib import gridspec as _gridspec
from matplotlib.axes import Axes as _Axes

plt.style.use("ggplot")
_plt.style.use("ggplot")
from typing import Optional as _Optional
from typing import Tuple as _Tuple

import numpy as np
from matplotlib.lines import Line2D
import numpy as _np
from matplotlib.lines import Line2D as _Line2D

from ._exceptions import InvalidInputError
from ._validation import check_data


def plot_timeseries(
ax: Axes,
data: np.ndarray,
times: np.ndarray,
ax: _Axes,
data: _np.ndarray,
times: _np.ndarray,
show_ci: bool = True,
n_blocks: int = 100,
time_units: str = "ns",
Expand Down Expand Up @@ -62,7 +63,7 @@ def plot_timeseries(
n_runs, n_samples = data.shape

# Check that times is valid.
if not isinstance(times, np.ndarray):
if not isinstance(times, _np.ndarray):
raise InvalidInputError("Times must be a numpy array.")

if times.ndim != 1:
Expand Down Expand Up @@ -130,12 +131,12 @@ def plot_timeseries(


def plot_p_values(
ax: Axes,
p_values: np.ndarray,
times: np.ndarray,
ax: _Axes,
p_values: _np.ndarray,
times: _np.ndarray,
p_threshold: float = 0.05,
time_units: str = "ns",
threshold_times: _Optional[np.ndarray] = None,
threshold_times: _Optional[_np.ndarray] = None,
) -> None:
"""
Plot the p-values of the paired t-test.
Expand Down Expand Up @@ -163,7 +164,7 @@ def plot_p_values(
using this plot underneath a time series plot.
"""
# Check that p_values is valid.
if not isinstance(p_values, np.ndarray) or not isinstance(times, np.ndarray):
if not isinstance(p_values, _np.ndarray) or not isinstance(times, _np.ndarray):
raise InvalidInputError("p_values and times must be numpy arrays.")

if p_values.ndim != 1:
Expand All @@ -183,7 +184,7 @@ def plot_p_values(
# Plot the p-value threshold.
ax.plot(
threshold_times,
np.full(threshold_times.shape, p_threshold),
_np.full(threshold_times.shape, p_threshold),
color="black",
linestyle="-",
linewidth=0.5,
Expand Down Expand Up @@ -217,15 +218,15 @@ def plot_p_values(


def plot_sse(
ax: Axes,
sse: np.ndarray,
max_lags: _Optional[np.ndarray],
window_sizes: _Optional[np.ndarray],
times: np.ndarray,
ax: _Axes,
sse: _np.ndarray,
max_lags: _Optional[_np.ndarray],
window_sizes: _Optional[_np.ndarray],
times: _np.ndarray,
time_units: str = "ns",
variance_y_label: str = r"$\frac{1}{\sigma^2(\Delta G)}$ / kcal$^{-2}$ mol$^2$",
reciprocal: bool = True,
) -> _Tuple[_List[Line2D], _List[str]]:
) -> _Tuple[_List[_Line2D], _List[str]]:
r"""
Plot the squared standard error (SSE) estimate against time.
Expand Down Expand Up @@ -265,7 +266,7 @@ def plot_sse(
The labels for the legend.
"""
# Check that sse is valid.
if not isinstance(sse, np.ndarray) or not isinstance(times, np.ndarray):
if not isinstance(sse, _np.ndarray) or not isinstance(times, _np.ndarray):
raise InvalidInputError("sse and times must be numpy arrays.")

if sse.ndim != 1:
Expand All @@ -291,7 +292,9 @@ def plot_sse(
to_plot = max_lags if window_sizes is None else window_sizes
ax2 = ax.twinx()
# Get the second colour from the colour cycle.
ax2.set_prop_cycle(color=[plt.rcParams["axes.prop_cycle"].by_key()["color"][1]])
ax2.set_prop_cycle(
color=[_plt.rcParams["axes.prop_cycle"].by_key()["color"][1]]
)
ax2.plot(times, to_plot, alpha=0.8, label=label)
ax2.set_ylabel(label)

Expand Down Expand Up @@ -320,16 +323,16 @@ def plot_sse(


def plot_equilibration_paired_t_test(
fig: figure.Figure,
subplot_spec: gridspec.SubplotSpec,
data: np.ndarray,
p_values: np.ndarray,
data_times: np.ndarray,
p_times: np.ndarray,
fig: _figure.Figure,
subplot_spec: _gridspec.SubplotSpec,
data: _np.ndarray,
p_values: _np.ndarray,
data_times: _np.ndarray,
p_times: _np.ndarray,
p_threshold: float = 0.05,
time_units: str = "ns",
data_y_label: str = r"$\Delta G$ / kcal mol$^{-1}$",
) -> _Tuple[Axes, Axes]:
) -> _Tuple[_Axes, _Axes]:
r"""
Plot the p-values of the paired t-test against time, underneath the
time series data.
Expand Down Expand Up @@ -376,7 +379,9 @@ def plot_equilibration_paired_t_test(
"""
# We need to split the gridspec into two subplots, one for the time series data (above)
# and one for the p-values (below). Share x-axis but not y-axis.
gs0 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=subplot_spec, hspace=0.05)
gs0 = _gridspec.GridSpecFromSubplotSpec(
2, 1, subplot_spec=subplot_spec, hspace=0.05
)
ax_top = fig.add_subplot(gs0[0])
ax_bottom = fig.add_subplot(gs0[1], sharex=ax_top)

Expand Down Expand Up @@ -407,7 +412,7 @@ def plot_equilibration_paired_t_test(
ax_bottom.legend(bbox_to_anchor=(1.05, 0.5), loc="center left")

# Hide the x tick labels for the top axis.
plt.setp(ax_top.get_xticklabels(), visible=False)
_plt.setp(ax_top.get_xticklabels(), visible=False)
ax_top.spines["bottom"].set_visible(False)
ax_top.tick_params(axis="x", which="both", length=0)
ax_top.set_xlabel("")
Expand All @@ -416,19 +421,19 @@ def plot_equilibration_paired_t_test(


def plot_equilibration_min_sse(
fig: figure.Figure,
subplot_spec: gridspec.SubplotSpec,
data: np.ndarray,
sse_series: np.ndarray,
data_times: np.ndarray,
sse_times: np.ndarray,
max_lag_series: _Optional[np.ndarray] = None,
window_size_series: _Optional[np.ndarray] = None,
fig: _figure.Figure,
subplot_spec: _gridspec.SubplotSpec,
data: _np.ndarray,
sse_series: _np.ndarray,
data_times: _np.ndarray,
sse_times: _np.ndarray,
max_lag_series: _Optional[_np.ndarray] = None,
window_size_series: _Optional[_np.ndarray] = None,
time_units: str = "ns",
data_y_label: str = r"$\Delta G$ / kcal mol$^{-1}$",
variance_y_label=r"$\frac{1}{\sigma^2(\Delta G)}$ / kcal$^{-2}$ mol$^2$",
reciprocal: bool = True,
) -> _Tuple[Axes, Axes]:
) -> _Tuple[_Axes, _Axes]:
r"""
Plot the (reciprocal of the) squared standard error (SSE)
estimates against time, underneath the time series data.
Expand Down Expand Up @@ -492,7 +497,9 @@ def plot_equilibration_min_sse(

# We need to split the gridspec into two subplots, one for the time series data (above)
# and one for the SSE estimates (below). Share x-axis but not y-axis.
gs0 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=subplot_spec, hspace=0.05)
gs0 = _gridspec.GridSpecFromSubplotSpec(
2, 1, subplot_spec=subplot_spec, hspace=0.05
)
ax_top = fig.add_subplot(gs0[0])
ax_bottom = fig.add_subplot(gs0[1], sharex=ax_top)

Expand Down Expand Up @@ -537,7 +544,7 @@ def plot_equilibration_min_sse(
)

# Hide the x tick labels for the top axis.
plt.setp(ax_top.get_xticklabels(), visible=False)
_plt.setp(ax_top.get_xticklabels(), visible=False)
ax_top.spines["bottom"].set_visible(False)
ax_top.tick_params(axis="x", which="both", length=0)
ax_top.set_xlabel("")
Expand Down
1 change: 0 additions & 1 deletion deea/tests/test_gelman_rubin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Tests for the Gelman-Rubin convergence diagnostic."""

import numpy as np
import pytest

from .._exceptions import InvalidInputError
Expand Down
Loading

0 comments on commit 1c810b3

Please sign in to comment.