Skip to content

Commit

Permalink
adding MoveSet
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Nov 27, 2023
1 parent 21112ee commit c6800fb
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 54 deletions.
135 changes: 82 additions & 53 deletions chiron/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,48 +40,28 @@
from chiron.potential import NeuralNetworkPotential
from openmm import unit
from loguru import logger as log
from typing import Dict


class GibbsSampler(object):
"""Basic Markov chain Monte Carlo Gibbs sampler.
Parameters
----------
StateVariablesCollection : states.StateVariablesCollection
Defines the states describing the conditional distributions.
move_set : container of MarkovChainMonteCarloMove objects
Moves to attempt during MCMC run.
The move set can be a single move or a sequence of moves.
The moves will define the joint distributions that are sampled.
"""
from typing import Optional

def __init__(self, state_variables: SimulationState, move_set: MoveSet):
from copy import deepcopy

log.info("Initializing Gibbs sampler")

# Make a deep copy of the state so that initial state is unchanged.
self.state_variables = deepcopy(state_variables)
self.move = move_set

def run(self, n_iterations: int = 1):
class StateUpdateMove:
def __init__(self, NeuralNetworkPotential: NeuralNetworkPotential):
"""
Run the sampler for a specified number of iterations.
Initialize the MCMove with a molecular system.
Parameters
----------
n_iterations : int
Number of iterations of the sampler to run.
system : object
A representation of the molecular system (e.g., coordinates, topology).
"""
# Apply move for n_iterations.


from typing import Optional
# system represents the potential energy function and topology
self.system: Optional[NeuralNetworkPotential] = None
self.NeuralNetworkPotential = NeuralNetworkPotential


class LangevinDynamicsMove:
class LangevinDynamicsMove(StateUpdateMove):
def __init__(
self,
n_steps: int,
Expand All @@ -103,13 +83,10 @@ def __init__(
collision_rate : unit.Quantity
Collision rate for the Langevin dynamics.
"""

super().__init__(NeuralNetworkPotential)
self.n_steps = n_steps
self.NeuralNetworkPotential = NeuralNetworkPotential
self.stepsize = stepsize
self.collision_rate = collision_rate
# system represents the potential energy function and topology
self.system: Optional[NeuralNetworkPotential] = None

def run(
self,
Expand Down Expand Up @@ -139,22 +116,13 @@ def run(
)


class MoveSet:
class MCMove(StateUpdateMove):
def __init__(self) -> None:
pass


class MCMove:
def __init__(self, NeuralNetworkPotential: NeuralNetworkPotential):
"""
Initialize the MCMove with a molecular system.
self.system: Optional[NeuralNetworkPotential] = None

Parameters
----------
system : object
A representation of the molecular system (e.g., coordinates, topology).
"""
self.system = NeuralNetworkPotential
def _initialize_system(self, state_variables: SimulationState):
if self.system is None:
self.system = self.NeuralNetworkPotential(state_variables)

def _check_state_compatiblity(
self,
Expand Down Expand Up @@ -189,6 +157,7 @@ def apply_move(self):
NotImplementedError
If the method is not implemented in subclasses.
"""

raise NotImplementedError("apply_move() must be implemented in subclasses")

def compute_acceptance_probability(
Expand All @@ -212,10 +181,8 @@ def compute_acceptance_probability(
Acceptance probability as a float.
"""
self._check_state_compatiblity(old_state, new_state)
old_system = self.NeuronNetworkPotential(old_state)
new_system = self.NeuronNetworkPotential(new_state).compute_energy(
new_state.position
)
old_system = self.system(old_state)
new_system = self.system(new_state)

energy_before_state_change = old_system.compute_energy(old_state.position)
enegy_after_state_change = new_system.compute_energy(new_state.position)
Expand Down Expand Up @@ -271,3 +238,65 @@ def apply_move(self):
Implement the logic specific to a MC position change.
"""
pass


class MoveSet:
"""
A container for a set of moves.
The moves will define the joint distributions that are sampled.
"""

def __init__(
self, availalbe_moves: Dict[str, StateUpdateMove], move_schedule: Dict[str, int]
) -> None:
self.availalbe_moves = availalbe_moves
self.move_schedule = move_schedule

self._check_completness()

def _check_completness(self):
for key in self.availalbe_moves.keys():
if key not in self.move_schedule.keys():
raise ValueError(f"Move {key} is not in the move schedule.")

def add_move(self, new_moves: Dict[str, MCMove]):
self.availalbe_moves.update(new_moves)

def remove_move(self, move_name: str):
del self.availalbe_moves[move_name]


class GibbsSampler(object):
"""Basic Markov chain Monte Carlo Gibbs sampler.
Parameters
----------
StateVariablesCollection : states.StateVariablesCollection
Defines the states describing the conditional distributions.
move_set : container of MarkovChainMonteCarloMove objects
Moves to attempt during MCMC run.
The move set can be a single move or a sequence of moves.
The moves will define the joint distributions that are sampled.
"""

def __init__(self, state_variables: SimulationState, move_set: MoveSet):
from copy import deepcopy

log.info("Initializing Gibbs sampler")

# Make a deep copy of the state so that initial state is unchanged.
self.state_variables = deepcopy(state_variables)
self.move = move_set

def run(self, ):
"""
Run the sampler for a specified number of iterations.
Parameters
----------
n_iterations : int
Number of iterations of the sampler to run.
"""
# Apply move for n_iterations.
3 changes: 2 additions & 1 deletion chiron/states.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from openmm import unit
from typing import List

from .potential import Potential

class SimulationState:
"""
Expand Down Expand Up @@ -46,6 +46,7 @@ def __init__(self) -> None:
self.pressure: unit.Quantity
self.nr_of_particles: int
self.position: unit.Quantity
self.potential: Potential

@classmethod
def are_states_compatible(cls, state1, state2):
Expand Down

0 comments on commit c6800fb

Please sign in to comment.