Skip to content

Commit

Permalink
Modified multistate sampler to accept a list of Pair/Neighbor lists (…
Browse files Browse the repository at this point in the history
…one for each sampler state). Changed test system (HOA) to use nonperiodic space and the pair list; modified neighbor/pairlist classes to not fail if box_vectors = None (errors will be thrown in Space class if box vectors are needed).
  • Loading branch information
chrisiacovella committed Feb 29, 2024
1 parent 124e519 commit d38ed6d
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 61 deletions.
51 changes: 38 additions & 13 deletions chiron/multistate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional, Union
from chiron.states import SamplerState, ThermodynamicState
from chiron.neighbors import NeighborListNsqrd
from chiron.neighbors import PairsBase
from openmm import unit
import numpy as np
from chiron.mcmc import MCMCMove, MCMCSampler
Expand Down Expand Up @@ -36,7 +36,7 @@ class MultiStateSampler:
Methods
-------
create(thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_list: NeighborListNsqrd)
create(thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_lists: List[PairsBase])
Creates a new multistate sampler simulation.
minimize(tolerance: unit.Quantity = 1.0 * unit.kilojoules_per_mole / unit.nanometers, max_iterations: int = 1000)
Minimizes all replicas in the sampler.
Expand Down Expand Up @@ -75,6 +75,7 @@ def __init__(
self._neighborhoods = None
self._n_accepted_matrix = None
self._n_proposed_matrix = None
self._nbr_lists = None

self._reporter = reporter # NOTE: reporter needs to be putlic, API change ahead
self._metadata = None
Expand Down Expand Up @@ -162,7 +163,7 @@ def is_periodic(self):
"""
if self._sampler_states is None:
return None
return self._thermodynamic_states[0].is_periodic
return self.is_periodic

@property
def is_completed(self):
Expand All @@ -187,17 +188,18 @@ def _compute_replica_energies(self, replica_id: int) -> np.ndarray:

# Retrieve sampler state associated to this replica.
sampler_state = self._sampler_states[replica_id]
nbr_list = self._sampler_states[replica_id]
# Compute energy for all thermodynamic states.
energies = calculate_reduced_potential_at_states(
sampler_state, self._thermodynamic_states, self.nbr_list
sampler_state, self._thermodynamic_states, nbr_list
)
return energies

def create(
self,
thermodynamic_states: List[ThermodynamicState],
sampler_states: List[SamplerState],
nbr_list: NeighborListNsqrd,
nbr_lists: List[PairsBase],
):
"""
Create a new multistate sampler simulation.
Expand All @@ -208,8 +210,8 @@ def create(
List of ThermodynamicStates to simulate, with one replica per state.
sampler_states : List[SamplerState]
List of initial SamplerStates. The number of states is the number of replicas.
nbr_list : NeighborListNsqrd
Neighbor list object for the simulation.
nbr_lists : List[PairsBase]
A list of objects used to efficiently calculate interacting pairs for each sampler state.
Raises
------
Expand All @@ -227,14 +229,14 @@ def create(
"Number of thermodynamic states and sampler states must be equal."
)

self.nbr_list = nbr_list
self._allocate_variables(thermodynamic_states, sampler_states)
self._allocate_variables(thermodynamic_states, sampler_states, nbr_lists)
self._reporter = MultistateReporter()

def _allocate_variables(
self,
thermodynamic_states: List[ThermodynamicState],
sampler_states: List[SamplerState],
nbr_lists: List[PairsBase],
) -> None:
"""
Allocate and initialize internal variables for the sampler.
Expand All @@ -245,6 +247,8 @@ def _allocate_variables(
A list of ThermodynamicState objects to be used in the sampler.
sampler_states : List[SamplerState]
A list of SamplerState objects for initializing the sampler.
nbr_lists : List[PairsBase]
A list of objects used to efficiently calculate interacting pairs for each sampler state.
Raises
------
Expand All @@ -255,8 +259,16 @@ def _allocate_variables(
import numpy as np

self._thermodynamic_states = copy.deepcopy(thermodynamic_states)
self._sampler_states = sampler_states
self._sampler_states = copy.deepcopy(sampler_states)
self._nbr_lists = copy.deepcopy(nbr_lists)

assert len(self._thermodynamic_states) == len(self._sampler_states)
assert len(self._thermodynamic_states) == len(self._nbr_lists)

# initial build of neighborlists
for nbr_list, state in zip(self._nbr_lists, self._sampler_states):
nbr_list.build(state.positions, state.box_vectors)

self._replica_thermodynamic_states = np.arange(
len(thermodynamic_states), dtype=int
)
Expand Down Expand Up @@ -325,13 +337,23 @@ def _minimize_replica(
minimized_state = minimize_energy(
sampler_state.positions,
thermodynamic_state.potential.compute_energy,
self.nbr_list,
self._nbr_lists[replica_id],
maxiter=max_iterations,
)

# Update the sampler state
self._sampler_states[replica_id].positions = minimized_state.params

# it is not likely that we would need to rebuild after minimization
# but we should make sure check to make sure
if self._nbr_lists[replica_id].check(
self._sampler_states[replica_id].positions
):
self._nbr_lists[replica_id].build(
self._sampler_states[replica_id].positions,
self._sampler_states[replica_id].box_vectors,
)

# Compute and log final energy
final_energy = thermodynamic_state.get_reduced_potential(sampler_state)
log.debug(
Expand Down Expand Up @@ -395,15 +417,18 @@ def _propagate_replica(self, replica_id: int):
thermodynamic_state_id = self._replica_thermodynamic_states[replica_id]
sampler_state = self._sampler_states[replica_id]
thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id]
nbr_list = self._nbr_lists[replica_id]

mcmc_sampler = self._mcmc_sampler[thermodynamic_state_id]
# Propagate using the mcmc sampler
# NOTE this needs to be updated to support neighborlists
(
self._sampler_states[replica_id],
self._thermodynamic_states[thermodynamic_state_id],
nbr_list,
) = mcmc_sampler.run(sampler_state, thermodynamic_state)
self._nbr_lists[replica_id],
) = mcmc_sampler.run(
sampler_state, thermodynamic_state, self.number_of_iterations, nbr_list
)
# Append the new state to the trajectory for analysis.
self._traj[replica_id].append(self._sampler_states[replica_id].positions)

Expand Down
Loading

0 comments on commit d38ed6d

Please sign in to comment.