diff --git a/chiron/mcmc.py b/chiron/mcmc.py index f7fdbf7..52ac712 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -40,48 +40,28 @@ from chiron.potential import NeuralNetworkPotential from openmm import unit from loguru import logger as log +from typing import Dict -class GibbsSampler(object): - """Basic Markov chain Monte Carlo Gibbs sampler. - - Parameters - ---------- - StateVariablesCollection : states.StateVariablesCollection - Defines the states describing the conditional distributions. - move_set : container of MarkovChainMonteCarloMove objects - Moves to attempt during MCMC run. - The move set can be a single move or a sequence of moves. - The moves will define the joint distributions that are sampled. - - """ +from typing import Optional - def __init__(self, state_variables: SimulationState, move_set: MoveSet): - from copy import deepcopy - log.info("Initializing Gibbs sampler") - - # Make a deep copy of the state so that initial state is unchanged. - self.state_variables = deepcopy(state_variables) - self.move = move_set - - def run(self, n_iterations: int = 1): +class StateUpdateMove: + def __init__(self, NeuralNetworkPotential: NeuralNetworkPotential): """ - Run the sampler for a specified number of iterations. + Initialize the MCMove with a molecular system. Parameters ---------- - n_iterations : int - Number of iterations of the sampler to run. - + system : object + A representation of the molecular system (e.g., coordinates, topology). """ - # Apply move for n_iterations. - - -from typing import Optional + # system represents the potential energy function and topology + self.system: Optional[NeuralNetworkPotential] = None + self.NeuralNetworkPotential = NeuralNetworkPotential -class LangevinDynamicsMove: +class LangevinDynamicsMove(StateUpdateMove): def __init__( self, n_steps: int, @@ -103,13 +83,10 @@ def __init__( collision_rate : unit.Quantity Collision rate for the Langevin dynamics. """ - + super().__init__(NeuralNetworkPotential) self.n_steps = n_steps - self.NeuralNetworkPotential = NeuralNetworkPotential self.stepsize = stepsize self.collision_rate = collision_rate - # system represents the potential energy function and topology - self.system: Optional[NeuralNetworkPotential] = None def run( self, @@ -139,22 +116,13 @@ def run( ) -class MoveSet: +class MCMove(StateUpdateMove): def __init__(self) -> None: - pass - - -class MCMove: - def __init__(self, NeuralNetworkPotential: NeuralNetworkPotential): - """ - Initialize the MCMove with a molecular system. + self.system: Optional[NeuralNetworkPotential] = None - Parameters - ---------- - system : object - A representation of the molecular system (e.g., coordinates, topology). - """ - self.system = NeuralNetworkPotential + def _initialize_system(self, state_variables: SimulationState): + if self.system is None: + self.system = self.NeuralNetworkPotential(state_variables) def _check_state_compatiblity( self, @@ -189,6 +157,7 @@ def apply_move(self): NotImplementedError If the method is not implemented in subclasses. """ + raise NotImplementedError("apply_move() must be implemented in subclasses") def compute_acceptance_probability( @@ -212,10 +181,8 @@ def compute_acceptance_probability( Acceptance probability as a float. """ self._check_state_compatiblity(old_state, new_state) - old_system = self.NeuronNetworkPotential(old_state) - new_system = self.NeuronNetworkPotential(new_state).compute_energy( - new_state.position - ) + old_system = self.system(old_state) + new_system = self.system(new_state) energy_before_state_change = old_system.compute_energy(old_state.position) enegy_after_state_change = new_system.compute_energy(new_state.position) @@ -271,3 +238,65 @@ def apply_move(self): Implement the logic specific to a MC position change. """ pass + + +class MoveSet: + """ + A container for a set of moves. + The moves will define the joint distributions that are sampled. + """ + + def __init__( + self, availalbe_moves: Dict[str, StateUpdateMove], move_schedule: Dict[str, int] + ) -> None: + self.availalbe_moves = availalbe_moves + self.move_schedule = move_schedule + + self._check_completness() + + def _check_completness(self): + for key in self.availalbe_moves.keys(): + if key not in self.move_schedule.keys(): + raise ValueError(f"Move {key} is not in the move schedule.") + + def add_move(self, new_moves: Dict[str, MCMove]): + self.availalbe_moves.update(new_moves) + + def remove_move(self, move_name: str): + del self.availalbe_moves[move_name] + + +class GibbsSampler(object): + """Basic Markov chain Monte Carlo Gibbs sampler. + + Parameters + ---------- + StateVariablesCollection : states.StateVariablesCollection + Defines the states describing the conditional distributions. + move_set : container of MarkovChainMonteCarloMove objects + Moves to attempt during MCMC run. + The move set can be a single move or a sequence of moves. + The moves will define the joint distributions that are sampled. + + """ + + def __init__(self, state_variables: SimulationState, move_set: MoveSet): + from copy import deepcopy + + log.info("Initializing Gibbs sampler") + + # Make a deep copy of the state so that initial state is unchanged. + self.state_variables = deepcopy(state_variables) + self.move = move_set + + def run(self, ): + """ + Run the sampler for a specified number of iterations. + + Parameters + ---------- + n_iterations : int + Number of iterations of the sampler to run. + + """ + # Apply move for n_iterations. diff --git a/chiron/states.py b/chiron/states.py index 0a632e7..e486132 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -1,6 +1,6 @@ from openmm import unit from typing import List - +from .potential import Potential class SimulationState: """ @@ -46,6 +46,7 @@ def __init__(self) -> None: self.pressure: unit.Quantity self.nr_of_particles: int self.position: unit.Quantity + self.potential: Potential @classmethod def are_states_compatible(cls, state1, state2):