Skip to content

Commit

Permalink
Fix optional dependencies error detection and remove pybtex hidden …
Browse files Browse the repository at this point in the history
…dependency (pybamm-team#3968)

* Edit pybamm import test

* Guard pybtex imports

* Update pybamm/citations.py

Co-authored-by: Agriya Khetarpal <[email protected]>

* Update pybamm/citations.py (1)

Co-authored-by: Agriya Khetarpal <[email protected]>

* Update pybamm/citations.py (2)

Co-authored-by: Agriya Khetarpal <[email protected]>

* Update pybamm/citations.py (3)

Co-authored-by: Agriya Khetarpal <[email protected]>

* Update contributing guidelines

* Exclude excepts from test coverage

* style: pre-commit fixes

---------

Co-authored-by: Lorenzo Favaro <[email protected]>
Co-authored-by: Agriya Khetarpal <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric G. Kratz <[email protected]>
Co-authored-by: Arjun Verma <[email protected]>
  • Loading branch information
6 people authored and js1tr3 committed Aug 12, 2024
1 parent 5ef9ea6 commit 6030036
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 93 deletions.
30 changes: 11 additions & 19 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,26 +126,18 @@ This allows people to (1) use PyBaMM without importing optional dependencies by

**Writing Tests for Optional Dependencies**

Whenever a new optional dependency is added for optional functionality, it is recommended to write a corresponding unit test in `test_util.py`. This ensures that an error is raised upon the absence of said dependency. Here's an example:
Below, we list the currently available test functions to provide an overview. If you find it useful to add new test cases please do so within `tests/unit/test_util.py`.

```python
from tests import TestCase
import pybamm


class TestUtil(TestCase):
def test_optional_dependency(self):
# Test that an error is raised when pybtex is not available
with self.assertRaisesRegex(
ModuleNotFoundError, "Optional dependency pybtex is not available"
):
sys.modules["pybtex"] = None
pybamm.function_using_pybtex(x, y, z)

# Test that the function works when pybtex is available
sys.modules["pybtex"] = pybamm.util.import_optional_dependency("pybtex")
pybamm.function_using_pybtex(x, y, z)
```
Currently, there are three functions to test what concerns optional dependencies:
- `test_import_optional_dependency`
- `test_pybamm_import`
- `test_optional_dependencies`

The `test_import_optional_dependency` function extracts the optional dependencies installed in the setup environment, makes them unimportable (by setting them to `None` among the `sys.modules`), and tests that the `pybamm.util.import_optional_dependency` function throws a `ModuleNotFoundError` exception when their import is attempted.

The `test_pybamm_import` function extracts the optional dependencies installed in the setup environment and makes them unimportable (by setting them to `None` among the `sys.modules`), unloads `pybamm` and its sub-modules, and finally tests that `pybamm` can be imported successfully. In fact, it is essential that the `pybamm` package is importable with only the mandatory dependencies.

The `test_optional_dependencies` function extracts `pybamm` mandatory distribution packages and verifies that they are not present in the optional distribution packages list in `pyproject.toml`. This test is crucial for ensuring the consistency of the released package information and potential updates to dependencies during development.

## Testing

Expand Down
154 changes: 90 additions & 64 deletions pybamm/citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,29 +67,41 @@ def read_citations(self):
"""Reads the citations in `pybamm.CITATIONS.bib`. Other works can be cited
by passing a BibTeX citation to :meth:`register`.
"""
parse_file = import_optional_dependency("pybtex.database", "parse_file")
citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib")
bib_data = parse_file(citations_file, bib_format="bibtex")
for key, entry in bib_data.entries.items():
self._add_citation(key, entry)
try:
parse_file = import_optional_dependency("pybtex.database", "parse_file")
citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib")
bib_data = parse_file(citations_file, bib_format="bibtex")
for key, entry in bib_data.entries.items():
self._add_citation(key, entry)
except ModuleNotFoundError: # pragma: no cover
pybamm.logger.warning(
"Citations could not be read because the 'pybtex' library is not installed. "
"Install 'pybamm[cite]' to enable citation reading."
)

def _add_citation(self, key, entry):
"""Adds `entry` to `self._all_citations` under `key`, warning the user if a
previous entry is overwritten
"""

Entry = import_optional_dependency("pybtex.database", "Entry")
# Check input types are correct
if not isinstance(key, str) or not isinstance(entry, Entry):
raise TypeError()

# Warn if overwriting a previous citation
new_citation = entry.to_string("bibtex")
if key in self._all_citations and new_citation != self._all_citations[key]:
warnings.warn(f"Replacing citation for {key}", stacklevel=2)

# Add to database
self._all_citations[key] = new_citation
try:
Entry = import_optional_dependency("pybtex.database", "Entry")
# Check input types are correct
if not isinstance(key, str) or not isinstance(entry, Entry):
raise TypeError()

# Warn if overwriting a previous citation
new_citation = entry.to_string("bibtex")
if key in self._all_citations and new_citation != self._all_citations[key]:
warnings.warn(f"Replacing citation for {key}", stacklevel=2)

# Add to database
self._all_citations[key] = new_citation
except ModuleNotFoundError: # pragma: no cover
pybamm.logger.warning(
f"Could not add citation for '{key}' because the 'pybtex' library is not installed. "
"Install 'pybamm[cite]' to enable adding citations."
)

def _add_citation_tag(self, key, entry):
"""Adds a tag for a citation key in the dict, which represents the name of the
Expand Down Expand Up @@ -143,24 +155,32 @@ def _parse_citation(self, key):
key: str
A BibTeX formatted citation
"""
PybtexError = import_optional_dependency("pybtex.scanner", "PybtexError")
parse_string = import_optional_dependency("pybtex.database", "parse_string")
try:
# Parse string as a bibtex citation, and check that a citation was found
bib_data = parse_string(key, bib_format="bibtex")
if not bib_data.entries:
raise PybtexError("no entries found")

# Add and register all citations
for key, entry in bib_data.entries.items():
# Add to _all_citations dictionary
self._add_citation(key, entry)
# Add to _papers_to_cite set
self._papers_to_cite.add(key)
return
except PybtexError as error:
# Unable to parse / unknown key
raise KeyError(f"Not a bibtex citation or known citation: {key}") from error
PybtexError = import_optional_dependency("pybtex.scanner", "PybtexError")
parse_string = import_optional_dependency("pybtex.database", "parse_string")
try:
# Parse string as a bibtex citation, and check that a citation was found
bib_data = parse_string(key, bib_format="bibtex")
if not bib_data.entries:
raise PybtexError("no entries found")

# Add and register all citations
for key, entry in bib_data.entries.items():
# Add to _all_citations dictionary
self._add_citation(key, entry)
# Add to _papers_to_cite set
self._papers_to_cite.add(key)
return
except PybtexError as error:
# Unable to parse / unknown key
raise KeyError(
f"Not a bibtex citation or known citation: {key}"
) from error
except ModuleNotFoundError: # pragma: no cover
pybamm.logger.warning(
f"Could not parse citation for '{key}' because the 'pybtex' library is not installed. "
"Install 'pybamm[cite]' to enable citation parsing."
)

def _tag_citations(self):
"""Prints the citation tags for the citations that have been registered
Expand Down Expand Up @@ -211,38 +231,44 @@ def print(self, filename=None, output_format="text", verbose=False):
"""
# Parse citations that were not known keys at registration, but do not
# fail if they cannot be parsed
pybtex = import_optional_dependency("pybtex")
try:
for key in self._unknown_citations:
self._parse_citation(key)
except KeyError: # pragma: no cover
warnings.warn(
message=f'\nCitation with key "{key}" is invalid. Please try again\n',
category=UserWarning,
stacklevel=2,
)
# delete the invalid citation from the set
self._unknown_citations.remove(key)

if output_format == "text":
citations = pybtex.format_from_strings(
self._cited, style="plain", output_backend="plaintext"
pybtex = import_optional_dependency("pybtex")
try:
for key in self._unknown_citations:
self._parse_citation(key)
except KeyError: # pragma: no cover
warnings.warn(
message=f'\nCitation with key "{key}" is invalid. Please try again\n',
category=UserWarning,
stacklevel=2,
)
# delete the invalid citation from the set
self._unknown_citations.remove(key)

if output_format == "text":
citations = pybtex.format_from_strings(
self._cited, style="plain", output_backend="plaintext"
)
elif output_format == "bibtex":
citations = "\n".join(self._cited)
else:
raise pybamm.OptionError(
f"Output format {output_format} not recognised."
"It should be 'text' or 'bibtex'."
)

if filename is None:
print(citations)
if verbose:
self._tag_citations() # pragma: no cover
else:
with open(filename, "w") as f:
f.write(citations)
except ModuleNotFoundError: # pragma: no cover
pybamm.logger.warning(
"Could not print citations because the 'pybtex' library is not installed. "
"Please, install 'pybamm[cite]' to print citations."
)
elif output_format == "bibtex":
citations = "\n".join(self._cited)
else:
raise pybamm.OptionError(
f"Output format {output_format} not recognised."
"It should be 'text' or 'bibtex'."
)

if filename is None:
print(citations)
if verbose:
self._tag_citations() # pragma: no cover
else:
with open(filename, "w") as f:
f.write(citations)


def print_citations(filename=None, output_format="text", verbose=False):
Expand Down
28 changes: 18 additions & 10 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_import_optional_dependency(self):
"pybamm", optional_distribution_deps=optional_distribution_deps
)

# Save optional dependencies, then set to None
# Save optional dependencies, then make them not importable
modules = {}
for import_pkg in present_optional_import_deps:
modules[import_pkg] = sys.modules.get(import_pkg)
Expand All @@ -117,23 +117,31 @@ def test_pybamm_import(self):
"pybamm", optional_distribution_deps=optional_distribution_deps
)

# Save optional dependencies, then set to None
# Save optional dependencies and their sub-modules, then make them not importable
modules = {}
for import_pkg in present_optional_import_deps:
modules[import_pkg] = sys.modules.get(import_pkg)
sys.modules[import_pkg] = None
for module_name, module in sys.modules.items():
base_module_name = module_name.split(".")[0]
if base_module_name in present_optional_import_deps:
modules[module_name] = module
sys.modules[module_name] = None

# Unload pybamm and its sub-modules
for module_name in list(sys.modules.keys()):
base_module_name = module_name.split(".")[0]
if base_module_name == "pybamm":
sys.modules.pop(module_name)

# Test pybamm is still importable
try:
importlib.reload(importlib.import_module("pybamm"))
importlib.import_module("pybamm")
except ModuleNotFoundError as error:
self.fail(
f"Import of 'pybamm' shouldn't require optional dependencies. Error: {error}"
)

# Restore optional dependencies
for import_pkg in present_optional_import_deps:
sys.modules[import_pkg] = modules[import_pkg]
finally:
# Restore optional dependencies and their sub-modules
for module_name, module in modules.items():
sys.modules[module_name] = module

def test_optional_dependencies(self):
optional_distribution_deps = get_optional_distribution_deps("pybamm")
Expand Down

0 comments on commit 6030036

Please sign in to comment.