diff --git a/chiron/mcmc.py b/chiron/mcmc.py index 6dfa25c..1f9b295 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -365,7 +365,6 @@ def apply(self, thermodynamic_state, sampler_state): The initial sampler state to apply the move to. This is modified. """ - import copy import jax.numpy as jnp # Compute initial energy @@ -374,16 +373,22 @@ def apply(self, thermodynamic_state, sampler_state): # Store initial positions of the atoms that are moved. # We'll use this also to recover in case the move is rejected. atom_subset = self.atom_subset + x0 = sampler_state.x0 log.debug(f"Atom subset is {atom_subset}.") - initial_positions = copy.deepcopy(sampler_state.x0[atom_subset]) + initial_positions = (jnp.copy(x0[jnp.array(atom_subset)]),) log.debug(f"Initial positions are {initial_positions}.") # Propose perturbed positions. Modifying the reference changes the sampler state. - proposed_positions = self._propose_positions(initial_positions) + proposed_positions = self._propose_positions( + unit.Quantity(initial_positions, sampler_state.distance_unit) + ) log.debug(f"Proposed positions are {proposed_positions}.") - log.debug(f"Sampler state is {sampler_state.x0}.") + log.debug(f"Sampler state is {sampler_state.x0_unitless}.") # Compute the energy of the proposed positions. - sampler_state.x0[atom_subset] = proposed_positions + sampler_state.x0 = sampler_state.x0.at[jnp.array(atom_subset)].set( + proposed_positions + ) + log.debug(f"Sampler state is {sampler_state.x0}.") proposed_energy = thermodynamic_state.get_reduced_potential(sampler_state) log.debug(f"Proposed energy is {proposed_energy}.") # Accept or reject with Metropolis criteria. @@ -408,14 +413,14 @@ def apply(self, thermodynamic_state, sampler_state): ) self.n_proposed += 1 - def _propose_positions(self, positions: jnp.ndarray): + def _propose_positions(self, positions: unit.Quantity): """Return new proposed positions. These method must be implemented in subclasses. Parameters ---------- - positions : nx3 jnp.ndarray + positions : nx3 jnp.ndarray unit.Quantity The original positions of the subset of atoms that these move applied to. @@ -463,17 +468,44 @@ def __init__( displacement_sigma=1.0 * unit.nanometer, nr_of_moves: int = 100, atom_subset: Optional[List[int]] = None, + slice_dim: Optional[int] = None, ): + """ + Initialize the MCMC class. + + Parameters + ---------- + seed : int, optional + The seed for the random number generator. Default is 1234. + displacement_sigma : float or unit.Quantity, optional + The standard deviation of the displacement for each move. Default is 1.0 nm. + nr_of_moves : int, optional + The number of moves to perform. Default is 100. + atom_subset : list of int, optional + A subset of atom indices to consider for the moves. Default is None. + slice_dim : int, optional + The dimension along which to slice the atom subset. Default is None. + + Returns + ------- + None + """ + super().__init__(nr_of_moves=nr_of_moves, seed=seed) self.displacement_sigma = displacement_sigma self.atom_subset = atom_subset + self.slice_dim = slice_dim + if slice_dim is not None: + log.info(f"Updating coordinates only along dimension {self.slice_dim}") - def displace_positions(self, positions, displacement_sigma=1.0 * unit.nanometer): + def displace_positions( + self, positions: unit.Quantity, displacement_sigma=1.0 * unit.nanometer + ): """Return the positions after applying a random displacement to them. Parameters ---------- - positions : nx3 numpy.ndarray openmm.unit.Quantity + positions : nx3 jnp.array unit.Quantity The positions to displace. displacement_sigma : openmm.unit.Quantity The standard deviation of the normal distribution used to propose @@ -488,15 +520,24 @@ def displace_positions(self, positions, displacement_sigma=1.0 * unit.nanometer) import jax.random as jrandom key, subkey = jrandom.split(self.key) - positions_unit = positions.unit - unitless_displacement_sigma = displacement_sigma / positions_unit - displacement_vector = unit.Quantity( - jrandom.normal(subkey, shape=(3,)) * unitless_displacement_sigma, - positions_unit, - ) - return positions + displacement_vector + distance_unit = positions.unit + x0 = positions.value_in_unit(distance_unit) + unitless_displacement_sigma = displacement_sigma.value_in_unit(distance_unit) + if self.slice_dim is not None: + displacement_scalar = ( + jrandom.normal(subkey, shape=(1,)) * unitless_displacement_sigma + ) + updated_position = (x0.at[0, self.slice_dim].add(displacement_scalar),) + else: + displacement_vector = ( + jrandom.normal(subkey, shape=(3,)) * unitless_displacement_sigma + ) + updated_position = x0 + displacement_vector + + log.debug(f"Updated position: {updated_position}") + return unit.Quantity(updated_position, distance_unit) - def _propose_positions(self, initial_positions): + def _propose_positions(self, initial_positions: unit.Quantity): """Implement MetropolizedMove._propose_positions for apply().""" return self.displace_positions(initial_positions, self.displacement_sigma) diff --git a/chiron/states.py b/chiron/states.py index 26a5cef..83db8a6 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -27,9 +27,45 @@ def __init__( velocities: Optional[unit.Quantity] = None, box_vectors: Optional[unit.Quantity] = None, ) -> None: - self.x0 = x0 - self.velocities = velocities - self.box_vectors = box_vectors + import jax.numpy as jnp + + self._distance_unit = x0.unit + self._x0 = x0 + self._velocities = velocities + self._box_vectors = box_vectors + + @property + def x0(self) -> jnp.array: + return self._convert_to_jnp(self._x0) + + @property + def velocities(self) -> jnp.array: + if self._velocities is None: + return None + return self._convert_to_jnp(self._velocities) + + @property + def box_vectors(self) -> jnp.array: + if self._box_vectors is None: + return None + return self._convert_to_jnp(self._box_vectors) + + @x0.setter + def x0(self, x0: jnp.array) -> None: + self._x0 = unit.Quantity(x0, self._distance_unit) + + @property + def distance_unit(self) -> unit.Unit: + return self._distance_unit + + def _convert_to_jnp(self, array: unit.Quantity) -> unit.Quantity: + """ + Convert the sampler state to jnp arrays. + """ + import jax.numpy as jnp + + array_ = array / self.distance_unit + return unit.Quantity(jnp.array(array_), self.distance_unit) @property def x0_unitless(self) -> jnp.ndarray: diff --git a/chiron/tests/test_mcmc.py b/chiron/tests/test_mcmc.py index c5705f1..b2ffad9 100644 --- a/chiron/tests/test_mcmc.py +++ b/chiron/tests/test_mcmc.py @@ -124,6 +124,7 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla nr_of_moves=10, displacement_sigma=0.1 * unit.angstrom, atom_subset=[0], + slice_dim=None, ) move_set = MoveSet([("MetropolisDisplacementMove", mc_displacement_move)]) @@ -133,6 +134,7 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla # Run the sampler with the thermodynamic state and sampler state and return the sampler state sampler.run(n_iterations=2) # how many times to repeat + #assert False def test_sample_from_joint_distribution_of_two_HO_with_local_moves_and_MC_updates():