diff --git a/CHANGELOG.md b/CHANGELOG.md index ba4fb02d..7cf1cc93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,11 +4,11 @@ ## 5.5.0 -* Clean up application typing -* Update tests and docs to use non-deprecated functions -* Clean up version handling -* Prep for jupyter releaser -* Format the changelog +- Clean up application typing +- Update tests and docs to use non-deprecated functions +- Clean up version handling +- Prep for jupyter releaser +- Format the changelog diff --git a/pyproject.toml b/pyproject.toml index 993179a9..e0b8ed86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ requires-python = ">=3.7" dynamic = ["description", "version"] [project.optional-dependencies] -test = ["pytest", "pre-commit"] +test = ["pytest", "pre-commit", "pytest-mypy-testing"] docs = [ "myst-parser", "pydata-sphinx-theme", diff --git a/traitlets/tests/test_typing.py b/traitlets/tests/test_typing.py new file mode 100644 index 00000000..b40b4b7e --- /dev/null +++ b/traitlets/tests/test_typing.py @@ -0,0 +1,28 @@ +from lib2to3 import pytree + +import pytest +from typing_extensions import reveal_type + +from traitlets import Bool, HasTraits, TCPAddress + + +@pytest.mark.mypy_testing +def mypy_bool_typing(): + class T(HasTraits): + b = Bool().tag(sync=True) + + t = T() + reveal_type(t.b) # R: Union[builtins.bool, None] + # we would expect this to be Optional[Union[bool, int]], but... + t.b = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Optional[int]") [assignment] + T.b.tag(foo=True) + + +@pytest.mark.mypy_testing +def mypy_tcp_typing(): + class T(HasTraits): + tcp = TCPAddress() + + t = T() + reveal_type(t.tcp) # R: Union[Tuple[builtins.str, builtins.int], None] + t.tcp = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Optional[Tuple[str, int]]") [assignment] diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index 7489d019..6b9bbece 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -39,6 +39,7 @@ # Adapted from enthought.traits, Copyright (c) Enthought, Inc., # also under the terms of the Modified BSD License. +import collections.abc import contextlib import enum import inspect @@ -514,7 +515,13 @@ def instance_init(self, obj): pass -class TraitType(BaseDescriptor): +G = t.TypeVar("G") +S = t.TypeVar("S") + +Self = t.TypeVar("Self", bound="TraitType") # Holdover waiting for typings.Self in Python 3.11 + + +class TraitType(BaseDescriptor, t.Generic[G, S]): """A base class for all trait types.""" metadata: t.Dict[str, t.Any] = {} @@ -640,9 +647,9 @@ def init_default_value(self, obj): obj._trait_values[self.name] = value return value - def get(self, obj, cls=None): + def get(self, obj: "HasTraits", cls: t.Any = None) -> t.Optional[G]: try: - value = obj._trait_values[self.name] + value = obj._trait_values[self.name] # type: ignore except KeyError: # Check for a dynamic initializer. default = obj.trait_defaults(self.name) @@ -656,7 +663,7 @@ def get(self, obj, cls=None): ) with obj.cross_validation_lock: value = self._validate(obj, default) - obj._trait_values[self.name] = value + obj._trait_values[self.name] = value # type: ignore obj._notify_observers( Bunch( name=self.name, @@ -665,14 +672,14 @@ def get(self, obj, cls=None): type="default", ) ) - return value + return value # type: ignore except Exception: # This should never be reached. raise TraitError("Unexpected error in TraitType: default value not set properly") else: - return value + return value # type: ignore - def __get__(self, obj, cls=None): + def __get__(self, obj: "HasTraits", cls: t.Any = None) -> t.Optional[G]: """Get the value of the trait by self.name for the instance. Default values are instantiated when :meth:`HasTraits.__new__` @@ -703,7 +710,7 @@ def set(self, obj, value): # comparison above returns something other than True/False obj._notify_trait(self.name, old_value, new_value) - def __set__(self, obj, value): + def __set__(self, obj: "HasTraits", value: t.Optional[S]) -> None: """Set the value of the trait by self.name for the instance. Values pass through a validation stage where errors are raised when @@ -850,7 +857,7 @@ def set_metadata(self, key, value): warn("Deprecated in traitlets 4.1, " + msg, DeprecationWarning, stacklevel=2) self.metadata[key] = value - def tag(self, **metadata): + def tag(self, **metadata) -> "TraitType[G, S]": """Sets metadata and returns self. This allows convenient metadata tagging when initializing the trait, such as: @@ -1859,7 +1866,7 @@ def trait_events(cls, name=None): # ----------------------------------------------------------------------------- -class ClassBasedTraitType(TraitType): +class ClassBasedTraitType(TraitType[t.Any, t.Any]): """ A trait with error reporting and string -> type resolution for Type, Instance and This. @@ -2117,7 +2124,7 @@ def validate(self, obj, value): self.error(obj, value) -class Union(TraitType): +class Union(TraitType[t.Any, t.Any]): """A trait type representing a Union type.""" def __init__(self, trait_types, **kwargs): @@ -2205,7 +2212,7 @@ def from_string(self, s): # ----------------------------------------------------------------------------- -class Any(TraitType): +class Any(TraitType[t.Optional[t.Any], t.Optional[t.Any]]): """A trait which allows any value.""" default_value: t.Optional[t.Any] = None @@ -2239,7 +2246,7 @@ def _validate_bounds(trait, obj, value): return value -class Int(TraitType): +class Int(TraitType[int, int]): """An int trait.""" default_value = 0 @@ -2261,7 +2268,7 @@ def from_string(self, s): return int(s) -class CInt(Int): +class CInt(Int, TraitType[int, t.Any]): """A casting version of the int trait.""" def validate(self, obj, value): @@ -2276,7 +2283,7 @@ def validate(self, obj, value): Integer = Int -class Float(TraitType): +class Float(TraitType[float, int | float]): """A float trait.""" default_value = 0.0 @@ -2300,7 +2307,7 @@ def from_string(self, s): return float(s) -class CFloat(Float): +class CFloat(Float, TraitType[float, t.Any]): """A casting version of the float trait.""" def validate(self, obj, value): @@ -2311,7 +2318,7 @@ def validate(self, obj, value): return _validate_bounds(self, obj, value) -class Complex(TraitType): +class Complex(TraitType[complex, complex | float | int]): """A trait for complex numbers.""" default_value = 0.0 + 0.0j @@ -2330,7 +2337,7 @@ def from_string(self, s): return complex(s) -class CComplex(Complex): +class CComplex(Complex, TraitType[complex, t.Any]): """A casting version of the complex number trait.""" def validate(self, obj, value): @@ -2343,7 +2350,7 @@ def validate(self, obj, value): # We should always be explicit about whether we're using bytes or unicode, both # for Python 3 conversion and for reliable unicode behaviour on Python 2. So # we don't have a Str type. -class Bytes(TraitType): +class Bytes(TraitType[bytes, bytes]): """A trait for byte strings.""" default_value = b"" @@ -2372,7 +2379,7 @@ def from_string(self, s): return s.encode("utf8") -class CBytes(Bytes): +class CBytes(Bytes, TraitType[bytes, t.Any]): """A casting version of the byte string trait.""" def validate(self, obj, value): @@ -2382,7 +2389,7 @@ def validate(self, obj, value): self.error(obj, value) -class Unicode(TraitType): +class Unicode(TraitType[str, str | bytes]): """A trait for unicode strings.""" default_value = "" @@ -2417,7 +2424,7 @@ def from_string(self, s): return s -class CUnicode(Unicode): +class CUnicode(Unicode, TraitType[str, t.Any]): """A casting version of the unicode trait.""" def validate(self, obj, value): @@ -2427,7 +2434,7 @@ def validate(self, obj, value): self.error(obj, value) -class ObjectName(TraitType): +class ObjectName(TraitType[str, str]): """A string holding a valid object name in this version of Python. This does not check that the name exists in any scope.""" @@ -2460,7 +2467,7 @@ def validate(self, obj, value): self.error(obj, value) -class Bool(TraitType): +class Bool(TraitType[bool, t.Union[bool, int]]): """A boolean (True, False) trait.""" default_value = False @@ -2488,7 +2495,7 @@ def from_string(self, s): raise ValueError("%r is not 1, 0, true, or false") -class CBool(Bool): +class CBool(Bool, TraitType[bool, t.Any]): """A casting version of the boolean trait.""" def validate(self, obj, value): @@ -2498,7 +2505,7 @@ def validate(self, obj, value): self.error(obj, value) -class Enum(TraitType): +class Enum(TraitType[t.Any, t.Any]): """An enum whose value must be in a given sequence.""" def __init__(self, values, default_value=Undefined, **kwargs): @@ -3323,7 +3330,7 @@ def item_from_string(self, s): return {key: value} -class TCPAddress(TraitType): +class TCPAddress(TraitType[tuple[str, int], tuple[str, int]]): """A trait for an (ip, port) tuple. This allows for both IPv4 IP addresses as well as hostnames. @@ -3351,7 +3358,7 @@ def from_string(self, s): return (ip, port) -class CRegExp(TraitType): +class CRegExp(TraitType[re.Pattern[t.Any], re.Pattern[t.Any] | str]): """A casting compiled regular expression trait. Accepts both strings and compiled regular expressions. The resulting @@ -3366,7 +3373,7 @@ def validate(self, obj, value): self.error(obj, value) -class UseEnum(TraitType): +class UseEnum(TraitType[t.Any, t.Any]): """Use a Enum class as model for the data type description. Note that if no default-value is provided, the first enum-value is used as default-value. @@ -3463,7 +3470,9 @@ def info_rst(self): return self._info(as_rst=True) -class Callable(TraitType): +class Callable( + TraitType[collections.abc.Callable[..., t.Any], collections.abc.Callable[..., t.Any]] +): """A trait which is callable. Notes