diff --git a/chiron/multistate.py b/chiron/multistate.py index e67d7bb..89c9f8f 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -1,6 +1,6 @@ from typing import List, Optional, Union from chiron.states import SamplerState, ThermodynamicState -from chiron.neighbors import NeighborListNsqrd +from chiron.neighbors import PairsBase from openmm import unit import numpy as np from chiron.mcmc import MCMCMove, MCMCSampler @@ -36,7 +36,7 @@ class MultiStateSampler: Methods ------- - create(thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_list: NeighborListNsqrd) + create(thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_lists: List[PairsBase]) Creates a new multistate sampler simulation. minimize(tolerance: unit.Quantity = 1.0 * unit.kilojoules_per_mole / unit.nanometers, max_iterations: int = 1000) Minimizes all replicas in the sampler. @@ -75,6 +75,7 @@ def __init__( self._neighborhoods = None self._n_accepted_matrix = None self._n_proposed_matrix = None + self._nbr_lists = None self._reporter = reporter # NOTE: reporter needs to be putlic, API change ahead self._metadata = None @@ -162,7 +163,7 @@ def is_periodic(self): """ if self._sampler_states is None: return None - return self._thermodynamic_states[0].is_periodic + return self.is_periodic @property def is_completed(self): @@ -187,9 +188,10 @@ def _compute_replica_energies(self, replica_id: int) -> np.ndarray: # Retrieve sampler state associated to this replica. sampler_state = self._sampler_states[replica_id] + nbr_list = self._sampler_states[replica_id] # Compute energy for all thermodynamic states. energies = calculate_reduced_potential_at_states( - sampler_state, self._thermodynamic_states, self.nbr_list + sampler_state, self._thermodynamic_states, nbr_list ) return energies @@ -197,7 +199,7 @@ def create( self, thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], - nbr_list: NeighborListNsqrd, + nbr_lists: List[PairsBase], ): """ Create a new multistate sampler simulation. @@ -208,8 +210,8 @@ def create( List of ThermodynamicStates to simulate, with one replica per state. sampler_states : List[SamplerState] List of initial SamplerStates. The number of states is the number of replicas. - nbr_list : NeighborListNsqrd - Neighbor list object for the simulation. + nbr_lists : List[PairsBase] + A list of objects used to efficiently calculate interacting pairs for each sampler state. Raises ------ @@ -227,14 +229,14 @@ def create( "Number of thermodynamic states and sampler states must be equal." ) - self.nbr_list = nbr_list - self._allocate_variables(thermodynamic_states, sampler_states) + self._allocate_variables(thermodynamic_states, sampler_states, nbr_lists) self._reporter = MultistateReporter() def _allocate_variables( self, thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], + nbr_lists: List[PairsBase], ) -> None: """ Allocate and initialize internal variables for the sampler. @@ -245,6 +247,8 @@ def _allocate_variables( A list of ThermodynamicState objects to be used in the sampler. sampler_states : List[SamplerState] A list of SamplerState objects for initializing the sampler. + nbr_lists : List[PairsBase] + A list of objects used to efficiently calculate interacting pairs for each sampler state. Raises ------ @@ -255,8 +259,16 @@ def _allocate_variables( import numpy as np self._thermodynamic_states = copy.deepcopy(thermodynamic_states) - self._sampler_states = sampler_states + self._sampler_states = copy.deepcopy(sampler_states) + self._nbr_lists = copy.deepcopy(nbr_lists) + assert len(self._thermodynamic_states) == len(self._sampler_states) + assert len(self._thermodynamic_states) == len(self._nbr_lists) + + # initial build of neighborlists + for nbr_list, state in zip(self._nbr_lists, self._sampler_states): + nbr_list.build(state.positions, state.box_vectors) + self._replica_thermodynamic_states = np.arange( len(thermodynamic_states), dtype=int ) @@ -325,13 +337,23 @@ def _minimize_replica( minimized_state = minimize_energy( sampler_state.positions, thermodynamic_state.potential.compute_energy, - self.nbr_list, + self._nbr_lists[replica_id], maxiter=max_iterations, ) # Update the sampler state self._sampler_states[replica_id].positions = minimized_state.params + # it is not likely that we would need to rebuild after minimization + # but we should make sure check to make sure + if self._nbr_lists[replica_id].check( + self._sampler_states[replica_id].positions + ): + self._nbr_lists[replica_id].build( + self._sampler_states[replica_id].positions, + self._sampler_states[replica_id].box_vectors, + ) + # Compute and log final energy final_energy = thermodynamic_state.get_reduced_potential(sampler_state) log.debug( @@ -395,6 +417,7 @@ def _propagate_replica(self, replica_id: int): thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] sampler_state = self._sampler_states[replica_id] thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] + nbr_list = self._nbr_lists[replica_id] mcmc_sampler = self._mcmc_sampler[thermodynamic_state_id] # Propagate using the mcmc sampler @@ -402,8 +425,10 @@ def _propagate_replica(self, replica_id: int): ( self._sampler_states[replica_id], self._thermodynamic_states[thermodynamic_state_id], - nbr_list, - ) = mcmc_sampler.run(sampler_state, thermodynamic_state) + self._nbr_lists[replica_id], + ) = mcmc_sampler.run( + sampler_state, thermodynamic_state, self.number_of_iterations, nbr_list + ) # Append the new state to the trajectory for analysis. self._traj[replica_id].append(self._sampler_states[replica_id].positions) diff --git a/chiron/neighbors.py b/chiron/neighbors.py index be80cf0..ead53df 100644 --- a/chiron/neighbors.py +++ b/chiron/neighbors.py @@ -68,6 +68,9 @@ def displacement( # calculate uncorrected r_ij r_ij = xyz_1 - xyz_2 + if box_vectors is None: + raise ValueError("box_vectors must be provided for a periodic system") + box_lengths = jnp.array( [box_vectors[0][0], box_vectors[1][1], box_vectors[2][2]] ) @@ -97,6 +100,9 @@ def wrap(self, xyz: jnp.array, box_vectors: jnp.array) -> jnp.array: Wrapped positions of the system """ + if box_vectors is None: + raise ValueError("box_vectors must be provided for a periodic system") + box_lengths = jnp.array( [box_vectors[0][0], box_vectors[1][1], box_vectors[2][2]] ) @@ -106,16 +112,16 @@ def wrap(self, xyz: jnp.array, box_vectors: jnp.array) -> jnp.array: return xyz -class OrthogonalNonperiodicSpace(Space): +class OrthogonalNonPeriodicSpace(Space): @partial(jax.jit, static_argnums=(0,)) def displacement( self, xyz_1: jnp.array, xyz_2: jnp.array, - box_vectors: jnp.array, + box_vectors: Optional[jnp.array] = None, ) -> Tuple[jnp.array, jnp.array]: """ - Calculate the periodic distance between two points. + Calculate the distance between two points in a non-periodic system. Parameters ---------- @@ -123,8 +129,9 @@ def displacement( Positions of the first point xyz_2: jnp.array Positions of the second point - box_vectors: jnp.array + box_vectors: Optional[jnp.array]=None Box vectors for the system. + These are not needed for a non-periodic system, but are included for consistent API usage. Returns ------- @@ -134,7 +141,7 @@ def displacement( Distance between the two points """ - # calculate uncorrect r_ij + # calculate r_ij r_ij = xyz_1 - xyz_2 # calculate the scalar distance @@ -143,17 +150,21 @@ def displacement( return r_ij, dist @partial(jax.jit, static_argnums=(0,)) - def wrap(self, xyz: jnp.array, box_vectors: jnp.array) -> jnp.array: + def wrap( + self, xyz: jnp.array, box_vectors: Optional[jnp.array] = None + ) -> jnp.array: """ - Wrap the positions of the system. - For the Non-periodic system, this does not alter the positions + Wrap the positions of the system inside the box. + For the non-periodic system, this does not alter the positions. Parameters ---------- xyz: jnp.array Positions of the system - box_vectors: jnp.array - Box vectors for the system + box_vectors: Optional[jnp.array]=None + Box vectors for the system. + These are not needed for a non-periodic system, but are included for consistent API usage. + Returns ------- @@ -226,7 +237,7 @@ def __init__( def build( self, positions: Union[jnp.array, unit.Quantity], - box_vectors: Union[jnp.array, unit.Quantity], + box_vectors: Union[jnp.array, unit.Quantity, None], ): """ Build list from an array of positions and array of box vectors. @@ -236,9 +247,10 @@ def build( positions: jnp.array or unit.Quantity Shape[n_particles,3] array of particle positions, either with or without units attached. If the array is passed as a unit.Quantity, the units must be distances and will be converted to nanometers. - box_vectors: jnp.array or unit.Quantity + box_vectors: jnp.array or unit.Quantity or None Shape[3,3] array of box vectors for the system, either with or without units attached. If the array is passed as a unit.Quantity, the units must be distances and will be converted to nanometers. + If None, the system is assumed to be non-periodic and the Space class must reflect this. Returns ------- @@ -250,7 +262,7 @@ def build( def _validate_build_inputs( self, positions: Union[jnp.array, unit.Quantity], - box_vectors: Union[jnp.array, unit.Quantity], + box_vectors: Union[jnp.array, unit.Quantity, None], ): """ Validate the inputs to the build function. @@ -292,6 +304,8 @@ def _validate_build_inputs( f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" ) self.box_vectors = box_vectors + if box_vectors is None: + self.box_vectors = None def build_from_state(self, sampler_state: SamplerState): """ @@ -310,8 +324,8 @@ def build_from_state(self, sampler_state: SamplerState): raise TypeError(f"Expected SamplerState, got {type(sampler_state)} instead") positions = sampler_state.positions - if sampler_state.box_vectors is None: - raise ValueError(f"SamplerState does not contain box vectors") + # if sampler_state.box_vectors is None: + # raise ValueError(f"SamplerState does not contain box vectors") box_vectors = sampler_state.box_vectors self.build(positions, box_vectors) @@ -557,8 +571,9 @@ def _build_neighborlist( Maximum number of neighbors for each particle. Used for padding arrays for efficient jax computations cutoff_and_skin: float Cutoff distance for the neighborlist plus the skin distance, in nanometers. - box_vectors: jnp.array - Box vectors for the system + box_vectors: Union[jnp.array, None] + Box vectors for the system. + If None, the system is assumed to be non-periodic and the Space class must be compatible with this. Returns ------- @@ -613,7 +628,7 @@ def _build_neighborlist( def build( self, positions: Union[jnp.array, unit.Quantity], - box_vectors: Union[jnp.array, unit.Quantity], + box_vectors: Union[jnp.array, unit.Quantity, None], ): """ Build the neighbor list from an array of positions and box vectors. @@ -622,7 +637,7 @@ def build( ---------- positions: jnp.array Shape[N,3] array of particle positions - box_vectors: jnp.array + box_vectors: Union[jnp.array, None] Shape[3,3] array of box vectors Returns @@ -647,10 +662,11 @@ def build( ) box_vectors = box_vectors.value_in_unit_system(unit.md_unit_system) - if box_vectors.shape != (3, 3): - raise ValueError( - f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" - ) + if isinstance(box_vectors, jnp.ndarray): + if box_vectors.shape != (3, 3): + raise ValueError( + f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" + ) self.ref_positions = positions self.box_vectors = box_vectors @@ -714,7 +730,13 @@ def build( @partial(jax.jit, static_argnums=(0,)) def _calc_distance_per_particle( - self, particle1, neighbors, neighbor_mask, positions, cutoff, box_vectors + self, + particle1: int, + neighbors: jnp.array, + neighbor_mask: jnp.array, + positions: jnp.array, + cutoff: float, + box_vectors: Union[jnp.array, None], ): """ Jitted function to calculate the distance between a particle and its neighbors @@ -731,8 +753,9 @@ def _calc_distance_per_particle( X,Y,Z positions of all particles cutoff: float Cutoff distance for the neighborlist, in nanometers - box_vectors: jnp.array - Box vectors for the system + box_vectors: Union[jnp.array, None] + Box vectors for the system. + If None, the system is assumed to be non-periodic and the Space class must be compatible with this. Returns ------- @@ -1046,7 +1069,7 @@ def _remove_self_interactions(self, particles, temp_mask): def build( self, positions: Union[jnp.array, unit.Quantity], - box_vectors: Union[jnp.array, unit.Quantity], + box_vectors: Union[jnp.array, unit.Quantity, None], ): """ Build the list from an array of positions and box vectors. @@ -1055,8 +1078,9 @@ def build( ---------- positions: jnp.array Shape[n_particles,3] array of particle positions - box_vectors: jnp.array - Shape[3,3] array of box vectors + box_vectors: jnp.array or unit.Quantity, or None + Shape[3,3] array of box vectors, with or without units. + If None, the system is assumed to be non-periodic and the Space class must be compatible with this. Returns ------- @@ -1098,8 +1122,9 @@ def _calc_distance_per_particle_with_cutoff( X,Y,Z positions of all particles, shaped (n_particles, 3) cutoff: float Cutoff distance for the interaction. - box_vectors: jnp.array - Box vectors for the system + box_vectors: Union[jnp.array, None] + Box vectors for the system. + If None, the system is assumed to be non-periodic and the Space class must be compatible with this. Returns ------- @@ -1152,8 +1177,9 @@ def _calc_distance_per_particle_no_cutoff( Mask to exclude double particles to prevent double counting positions: jnp.array X,Y,Z positions of all particles, shaped (n_particles, 3) - box_vectors: jnp.array - Box vectors of the system + box_vectors: Union[jnp.array, None] + Box vectors of the system. + If None, the system is assumed to be non-periodic and the Space class must be compatible with this. Returns ------- diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 533d879..c850efb 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -1,29 +1,27 @@ +import copy + from chiron.multistate import MultiStateSampler -from chiron.neighbors import NeighborListNsqrd +from chiron.neighbors import PairListNsqrd import pytest from typing import Tuple -def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: +def setup_sampler() -> Tuple[PairListNsqrd, MultiStateSampler]: """ - Set up the neighbor list and multistate sampler for the simulation. + Set up the pair list and multistate sampler for the simulation. Returns: - Tuple: A tuple containing the neighbor list and multistate sampler objects. + Tuple: A tuple containing the pair list and multistate sampler objects. """ from openmm import unit from chiron.mcmc import LangevinDynamicsMove - from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace + from chiron.neighbors import PairListNsqrd, OrthogonalNonPeriodicSpace from chiron.reporters import MultistateReporter, BaseReporter from chiron.mcmc import MCMCSampler, MoveSchedule - sigma = 0.34 * unit.nanometer - cutoff = 3.0 * sigma - skin = 0.5 * unit.nanometer + cutoff = 1.0 * unit.nanometer - nbr_list = NeighborListNsqrd( - OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 - ) + nbr_list = PairListNsqrd(OrthogonalNonPeriodicSpace(), cutoff=cutoff) lang_move = LangevinDynamicsMove( timestep=1.0 * unit.femtoseconds, number_of_steps=100 @@ -76,10 +74,14 @@ def ho_multistate_sampler_multiple_minima() -> MultiStateSampler: sampler_state = [SamplerState(ho.positions, PRNG.get_random_key()) for _ in x0s] nbr_list, multistate_sampler = setup_sampler() + import copy + + nbr_lists = [copy.deepcopy(nbr_list) for _ in x0s] + multistate_sampler.create( thermodynamic_states=thermodynamic_states, sampler_states=sampler_state, - nbr_list=nbr_list, + nbr_lists=nbr_lists, ) return multistate_sampler @@ -135,11 +137,14 @@ def ho_multistate_sampler_multiple_ks() -> MultiStateSampler: ) nbr_list, multistate_sampler = setup_sampler() + import copy + + nbr_lists = [copy.deepcopy(nbr_list) for _ in sigmas] multistate_sampler.create( thermodynamic_states=thermodynamic_states, sampler_states=sampler_state, - nbr_list=nbr_list, + nbr_lists=nbr_lists, ) multistate_sampler.analytical_f_i = f_i multistate_sampler.delta_f_ij_analytical = f_i - f_i[:, np.newaxis] diff --git a/chiron/tests/test_pairs.py b/chiron/tests/test_pairs.py index 60df5ad..f463a43 100644 --- a/chiron/tests/test_pairs.py +++ b/chiron/tests/test_pairs.py @@ -4,7 +4,7 @@ NeighborListNsqrd, PairListNsqrd, OrthogonalPeriodicSpace, - OrthogonalNonperiodicSpace, + OrthogonalNonPeriodicSpace, ) from chiron.states import SamplerState @@ -42,7 +42,7 @@ def test_orthogonal_periodic_displacement(): def test_orthogonal_nonperiodic_displacement(): - space = OrthogonalNonperiodicSpace() + space = OrthogonalNonPeriodicSpace() box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) p1 = jnp.array([[0, 0, 0], [0, 0, 0]]) p2 = jnp.array([[1, 0, 0], [6, 0, 0]])