Skip to content

Commit

Permalink
moving unit conversion to states
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Dec 4, 2023
1 parent 1324bcd commit db03bff
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 20 deletions.
75 changes: 58 additions & 17 deletions chiron/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,6 @@ def apply(self, thermodynamic_state, sampler_state):
The initial sampler state to apply the move to. This is modified.
"""
import copy
import jax.numpy as jnp

# Compute initial energy
Expand All @@ -374,16 +373,22 @@ def apply(self, thermodynamic_state, sampler_state):
# Store initial positions of the atoms that are moved.
# We'll use this also to recover in case the move is rejected.
atom_subset = self.atom_subset
x0 = sampler_state.x0
log.debug(f"Atom subset is {atom_subset}.")
initial_positions = copy.deepcopy(sampler_state.x0[atom_subset])
initial_positions = (jnp.copy(x0[jnp.array(atom_subset)]),)
log.debug(f"Initial positions are {initial_positions}.")
# Propose perturbed positions. Modifying the reference changes the sampler state.
proposed_positions = self._propose_positions(initial_positions)
proposed_positions = self._propose_positions(
unit.Quantity(initial_positions, sampler_state.distance_unit)
)
log.debug(f"Proposed positions are {proposed_positions}.")
log.debug(f"Sampler state is {sampler_state.x0}.")
log.debug(f"Sampler state is {sampler_state.x0_unitless}.")

# Compute the energy of the proposed positions.
sampler_state.x0[atom_subset] = proposed_positions
sampler_state.x0 = sampler_state.x0.at[jnp.array(atom_subset)].set(
proposed_positions
)
log.debug(f"Sampler state is {sampler_state.x0}.")
proposed_energy = thermodynamic_state.get_reduced_potential(sampler_state)
log.debug(f"Proposed energy is {proposed_energy}.")
# Accept or reject with Metropolis criteria.
Expand All @@ -408,14 +413,14 @@ def apply(self, thermodynamic_state, sampler_state):
)
self.n_proposed += 1

def _propose_positions(self, positions: jnp.ndarray):
def _propose_positions(self, positions: unit.Quantity):
"""Return new proposed positions.
These method must be implemented in subclasses.
Parameters
----------
positions : nx3 jnp.ndarray
positions : nx3 jnp.ndarray unit.Quantity
The original positions of the subset of atoms that these move
applied to.
Expand Down Expand Up @@ -463,17 +468,44 @@ def __init__(
displacement_sigma=1.0 * unit.nanometer,
nr_of_moves: int = 100,
atom_subset: Optional[List[int]] = None,
slice_dim: Optional[int] = None,
):
"""
Initialize the MCMC class.
Parameters
----------
seed : int, optional
The seed for the random number generator. Default is 1234.
displacement_sigma : float or unit.Quantity, optional
The standard deviation of the displacement for each move. Default is 1.0 nm.
nr_of_moves : int, optional
The number of moves to perform. Default is 100.
atom_subset : list of int, optional
A subset of atom indices to consider for the moves. Default is None.
slice_dim : int, optional
The dimension along which to slice the atom subset. Default is None.
Returns
-------
None
"""

super().__init__(nr_of_moves=nr_of_moves, seed=seed)
self.displacement_sigma = displacement_sigma
self.atom_subset = atom_subset
self.slice_dim = slice_dim
if slice_dim is not None:
log.info(f"Updating coordinates only along dimension {self.slice_dim}")

def displace_positions(self, positions, displacement_sigma=1.0 * unit.nanometer):
def displace_positions(
self, positions: unit.Quantity, displacement_sigma=1.0 * unit.nanometer
):
"""Return the positions after applying a random displacement to them.
Parameters
----------
positions : nx3 numpy.ndarray openmm.unit.Quantity
positions : nx3 jnp.array unit.Quantity
The positions to displace.
displacement_sigma : openmm.unit.Quantity
The standard deviation of the normal distribution used to propose
Expand All @@ -488,15 +520,24 @@ def displace_positions(self, positions, displacement_sigma=1.0 * unit.nanometer)
import jax.random as jrandom

key, subkey = jrandom.split(self.key)
positions_unit = positions.unit
unitless_displacement_sigma = displacement_sigma / positions_unit
displacement_vector = unit.Quantity(
jrandom.normal(subkey, shape=(3,)) * unitless_displacement_sigma,
positions_unit,
)
return positions + displacement_vector
distance_unit = positions.unit
x0 = positions.value_in_unit(distance_unit)
unitless_displacement_sigma = displacement_sigma.value_in_unit(distance_unit)
if self.slice_dim is not None:
displacement_scalar = (
jrandom.normal(subkey, shape=(1,)) * unitless_displacement_sigma
)
updated_position = (x0.at[0, self.slice_dim].add(displacement_scalar),)
else:
displacement_vector = (
jrandom.normal(subkey, shape=(3,)) * unitless_displacement_sigma
)
updated_position = x0 + displacement_vector

log.debug(f"Updated position: {updated_position}")
return unit.Quantity(updated_position, distance_unit)

def _propose_positions(self, initial_positions):
def _propose_positions(self, initial_positions: unit.Quantity):
"""Implement MetropolizedMove._propose_positions for apply()."""
return self.displace_positions(initial_positions, self.displacement_sigma)

Expand Down
42 changes: 39 additions & 3 deletions chiron/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,45 @@ def __init__(
velocities: Optional[unit.Quantity] = None,
box_vectors: Optional[unit.Quantity] = None,
) -> None:
self.x0 = x0
self.velocities = velocities
self.box_vectors = box_vectors
import jax.numpy as jnp

self._distance_unit = x0.unit
self._x0 = x0
self._velocities = velocities
self._box_vectors = box_vectors

@property
def x0(self) -> jnp.array:
return self._convert_to_jnp(self._x0)

@property
def velocities(self) -> jnp.array:
if self._velocities is None:
return None
return self._convert_to_jnp(self._velocities)

@property
def box_vectors(self) -> jnp.array:
if self._box_vectors is None:
return None
return self._convert_to_jnp(self._box_vectors)

@x0.setter
def x0(self, x0: jnp.array) -> None:
self._x0 = unit.Quantity(x0, self._distance_unit)

@property
def distance_unit(self) -> unit.Unit:
return self._distance_unit

def _convert_to_jnp(self, array: unit.Quantity) -> unit.Quantity:
"""
Convert the sampler state to jnp arrays.
"""
import jax.numpy as jnp

array_ = array / self.distance_unit
return unit.Quantity(jnp.array(array_), self.distance_unit)

@property
def x0_unitless(self) -> jnp.ndarray:
Expand Down
2 changes: 2 additions & 0 deletions chiron/tests/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla
nr_of_moves=10,
displacement_sigma=0.1 * unit.angstrom,
atom_subset=[0],
slice_dim=None,
)

move_set = MoveSet([("MetropolisDisplacementMove", mc_displacement_move)])
Expand All @@ -133,6 +134,7 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla

# Run the sampler with the thermodynamic state and sampler state and return the sampler state
sampler.run(n_iterations=2) # how many times to repeat
#assert False


def test_sample_from_joint_distribution_of_two_HO_with_local_moves_and_MC_updates():
Expand Down

0 comments on commit db03bff

Please sign in to comment.