diff --git a/src/_pytest/config/__init__.py b/src/_pytest/config/__init__.py index 45aa4d9a824..45710a70177 100644 --- a/src/_pytest/config/__init__.py +++ b/src/_pytest/config/__init__.py @@ -796,7 +796,7 @@ def __init__(self, pluginmanager, *, invocation_params=None) -> None: ) @property - def invocation_dir(self): + def invocation_dir(self) -> py.path.local: """Backward compatibility""" return py.path.local(str(self.invocation_params.dir)) diff --git a/src/_pytest/doctest.py b/src/_pytest/doctest.py index 0140e0b3a98..d0ee40ad865 100644 --- a/src/_pytest/doctest.py +++ b/src/_pytest/doctest.py @@ -7,6 +7,7 @@ import warnings from contextlib import contextmanager from typing import Dict +from typing import Iterable from typing import List from typing import Optional from typing import Sequence @@ -108,13 +109,18 @@ def pytest_unconfigure(): RUNNER_CLASS = None -def pytest_collect_file(path, parent): +def pytest_collect_file( + path: py.path.local, parent +) -> Optional[Union["DoctestModule", "DoctestTextfile"]]: config = parent.config if path.ext == ".py": if config.option.doctestmodules and not _is_setup_py(config, path, parent): - return DoctestModule.from_parent(parent, fspath=path) + mod = DoctestModule.from_parent(parent, fspath=path) # type: DoctestModule + return mod elif _is_doctest(config, path, parent): - return DoctestTextfile.from_parent(parent, fspath=path) + txt = DoctestTextfile.from_parent(parent, fspath=path) # type: DoctestTextfile + return txt + return None def _is_setup_py(config, path, parent): @@ -361,7 +367,7 @@ def _get_continue_on_failure(config): class DoctestTextfile(pytest.Module): obj = None - def collect(self): + def collect(self) -> Iterable[DoctestItem]: import doctest # inspired by doctest.testfile; ideally we would use it directly, @@ -440,7 +446,7 @@ def _mock_aware_unwrap(obj, stop=None): class DoctestModule(pytest.Module): - def collect(self): + def collect(self) -> Iterable[DoctestItem]: import doctest class MockAwareDocTestFinder(doctest.DocTestFinder): diff --git a/src/_pytest/hookspec.py b/src/_pytest/hookspec.py index 3cd7f5ffebb..9bf59b52ef9 100644 --- a/src/_pytest/hookspec.py +++ b/src/_pytest/hookspec.py @@ -1,13 +1,20 @@ """ hook specifications for pytest plugins, invoked from main.py and builtin plugins. """ from typing import Any +from typing import List from typing import Optional +from typing import Union +import py.path from pluggy import HookspecMarker from _pytest.compat import TYPE_CHECKING if TYPE_CHECKING: from _pytest.main import Session + from _pytest.nodes import Collector + from _pytest.nodes import Item + from _pytest.python import Module + from _pytest.python import PyCollector hookspec = HookspecMarker("pytest") @@ -215,7 +222,7 @@ def pytest_collect_directory(path, parent): """ -def pytest_collect_file(path, parent): +def pytest_collect_file(path: py.path.local, parent) -> "Optional[Collector]": """ return collection Node or None for the given path. Any new node needs to have the specified ``parent`` as a parent. @@ -255,7 +262,7 @@ def pytest_make_collect_report(collector): @hookspec(firstresult=True) -def pytest_pycollect_makemodule(path, parent): +def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Optional[Module]": """ return a Module collector or None for the given path. This hook will be called for each matching test module path. The pytest_collect_file hook needs to be used if you want to @@ -268,7 +275,9 @@ def pytest_pycollect_makemodule(path, parent): @hookspec(firstresult=True) -def pytest_pycollect_makeitem(collector, name, obj): +def pytest_pycollect_makeitem( + collector: "PyCollector", name: str, obj +) -> "Union[None, Item, Collector, List[Union[Item, Collector]]]": """ return custom item/collector for a python object in a module, or None. Stops at first non-None result, see :ref:`firstresult` """ diff --git a/src/_pytest/main.py b/src/_pytest/main.py index dbb6236a3ec..9961a2c69bb 100644 --- a/src/_pytest/main.py +++ b/src/_pytest/main.py @@ -5,11 +5,14 @@ import os import sys from typing import Callable +from typing import cast from typing import Dict from typing import FrozenSet +from typing import Iterator from typing import List from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple from typing import Union @@ -18,12 +21,14 @@ import _pytest._code from _pytest import nodes +from _pytest.compat import overload from _pytest.compat import TYPE_CHECKING from _pytest.config import Config from _pytest.config import directory_arg from _pytest.config import ExitCode from _pytest.config import hookimpl from _pytest.config import UsageError +from _pytest.config.argparsing import Parser from _pytest.fixtures import FixtureManager from _pytest.outcomes import Exit from _pytest.reports import CollectReport @@ -33,11 +38,12 @@ if TYPE_CHECKING: from typing import Type + from typing_extensions import Literal from _pytest.python import Package -def pytest_addoption(parser): +def pytest_addoption(parser: Parser) -> None: parser.addini( "norecursedirs", "directory patterns to avoid for recursion", @@ -237,7 +243,7 @@ def wrap_session( return session.exitstatus -def pytest_cmdline_main(config): +def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]: return wrap_session(config, _main) @@ -254,11 +260,11 @@ def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]: return None -def pytest_collection(session): +def pytest_collection(session: "Session") -> Sequence[nodes.Item]: return session.perform_collect() -def pytest_runtestloop(session): +def pytest_runtestloop(session: "Session") -> bool: if session.testsfailed and not session.config.option.continue_on_collection_errors: raise session.Interrupted( "%d error%s during collection" @@ -278,7 +284,7 @@ def pytest_runtestloop(session): return True -def _in_venv(path): +def _in_venv(path: py.path.local) -> bool: """Attempts to detect if ``path`` is the root of a Virtual Environment by checking for the existence of the appropriate activate script""" bindir = path.join("Scripts" if sys.platform.startswith("win") else "bin") @@ -295,7 +301,7 @@ def _in_venv(path): return any([fname.basename in activates for fname in bindir.listdir()]) -def pytest_ignore_collect(path, config): +def pytest_ignore_collect(path: py.path.local, config: Config) -> bool: ignore_paths = config._getconftest_pathlist("collect_ignore", path=path.dirpath()) ignore_paths = ignore_paths or [] excludeopt = config.getoption("ignore") @@ -323,7 +329,7 @@ def pytest_ignore_collect(path, config): return False -def pytest_collection_modifyitems(items, config): +def pytest_collection_modifyitems(items, config: Config) -> None: deselect_prefixes = tuple(config.getoption("deselect") or []) if not deselect_prefixes: return @@ -380,8 +386,8 @@ def __init__(self, config: Config) -> None: ) self.testsfailed = 0 self.testscollected = 0 - self.shouldstop = False - self.shouldfail = False + self.shouldstop = False # type: Union[bool, str] + self.shouldfail = False # type: Union[bool, str] self.trace = config.trace.root.get("collection") self.startdir = config.invocation_dir self._initialpaths = frozenset() # type: FrozenSet[py.path.local] @@ -407,10 +413,11 @@ def __init__(self, config: Config) -> None: self.config.pluginmanager.register(self, name="session") @classmethod - def from_config(cls, config): - return cls._create(config) + def from_config(cls, config: Config) -> "Session": + session = cls._create(config) # type: Session + return session - def __repr__(self): + def __repr__(self) -> str: return "<%s %s exitstatus=%r testsfailed=%d testscollected=%d>" % ( self.__class__.__name__, self.name, @@ -424,14 +431,14 @@ def _node_location_to_relpath(self, node_path: py.path.local) -> str: return self._bestrelpathcache[node_path] @hookimpl(tryfirst=True) - def pytest_collectstart(self): + def pytest_collectstart(self) -> None: if self.shouldfail: raise self.Failed(self.shouldfail) if self.shouldstop: raise self.Interrupted(self.shouldstop) @hookimpl(tryfirst=True) - def pytest_runtest_logreport(self, report): + def pytest_runtest_logreport(self, report) -> None: if report.failed and not hasattr(report, "wasxfail"): self.testsfailed += 1 maxfail = self.config.getvalue("maxfail") @@ -440,13 +447,27 @@ def pytest_runtest_logreport(self, report): pytest_collectreport = pytest_runtest_logreport - def isinitpath(self, path): + def isinitpath(self, path: py.path.local) -> bool: return path in self._initialpaths def gethookproxy(self, fspath: py.path.local): return super()._gethookproxy(fspath) - def perform_collect(self, args=None, genitems=True): + @overload + def perform_collect( + self, args: Optional[Sequence[str]] = ..., genitems: "Literal[True]" = ... + ) -> Sequence[nodes.Item]: + raise NotImplementedError() + + @overload # noqa: F811 + def perform_collect( # noqa: F811 + self, args: Optional[Sequence[str]] = ..., genitems: bool = ... + ) -> Sequence[Union[nodes.Item, nodes.Collector]]: + raise NotImplementedError() + + def perform_collect( # noqa: F811 + self, args: Optional[Sequence[str]] = None, genitems: bool = True + ) -> Sequence[Union[nodes.Item, nodes.Collector]]: hook = self.config.hook try: items = self._perform_collect(args, genitems) @@ -459,15 +480,29 @@ def perform_collect(self, args=None, genitems=True): self.testscollected = len(items) return items - def _perform_collect(self, args, genitems): + @overload + def _perform_collect( + self, args: Optional[Sequence[str]], genitems: "Literal[True]" + ) -> Sequence[nodes.Item]: + raise NotImplementedError() + + @overload # noqa: F811 + def _perform_collect( # noqa: F811 + self, args: Optional[Sequence[str]], genitems: bool + ) -> Sequence[Union[nodes.Item, nodes.Collector]]: + raise NotImplementedError() + + def _perform_collect( # noqa: F811 + self, args: Optional[Sequence[str]], genitems: bool + ) -> Sequence[Union[nodes.Item, nodes.Collector]]: if args is None: args = self.config.args self.trace("perform_collect", self, args) self.trace.root.indent += 1 - self._notfound = [] + self._notfound = [] # type: List[Tuple[str, NoMatch]] initialpaths = [] # type: List[py.path.local] self._initial_parts = [] # type: List[Tuple[py.path.local, List[str]]] - self.items = items = [] + self.items = items = [] # type: List[nodes.Item] for arg in args: fspath, parts = self._parsearg(arg) self._initial_parts.append((fspath, parts)) @@ -490,7 +525,7 @@ def _perform_collect(self, args, genitems): self.items.extend(self.genitems(node)) return items - def collect(self): + def collect(self) -> Iterator[Union[nodes.Item, nodes.Collector]]: for fspath, parts in self._initial_parts: self.trace("processing argument", (fspath, parts)) self.trace.root.indent += 1 @@ -500,7 +535,8 @@ def collect(self): report_arg = "::".join((str(fspath), *parts)) # we are inside a make_report hook so # we cannot directly pass through the exception - self._notfound.append((report_arg, sys.exc_info()[1])) + exc = cast(NoMatch, sys.exc_info()[1]) + self._notfound.append((report_arg, exc)) self.trace.root.indent -= 1 self._collection_node_cache1.clear() @@ -508,7 +544,9 @@ def collect(self): self._collection_node_cache3.clear() self._collection_pkg_roots.clear() - def _collect(self, argpath, names): + def _collect( + self, argpath: py.path.local, names: List[str] + ) -> Iterator[Union[nodes.Item, nodes.Collector]]: from _pytest.python import Package # Start with a Session root, and delve to argpath item (dir or file) @@ -536,7 +574,7 @@ def _collect(self, argpath, names): if argpath.check(dir=1): assert not names, "invalid arg {!r}".format((argpath, names)) - seen_dirs = set() + seen_dirs = set() # type: Set[py.path.local] for path in argpath.visit( fil=self._visit_filter, rec=self._recurse, bf=True, sort=True ): @@ -577,8 +615,9 @@ def _collect(self, argpath, names): # Module itself, so just use that. If this special case isn't taken, then all # the files in the package will be yielded. if argpath.basename == "__init__.py": + assert isinstance(m[0], nodes.Collector) try: - yield next(m[0].collect()) + yield next(iter(m[0].collect())) except StopIteration: # The package collects nothing with only an __init__.py # file in it, which gets ignored by the default @@ -587,7 +626,9 @@ def _collect(self, argpath, names): return yield from m - def _collectfile(self, path, handle_dupes=True): + def _collectfile( + self, path: py.path.local, handle_dupes: bool = True + ) -> Sequence[nodes.Collector]: assert ( path.isfile() ), "{!r} is not a file (isdir={!r}, exists={!r}, islink={!r})".format( @@ -607,13 +648,17 @@ def _collectfile(self, path, handle_dupes=True): else: duplicate_paths.add(path) - return ihook.pytest_collect_file(path=path, parent=self) + collector = ihook.pytest_collect_file( + path=path, parent=self + ) # type: List[nodes.Collector] + return collector @staticmethod - def _visit_filter(f): - return f.check(file=1) + def _visit_filter(f: py.path.local) -> bool: + # TODO: Remove type: ignore once `py` is typed. + return f.check(file=1) # type: ignore - def _tryconvertpyarg(self, x): + def _tryconvertpyarg(self, x: str) -> str: """Convert a dotted module name to path.""" try: spec = importlib.util.find_spec(x) @@ -622,14 +667,14 @@ def _tryconvertpyarg(self, x): # ValueError: not a module name except (AttributeError, ImportError, ValueError): return x - if spec is None or spec.origin in {None, "namespace"}: + if spec is None or spec.origin is None or spec.origin == "namespace": return x elif spec.submodule_search_locations: return os.path.dirname(spec.origin) else: return spec.origin - def _parsearg(self, arg): + def _parsearg(self, arg: str) -> Tuple[py.path.local, List[str]]: """ return (fspath, names) tuple after checking the file exists. """ strpath, *parts = str(arg).split("::") if self.config.option.pyargs: @@ -645,7 +690,9 @@ def _parsearg(self, arg): fspath = fspath.realpath() return (fspath, parts) - def matchnodes(self, matching, names): + def matchnodes( + self, matching: Sequence[Union[nodes.Item, nodes.Collector]], names: List[str], + ) -> Sequence[Union[nodes.Item, nodes.Collector]]: self.trace("matchnodes", matching, names) self.trace.root.indent += 1 nodes = self._matchnodes(matching, names) @@ -656,13 +703,15 @@ def matchnodes(self, matching, names): raise NoMatch(matching, names[:1]) return nodes - def _matchnodes(self, matching, names): + def _matchnodes( + self, matching: Sequence[Union[nodes.Item, nodes.Collector]], names: List[str], + ) -> Sequence[Union[nodes.Item, nodes.Collector]]: if not matching or not names: return matching name = names[0] assert name nextnames = names[1:] - resultnodes = [] + resultnodes = [] # type: List[Union[nodes.Item, nodes.Collector]] for node in matching: if isinstance(node, nodes.Item): if not names: @@ -693,7 +742,9 @@ def _matchnodes(self, matching, names): node.ihook.pytest_collectreport(report=rep) return resultnodes - def genitems(self, node): + def genitems( + self, node: Union[nodes.Item, nodes.Collector] + ) -> Iterator[nodes.Item]: self.trace("genitems", node) if isinstance(node, nodes.Item): node.ihook.pytest_itemcollected(item=node) diff --git a/src/_pytest/nodes.py b/src/_pytest/nodes.py index 81a25ddd5c2..99453941124 100644 --- a/src/_pytest/nodes.py +++ b/src/_pytest/nodes.py @@ -3,6 +3,7 @@ from functools import lru_cache from typing import Any from typing import Dict +from typing import Iterable from typing import List from typing import Optional from typing import Set @@ -205,7 +206,7 @@ def warn(self, warning): # methods for ordering nodes @property - def nodeid(self): + def nodeid(self) -> str: """ a ::-separated string denoting its collection tree address. """ return self._nodeid @@ -390,7 +391,7 @@ class Collector(Node): class CollectError(Exception): """ an error during collection, contains a custom message. """ - def collect(self): + def collect(self) -> Iterable[Union["Item", "Collector"]]: """ returns a list of children (items and collectors) for this collection node. """ diff --git a/src/_pytest/python.py b/src/_pytest/python.py index 6402164f9e4..a9264667119 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -4,14 +4,17 @@ import inspect import os import sys +import typing import warnings from collections import Counter from collections import defaultdict from collections.abc import Sequence from functools import partial from typing import Dict +from typing import Iterable from typing import List from typing import Optional +from typing import Set from typing import Tuple from typing import Union @@ -177,16 +180,20 @@ def pytest_pyfunc_call(pyfuncitem: "Function"): return True -def pytest_collect_file(path, parent): +def pytest_collect_file(path: py.path.local, parent) -> Optional["Module"]: ext = path.ext if ext == ".py": if not parent.session.isinitpath(path): if not path_matches_patterns( path, parent.config.getini("python_files") + ["__init__.py"] ): - return + return None ihook = parent.session.gethookproxy(path) - return ihook.pytest_pycollect_makemodule(path=path, parent=parent) + module = ihook.pytest_pycollect_makemodule( + path=path, parent=parent + ) # type: Module + return module + return None def path_matches_patterns(path, patterns): @@ -194,14 +201,16 @@ def path_matches_patterns(path, patterns): return any(path.fnmatch(pattern) for pattern in patterns) -def pytest_pycollect_makemodule(path, parent): +def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Module": if path.basename == "__init__.py": - return Package.from_parent(parent, fspath=path) - return Module.from_parent(parent, fspath=path) + pkg = Package.from_parent(parent, fspath=path) # type: Package + return pkg + mod = Module.from_parent(parent, fspath=path) # type: Module + return mod @hookimpl(hookwrapper=True) -def pytest_pycollect_makeitem(collector, name, obj): +def pytest_pycollect_makeitem(collector: "PyCollector", name: str, obj): outcome = yield res = outcome.get_result() if res is not None: @@ -353,7 +362,7 @@ def _matches_prefix_or_glob_option(self, option_name, name): return True return False - def collect(self): + def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: if not getattr(self.obj, "__test__", True): return [] @@ -362,8 +371,8 @@ def collect(self): dicts = [getattr(self.obj, "__dict__", {})] for basecls in inspect.getmro(self.obj.__class__): dicts.append(basecls.__dict__) - seen = {} - values = [] + seen = {} # type: Dict[str, bool] + values = [] # type: List[Union[nodes.Item, nodes.Collector]] for dic in dicts: for name, obj in list(dic.items()): if name in seen: @@ -383,9 +392,16 @@ def sort_key(item): values.sort(key=sort_key) return values - def _makeitem(self, name, obj): + def _makeitem( + self, name: str, obj + ) -> Union[ + None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]] + ]: # assert self.ihook.fspath == self.fspath, self - return self.ihook.pytest_pycollect_makeitem(collector=self, name=name, obj=obj) + item = self.ihook.pytest_pycollect_makeitem( + collector=self, name=name, obj=obj + ) # type: Union[None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]]] + return item def _genfunctions(self, name, funcobj): module = self.getparent(Module).obj @@ -437,7 +453,7 @@ class Module(nodes.File, PyCollector): def _getobj(self): return self._importtestmodule() - def collect(self): + def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: self._inject_setup_module_fixture() self._inject_setup_function_fixture() self.session._fixturemanager.parsefactories(self) @@ -584,7 +600,9 @@ def setup(self): def gethookproxy(self, fspath: py.path.local): return super()._gethookproxy(fspath) - def _collectfile(self, path, handle_dupes=True): + def _collectfile( + self, path: py.path.local, handle_dupes: bool = True + ) -> typing.Sequence[nodes.Collector]: assert ( path.isfile() ), "{!r} is not a file (isdir={!r}, exists={!r}, islink={!r})".format( @@ -607,19 +625,22 @@ def _collectfile(self, path, handle_dupes=True): if self.fspath == path: # __init__.py return [self] - return ihook.pytest_collect_file(path=path, parent=self) + collectors = ihook.pytest_collect_file( + path=path, parent=self + ) # type: List[nodes.Collector] + return collectors - def isinitpath(self, path): + def isinitpath(self, path: py.path.local) -> bool: return path in self.session._initialpaths - def collect(self): + def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: this_path = self.fspath.dirpath() init_module = this_path.join("__init__.py") if init_module.check(file=1) and path_matches_patterns( init_module, self.config.getini("python_files") ): yield Module.from_parent(self, fspath=init_module) - pkg_prefixes = set() + pkg_prefixes = set() # type: Set[py.path.local] for path in this_path.visit(rec=self._recurse, bf=True, sort=True): # We will visit our own __init__.py file, in which case we skip it. is_file = path.isfile() @@ -676,10 +697,11 @@ def from_parent(cls, parent, *, name, obj=None): """ return super().from_parent(name=name, parent=parent) - def collect(self): + def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: if not safe_getattr(self.obj, "__test__", True): return [] if hasinit(self.obj): + assert self.parent is not None self.warn( PytestCollectionWarning( "cannot collect test class %r because it has a " @@ -689,6 +711,7 @@ def collect(self): ) return [] elif hasnew(self.obj): + assert self.parent is not None self.warn( PytestCollectionWarning( "cannot collect test class %r because it has a " @@ -762,7 +785,7 @@ class Instance(PyCollector): def _getobj(self): return self.parent.obj() - def collect(self): + def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: self.session._fixturemanager.parsefactories(self) return super().collect() diff --git a/src/_pytest/reports.py b/src/_pytest/reports.py index 3ad67c224c4..d3bfc91979a 100644 --- a/src/_pytest/reports.py +++ b/src/_pytest/reports.py @@ -20,7 +20,8 @@ from _pytest._code.code import TerminalRepr from _pytest._io import TerminalWriter from _pytest.compat import TYPE_CHECKING -from _pytest.nodes import Node +from _pytest.nodes import Collector +from _pytest.nodes import Item from _pytest.outcomes import skip from _pytest.pathlib import Path @@ -315,7 +316,13 @@ class CollectReport(BaseReport): when = "collect" def __init__( - self, nodeid: str, outcome, longrepr, result: List[Node], sections=(), **extra + self, + nodeid: str, + outcome, + longrepr, + result: Optional[List[Union[Item, Collector]]], + sections=(), + **extra ) -> None: self.nodeid = nodeid self.outcome = outcome diff --git a/src/_pytest/runner.py b/src/_pytest/runner.py index e10e4d8bdf0..3eb9ca28667 100644 --- a/src/_pytest/runner.py +++ b/src/_pytest/runner.py @@ -373,10 +373,10 @@ def prepare(self, colitem): raise -def collect_one_node(collector): +def collect_one_node(collector: Collector) -> CollectReport: ihook = collector.ihook ihook.pytest_collectstart(collector=collector) - rep = ihook.pytest_make_collect_report(collector=collector) + rep = ihook.pytest_make_collect_report(collector=collector) # type: CollectReport call = rep.__dict__.pop("call", None) if call and check_interactive_exception(call, rep): ihook.pytest_exception_interact(node=collector, call=call, report=rep) diff --git a/src/_pytest/unittest.py b/src/_pytest/unittest.py index a5512e9443c..3740ff1249a 100644 --- a/src/_pytest/unittest.py +++ b/src/_pytest/unittest.py @@ -2,29 +2,40 @@ import functools import sys import traceback +from typing import Iterable +from typing import Optional +from typing import Union import _pytest._code import pytest from _pytest.compat import getimfunc from _pytest.config import hookimpl +from _pytest.nodes import Collector +from _pytest.nodes import Item from _pytest.outcomes import exit from _pytest.outcomes import fail from _pytest.outcomes import skip from _pytest.outcomes import xfail from _pytest.python import Class from _pytest.python import Function +from _pytest.python import PyCollector from _pytest.runner import CallInfo -def pytest_pycollect_makeitem(collector, name, obj): +def pytest_pycollect_makeitem( + collector: PyCollector, name: str, obj +) -> Optional["UnitTestCase"]: # has unittest been imported and is obj a subclass of its TestCase? try: - if not issubclass(obj, sys.modules["unittest"].TestCase): - return + ut = sys.modules["unittest"] + # Type ignored because `ut` is an opaque module. + if not issubclass(obj, ut.TestCase): # type: ignore + return None except Exception: - return + return None # yes, so let's collect it - return UnitTestCase.from_parent(collector, name=name, obj=obj) + item = UnitTestCase.from_parent(collector, name=name, obj=obj) # type: UnitTestCase + return item class UnitTestCase(Class): @@ -32,7 +43,7 @@ class UnitTestCase(Class): # to declare that our children do not support funcargs nofuncargs = True - def collect(self): + def collect(self) -> Iterable[Union[Item, Collector]]: from unittest import TestLoader cls = self.obj @@ -59,8 +70,8 @@ def collect(self): runtest = getattr(self.obj, "runTest", None) if runtest is not None: ut = sys.modules.get("twisted.trial.unittest", None) - if ut is None or runtest != ut.TestCase.runTest: - # TODO: callobj consistency + # Type ignored because `ut` is an opaque module. + if ut is None or runtest != ut.TestCase.runTest: # type: ignore yield TestCaseFunction.from_parent(self, name="runTest") def _inject_setup_teardown_fixtures(self, cls):