From c7fd7c4028da6c8f203c56a723337c6474a27517 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Sat, 15 Jun 2024 19:35:17 +0200 Subject: [PATCH 1/4] Check required attributes before dumping to a file --- iodata/api.py | 51 +++++++++++++++++++++++++++++-- iodata/formats/wfn.py | 6 +++- iodata/test/test_api.py | 67 +++++++++++++++++++++++++++++++++++++++++ iodata/utils.py | 6 ++-- 4 files changed, 125 insertions(+), 5 deletions(-) create mode 100644 iodata/test/test_api.py diff --git a/iodata/api.py b/iodata/api.py index 838fa21d..e84c87f0 100644 --- a/iodata/api.py +++ b/iodata/api.py @@ -27,7 +27,7 @@ from typing import Callable, Optional from .iodata import IOData -from .utils import LineIterator +from .utils import FileFormatError, LineIterator __all__ = ["load_one", "load_many", "dump_one", "dump_many", "write_input"] @@ -173,6 +173,28 @@ def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IO return +def _check_required(iodata: IOData, dump_func: Callable): + """Check that required attributes are not None before dumping to a file. + + Parameters + ---------- + iodata + The data to be written. + dump_func + The dump_one or dump_many function that will write the file. + + Raises + ------ + FileFormatError + When a required attribute is ``None``. + """ + for attr_name in dump_func.required: + if getattr(iodata, attr_name) is None: + raise FileFormatError( + f"Required attribute {attr_name}, for format {dump_func.fmt}, is None." + ) + + def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs): """Write data to a file. 
@@ -194,6 +216,7 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs) """ format_module = _select_format_module(filename, "dump_one", fmt) + _check_required(iodata, format_module.dump_one) with open(filename, "w") as f: format_module.dump_one(f, iodata, **kwargs) @@ -216,10 +239,34 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non **kwargs Keyword arguments are passed on to the format-specific dump_many function. + Raises + ------ + FileFormatError + When iodatas has zero length + or when one of the iodata items does not have the required attributes. """ format_module = _select_format_module(filename, "dump_many", fmt) + + # Check the first item before creating the file. + # If the file already exists, this may prevent data loss: + # The file is not overwritten when it is clear that writing will fail. + iter_iodatas = iter(iodatas) + try: + first = next(iter_iodatas) + _check_required(first, format_module.dump_many) + except StopIteration as exc: + raise FileFormatError("dump_many needs at least one iodata object.") from exc + + def checking_iterator(): + """Iterate over all iodata items, not checking the first.""" + # The first one was already checked. 
+ yield first + for other in iter_iodatas: + _check_required(other, format_module.dump_many) + yield other + with open(filename, "w") as f: - format_module.dump_many(f, iodatas, **kwargs) + format_module.dump_many(f, checking_iterator(), **kwargs) def write_input( diff --git a/iodata/formats/wfn.py b/iodata/formats/wfn.py index b1719cb8..98fcdcb4 100644 --- a/iodata/formats/wfn.py +++ b/iodata/formats/wfn.py @@ -494,7 +494,11 @@ def _dump_helper_section(f: TextIO, data: NDArray, fmt: str, skip: int, step: in DEFAULT_WFN_TTL = "WFN auto-generated by IOData" -@document_dump_one("WFN", ["atcoords", "atnums", "energy", "mo", "obasis", "title", "extra"]) +@document_dump_one( + "WFN", + ["atcoords", "atnums", "mo", "obasis"], + ["energy", "title", "extra"], +) def dump_one(f: TextIO, data: IOData) -> None: """Do not edit this docstring. It will be overwritten.""" # occs_aminusb is not supported diff --git a/iodata/test/test_api.py b/iodata/test/test_api.py new file mode 100644 index 00000000..db6f9015 --- /dev/null +++ b/iodata/test/test_api.py @@ -0,0 +1,67 @@ +# IODATA is an input and output module for quantum chemistry. +# Copyright (C) 2011-2019 The IODATA Development Team +# +# This file is part of IODATA. +# +# IODATA is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 3 +# of the License, or (at your option) any later version. +# +# IODATA is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, see <http://www.gnu.org/licenses/> +# -- +"""Unit tests for iodata.api. 
+ +Relatively simple formats are used in this module to keep testing simple +and focus on the functionality of the API rather than the formats. +""" + +import os + +import pytest + +from ..api import dump_many, dump_one +from ..iodata import IOData +from ..utils import FileFormatError + + +def test_empty_dump_many_no_file(tmpdir): + path_xyz = os.path.join(tmpdir, "empty.xyz") + with pytest.raises(FileFormatError): + dump_many([], path_xyz) + assert not os.path.isfile(path_xyz) + + +def test_dump_one_missing_attribute_no_file(tmpdir): + path_xyz = os.path.join(tmpdir, "missing_atcoords.xyz") + with pytest.raises(FileFormatError): + dump_one(IOData(atnums=[1, 2, 3]), path_xyz) + assert not os.path.isfile(path_xyz) + + +def test_dump_many_missing_attribute_first(tmpdir): + path_xyz = os.path.join(tmpdir, "missing_atcoords.xyz") + iodatas = [ + IOData(atnums=[1, 1]), + IOData(atnums=[1, 1], atcoords=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), + ] + with pytest.raises(FileFormatError): + dump_many(iodatas, path_xyz) + assert not os.path.isfile(path_xyz) + + +def test_dump_many_missing_attribute_second(tmpdir): + path_xyz = os.path.join(tmpdir, "missing_atcoords.xyz") + iodatas = [ + IOData(atnums=[1, 1], atcoords=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), + IOData(atnums=[1, 1]), + ] + with pytest.raises(FileFormatError): + dump_many(iodatas, path_xyz) + assert os.path.isfile(path_xyz) diff --git a/iodata/utils.py b/iodata/utils.py index 6abb0b68..b34eba67 100644 --- a/iodata/utils.py +++ b/iodata/utils.py @@ -28,7 +28,9 @@ from .attrutils import validate_shape -__all__ = [ +__all__ = ( + "FileFormatError", + "FileFormatWarning", "LineIterator", "Cube", "set_four_index_element", @@ -36,7 +38,7 @@ "derive_naturals", "check_dm", "strtobool", -] +) # The unit conversion factors below can be used as follows: From 3350ba3b4ef89e83860aa80608118579f0a35eea Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Sat, 15 Jun 2024 19:59:15 +0200 Subject: [PATCH 2/4] Extra test based on AI 
suggestions --- iodata/test/test_api.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/iodata/test/test_api.py b/iodata/test/test_api.py index db6f9015..5be1455d 100644 --- a/iodata/test/test_api.py +++ b/iodata/test/test_api.py @@ -25,8 +25,9 @@ import os import pytest +from numpy.testing import assert_allclose, assert_array_equal -from ..api import dump_many, dump_one +from ..api import dump_many, dump_one, load_many from ..iodata import IOData from ..utils import FileFormatError @@ -65,3 +66,23 @@ def test_dump_many_missing_attribute_second(tmpdir): with pytest.raises(FileFormatError): dump_many(iodatas, path_xyz) assert os.path.isfile(path_xyz) + + +def test_dump_many_generator(tmpdir): + path_xyz = os.path.join(tmpdir, "traj.xyz") + + iodata0 = IOData(atnums=[1, 1], atcoords=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) + iodata1 = IOData(atnums=[2, 2], atcoords=[[0.0, 1.0, 0.0], [0.0, 1.0, 1.0]]) + + def iodata_generator(): + yield iodata0 + yield iodata1 + + dump_many(iodata_generator(), path_xyz) + assert os.path.isfile(path_xyz) + iodatas = list(load_many(path_xyz)) + assert len(iodatas) == 2 + assert_array_equal(iodatas[0].atnums, iodata0.atnums) + assert_array_equal(iodatas[1].atnums, iodata1.atnums) + assert_allclose(iodatas[0].atcoords, iodata0.atcoords) + assert_allclose(iodatas[1].atcoords, iodata1.atcoords) From ff1b3b66d3bf89af2cafc4b2dc7357d0848653cc Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Sat, 15 Jun 2024 20:01:21 +0200 Subject: [PATCH 3/4] One more AI suggestion --- iodata/test/test_api.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/iodata/test/test_api.py b/iodata/test/test_api.py index 5be1455d..dd7f6169 100644 --- a/iodata/test/test_api.py +++ b/iodata/test/test_api.py @@ -59,13 +59,15 @@ def test_dump_many_missing_attribute_first(tmpdir): def test_dump_many_missing_attribute_second(tmpdir): path_xyz = os.path.join(tmpdir, "missing_atcoords.xyz") - iodatas = [ - 
IOData(atnums=[1, 1], atcoords=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), - IOData(atnums=[1, 1]), - ] + iodata0 = IOData(atnums=[1, 1], atcoords=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) + iodatas = [iodata0, IOData(atnums=[1, 1])] with pytest.raises(FileFormatError): dump_many(iodatas, path_xyz) assert os.path.isfile(path_xyz) + iodatas = list(load_many(path_xyz)) + assert len(iodatas) == 1 + assert_array_equal(iodatas[0].atnums, iodata0.atnums) + assert_allclose(iodatas[0].atcoords, iodata0.atcoords) def test_dump_many_generator(tmpdir): From 0b2a68e76f9d0a80936e6bd77d099e1e0b88cf03 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Sat, 15 Jun 2024 20:15:28 +0200 Subject: [PATCH 4/4] Add Raises section to dump_one docstring --- iodata/api.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/iodata/api.py b/iodata/api.py index e84c87f0..9639c437 100644 --- a/iodata/api.py +++ b/iodata/api.py @@ -214,6 +214,10 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs) **kwargs Keyword arguments are passed on to the format-specific dump_one function. + Raises + ------ + FileFormatError + When a required attribute of iodata is ``None``. """ format_module = _select_format_module(filename, "dump_one", fmt) _check_required(iodata, format_module.dump_one)