Skip to content

Commit

Permalink
starting with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Nov 27, 2023
1 parent c6800fb commit 436aafe
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 10 deletions.
72 changes: 62 additions & 10 deletions chiron/states.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from openmm import unit
from typing import List
from .potential import Potential
from typing import List, Optional
from .potential import NeuralNetworkPotential
from jax import numpy as jnp
from loguru import logger as log


class SimulationState:
"""
Expand Down Expand Up @@ -39,14 +42,62 @@ class SimulationState:
"""

def __init__(self) -> None:
def __init__(
self,
temperature: Optional[unit.Quantity] = None,
volume: Optional[unit.Quantity] = None,
pressure: Optional[unit.Quantity] = None,
nr_of_particles: Optional[int] = None,
position: Optional[jnp.ndarray] = None,
potential: Optional[NeuralNetworkPotential] = None,
) -> None:
# initialize all state variables
self.temperature: unit.Quantity
self.volume: unit.Quantity
self.pressure: unit.Quantity
self.nr_of_particles: int
self.position: unit.Quantity
self.potential: Potential
self.temperature = temperature
self.volume = volume
self.pressure = pressure
self.nr_of_particles = nr_of_particles
self.position = position
self.potential = potential

# check which variables are not None

self._check_completness()

def check_variables(self):
"""
Check which variables in the __init__ method are None.
Returns
-------
List[str]
A list of variable names that are None.
"""
variables = [
"temperature",
"volume",
"pressure",
"nr_of_particles",
"position",
"potential",
]
set_variables = [var for var in variables if getattr(self, var) is not None]
return set_variables

def _check_completness(self):
# check which variables are set
set_variables = self.check_variables()

if len(set_variables) == 0:
log.info("No variables are set.")

# print all set variables
for var in set_variables:
log.info(f"{var} is set.")

if self.temperature and self.volume and self.nr_of_particles:
log.info("NVT ensemble simulated.")
if self.temperature and self.pressure and self.nr_of_particles:
log.info("NpT ensemble is simulated.")

@classmethod
def are_states_compatible(cls, state1, state2):
Expand Down Expand Up @@ -125,7 +176,8 @@ def get_reduced_potential(self):

class JointSimulationStates:
"""
Manages a collection of SimulationState objects.
Manages a collection of SimulationState objects to define a joint probability distribution
to generate samples from.
"""

def __init__(self, states: List[SimulationState]):
Expand Down
39 changes: 39 additions & 0 deletions chiron/tests/test_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
def test_initialize_state():
from chiron.states import SimulationState
from openmm import unit

state = SimulationState()
assert state.temperature is None
assert state.pressure is None
assert state.volume is None
assert state.nr_of_particles is None
assert state.position is None
assert state.potential is None

state = SimulationState(
temperature=300, volume=30 * (unit.angstrom**3), nr_of_particles=3000
)
assert state.temperature == 300
assert state.pressure is None
assert state.volume == 30 * (unit.angstrom**3)
assert state.nr_of_particles == 3000
assert state.position is None


def test_reduced_potential():
from chiron.states import SimulationState
from openmm import unit
from chiron.potential import HarmonicOscillatorPotential
import jax.numpy as jnp
from openmmtools.testsystems import HarmonicOscillator

state = SimulationState(
temperature=300, volume=30 * (unit.angstrom**3), nr_of_particles=1
)
ho = HarmonicOscillator()

harmonic_potential = HarmonicOscillatorPotential(
ho.K, jnp.array([0, 0, 0]) * unit.angstrom, 0.0
)
state.potential = harmonic_potential
assert state.reduced_potential() == 0.5

0 comments on commit 436aafe

Please sign in to comment.