From 88bd877506353d7329e24172634c20ca2b50ff66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 23 Sep 2022 22:12:18 +0200 Subject: [PATCH] Add `lightning_utilities.test.warning.no_warning_call` (#55) --- .github/workflows/check-package.yml | 3 --- src/lightning_utilities/test/warning.py | 14 ++++++++++++++ tests/unittests/test/test_warnings.py | 25 +++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 3 deletions(-) create mode 100644 src/lightning_utilities/test/warning.py create mode 100644 tests/unittests/test/test_warnings.py diff --git a/.github/workflows/check-package.yml b/.github/workflows/check-package.yml index 45c8558b..aa07d920 100644 --- a/.github/workflows/check-package.yml +++ b/.github/workflows/check-package.yml @@ -41,9 +41,6 @@ jobs: - name: Check local package run: | - pip install -q -U check-manifest - # check MANIFEST.in - check-manifest # check package python setup.py check --metadata --strict diff --git a/src/lightning_utilities/test/warning.py b/src/lightning_utilities/test/warning.py new file mode 100644 index 00000000..52fa1047 --- /dev/null +++ b/src/lightning_utilities/test/warning.py @@ -0,0 +1,14 @@ +import re +import warnings +from contextlib import contextmanager +from typing import Generator, Optional, Type + + +@contextmanager +def no_warning_call(expected_warning: Type[Warning] = Warning, match: Optional[str] = None) -> Generator: + with warnings.catch_warnings(record=True) as record: + yield + + for w in record: + if issubclass(w.category, expected_warning) and (match is None or re.compile(match).search(str(w.message))): + raise AssertionError(f"`{expected_warning.__name__}` was raised: {w.message!r}") diff --git a/tests/unittests/test/test_warnings.py b/tests/unittests/test/test_warnings.py new file mode 100644 index 00000000..3b5954ed --- /dev/null +++ b/tests/unittests/test/test_warnings.py @@ -0,0 +1,25 @@ +import warnings +from re import escape + +import pytest + +from lightning_utilities.test.warning import no_warning_call + + +def test_no_warning_call(): + with no_warning_call(): + ... + + with pytest.raises(AssertionError, match=escape("`Warning` was raised: UserWarning('foo')")): + with no_warning_call(): + warnings.warn("foo") + + with no_warning_call(DeprecationWarning): + warnings.warn("foo") + + class MyDeprecationWarning(DeprecationWarning): + ... + + with pytest.raises(AssertionError, match=escape("`DeprecationWarning` was raised: MyDeprecationWarning('bar')")): + with pytest.warns(DeprecationWarning), no_warning_call(DeprecationWarning): + warnings.warn("bar", category=MyDeprecationWarning)