Skip to content

Commit

Permalink
Ensure parity with top level legacy methods (#691)
Browse files Browse the repository at this point in the history
* Parity in get_entries* methods

* Add additional criteria warning

* Get_entry_by_material_id parity

* Update tests

* Fix assert statements

* Fix conventional unit cell behavior

* Update get_entries and tests

* Fix tests

* Update docstrings

* Update warnings

* Skip chgcar test
  • Loading branch information
munrojm authored Oct 11, 2022
1 parent 25b21d8 commit 2292fc1
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 37 deletions.
185 changes: 152 additions & 33 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
from multiprocessing.sharedctypes import Value
import warnings
from functools import lru_cache
from os import environ
Expand All @@ -16,13 +17,14 @@
from pymatgen.analysis.pourbaix_diagram import IonEntry
from pymatgen.core import Composition, Element, Structure
from pymatgen.core.ion import Ion
from pymatgen.entries.computed_entries import ComputedEntry
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
from pymatgen.io.vasp import Chgcar
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from requests import get
from typing import Literal

from mp_api.client.core import BaseRester, MPRestError
from mp_api.client.core.utils import validate_ids
from mp_api.client.routes import *

_DEPRECATION_WARNING = (
Expand Down Expand Up @@ -438,50 +440,111 @@ def find_structure(
)

def get_entries(
self, chemsys_formula: Union[str, List[str]], sort_by_e_above_hull=False,
):
self,
chemsys_formula_mpids: Union[str, List[str]],
compatible_only: bool = True,
inc_structure: bool = None,
property_data: List[str] = None,
conventional_unit_cell: bool = False,
sort_by_e_above_hull=False,
) -> List[ComputedStructureEntry]:
"""
Get a list of ComputedStructureEntries corresponding to a chemical system,
formula, or Materials Project ID.
Args:
chemsys_formula (str): A chemical system, list of chemical systems
(e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]), or single formula (e.g., Fe2O3, Si*).
chemsys_formula_mpids (str, List[str]): A chemical system, list of chemical systems
(e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]), formula, list of formulas
(e.g., Fe2O3, Si*, [SiO2, BiFeO3]), Materials Project ID, or list of Materials
Project IDs (e.g., mp-22526, [mp-22526, mp-149]).
compatible_only (bool): Whether to return only "compatible"
entries. Compatible entries are entries that have been
processed using the MaterialsProject2020Compatibility class,
which performs adjustments to allow mixing of GGA and GGA+U
calculations for more accurate phase diagrams and reaction
energies. This data is obtained from the core "thermo" API endpoint.
inc_structure (str): *This is a deprecated argument*. Previously, if None, entries
returned were ComputedEntries. If inc_structure="initial",
ComputedStructureEntries with initial structures were returned.
Otherwise, ComputedStructureEntries with final structures
were returned. This is no longer needed as all entries will contain
structure data by default.
property_data (list): Specify additional properties to include in
entry.data. If None, only default data is included. Should be a subset of
the fields in the 'MPRester.thermo.available_fields' list.
conventional_unit_cell (bool): Whether to get the standard
conventional unit cell
sort_by_e_above_hull (bool): Whether to sort the list of entries by
e_above_hull in ascending order.
Returns:
List of ComputedEntry or ComputedStructureEntry objects.
List of ComputedStructureEntry objects.
"""

if isinstance(chemsys_formula, list) or (
isinstance(chemsys_formula, str) and "-" in chemsys_formula
):
input_params = {"chemsys": chemsys_formula}
else:
input_params = {"formula": chemsys_formula}
if inc_structure is not None:
warnings.warn("The 'inc_structure' argument is deprecated as structure "
"data is now always included in all returned entry objects.")

if isinstance(chemsys_formula_mpids, str):
chemsys_formula_mpids = [chemsys_formula_mpids]

try:
input_params = {"material_ids": validate_ids(chemsys_formula_mpids)}
except ValueError:

if any("-" in entry for entry in chemsys_formula_mpids):
input_params = {"chemsys": chemsys_formula_mpids}
else:
input_params = {"formula": chemsys_formula_mpids}

entries = []

if sort_by_e_above_hull:
fields = ["entries"] if not property_data else ["entries"] + property_data

for doc in self.thermo.search(
if sort_by_e_above_hull:
docs = self.thermo.search(
**input_params, # type: ignore
all_fields=False,
fields=["entries"],
fields=fields,
sort_fields=["energy_above_hull"],
):
entries.extend(list(doc.entries.values()))
)
else:
docs = self.thermo.search(
**input_params, all_fields=False, fields=fields, # type: ignore
)

return entries
for doc in docs:
for entry in doc.entries.values():
if not compatible_only:
entry.correction = 0.0
entry.energy_adjustments = []

else:
for doc in self.thermo.search(
**input_params, all_fields=False, fields=["entries"], # type: ignore
):
entries.extend(list(doc.entries.values()))
if property_data:
for property in property_data:
entry.data[property] = doc.dict()[property]

if conventional_unit_cell:

s = SpacegroupAnalyzer(entry.structure).get_conventional_standard_structure()
site_ratio = (len(s) / len(entry.structure))
new_energy = entry.uncorrected_energy * site_ratio

entry_dict = entry.as_dict()
entry_dict["energy"] = new_energy
entry_dict["structure"] = s.as_dict()
entry_dict["correction"] = 0.0

for element in entry_dict["composition"]:
entry_dict["composition"][element] *= site_ratio

return entries
for correction in entry_dict["energy_adjustments"]:
correction["n_atoms"] *= site_ratio

entry = ComputedStructureEntry.from_dict(entry_dict)

entries.append(entry)

return entries

def get_pourbaix_entries(
self,
Expand Down Expand Up @@ -783,24 +846,51 @@ def get_ion_entries(

return ion_entries

def get_entry_by_material_id(self, material_id: str):
def get_entry_by_material_id(self, material_id: str,
compatible_only: bool = True,
inc_structure: bool = None,
property_data: List[str] = None,
conventional_unit_cell: bool = False,):
"""
Get all ComputedEntry objects corresponding to a material_id.
Args:
material_id (str): Materials Project material_id (a string,
e.g., mp-1234).
compatible_only (bool): Whether to return only "compatible"
entries. Compatible entries are entries that have been
processed using the MaterialsProject2020Compatibility class,
which performs adjustments to allow mixing of GGA and GGA+U
calculations for more accurate phase diagrams and reaction
energies. This data is obtained from the core "thermo" API endpoint.
inc_structure (str): *This is a deprecated argument*. Previously, if None, entries
returned were ComputedEntries. If inc_structure="initial",
ComputedStructureEntries with initial structures were returned.
Otherwise, ComputedStructureEntries with final structures
were returned. This is no longer needed as all entries will contain
structure data by default.
property_data (list): Specify additional properties to include in
entry.data. If None, only default data is included. Should be a subset of
the fields in the 'MPRester.thermo.available_fields' list.
conventional_unit_cell (bool): Whether to get the standard
conventional unit cell
Returns:
List of ComputedEntry or ComputedStructureEntry objects.
"""
return list(
self.thermo.get_data_by_id(
document_id=material_id, fields=["entries"]
).entries.values()
)
return self.get_entries(material_id,
compatible_only=compatible_only,
inc_structure=inc_structure,
property_data=property_data,
conventional_unit_cell=conventional_unit_cell)

def get_entries_in_chemsys(
self, elements: Union[str, List[str]], use_gibbs: Optional[int] = None,
self, elements: Union[str, List[str]],
use_gibbs: Optional[int] = None,
compatible_only: bool = True,
inc_structure: bool = None,
property_data: List[str] = None,
conventional_unit_cell: bool = False,
additional_criteria=None,
):
"""
Helper method to get a list of ComputedEntries in a chemical system.
Expand All @@ -817,9 +907,34 @@ def get_entries_in_chemsys(
(see GibbsComputedStructureEntry). The number is the temperature in
Kelvin at which to estimate the free energy. Must be between 300 K and
2000 K.
compatible_only (bool): Whether to return only "compatible"
entries. Compatible entries are entries that have been
processed using the MaterialsProject2020Compatibility class,
which performs adjustments to allow mixing of GGA and GGA+U
calculations for more accurate phase diagrams and reaction
energies. This data is obtained from the core "thermo" API endpoint.
inc_structure (str): *This is a deprecated argument*. Previously, if None, entries
returned were ComputedEntries. If inc_structure="initial",
ComputedStructureEntries with initial structures were returned.
Otherwise, ComputedStructureEntries with final structures
were returned. This is no longer needed as all entries will contain
structure data by default.
property_data (list): Specify additional properties to include in
entry.data. If None, only default data is included. Should be a subset of
the fields in the 'MPRester.thermo.available_fields' list.
conventional_unit_cell (bool): Whether to get the standard
conventional unit cell
additional_criteria (dict): *This is a deprecated argument*. To obtain entry objects
with additional criteria, use the `MPRester.thermo.search` method directly.
Returns:
List of ComputedEntries.
List of ComputedStructureEntries.
"""

if additional_criteria is not None:
warnings.warn("The 'additional_criteria' argument is deprecated. "
"To obtain entry objects with additional criteria, use "
"the 'MPRester.thermo.search' method directly")

if isinstance(elements, str):
elements = elements.split("-")

Expand All @@ -830,7 +945,11 @@ def get_entries_in_chemsys(

entries = [] # type: List[ComputedEntry]

entries.extend(self.get_entries(all_chemsyses))
entries.extend(self.get_entries(all_chemsyses,
compatible_only=compatible_only,
inc_structure=inc_structure,
property_data=property_data,
conventional_unit_cell=conventional_unit_cell))

if use_gibbs:
# replace the entries with GibbsComputedStructureEntry
Expand Down
36 changes: 32 additions & 4 deletions tests/test_mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from emmet.core.symmetry import CrystalSystem
from emmet.core.tasks import TaskDoc
from emmet.core.vasp.calc_types import CalcType
from sympy import prime
from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client import MPRester
from pymatgen.analysis.magnetism import Ordering
Expand Down Expand Up @@ -90,7 +91,7 @@ def test_get_structures(self, mpr):
structs = mpr.get_structures("Mn-O", final=False)
assert len(structs) > 0

@pytest.mark.skip(reason="endpoint issues")
@pytest.mark.skip(reason="Endpoint issues")
def test_find_structure(self, mpr):
path = os.path.join(MAPIClientSettings().TEST_FILES, "Si_mp_149.cif")
with open(path) as file:
Expand Down Expand Up @@ -130,12 +131,41 @@ def test_get_entries(self, mpr):

assert sorted_entries != entries

# Formula
formula = "SiO2"
entries = mpr.get_entries(formula)

for e in entries:
assert isinstance(e, ComputedEntry)

# Property data
formula = "BiFeO3"
entries = mpr.get_entries(formula, property_data=["energy_above_hull"])

for e in entries:
assert e.data.get("energy_above_hull", None) is not None

# Conventional structure
formula = "BiFeO3"
entry = mpr.get_entry_by_material_id("mp-22526", inc_structure=True, conventional_unit_cell=True)[0]

s = entry.structure
assert pytest.approx(s.lattice.a) == s.lattice.b
assert pytest.approx(s.lattice.a) != s.lattice.c
assert pytest.approx(s.lattice.alpha) == 90
assert pytest.approx(s.lattice.beta) == 90
assert pytest.approx(s.lattice.gamma) == 120

# Ensure energy per atom is same
prim = mpr.get_entry_by_material_id("mp-22526", inc_structure=True, conventional_unit_cell=False)[0]
assert pytest.approx(prim.energy_per_atom) == entry.energy_per_atom

s = prim.structure
assert pytest.approx(s.lattice.a) == s.lattice.b
assert pytest.approx(s.lattice.a) == s.lattice.c
assert pytest.approx(s.lattice.alpha) == s.lattice.beta
assert pytest.approx(s.lattice.alpha) == s.lattice.gamma

def test_get_entries_in_chemsys(self, mpr):
syms = ["Li", "Fe", "O"]
syms2 = "Li-Fe-O"
Expand All @@ -154,7 +184,6 @@ def test_get_entries_in_chemsys(self, mpr):
for e in gibbs_entries:
assert isinstance(e, GibbsComputedStructureEntry)

@pytest.mark.skip(reason="Until SSL issue fix")
def test_get_pourbaix_entries(self, mpr):
# test input chemsys as a list of elements
pbx_entries = mpr.get_pourbaix_entries(["Fe", "Cr"])
Expand Down Expand Up @@ -195,7 +224,6 @@ def test_get_pourbaix_entries(self, mpr):
# so4_two_minus = pbx_entries[9]
# self.assertAlmostEqual(so4_two_minus.energy, 0.301511, places=3)

@pytest.mark.skip(reason="Until SSL issue fix")
def test_get_ion_entries(self, mpr):
entries = mpr.get_entries_in_chemsys("Ti-O-H")
pd = PhaseDiagram(entries)
Expand Down Expand Up @@ -249,7 +277,7 @@ def test_get_phonon_data_by_material_id(self, mpr):
dos = mpr.get_phonon_dos_by_material_id("mp-11659")
assert isinstance(dos, PhononDos)

@pytest.mark.xfail(reason="SSL issue")
@pytest.mark.skip(reason="Test needs fixing with ENV variables")
def test_get_charge_density_data(self, mpr):
chgcar = mpr.get_charge_density_from_material_id("mp-149")
assert isinstance(chgcar, Chgcar)
Expand Down

0 comments on commit 2292fc1

Please sign in to comment.