diff --git a/Examples/LJ_langevin.py b/Examples/LJ_langevin.py index 93f4415..e2d66ca 100644 --- a/Examples/LJ_langevin.py +++ b/Examples/LJ_langevin.py @@ -39,8 +39,6 @@ nbr_list = NeighborListNsqrd( OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) -from chiron.neighbors import PairList - # build the neighbor list from the sampler state nbr_list.build_from_state(sampler_state) diff --git a/chiron/integrators.py b/chiron/integrators.py index 1ac682d..e94aed7 100644 --- a/chiron/integrators.py +++ b/chiron/integrators.py @@ -12,6 +12,17 @@ class LangevinIntegrator: + """ + Langevin dynamics integrator for molecular dynamics simulation using the BAOAB splitting scheme [1]. + + References: + [1] Benedict Leimkuhler, Charles Matthews; + Robust and efficient configurational molecular sampling via Langevin dynamics. + J. Chem. Phys. 7 May 2013; 138 (17): 174102. https://doi.org/10.1063/1.4802990 + + + """ + def __init__( self, stepsize=1.0 * unit.femtoseconds, @@ -25,14 +36,20 @@ def __init__( Parameters ---------- stepsize : unit.Quantity, optional - Time step size for the integration. + Time step of integration with units of time. Default is 1.0 * unit.femtoseconds. collision_rate : unit.Quantity, optional - Collision rate for the Langevin dynamics. + Collision rate for the Langevin dynamics, with units 1/time. Default is 1.0 / unit.picoseconds. + save_frequency : int, optional + Frequency of saving the simulation data. Default is 100. + reporter : SimulationReporter, optional + Reporter object for saving the simulation data. Default is None. """ self.kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA log.info(f"stepsize = {stepsize}") log.info(f"collision_rate = {collision_rate}") + log.info(f"save_frequency = {save_frequency}") + self.stepsize = stepsize self.collision_rate = collision_rate if reporter is not None: @@ -40,6 +57,7 @@ def __init__( self.reporter = reporter self.save_frequency = save_frequency + self.velocities = None def set_velocities(self, vel: unit.Quantity) -> None: """ Set the initial velocities for the Langevin Integrator. @@ -73,6 +91,8 @@ def run( Number of simulation steps to perform. key : jax.random.PRNGKey, optional Random key for generating random numbers. + nbr_list : NeighborListNsqrd, optional + Neighbor list for the system. progress_bar : bool, optional Flag indicating whether to display a progress bar during integration. @@ -85,7 +105,6 @@ def run( self.box_vectors = sampler_state.box_vectors self.progress_bar = progress_bar - self.velocities = None temperature = thermodynamic_state.temperature x0 = sampler_state.x0 diff --git a/chiron/mcmc.py b/chiron/mcmc.py index b649afa..20efebd 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -412,6 +412,7 @@ def apply( log.debug(f"Initial positions are {initial_positions} nm.") # Propose perturbed positions. Modifying the reference changes the sampler state. proposed_positions = self._propose_positions(initial_positions) + log.debug(f"Proposed positions are {proposed_positions} nm.") # Compute the energy of the proposed positions. if atom_subset is None: @@ -420,6 +421,10 @@ def apply( sampler_state.x0 = sampler_state.x0.at[jnp.array(atom_subset)].set( proposed_positions ) + if nbr_list is not None: + if nbr_list.check(sampler_state.x0): + nbr_list.build(sampler_state.x0, sampler_state.box_vectors) + proposed_energy = thermodynamic_state.get_reduced_potential( sampler_state, nbr_list ) # NOTE: in kT diff --git a/chiron/neighbors.py b/chiron/neighbors.py index a8c7bbe..69a9502 100644 --- a/chiron/neighbors.py +++ b/chiron/neighbors.py @@ -14,9 +14,35 @@ class Space(ABC): - def __init__(self, box_vectors: Optional[jnp.array] = None) -> None: + def __init__( + self, box_vectors: Union[jnp.array, unit.Quantity, None] = None + ) -> None: + """ + Abstract base class for defining the simulation space. + + Parameters + ---------- + box_vectors: jnp.array, optional + Box vectors for the system. + """ if box_vectors is not None: - self.box_vectors = box_vectors + if isinstance(box_vectors, unit.Quantity): + if not box_vectors.unit.is_compatible(unit.nanometer): + raise ValueError( + f"Box vectors require distance unit, not {box_vectors.unit}" + ) + self.box_vectors = box_vectors.value_in_unit_system(unit.md_unit_system) + elif isinstance(box_vectors, jnp.ndarray): + if box_vectors.shape != (3, 3): + raise ValueError( + f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" + ) + + self.box_vectors = box_vectors + else: + raise TypeError( + f"box_vectors must be a jnp.array or unit.Quantity, not {type(box_vectors)}" + ) @property def box_vectors(self) -> jnp.array: @@ -39,21 +65,10 @@ def wrap(self, xyz: jnp.array) -> jnp.array: class OrthogonalPeriodicSpace(Space): """ - Calculate the periodic distance between two points. + Defines the simulation space for an orthogonal periodic system. - Returns - ------- - Callable - Function that calculates the periodic displacement and distance between two points """ - def __init__(self, box_vectors: Optional[jnp.array] = None) -> None: - super().__init__(box_vectors) - if box_vectors is not None: - self.box_lengths = jnp.array( - [box_vectors[0][0], box_vectors[1][1], box_vectors[2][2]] - ) - @property def box_vectors(self) -> jnp.array: return self._box_vectors @@ -176,7 +191,7 @@ def wrap(self, xyz: jnp.array) -> jnp.array: class PairsBase(ABC): """ - Base class for pairlist implementations that returns the particle pair ids, displacement vectors, and distances. + Abstract Base Class for different algorithms that determine which particles are interacting. Parameters ---------- @@ -185,7 +200,26 @@ class PairsBase(ABC): cutoff: float, default = 2.5 Cutoff distance for the neighborlist - Examples""" + Examples + -------- + >>> from chiron.neighbors import PairsBase, OrthogonalPeriodicSpace + >>> from chiron.states import SamplerState + >>> from openmm as unit + >>> import jax.numpy as jnp + >>> + >>> space = OrthogonalPeriodicSpace() # define the simulation space, in this case an orthogonal periodic space + >>> sampler_state = SamplerState(x0=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]]), + >>> box_vectors=jnp.array([[10, 0.0, 0.0], [0.0, 10, 0.0], [0.0, 0.0, 10]])) + >>> + >>> pair_list = PairsBase(space, cutoff=2.5*unit.nanometer) # initialize the pair list + >>> pair_list.build_from_state(sampler_state) # build the pair list from the sampler state + >>> + >>> coordinates = sampler_state.x0 # get the coordinates from the sampler state, without units attached + >>> + >>> # the calculate function will produce information used to calculate the energy + >>> n_neighbors, padding_mask, dist, r_ij = pair_list.calculate(coordinates) + >>> + """ def __init__( self, @@ -208,14 +242,16 @@ def build( box_vectors: Union[jnp.array, unit.Quantity], ): """ - Build the neighborlist from an array of coordinates and box vectors. + Build list from an array of coordinates and array of box vectors. Parameters ---------- - coordinates: jnp.array - Shape[N,3] array of particle coordinates - box_vectors: jnp.array - Shape[3,3] array of box vectors + coordinates: jnp.array or unit.Quantity + Shape[n_particles,3] array of particle coordinates, either with or without units attached. + If the array is passed as a unit.Quantity, the units must be distances and will be converted to nanometers. + box_vectors: jnp.array or unit.Quantity + Shape[3,3] array of box vectors for the system, either with or without units attached. + If the array is passed as a unit.Quantity, the units must be distances and will be converted to nanometers. Returns ------- @@ -224,9 +260,43 @@ def build( """ pass + def _validate_build_inputs( + self, + coordinates: Union[jnp.array, unit.Quantity], + box_vectors: Union[jnp.array, unit.Quantity], + ): + """ + Validate the inputs to the build function. + """ + if isinstance(coordinates, unit.Quantity): + if not coordinates.unit.is_compatible(unit.nanometer): + raise ValueError( + f"Coordinates require distance units, not {coordinates.unit}" + ) + self.ref_coordinates = coordinates.value_in_unit_system(unit.md_unit_system) + if isinstance(coordinates, jnp.ndarray): + if coordinates.shape[1] != 3: + raise ValueError( + f"coordinates should be a Nx3 array, shape provided: {coordinates.shape}" + ) + self.ref_coordinates = coordinates + if isinstance(box_vectors, unit.Quantity): + if not box_vectors.unit.is_compatible(unit.nanometer): + raise ValueError( + f"Box vectors require distance unit, not {box_vectors.unit}" + ) + self.box_vectors = box_vectors.value_in_unit_system(unit.md_unit_system) + + if isinstance(box_vectors, jnp.ndarray): + if box_vectors.shape != (3, 3): + raise ValueError( + f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" + ) + self.box_vectors = box_vectors + def build_from_state(self, sampler_state: SamplerState): """ - Build the neighbor list from a SamplerState object + Build the list from a SamplerState object Parameters ---------- @@ -261,6 +331,9 @@ def calculate(self, coordinates: jnp.array): ------- n_neighbors: jnp.array Array of number of neighbors for each particle + pairs: jnp.array + Array of particle ids for the possible neighbors of each particle. + The size of this array will depend on the underlying algorithm. padding_mask: jnp.array Array of masks to exclude padding from the neighbor list of each particle dist: jnp.array @@ -423,13 +496,11 @@ def _build_neighborlist( # since this needs to be uniformly sized, we can just fill this array up to the n_max_neighbors. neighbor_list = jnp.array( jnp.where(neighbor_mask, size=n_max_neighbors, fill_value=fill_value), - dtype=jnp.uint16, + dtype=jnp.uint32, ) # we need to generate a new mask associatd with the padded neighbor list # to be able to quickly exclude the padded values from the neighbor list - neighbor_list_mask = jnp.where( - jnp.arange(self.n_max_neighbors) < n_neighbors, 1, 0 - ) + neighbor_list_mask = jnp.where(jnp.arange(n_max_neighbors) < n_neighbors, 1, 0) del r_ij, dist return neighbor_list_mask, neighbor_list, n_neighbors @@ -485,7 +556,7 @@ def build( # store the ids of all the particles self.particle_ids = jnp.array( - range(0, self.ref_coordinates.shape[0]), dtype=jnp.uint16 + range(0, self.ref_coordinates.shape[0]), dtype=jnp.uint32 ) # calculate which pairs to exclude @@ -510,11 +581,12 @@ def build( self.neighbor_list = self.neighbor_list.reshape(-1, self.n_max_neighbors) - if jnp.any(self.n_neighbors == self.n_max_neighbors).block_until_ready(): - self.n_max_neighbors = int(jnp.max(self.n_neighbors) + 10) + while jnp.any(self.n_neighbors == self.n_max_neighbors).block_until_ready(): log.debug( f"Increasing n_max_neighbors from {self.n_max_neighbors} to at {jnp.max(self.n_neighbors)+10}" ) + self.n_max_neighbors = int(jnp.max(self.n_neighbors) + 10) + self.neighbor_mask, self.neighbor_list, self.n_neighbors = jax.vmap( self._build_neighborlist, in_axes=(0, 0, 0, None, None) )( @@ -588,12 +660,14 @@ def calculate(self, coordinates: jnp.array): ------- n_neighbors: jnp.array Array of number of neighbors for each particle + neighbor_list: jnp.array + Array of particle ids for the neighbors, padded to n_max_neighbors. Shape (n_particles, n_max_neighbors) padding_mask: jnp.array - Array of masks to exclude padding from the neighbor list of each particle + Array of masks to exclude padding from the neighbor list of each particle. Shape (n_particles, n_max_neighbors) dist: jnp.array - Array of distances between each particle and its neighbors + Array of distances between each particle and its neighbors. Shape (n_particles, n_max_neighbors) r_ij: jnp.array - Array of displacement vectors between each particle and its neighbors + Array of displacement vectors between each particle and its neighbors. Shape (n_particles, n_max_neighbors, 3) """ # coordinates = sampler_state.x0 # note, we assume the box vectors do not change between building and calculating the neighbor list @@ -603,7 +677,7 @@ def calculate(self, coordinates: jnp.array): self._calc_distance_per_particle, in_axes=(0, 0, 0, None) )(self.particle_ids, self.neighbor_list, self.neighbor_mask, coordinates) # mask = mask.reshape(-1, self.n_max_neighbors) - return n_neighbors, padding_mask, dist, r_ij + return n_neighbors, self.neighbor_list, padding_mask, dist, r_ij @partial(jax.jit, static_argnums=(0,)) def _calculate_particle_displacement(self, particle, coordinates, ref_coordinates): @@ -740,6 +814,7 @@ def _pairs_and_mask(self, particle_ids: jnp.array): """ # for the nsq approach, we consider the distance between a particle and all other particles in the system # if we used a cell list the possible_neighbors would be a smaller list, i.e., only those in the neigboring cells + # we'll just keep with naming syntax for future flexibility possible_neighbors = particle_ids @@ -747,19 +822,19 @@ def _pairs_and_mask(self, particle_ids: jnp.array): possible_neighbors, (particle_ids.shape[0], possible_neighbors.shape[0]), ) - # reshape the particle_ids particles_i = jnp.reshape(particle_ids, (particle_ids.shape[0], 1)) # create a mask to exclude self interactions and double counting temp_mask = particles_i != particles_j - all_pairs = jax.vmap(self._remove_self_interactions, in_axes=(0, 0))( particles_j, temp_mask ) + del temp_mask + all_pairs = jnp.array(all_pairs[0], dtype=jnp.uint32) - reduction_mask = jnp.where(particles_i < all_pairs[0], True, False) + reduction_mask = jnp.where(particles_i < all_pairs, True, False) - return all_pairs[0], reduction_mask + return all_pairs, reduction_mask @partial(jax.jit, static_argnums=(0,)) def _remove_self_interactions(self, particles, temp_mask): @@ -789,35 +864,17 @@ def build( """ # set our reference coordinates - # the call to x0 and box_vectors automatically convert these to jnp arrays in the correct unit system - if isinstance(coordinates, unit.Quantity): - if not coordinates.unit.is_compatible(unit.nanometer): - raise ValueError( - f"Coordinates require distance units, not {coordinates.unit}" - ) - coordinates = coordinates.value_in_unit_system(unit.md_unit_system) - - if isinstance(box_vectors, unit.Quantity): - if not box_vectors.unit.is_compatible(unit.nanometer): - raise ValueError( - f"Box vectors require distance unit, not {box_vectors.unit}" - ) - box_vectors = box_vectors.value_in_unit_system(unit.md_unit_system) - - if box_vectors.shape != (3, 3): - raise ValueError( - f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" - ) + # this will set self.ref_coordinates=coordinates and self.box_vectors + self._validate_build_inputs(coordinates, box_vectors) - self.n_particles = coordinates.shape[0] - self.box_vectors = box_vectors + self.n_particles = self.ref_coordinates.shape[0] # the neighborlist assumes that the box vectors do not change between building and calculating the neighbor list # changes to the box vectors require rebuilding the neighbor list self.space.box_vectors = self.box_vectors # store the ids of all the particles - self.particle_ids = jnp.array(range(0, coordinates.shape[0]), dtype=jnp.uint16) + self.particle_ids = jnp.array(range(0, coordinates.shape[0]), dtype=jnp.uint32) # calculate which pairs to exclude self.all_pairs, self.reduction_mask = self._pairs_and_mask(self.particle_ids) @@ -886,13 +943,15 @@ def calculate(self, coordinates: jnp.array): Returns ------- n_neighbors: jnp.array - Array of number of interacting particles for each particle + Array of the number of interacting particles (i.e., where dist < cutoff). Shape: (n_particles) + pairs: jnp.array + Array of particle ids that were considered for interaction. Shape: (n_particles, n_particles-1) padding_mask: jnp.array - Array used to masks non interaction particle pairs, + Array used to masks non interaction particle pairs. Shape: (n_particles, n_particles-1) dist: jnp.array - Array of distances between each particle and all other particles in the system + Array of distances between pairs in the system. Shape: (n_particles, n_particles-1) r_ij: jnp.array - Array of displacement vectors between each particle and all other particles in the system. + Array of displacement vectors between particle pairs. Shape: (n_particles, n_particles-1, 3). """ if coordinates.shape[0] != self.n_particles: raise ValueError( @@ -906,7 +965,7 @@ def calculate(self, coordinates: jnp.array): self._calc_distance_per_particle, in_axes=(0, 0, 0, None) )(self.particle_ids, self.all_pairs, self.reduction_mask, coordinates) - return n_neighbors, padding_mask, dist, r_ij + return n_neighbors, self.all_pairs, padding_mask, dist, r_ij def check(self, coordinates: jnp.array) -> bool: """ diff --git a/chiron/potential.py b/chiron/potential.py index abb3a03..b5982f0 100644 --- a/chiron/potential.py +++ b/chiron/potential.py @@ -169,7 +169,7 @@ def compute_energy(self, positions: jnp.array, nbr_list=None, debug_mode=False): if nbr_list is None: log.debug( - "nbr_list is None, computing pairlist using N^2 method without PBC." + "nbr_list is None, computing using inefficient N^2 pairlist without PBC." ) # Compute the pairlist for a given set of positions and a cutoff distance # Note in this case, we do not need the pairs or displacement vectors @@ -204,7 +204,7 @@ def compute_energy(self, positions: jnp.array, nbr_list=None, debug_mode=False): f"Neighborlist cutoff ({nbr_list.cutoff}) must be the same as the potential cutoff ({self.cutoff})" ) - n_neighbors, mask, dist, displacement_vectors = nbr_list.calculate( + n_neighbors, pairs, mask, dist, displacement_vectors = nbr_list.calculate( positions ) @@ -235,17 +235,16 @@ def compute_force(self, positions: jnp.array, nbr_list=None) -> jnp.array: return super().compute_force(positions, nbr_list=nbr_list) def compute_force_analytical( - self, positions: jnp.array, nbr_list=None + self, + positions: jnp.array, ) -> jnp.array: """ - Compute the LJ force using the analytical expression. + Compute the LJ force using the analytical expression for testing purposes. Parameters ---------- positions : jnp.array The positions of the particles in the system - nbr_list : NeighborList, optional - Instance of the neighborlist class to use. By default, set to None, which will use an N^2 pairlist Returns ------- @@ -253,22 +252,19 @@ def compute_force_analytical( The forces on the particles in the system """ - if nbr_list is None: - dist, displacement_vector, pairs = self.compute_pairlist( - positions, self.cutoff - ) - - forces = ( - 24 - * (self.epsilon / (dist * dist)) - * (2 * (self.sigma / dist) ** 12 - (self.sigma / dist) ** 6) - ).reshape(-1, 1) * displacement_vector - - force_array = jnp.zeros((positions.shape[0], 3)) - for force, p1, p2 in zip(forces, pairs[0], pairs[1]): - force_array = force_array.at[p1].add(force) - force_array = force_array.at[p2].add(-force) - return force_array + dist, displacement_vector, pairs = self.compute_pairlist(positions, self.cutoff) + + forces = ( + 24 + * (self.epsilon / (dist * dist)) + * (2 * (self.sigma / dist) ** 12 - (self.sigma / dist) ** 6) + ).reshape(-1, 1) * displacement_vector + + force_array = jnp.zeros((positions.shape[0], 3)) + for force, p1, p2 in zip(forces, pairs[0], pairs[1]): + force_array = force_array.at[p1].add(force) + force_array = force_array.at[p2].add(-force) + return force_array class HarmonicOscillatorPotential(NeuralNetworkPotential): diff --git a/chiron/tests/test_pairs.py b/chiron/tests/test_pairs.py index 79bea7f..4802bb6 100644 --- a/chiron/tests/test_pairs.py +++ b/chiron/tests/test_pairs.py @@ -12,12 +12,32 @@ def test_orthogonal_periodic_displacement(): + # test that the incorrect box shapes throw an exception + with pytest.raises(ValueError): + space = OrthogonalPeriodicSpace(jnp.array([10.0, 10.0, 10.0])) + # test that incorrect units throw an exception + with pytest.raises(ValueError): + space = OrthogonalPeriodicSpace( + unit.Quantity( + jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]), + unit.radians, + ) + ) + space = OrthogonalPeriodicSpace( jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) ) + # test that the box vectors are set correctly + assert jnp.all( + space.box_vectors + == jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + ) + + # test that the box lengths for an orthogonal box are set correctly assert jnp.all(space._box_lengths == jnp.array([10.0, 10.0, 10.0])) + # test calculation of the displacement_vector and distance between two points p1 = jnp.array([[0, 0, 0], [0, 0, 0]]) p2 = jnp.array([[1, 0, 0], [6, 0, 0]]) @@ -27,6 +47,7 @@ def test_orthogonal_periodic_displacement(): assert jnp.all(distance == jnp.array([1, 4])) + # test that the periodic wrapping works as expected wrapped_x = space.wrap(jnp.array([11, 0, 0])) assert jnp.all(wrapped_x == jnp.array([1, 0, 0])) @@ -39,6 +60,7 @@ def test_orthogonal_periodic_displacement(): wrapped_x = space.wrap(jnp.array([5, 12, -1])) assert jnp.all(wrapped_x == jnp.array([5, 2, 9])) + # test the setter for the box vectors space.box_vectors = jnp.array( [[10.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 30.0]] ) @@ -116,8 +138,16 @@ def test_neighborlist_pair(): nbr_list.neighbor_mask == jnp.array([[1, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) ) - n_neighbors, padding_mask, dist, r_ij = nbr_list.calculate(coordinates) + n_neighbors, neighbor_list, padding_mask, dist, r_ij = nbr_list.calculate( + coordinates + ) assert jnp.all(n_neighbors == jnp.array([1, 0])) + + # 2 particles, padded to 5 + assert jnp.all(neighbor_list.shape == (2, 5)) + + assert jnp.all(neighbor_list == jnp.array([[1, 1, 1, 1, 1], [0, 0, 0, 0, 0]])) + assert jnp.all(padding_mask == jnp.array([[1, 0, 0, 0, 0], [0, 0, 0, 0, 0]])) assert jnp.all( @@ -260,7 +290,7 @@ def test_neighborlist_pair_multiple_particles(): assert jnp.all(nbr_list.n_neighbors == jnp.array([7, 6, 5, 4, 3, 2, 1, 0])) - n_interacting, mask, dist, rij = nbr_list.calculate(coordinates) + n_interacting, neighbor_list, mask, dist, rij = nbr_list.calculate(coordinates) assert jnp.all(n_interacting == jnp.array([7, 6, 5, 4, 3, 2, 1, 0])) # every particle should be in the nieghbor list, but only a subset in the interacting range @@ -276,14 +306,31 @@ def test_neighborlist_pair_multiple_particles(): assert jnp.all(nbr_list.n_neighbors == jnp.array([7, 6, 5, 4, 3, 2, 1, 0])) - n_interacting, mask, dist, rij = nbr_list.calculate(coordinates) + n_interacting, neighbor_list, mask, dist, rij = nbr_list.calculate(coordinates) assert jnp.all(n_interacting == jnp.array([3, 2, 2, 1, 2, 1, 1, 0])) + assert jnp.all(neighbor_list.shape == (8, 17)) + assert jnp.all( + neighbor_list + == jnp.array( + [ + [1, 2, 3, 4, 5, 6, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [2, 3, 4, 5, 6, 7, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], + [3, 4, 5, 6, 7, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], + [4, 5, 6, 7, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + [5, 6, 7, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], + [6, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6], + [7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ) + ) + # test passing coordinates and box vectors directly nbr_list.build(state.x0, state.box_vectors) assert jnp.all(nbr_list.n_neighbors == jnp.array([7, 6, 5, 4, 3, 2, 1, 0])) - n_interacting, mask, dist, rij = nbr_list.calculate(coordinates) + n_interacting, neighbor_list, mask, dist, rij = nbr_list.calculate(coordinates) assert jnp.all(n_interacting == jnp.array([3, 2, 2, 1, 2, 1, 1, 0])) @@ -314,9 +361,11 @@ def test_pairlist_pair(): assert jnp.all(pair_list.reduction_mask == jnp.array([[True], [False]])) assert pair_list.is_built == True - n_pairs, mask, dist, displacement = pair_list.calculate(coordinates) + n_pairs, all_pairs, mask, dist, displacement = pair_list.calculate(coordinates) assert jnp.all(n_pairs == jnp.array([1, 0])) + assert jnp.all(all_pairs.shape == (2, 1)) + assert jnp.all(all_pairs == jnp.array([[1], [0]])) assert jnp.all(mask == jnp.array([[1], [0]])) assert jnp.all(dist == jnp.array([[1.0], [1.0]])) assert displacement.shape == (2, 1, 3) @@ -356,9 +405,24 @@ def test_pair_list_multiple_particles(): ) pair_list.build_from_state(state) - n_interacting, mask, dist, rij = pair_list.calculate(coordinates) + n_interacting, all_pairs, mask, dist, rij = pair_list.calculate(coordinates) assert jnp.all(n_interacting == jnp.array([7, 6, 5, 4, 3, 2, 1, 0])) - + assert jnp.all(all_pairs.shape == (8, 7)) + assert jnp.all( + all_pairs + == jnp.array( + [ + [1, 2, 3, 4, 5, 6, 7], + [0, 2, 3, 4, 5, 6, 7], + [0, 1, 3, 4, 5, 6, 7], + [0, 1, 2, 4, 5, 6, 7], + [0, 1, 2, 3, 5, 6, 7], + [0, 1, 2, 3, 4, 6, 7], + [0, 1, 2, 3, 4, 5, 7], + [0, 1, 2, 3, 4, 5, 6], + ] + ) + ) assert jnp.all(mask.shape == (coordinates.shape[0], coordinates.shape[0] - 1)) # compare to nbr_list @@ -369,7 +433,7 @@ def test_pair_list_multiple_particles(): n_max_neighbors=20, ) nbr_list.build_from_state(state) - n_interacting1, mask1, dist1, rij1 = nbr_list.calculate(coordinates) + n_interacting1, all_pairs, mask1, dist1, rij1 = nbr_list.calculate(coordinates) # sum up all the distances within range, see if they match those in the nlist assert jnp.where(mask, dist, 0).sum() == jnp.where(mask1, dist1, 0).sum()