Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MyPy disallow untyped decorators #5824

Merged
merged 9 commits into from
Jan 9, 2023
2 changes: 1 addition & 1 deletion monai/apps/deepedit/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,4 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd

# first item in batch only
engine.state.batch = batchdata
return engine._iteration(engine, batchdata)
return engine._iteration(engine, batchdata) # type: ignore[arg-type]
2 changes: 1 addition & 1 deletion monai/apps/deepgrow/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd
# collate list into a batch for next round interaction
batchdata = list_data_collate(batchdata_list)

return engine._iteration(engine, batchdata)
return engine._iteration(engine, batchdata) # type: ignore[arg-type]
1 change: 1 addition & 0 deletions monai/apps/reconstruction/networks/nets/complex_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
conv_net: Optional[nn.Module] = None,
):
super().__init__()
self.unet: nn.Module
if conv_net is None:
self.unet = BasicUNet(
spatial_dims=spatial_dims,
Expand Down
8 changes: 4 additions & 4 deletions monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
val_handlers: Sequence | None = None,
amp: bool = False,
mode: ForwardMode | str = ForwardMode.EVAL,
event_names: list[str | EventEnum] | None = None,
event_names: list[str | EventEnum | type[EventEnum]] | None = None,
event_to_attr: dict | None = None,
decollate: bool = True,
to_kwargs: dict | None = None,
Expand Down Expand Up @@ -133,7 +133,7 @@ def __init__(
else:
raise ValueError(f"unsupported mode: {mode}, should be 'eval' or 'train'.")

def run(self, global_epoch: int = 1) -> None:
def run(self, global_epoch: int = 1) -> None: # type: ignore[override]
"""
Execute validation/evaluation based on Ignite Engine.

Expand Down Expand Up @@ -237,7 +237,7 @@ def __init__(
val_handlers: Sequence | None = None,
amp: bool = False,
mode: ForwardMode | str = ForwardMode.EVAL,
event_names: list[str | EventEnum] | None = None,
event_names: list[str | EventEnum | type[EventEnum]] | None = None,
event_to_attr: dict | None = None,
decollate: bool = True,
to_kwargs: dict | None = None,
Expand Down Expand Up @@ -380,7 +380,7 @@ def __init__(
val_handlers: Sequence | None = None,
amp: bool = False,
mode: ForwardMode | str = ForwardMode.EVAL,
event_names: list[str | EventEnum] | None = None,
event_names: list[str | EventEnum | type[EventEnum]] | None = None,
event_to_attr: dict | None = None,
decollate: bool = True,
to_kwargs: dict | None = None,
Expand Down
4 changes: 2 additions & 2 deletions monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Trainer(Workflow):

"""

def run(self) -> None:
def run(self) -> None: # type: ignore[override]
"""
Execute training based on Ignite Engine.
If call this function multiple times, it will continuously run from the previous state.
Expand Down Expand Up @@ -151,7 +151,7 @@ def __init__(
metric_cmp_fn: Callable = default_metric_cmp_fn,
train_handlers: Sequence | None = None,
amp: bool = False,
event_names: list[str | EventEnum] | None = None,
event_names: list[str | EventEnum | type[EventEnum]] | None = None,
event_to_attr: dict | None = None,
decollate: bool = True,
optim_set_to_none: bool = False,
Expand Down
25 changes: 12 additions & 13 deletions monai/engines/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Type, Union

import torch
import torch.distributed as dist
Expand All @@ -24,7 +24,6 @@

from .utils import engine_apply_transform

IgniteEngine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="")
State, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "State")
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")

Expand All @@ -43,7 +42,7 @@
)


class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import
class Workflow(Engine):
"""
Workflow defines the core work process inheriting from Ignite engine.
All trainer, validator and evaluator share this same workflow as base class,
Expand Down Expand Up @@ -114,7 +113,7 @@ def __init__(
metric_cmp_fn: Callable = default_metric_cmp_fn,
handlers: Optional[Sequence] = None,
amp: bool = False,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_names: Optional[List[Union[str, EventEnum, Type[EventEnum]]]] = None,
event_to_attr: Optional[dict] = None,
decollate: bool = True,
to_kwargs: Optional[Dict] = None,
Expand All @@ -140,7 +139,7 @@ def set_sampler_epoch(engine: Engine):
raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.")

# set all sharable data for the workflow based on Ignite engine.state
self.state = State(
self.state: Any = State(
rank=dist.get_rank() if dist.is_available() and dist.is_initialized() else 0,
seed=0,
iteration=0,
Expand All @@ -167,18 +166,18 @@ def set_sampler_epoch(engine: Engine):
self.scaler: Optional[torch.cuda.amp.GradScaler] = None

if event_names is None:
event_names = [IterationEvents] # type: ignore
event_names = [IterationEvents]
else:
if not isinstance(event_names, list):
raise ValueError("`event_names` must be a list or string or EventEnum.")
event_names += [IterationEvents] # type: ignore
raise ValueError("`event_names` must be a list of strings or EventEnums.")
event_names += [IterationEvents]
for name in event_names:
if isinstance(name, str):
self.register_events(name, event_to_attr=event_to_attr)
elif issubclass(name, EventEnum): # type: ignore
if isinstance(name, (str, EventEnum)):
self.register_events(name, event_to_attr=event_to_attr) # type: ignore[arg-type]
elif issubclass(name, EventEnum):
self.register_events(*name, event_to_attr=event_to_attr)
else:
raise ValueError("`event_names` must be a list or string or EventEnum.")
raise ValueError("`event_names` must be a list of strings or EventEnums.")

if decollate:
self._register_decollate()
Expand Down Expand Up @@ -267,7 +266,7 @@ def _register_handlers(self, handlers: Sequence):
for handler in handlers_:
handler.attach(self)

def run(self) -> None:
def run(self) -> None: # type: ignore[override]
"""
Execute training, validation or evaluation based on Ignite Engine.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/fl/client/monai_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from monai.apps.auto3dseg.data_analyzer import DataAnalyzer
from monai.auto3dseg import SegSummarizer
from monai.bundle import DEFAULT_EXP_MGMT_SETTINGS, ConfigComponent, ConfigItem, ConfigParser, patch_bundle_tracking
from monai.engines import Trainer
from monai.engines import SupervisedTrainer, Trainer
from monai.fl.client import ClientAlgo, ClientAlgoStats
from monai.fl.utils.constants import (
BundleKeys,
Expand Down Expand Up @@ -429,7 +429,7 @@ def __init__(
self.train_parser: Optional[ConfigParser] = None
self.eval_parser: Optional[ConfigParser] = None
self.filter_parser: Optional[ConfigParser] = None
self.trainer: Optional[Trainer] = None
self.trainer: Optional[SupervisedTrainer] = None
self.evaluator: Optional[Any] = None
self.pre_filters = None
self.post_weight_filters = None
Expand Down
16 changes: 10 additions & 6 deletions monai/handlers/ignite_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,21 @@
from monai.utils import min_version, optional_import

idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="base")
reinit__is_reduced, _ = optional_import(
"ignite.metrics.metric", IgniteInfo.OPT_IMPORT_VERSION, min_version, "reinit__is_reduced", as_type="decorator"
)


if TYPE_CHECKING:
from ignite.engine import Engine
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced
else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="base")
reinit__is_reduced, _ = optional_import(
"ignite.metrics.metric", IgniteInfo.OPT_IMPORT_VERSION, min_version, "reinit__is_reduced", as_type="decorator"
)


class IgniteMetric(Metric): # type: ignore[valid-type, misc] # due to optional_import
class IgniteMetric(Metric):
"""
Base Metric class based on ignite event handler mechanism.
The input `prediction` or `label` data can be a PyTorch Tensor or numpy array with batch dim and channel dim,
Expand Down Expand Up @@ -107,7 +111,7 @@ def compute(self) -> Any:
result = result.item()
return result

def attach(self, engine: Engine, name: str) -> None:
def attach(self, engine: Engine, name: str) -> None: # type: ignore[override]
"""
Attaches current metric to provided engine. On the end of engine's run,
`engine.state.metrics` dictionary will contain computed metric's value under provided name.
Expand Down
4 changes: 2 additions & 2 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from copy import deepcopy
from enum import Enum
from itertools import zip_longest
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast

import numpy as np
import torch
Expand Down Expand Up @@ -2730,7 +2730,7 @@ def __call__(
grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
else:
_device = img.device if isinstance(img, torch.Tensor) else self.device
grid = create_grid(spatial_size=sp_size, device=_device, backend="torch")
grid = cast(torch.Tensor, create_grid(spatial_size=sp_size, device=_device, backend="torch"))
out: torch.Tensor = self.resampler(
img,
grid,
Expand Down
6 changes: 3 additions & 3 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Class names are ended with 'd' to denote dictionary-based transforms.
"""

from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union, cast

import numpy as np
import torch
Expand Down Expand Up @@ -426,7 +426,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.spacing_transform.inverse(d[key])
d[key] = self.spacing_transform.inverse(cast(torch.Tensor, d[key]))
return d


Expand Down Expand Up @@ -1045,7 +1045,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
)
grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
else:
grid = create_grid(spatial_size=sp_size, device=device, backend="torch")
grid = cast(torch.Tensor, create_grid(spatial_size=sp_size, device=device, backend="torch"))

for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) # type: ignore
Expand Down
13 changes: 5 additions & 8 deletions monai/utils/deprecate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
import warnings
from functools import wraps
from types import FunctionType
from typing import Any, Optional
from typing import Any, Callable, Optional, TypeVar

from monai.utils.module import version_leq

from .. import __version__

__all__ = ["deprecated", "deprecated_arg", "DeprecatedError", "deprecated_arg_default"]
T = TypeVar("T", type, Callable)


class DeprecatedError(Exception):
Expand All @@ -40,7 +41,7 @@ def deprecated(
msg_suffix: str = "",
version_val: str = __version__,
warning_category=FutureWarning,
):
) -> Callable[[T], T]:
"""
Marks a function or class as deprecated. If `since` is given this should be a version at or earlier than the
current version and states at what version of the definition was marked as deprecated. If `removed` is given
Expand Down Expand Up @@ -124,7 +125,7 @@ def deprecated_arg(
version_val: str = __version__,
new_name: Optional[str] = None,
warning_category=FutureWarning,
):
) -> Callable[[T], T]:
"""
Marks a particular named argument of a callable as deprecated. The same conditions for `since` and `removed` as
described in the `deprecated` decorator.
Expand All @@ -138,8 +139,6 @@ def deprecated_arg(
using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`.
https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded

In the current implementation type annotations are not preserved.


Args:
name: name of position or keyword argument to mark as deprecated.
Expand Down Expand Up @@ -234,7 +233,7 @@ def deprecated_arg_default(
msg_suffix: str = "",
version_val: str = __version__,
warning_category=FutureWarning,
):
) -> Callable[[T], T]:
"""
Marks a particular arguments default of a callable as deprecated. It is changed from `old_default` to `new_default`
in version `changed`.
Expand All @@ -247,8 +246,6 @@ def deprecated_arg_default(
using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`.
https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded

In the current implementation type annotations are not preserved.


Args:
name: name of position or keyword argument where the default is deprecated/changed.
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,10 @@ ignore_errors = True
ignore_errors = True

[mypy-monai.*]
# Also check the body of functions with no types in their type signature.
check_untyped_defs = True
# Warns about usage of untyped decorators.
disallow_untyped_decorators = True

[pytype]
# Space-separated list of files or directories to exclude.
Expand Down
4 changes: 3 additions & 1 deletion tests/test_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

import unittest
from typing import cast

import nibabel as nib
import numpy as np
Expand Down Expand Up @@ -186,7 +187,7 @@ def test_ornt_meta(
):
img = MetaTensor(img, affine=affine).to(device)
ornt = Orientation(**init_param)
res: MetaTensor = ornt(img)
res = cast(MetaTensor, ornt(img))
assert_allclose(res, expected_data.to(device))
new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels)
self.assertEqual("".join(new_code), expected_code)
Expand All @@ -204,6 +205,7 @@ def test_ornt_torch(self, init_param, img: torch.Tensor, track_meta: bool, devic
assert_allclose(res, expected_data)
if track_meta:
self.assertIsInstance(res, MetaTensor)
assert isinstance(res, MetaTensor) # for mypy type narrowing
new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels)
self.assertEqual("".join(new_code), expected_code)
else:
Expand Down
5 changes: 3 additions & 2 deletions tests/test_orientationd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

import unittest
from typing import Optional
from typing import Optional, cast

import nibabel as nib
import numpy as np
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_orntd(
data = {k: img.clone() for k in ornt.keys}
res = ornt(data)
for k in ornt.keys:
_im = res[k]
_im = cast(MetaTensor, res[k])
self.assertIsInstance(_im, MetaTensor)
np.testing.assert_allclose(_im.shape, expected_shape)
code = nib.aff2axcodes(_im.affine.cpu(), ornt.ornt_transform.labels)
Expand All @@ -94,6 +94,7 @@ def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, devi
np.testing.assert_allclose(_im.shape, expected_shape)
if track_meta:
self.assertIsInstance(_im, MetaTensor)
assert isinstance(_im, MetaTensor) # for mypy type narrowing
code = nib.aff2axcodes(_im.affine.cpu(), ornt.ornt_transform.labels)
self.assertEqual("".join(code), expected_code)
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_spacingd.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def test_spacingd(self, _, data, kw_args, expected_shape, expected_affine, devic
def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, device):
set_track_meta(track_meta)
tr = Spacingd(**init_param)
data = {"seg": img.to(device)}
res = tr(data)["seg"]
res = tr({"seg": img.to(device)})["seg"]

if track_meta:
self.assertIsInstance(res, MetaTensor)
assert isinstance(res, MetaTensor) # for mypy type narrowing
new_spacing = affine_to_spacing(res.affine, 3)
assert_allclose(new_spacing, init_param["pixdim"], type_test=False)
self.assertNotEqual(img.shape, res.shape)
Expand Down