Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utility function to warn changes in default arguments #5738

Merged
merged 8 commits into from
Jan 4, 2023
2 changes: 1 addition & 1 deletion monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# have to explicitly bring these in here to resolve circular import issues
from .aliases import alias, resolve_name
from .decorators import MethodReplacer, RestartGenerator
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
from .dist import evenly_divisible_all_gather, get_dist_device, string_list_all_gather
from .enums import (
Average,
Expand Down
106 changes: 104 additions & 2 deletions monai/utils/deprecate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
import warnings
from functools import wraps
from types import FunctionType
from typing import Optional
from typing import Any, Optional

from monai.utils.module import version_leq

from .. import __version__

__all__ = ["deprecated", "deprecated_arg", "DeprecatedError"]
__all__ = ["deprecated", "deprecated_arg", "DeprecatedError", "deprecated_arg_default"]


class DeprecatedError(Exception):
Expand Down Expand Up @@ -223,3 +223,105 @@ def _wrapper(*args, **kwargs):
return _wrapper

return _decorator


def deprecated_arg_default(
name: str,
old_default: Any,
new_default: Any,
since: Optional[str] = None,
replaced: Optional[str] = None,
msg_suffix: str = "",
version_val: str = __version__,
warning_category=FutureWarning,
):
"""
Marks a particular arguments default of a callable as deprecated. It is changed from `old_default` to `new_default`
in version `changed`.

When the decorated definition is called, a `warning_category` is issued if `since` is given,
the default is not explicitly set by the caller and the current version is at or later than that given.
Another warning with the same category is issued if `changed` is given and the current version is at or later.

The relevant docstring of the deprecating function should also be updated accordingly,
using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`.
https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded

In the current implementation type annotations are not preserved.


Args:
name: name of position or keyword argument where the default is deprecated/changed.
old_default: name of the old default. This is only for the warning message, it will not be validated.
new_default: name of the new default.
It is validated that this value is not present as the default before version `replaced`.
This means, that you can also use this if the actual default value is `None` and set later in the function.
You can also set this to any string representation, e.g. `"calculate_default_value()"`
if the default is calculated from another function.
since: version at which the argument default was marked deprecated but not replaced.
replaced: version at which the argument default was/will be replaced.
msg_suffix: message appended to warning/exception detailing reasons for deprecation.
version_val: (used for testing) version to compare since and removed against, default is MONAI version.
warning_category: a warning category class, defaults to `FutureWarning`.

Returns:
Decorated callable which warns when deprecated default argument is not explicitly specified.
"""

if version_val.startswith("0+") or not f"{version_val}".strip()[0].isdigit():
# version unknown, set version_val to a large value (assuming the latest version)
version_val = f"{sys.maxsize}"
if since is not None and replaced is not None and not version_leq(since, replaced):
raise ValueError(f"since must be less or equal to replaced, got since={since}, replaced={replaced}.")
is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)
if is_not_yet_deprecated:
# smaller than `since`, do nothing
return lambda obj: obj
if since is None and replaced is None:
# raise a DeprecatedError directly
is_replaced = True
is_deprecated = True
else:
# compare the numbers
is_deprecated = since is not None and version_leq(since, version_val)
is_replaced = replaced is not None and version_leq(replaced, version_val)

def _decorator(func):
argname = f"{func.__module__} {func.__qualname__}:{name}"

msg_prefix = f"Default of argument `{name}`"

if is_replaced:
msg_infix = f"was replaced in version {replaced} from `{old_default}` to `{new_default}`."
elif is_deprecated:
msg_infix = f"has been deprecated since version {since} from `{old_default}` to `{new_default}`."
if replaced is not None:
msg_infix += f" It will be replaced in version {replaced}."
else:
msg_infix = f"has been deprecated from `{old_default}` to `{new_default}`."

msg = f"{msg_prefix} {msg_infix} {msg_suffix}".strip()

sig = inspect.signature(func)
if name not in sig.parameters:
raise ValueError(f"Argument `{name}` not found in signature of {func.__qualname__}.")
param = sig.parameters[name]
if param.default is inspect.Parameter.empty:
raise ValueError(f"Argument `{name}` has no default value.")

if param.default == new_default and not is_replaced:
raise ValueError(
f"Argument `{name}` was replaced to the new default value `{new_default}` before the specified version {replaced}."
)

@wraps(func)
def _wrapper(*args, **kwargs):
if name not in sig.bind(*args, **kwargs).arguments and is_deprecated:
# arg was not found so the default value is used
warn_deprecated(argname, msg, warning_category)

return func(*args, **kwargs)

return _wrapper

return _decorator
155 changes: 154 additions & 1 deletion tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import unittest
import warnings

from monai.utils import DeprecatedError, deprecated, deprecated_arg
from monai.utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default


class TestDeprecatedRC(unittest.TestCase):
Expand Down Expand Up @@ -287,6 +287,159 @@ def afoo4(a, b=None, **kwargs):
self.assertEqual(afoo4(a=1, b=2, c=3), (1, {"c": 3})) # prefers the new arg
self.assertEqual(afoo4(1, 2, c=3), (1, {"c": 3})) # prefers the new positional arg

def test_deprecated_arg_default_explicit_default(self):
"""
Test deprecated arg default, where the default is explicitly set (no warning).
"""

@deprecated_arg_default(
"b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version
)
def foo(a, b="a"):
return a, b

with self.assertWarns(FutureWarning) as aw:
self.assertEqual(foo("a", "a"), ("a", "a"))
self.assertEqual(foo("a", "b"), ("a", "b"))
self.assertEqual(foo("a", "c"), ("a", "c"))
warnings.warn("fake warning", FutureWarning)

self.assertEqual(aw.warning.args[0], "fake warning")

def test_deprecated_arg_default_version_less_than_since(self):
"""
Test deprecated arg default, where the current version is less than `since` (no warning).
"""

@deprecated_arg_default(
"b", old_default="a", new_default="b", since=self.test_version, version_val=self.prev_version
)
def foo(a, b="a"):
return a, b

with self.assertWarns(FutureWarning) as aw:
self.assertEqual(foo("a"), ("a", "a"))
self.assertEqual(foo("a", "a"), ("a", "a"))
warnings.warn("fake warning", FutureWarning)

self.assertEqual(aw.warning.args[0], "fake warning")

def test_deprecated_arg_default_warning_deprecated(self):
"""
Test deprecated arg default, where the default is used.
"""

@deprecated_arg_default(
"b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version
)
def foo(a, b="a"):
return a, b

self.assertWarns(FutureWarning, lambda: foo("a"))

def test_deprecated_arg_default_warning_replaced(self):
"""
Test deprecated arg default, where the default is used.
"""

@deprecated_arg_default(
"b",
old_default="a",
new_default="b",
since=self.prev_version,
replaced=self.prev_version,
version_val=self.test_version,
)
def foo(a, b="a"):
return a, b

self.assertWarns(FutureWarning, lambda: foo("a"))

def test_deprecated_arg_default_warning_with_none_as_placeholder(self):
"""
Test deprecated arg default, where the default is used.
"""

@deprecated_arg_default(
"b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version
)
def foo(a, b=None):
if b is None:
b = "a"
return a, b

self.assertWarns(FutureWarning, lambda: foo("a"))

@deprecated_arg_default(
"b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version
)
def foo2(a, b=None):
if b is None:
b = "b"
return a, b

self.assertWarns(FutureWarning, lambda: foo2("a"))

def test_deprecated_arg_default_errors(self):
"""
Test deprecated arg default, where the decorator is wrongly used.
"""

# since > replaced
def since_grater_than_replaced():
@deprecated_arg_default(
"b",
old_default="a",
new_default="b",
since=self.test_version,
replaced=self.prev_version,
version_val=self.test_version,
)
def foo(a, b=None):
return a, b

self.assertRaises(ValueError, since_grater_than_replaced)

# argname doesnt exist
def argname_doesnt_exist():
@deprecated_arg_default(
"other", old_default="a", new_default="b", since=self.test_version, version_val=self.test_version
)
def foo(a, b=None):
return a, b

self.assertRaises(ValueError, argname_doesnt_exist)

# argname has no default
def argname_has_no_default():
@deprecated_arg_default(
"a",
old_default="a",
new_default="b",
since=self.prev_version,
replaced=self.test_version,
version_val=self.test_version,
)
def foo(a):
return a

self.assertRaises(ValueError, argname_has_no_default)

# new default is used but version < replaced
def argname_was_replaced_before_specified_version():
@deprecated_arg_default(
"a",
old_default="a",
new_default="b",
since=self.prev_version,
replaced=self.next_version,
version_val=self.test_version,
)
def foo(a, b="b"):
return a, b

self.assertRaises(ValueError, argname_was_replaced_before_specified_version)


if __name__ == "__main__":
unittest.main()