Skip to content

Commit

Permalink
Ensure equilibration functions return integer indices
Browse files Browse the repository at this point in the history
  • Loading branch information
fjclark committed Oct 6, 2024
1 parent 10736da commit d01b9fa
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 8 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ PACKAGE_DIR := red
# For the CI github actions workflow, we skip "make env" and set up the environment manually. In this case,
# it's helpful to to set CONDA_ENV_RUN to be empty. However, for the documentation workflow, we want to override
# this and keep the normal behavior. We override this by setting KEEP_CONDA_ENV_RUN to true in the documentation workflow.
CONDA_ENV_RUN = $(if $(GITHUB_ACTIONS),$(if $(KEEP_CONDA_ENV_RUN),conda run --no-capture-output --name $(PACKAGE_NAME),),conda run --no-capture-output --name $(PACKAGE_NAME))
SKIP_CONDA_ENV = $(and $(GITHUB_ACTIONS),$(not $(KEEP_CONDA_ENV_RUN)))
CONDA_ENV_RUN = $(if $(SKIP_CONDA_ENV),,conda run --no-capture-output --name $(PACKAGE_NAME))

TEST_ARGS := -v --cov=$(PACKAGE_NAME) --cov-report=term --cov-report=xml --junitxml=unit.xml --color=yes

Expand Down
6 changes: 6 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## [0.1.1] - 2024-09-29

### Fixed

- Ensure that equilibration detection functions return integer indices when times are not provided.

## [0.1.0] - 2024-09-23

Initial release.
6 changes: 3 additions & 3 deletions red/equilibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def detect_equilibration_init_seq(
if times is None:
time_units = "index"
# Convert times to indices.
times_valid: _npt.NDArray[_np.float64] = _np.arange(n_samples, dtype=_np.float64)
times_valid: _npt.NDArray[_np.int64 | _np.float64] = _np.arange(n_samples, dtype=_np.int64)
else:
# To satisfy type checking.
times_valid = times
Expand Down Expand Up @@ -276,7 +276,7 @@ def detect_equilibration_window(
if times is None:
time_units = "index"
# Convert times to indices.
times_valid: _npt.NDArray[_np.float64] = _np.arange(n_samples, dtype=_np.float64)
times_valid: _npt.NDArray[_np.int64 | _np.float64] = _np.arange(n_samples, dtype=_np.int64)
else:
# To satisfy type checking.
times_valid = times
Expand Down Expand Up @@ -395,7 +395,7 @@ def get_paired_t_p_timeseries(

# Convert times to indices if necessary.
if times is None:
times = _np.arange(n_samples, dtype=_np.float64)
times = _np.arange(n_samples, dtype=_np.int64)

# Check that times is match the number of samples.
if n_samples != len(times):
Expand Down
8 changes: 4 additions & 4 deletions red/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
def plot_timeseries(
ax: _Axes,
data: _npt.NDArray[_np.float64],
times: _npt.NDArray[_np.float64],
times: _npt.NDArray[_np.float64 | _np.int64],
n_blocks: int = 100,
time_units: str = "ns",
y_label: str = r"$\Delta G$ / kcal mol$^{-1}$",
Expand Down Expand Up @@ -203,7 +203,7 @@ def plot_sse(
sse: _npt.NDArray[_np.float64],
max_lags: _Optional[_npt.NDArray[_np.float64]],
window_sizes: _Optional[_npt.NDArray[_np.float64]],
times: _npt.NDArray[_np.float64],
times: _npt.NDArray[_np.float64 | _np.int64],
time_units: str = "ns",
variance_y_label: str = r"$\frac{1}{\sigma^2(\Delta G)}$ / kcal$^{-2}$ mol$^2$",
reciprocal: bool = True,
Expand Down Expand Up @@ -409,8 +409,8 @@ def plot_equilibration_min_sse(
subplot_spec: _gridspec.SubplotSpec,
data: _npt.NDArray[_np.float64],
sse_series: _npt.NDArray[_np.float64],
data_times: _npt.NDArray[_np.float64],
sse_times: _npt.NDArray[_np.float64],
data_times: _npt.NDArray[_np.float64 | _np.int64],
sse_times: _npt.NDArray[_np.float64 | _np.int64],
max_lag_series: _Optional[_npt.NDArray[_np.float64]] = None,
window_size_series: _Optional[_npt.NDArray[_np.float64]] = None,
time_units: str = "ns",
Expand Down
3 changes: 3 additions & 0 deletions red/tests/test_equilibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def test_detect_equilibration_init_seq(example_timeseries, example_times, tmpdir
# Compute the equilibration index.
equil_idx, equil_g, equil_ess = detect_equilibration_init_seq(data=example_timeseries)

# Make sure that the index is a numpy int64
assert isinstance(equil_idx, np.int64)

# Check that the equilibration index is correct.
assert equil_idx == 398
assert equil_g == pytest.approx(4.292145845594654, abs=1e-4)
Expand Down

0 comments on commit d01b9fa

Please sign in to comment.