Skip to content

Commit

Permalink
Guard MSONAtoms definition behind ASE package availability (#3645)
Browse files Browse the repository at this point in the history
* Guard `MSONAtoms` definition behind ASE package availability

Add mypy exclude
* provide fallback Atoms class that just raises if ase not found and revert MSONAtoms implementation

* add skip_if_no_ase decorator to individual test functions instead of whole test_ase.py module

* Add test case for no ASE package error

---------

Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
ml-evs and janosh authored Feb 23, 2024
1 parent 5b5ea57 commit 5ce2ae1
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 25 deletions.
17 changes: 11 additions & 6 deletions pymatgen/io/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,14 @@
from ase.io.jsonio import decode, encode
from ase.spacegroup import Spacegroup

ase_loaded = True
no_ase_err = None
except ImportError:
ase_loaded = False
no_ase_err = PackageNotFoundError("AseAtomsAdaptor requires the ASE package. Use `pip install ase`")

class Atoms: # type: ignore[no-redef]
def __init__(self, *args, **kwargs):
raise no_ase_err


__author__ = "Shyue Ping Ong, Andrew S. Rosen"
__copyright__ = "Copyright 2012, The Materials Project"
Expand All @@ -51,12 +56,12 @@ def as_dict(s: Atoms) -> dict[str, Any]:
# See ASE issue #1387.
return {"@module": "pymatgen.io.ase", "@class": "MSONAtoms", "atoms_json": encode(s)}

def from_dict(d: dict[str, Any]) -> MSONAtoms:
def from_dict(dct: dict[str, Any]) -> MSONAtoms:
# Normally, we would want to this to be a wrapper around atoms.fromdict() with @module and
# @class key-value pairs inserted. However, atoms.todict()/atoms.fromdict() is not meant
# to be used in a round-trip fashion and does not work properly with constraints.
# See ASE issue #1387.
return MSONAtoms(decode(d["atoms_json"]))
return MSONAtoms(decode(dct["atoms_json"]))


# NOTE: If making notable changes to this class, please ping @Andrew-S-Rosen on GitHub.
Expand All @@ -77,8 +82,8 @@ def get_atoms(structure: SiteCollection, msonable: bool = True, **kwargs) -> MSO
Returns:
Atoms: ASE Atoms object
"""
if not ase_loaded:
raise PackageNotFoundError("AseAtomsAdaptor requires the ASE package. Use `pip install ase`")
if no_ase_err:
raise no_ase_err
if not structure.is_ordered:
raise ValueError("ASE Atoms only supports ordered structures")

Expand Down
67 changes: 48 additions & 19 deletions tests/io/test_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,17 @@
from pymatgen.io.ase import AseAtomsAdaptor, MSONAtoms
from pymatgen.util.testing import TEST_FILES_DIR

ase = pytest.importorskip("ase")
try:
import ase
except ImportError:
ase = None

structure = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR")

skip_if_no_ase = pytest.mark.skipif(ase is None, reason="ase not installed")


@skip_if_no_ase
def test_get_atoms_from_structure():
atoms = AseAtomsAdaptor.get_atoms(structure)
ase_composition = Composition(atoms.get_chemical_formula())
Expand All @@ -32,6 +39,7 @@ def test_get_atoms_from_structure():
assert atoms.get_array("prop").tolist() == prop


@skip_if_no_ase
def test_get_atoms_from_structure_mags():
mags = [1.0] * len(structure)
structure.add_site_property("final_magmom", mags)
Expand All @@ -53,6 +61,7 @@ def test_get_atoms_from_structure_mags():
assert atoms.get_magnetic_moments().tolist(), mags


@skip_if_no_ase
def test_get_atoms_from_structure_charge():
charges = [1.0] * len(structure)
structure.add_site_property("final_charge", charges)
Expand All @@ -74,19 +83,22 @@ def test_get_atoms_from_structure_charge():
assert atoms.get_charges().tolist(), charges


@skip_if_no_ase
def test_get_atoms_from_structure_oxi_states():
oxi_states = [1.0] * len(structure)
structure.add_oxidation_state_by_site(oxi_states)
atoms = AseAtomsAdaptor.get_atoms(structure)
assert atoms.get_array("oxi_states").tolist() == oxi_states


@skip_if_no_ase
def test_get_atoms_from_structure_dyn():
structure.add_site_property("selective_dynamics", [[False] * 3] * len(structure))
atoms = AseAtomsAdaptor.get_atoms(structure)
assert atoms.constraints[0].get_indices().tolist() == [atom.index for atom in atoms]


@skip_if_no_ase
def test_get_atoms_from_molecule():
mol = Molecule.from_file(f"{TEST_FILES_DIR}/acetylene.xyz")
atoms = AseAtomsAdaptor.get_atoms(mol)
Expand All @@ -98,6 +110,7 @@ def test_get_atoms_from_molecule():
assert not atoms.has("initial_magmoms")


@skip_if_no_ase
def test_get_atoms_from_molecule_mags():
molecule = Molecule.from_file(f"{TEST_FILES_DIR}/acetylene.xyz")
atoms = AseAtomsAdaptor.get_atoms(molecule)
Expand All @@ -123,13 +136,15 @@ def test_get_atoms_from_molecule_mags():
assert atoms.spin_multiplicity == 3


@skip_if_no_ase
def test_get_atoms_from_molecule_dyn():
molecule = Molecule.from_file(f"{TEST_FILES_DIR}/acetylene.xyz")
molecule.add_site_property("selective_dynamics", [[False] * 3] * len(molecule))
atoms = AseAtomsAdaptor.get_atoms(molecule)
assert atoms.constraints[0].get_indices().tolist() == [atom.index for atom in atoms]


@skip_if_no_ase
def test_get_structure():
atoms = ase.io.read(f"{TEST_FILES_DIR}/POSCAR")
struct = AseAtomsAdaptor.get_structure(atoms)
Expand All @@ -152,6 +167,7 @@ def test_get_structure():
struct = AseAtomsAdaptor.get_structure(atoms, validate_proximity=True)


@skip_if_no_ase
def test_get_structure_mag():
atoms = ase.io.read(f"{TEST_FILES_DIR}/POSCAR")
mags = [1.0] * len(atoms)
Expand All @@ -168,14 +184,10 @@ def test_get_structure_mag():
assert "initial_magmoms" not in structure.site_properties


@skip_if_no_ase
@pytest.mark.parametrize(
"select_dyn",
[
[True, True, True],
[False, False, False],
np.array([True, True, True]),
np.array([False, False, False]),
],
[[True, True, True], [False, False, False], np.array([True, True, True]), np.array([False, False, False])],
)
def test_get_structure_dyn(select_dyn):
atoms = ase.io.read(f"{TEST_FILES_DIR}/POSCAR")
Expand All @@ -197,6 +209,7 @@ def test_get_structure_dyn(select_dyn):
assert len(ase_atoms) == len(structure)


@skip_if_no_ase
def test_get_molecule():
atoms = ase.io.read(f"{TEST_FILES_DIR}/acetylene.xyz")
molecule = AseAtomsAdaptor.get_molecule(atoms)
Expand Down Expand Up @@ -224,6 +237,7 @@ def test_get_molecule():
assert molecule.spin_multiplicity == 3


@skip_if_no_ase
@pytest.mark.parametrize("filename", ["OUTCAR", "V2O3.cif"])
def test_back_forth(filename):
# Atoms --> Structure --> Atoms --> Structure
Expand All @@ -238,10 +252,11 @@ def test_back_forth(filename):
atoms_back = AseAtomsAdaptor.get_atoms(structure)
structure_back = AseAtomsAdaptor.get_structure(atoms_back)
assert structure_back == structure
for k, v in atoms.todict().items():
assert str(atoms_back.todict()[k]) == str(v)
for key, val in atoms.todict().items():
assert str(atoms_back.todict()[key]) == str(val)


@skip_if_no_ase
def test_back_forth_v2():
# Structure --> Atoms --> Structure --> Atoms
structure = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR")
Expand All @@ -255,14 +270,15 @@ def test_back_forth_v2():
structure_back = AseAtomsAdaptor.get_structure(atoms)
atoms_back = AseAtomsAdaptor.get_atoms(structure_back)
assert structure_back == structure
for k, v in atoms.todict().items():
assert str(atoms_back.todict()[k]) == str(v)
for key, val in atoms.todict().items():
assert str(atoms_back.todict()[key]) == str(val)

# test document can be jsanitized and decoded
dct = jsanitize(structure, strict=True, enum_values=True)
MontyDecoder().process_decoded(dct)


@skip_if_no_ase
def test_back_forth_v3():
# Atoms --> Molecule --> Atoms --> Molecule
atoms = ase.io.read(f"{TEST_FILES_DIR}/acetylene.xyz")
Expand All @@ -275,11 +291,12 @@ def test_back_forth_v3():
molecule = AseAtomsAdaptor.get_molecule(atoms)
atoms_back = AseAtomsAdaptor.get_atoms(molecule)
molecule_back = AseAtomsAdaptor.get_molecule(atoms_back)
for k, v in atoms.todict().items():
assert str(atoms_back.todict()[k]) == str(v)
for key, val in atoms.todict().items():
assert str(atoms_back.todict()[key]) == str(val)
assert molecule_back == molecule


@skip_if_no_ase
def test_back_forth_v4():
# Molecule --> Atoms --> Molecule --> Atoms
molecule = Molecule.from_file(f"{TEST_FILES_DIR}/acetylene.xyz")
Expand All @@ -288,32 +305,44 @@ def test_back_forth_v4():
atoms = AseAtomsAdaptor.get_atoms(molecule)
molecule_back = AseAtomsAdaptor.get_molecule(atoms)
atoms_back = AseAtomsAdaptor.get_atoms(molecule_back)
for k, v in atoms.todict().items():
assert str(atoms_back.todict()[k]) == str(v)
for key, val in atoms.todict().items():
assert str(atoms_back.todict()[key]) == str(val)
assert molecule_back == molecule

# test document can be jsanitized and decoded
dct = jsanitize(molecule, strict=True, enum_values=True)
MontyDecoder().process_decoded(dct)


@skip_if_no_ase
def test_msonable_atoms():
atoms = ase.io.read(f"{TEST_FILES_DIR}/OUTCAR")
assert not isinstance(atoms, MSONAtoms)
ref = {"@module": "pymatgen.io.ase", "@class": "MSONAtoms", "atoms_json": ase.io.jsonio.encode(atoms)}
msonable_atoms = MSONAtoms(atoms)
assert atoms == msonable_atoms
assert msonable_atoms.as_dict() == ref
assert MSONAtoms.from_dict(ref) == atoms


def test_msonable_atoms_v2():
structure = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR")

atoms = AseAtomsAdaptor.get_atoms(structure, msonable=True)
assert hasattr(atoms, "as_dict")
assert hasattr(atoms, "from_dict")
assert callable(atoms.as_dict)
assert callable(atoms.from_dict)
assert isinstance(atoms, MSONAtoms)

atoms = AseAtomsAdaptor.get_atoms(structure, msonable=False)
assert not hasattr(atoms, "as_dict")
assert not hasattr(atoms, "from_dict")
assert isinstance(atoms, ase.Atoms)


@pytest.mark.skipif(ase is not None, reason="ase is present")
def test_no_ase_err():
from importlib.metadata import PackageNotFoundError

import pymatgen.io.ase

expected_msg = str(pymatgen.io.ase.no_ase_err)
with pytest.raises(PackageNotFoundError, match=expected_msg):
pymatgen.io.ase.MSONAtoms()

0 comments on commit 5ce2ae1

Please sign in to comment.