Skip to content

Commit

Permalink
Add a test for chemfiles version check
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Apr 8, 2021
1 parent b285d13 commit b3b869d
Showing 1 changed file with 56 additions and 31 deletions.
87 changes: 56 additions & 31 deletions testsuite/MDAnalysisTests/coordinates/test_chemfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,45 @@
#
import numpy as np
import pytest
import unittest

import MDAnalysis as mda
from MDAnalysis.coordinates.chemfiles import (
ChemfilesReader, ChemfilesWriter, check_chemfiles_version,
)
from MDAnalysis.coordinates.chemfiles import ChemfilesReader, ChemfilesWriter
from MDAnalysis.coordinates.chemfiles import check_chemfiles_version

from MDAnalysisTests import datafiles
from MDAnalysisTests.coordinates.base import (
MultiframeReaderTest, BaseWriterTest, BaseReference
MultiframeReaderTest,
BaseWriterTest,
BaseReference,
)
from MDAnalysisTests.coordinates.test_xyz import XYZReference


# skip entire test module if no appropriate chemfiles
chemfiles = pytest.importorskip('chemfiles')
@pytest.mark.skipif(not check_chemfiles_version(),
reason="Wrong version of chemfiles")
class TestChemFileXYZ(MultiframeReaderTest):
chemfiles = pytest.importorskip("chemfiles")


class TestChemfileVersion(unittest.TestCase):
def test_version_check(self):
# Make sure the version check works as intended
import chemfiles

actual_version = chemfiles.__version__
chemfiles.__version__ = "0.8.3"
self.assertFalse(check_chemfiles_version())

chemfiles.__version__ = "0.11.0"
self.assertFalse(check_chemfiles_version())

chemfiles.__version__ = "1.1.0"
self.assertFalse(check_chemfiles_version())

chemfiles.__version__ = actual_version


@pytest.mark.skipif(not check_chemfiles_version(), reason="Wrong version of chemfiles")
class TestChemfileXYZ(MultiframeReaderTest):
@staticmethod
@pytest.fixture
def ref():
Expand All @@ -51,36 +72,45 @@ def ref():
@pytest.fixture
def reader(self, ref):
reader = ChemfilesReader(ref.trajectory)
reader.add_auxiliary('lowf', ref.aux_lowf, dt=ref.aux_lowf_dt, initial_time=0, time_selector=None)
reader.add_auxiliary('highf', ref.aux_highf, dt=ref.aux_highf_dt, initial_time=0, time_selector=None)
reader.add_auxiliary(
"lowf",
ref.aux_lowf,
dt=ref.aux_lowf_dt,
initial_time=0,
time_selector=None,
)
reader.add_auxiliary(
"highf",
ref.aux_highf,
dt=ref.aux_highf_dt,
initial_time=0,
time_selector=None,
)
return reader



class ChemfilesXYZReference(BaseReference):
def __init__(self):
super(ChemfilesXYZReference, self).__init__()
self.trajectory = datafiles.COORDINATES_XYZ
self.topology = datafiles.COORDINATES_XYZ
self.reader = ChemfilesReader
self.writer = ChemfilesWriter
self.ext = 'xyz'
self.ext = "xyz"
self.volume = 0
self.dimensions = np.zeros(6)
self.dimensions[3:] = 90.0


@pytest.mark.skipif(not check_chemfiles_version(),
reason="Wrong version of chemfiles")
@pytest.mark.skipif(not check_chemfiles_version(), reason="Wrong version of chemfiles")
class TestChemfilesReader(MultiframeReaderTest):
@staticmethod
@pytest.fixture()
def ref():
return ChemfilesXYZReference()


@pytest.mark.skipif(not check_chemfiles_version(),
reason="Wrong version of chemfiles")
@pytest.mark.skipif(not check_chemfiles_version(), reason="Wrong version of chemfiles")
class TestChemfilesWriter(BaseWriterTest):
@staticmethod
@pytest.fixture()
Expand All @@ -94,18 +124,17 @@ def test_no_container(self, ref):

def test_no_extension_raises(self, ref):
with pytest.raises(chemfiles.ChemfilesError):
ref.writer('foo')
ref.writer("foo")


@pytest.mark.skipif(not check_chemfiles_version(),
reason="Wrong version of chemfiles")
@pytest.mark.skipif(not check_chemfiles_version(), reason="Wrong version of chemfiles")
class TestChemfiles(object):
def test_read_chemfiles_format(self):
u = mda.Universe(
datafiles.LAMMPSdata,
format="chemfiles",
topology_format="data",
chemfiles_format="LAMMPS Data"
chemfiles_format="LAMMPS Data",
)

for ts in u.trajectory:
Expand All @@ -128,19 +157,13 @@ def test_wrong_open_mode(self):

def check_topology(self, reference, file):
u = mda.Universe(reference)
atoms = set([
(atom.name, atom.type, atom.record_type)
for atom in u.atoms
])
bonds = set([
(bond.atoms[0].ix, bond.atoms[1].ix)
for bond in u.bonds
])
atoms = set([(atom.name, atom.type, atom.record_type) for atom in u.atoms])
bonds = set([(bond.atoms[0].ix, bond.atoms[1].ix) for bond in u.bonds])

check = mda.Universe(file)
np.testing.assert_equal(
u.trajectory.ts.positions,
check.trajectory.ts.positions
check.trajectory.ts.positions,
)

for atom in check.atoms:
Expand Down Expand Up @@ -174,7 +197,7 @@ def test_write_topology(self, tmpdir):

def test_write_velocities(self, tmpdir):
u = mda.Universe.empty(4, trajectory=True)
u.add_TopologyAttr('type', values=['H', 'H', 'H', 'H'])
u.add_TopologyAttr("type", values=["H", "H", "H", "H"])

ts = u.trajectory.ts
ts.dimensions = [20, 30, 41, 90, 90, 90]
Expand All @@ -193,7 +216,9 @@ def test_write_velocities(self, tmpdir):

outfile = "chemfiles-write-velocities.lmp"
with tmpdir.as_cwd():
with ChemfilesWriter(outfile, topology=u, chemfiles_format='LAMMPS Data') as writer:
with ChemfilesWriter(
outfile, topology=u, chemfiles_format="LAMMPS Data"
) as writer:
writer.write(u)

with open(outfile) as file:
Expand Down

0 comments on commit b3b869d

Please sign in to comment.