Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow type checkers to infer trait types #788

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

## 5.5.0

* Clean up application typing
* Update tests and docs to use non-deprecated functions
* Clean up version handling
* Prep for jupyter releaser
* Format the changelog
- Clean up application typing
- Update tests and docs to use non-deprecated functions
- Clean up version handling
- Prep for jupyter releaser
- Format the changelog

<!-- <END NEW CHANGELOG ENTRY> -->

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ requires-python = ">=3.7"
dynamic = ["description", "version"]

[project.optional-dependencies]
test = ["pytest", "pre-commit"]
test = ["pytest", "pre-commit", "pytest-mypy-testing"]
docs = [
"myst-parser",
"pydata-sphinx-theme",
Expand Down
28 changes: 28 additions & 0 deletions traitlets/tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from lib2to3 import pytree

import pytest
from typing_extensions import reveal_type

from traitlets import Bool, HasTraits, TCPAddress


@pytest.mark.mypy_testing
def mypy_bool_typing():
class T(HasTraits):
b = Bool().tag(sync=True)

t = T()
reveal_type(t.b) # R: Union[builtins.bool, None]
# we would expect this to be Optional[Union[bool, int]], but...
t.b = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Optional[int]") [assignment]
T.b.tag(foo=True)


@pytest.mark.mypy_testing
def mypy_tcp_typing():
class T(HasTraits):
tcp = TCPAddress()

t = T()
reveal_type(t.tcp) # R: Union[Tuple[builtins.str, builtins.int], None]
t.tcp = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Optional[Tuple[str, int]]") [assignment]
69 changes: 39 additions & 30 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
# Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
# also under the terms of the Modified BSD License.

import collections.abc
import contextlib
import enum
import inspect
Expand Down Expand Up @@ -514,7 +515,13 @@ def instance_init(self, obj):
pass


class TraitType(BaseDescriptor):
G = t.TypeVar("G")
S = t.TypeVar("S")

Self = t.TypeVar("Self", bound="TraitType") # Holdover waiting for typings.Self in Python 3.11


class TraitType(BaseDescriptor, t.Generic[G, S]):
"""A base class for all trait types."""

metadata: t.Dict[str, t.Any] = {}
Expand Down Expand Up @@ -640,9 +647,9 @@ def init_default_value(self, obj):
obj._trait_values[self.name] = value
return value

def get(self, obj, cls=None):
def get(self, obj: "HasTraits", cls: t.Any = None) -> t.Optional[G]:
try:
value = obj._trait_values[self.name]
value = obj._trait_values[self.name] # type: ignore
except KeyError:
# Check for a dynamic initializer.
default = obj.trait_defaults(self.name)
Expand All @@ -656,7 +663,7 @@ def get(self, obj, cls=None):
)
with obj.cross_validation_lock:
value = self._validate(obj, default)
obj._trait_values[self.name] = value
obj._trait_values[self.name] = value # type: ignore
obj._notify_observers(
Bunch(
name=self.name,
Expand All @@ -665,14 +672,14 @@ def get(self, obj, cls=None):
type="default",
)
)
return value
return value # type: ignore
except Exception:
# This should never be reached.
raise TraitError("Unexpected error in TraitType: default value not set properly")
else:
return value
return value # type: ignore

def __get__(self, obj, cls=None):
def __get__(self, obj: "HasTraits", cls: t.Any = None) -> t.Optional[G]:
"""Get the value of the trait by self.name for the instance.

Default values are instantiated when :meth:`HasTraits.__new__`
Expand Down Expand Up @@ -703,7 +710,7 @@ def set(self, obj, value):
# comparison above returns something other than True/False
obj._notify_trait(self.name, old_value, new_value)

def __set__(self, obj, value):
def __set__(self, obj: "HasTraits", value: t.Optional[S]) -> None:
"""Set the value of the trait by self.name for the instance.

Values pass through a validation stage where errors are raised when
Expand Down Expand Up @@ -850,7 +857,7 @@ def set_metadata(self, key, value):
warn("Deprecated in traitlets 4.1, " + msg, DeprecationWarning, stacklevel=2)
self.metadata[key] = value

def tag(self, **metadata):
def tag(self, **metadata) -> "TraitType[G, S]":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think with the typing_extension package we can do this:

Suggested change
def tag(self, **metadata) -> "TraitType[G, S]":
def tag(self, **metadata) -> typing_extension.Self":

"""Sets metadata and returns self.

This allows convenient metadata tagging when initializing the trait, such as:
Expand Down Expand Up @@ -1859,7 +1866,7 @@ def trait_events(cls, name=None):
# -----------------------------------------------------------------------------


class ClassBasedTraitType(TraitType):
class ClassBasedTraitType(TraitType[t.Any, t.Any]):
"""
A trait with error reporting and string -> type resolution for Type,
Instance and This.
Expand Down Expand Up @@ -2117,7 +2124,7 @@ def validate(self, obj, value):
self.error(obj, value)


class Union(TraitType):
class Union(TraitType[t.Any, t.Any]):
"""A trait type representing a Union type."""

def __init__(self, trait_types, **kwargs):
Expand Down Expand Up @@ -2205,7 +2212,7 @@ def from_string(self, s):
# -----------------------------------------------------------------------------


class Any(TraitType):
class Any(TraitType[t.Optional[t.Any], t.Optional[t.Any]]):
"""A trait which allows any value."""

default_value: t.Optional[t.Any] = None
Expand Down Expand Up @@ -2239,7 +2246,7 @@ def _validate_bounds(trait, obj, value):
return value


class Int(TraitType):
class Int(TraitType[int, int]):
"""An int trait."""

default_value = 0
Expand All @@ -2261,7 +2268,7 @@ def from_string(self, s):
return int(s)


class CInt(Int):
class CInt(Int, TraitType[int, t.Any]):
"""A casting version of the int trait."""

def validate(self, obj, value):
Expand All @@ -2276,7 +2283,7 @@ def validate(self, obj, value):
Integer = Int


class Float(TraitType):
class Float(TraitType[float, int | float]):
"""A float trait."""

default_value = 0.0
Expand All @@ -2300,7 +2307,7 @@ def from_string(self, s):
return float(s)


class CFloat(Float):
class CFloat(Float, TraitType[float, t.Any]):
"""A casting version of the float trait."""

def validate(self, obj, value):
Expand All @@ -2311,7 +2318,7 @@ def validate(self, obj, value):
return _validate_bounds(self, obj, value)


class Complex(TraitType):
class Complex(TraitType[complex, complex | float | int]):
"""A trait for complex numbers."""

default_value = 0.0 + 0.0j
Expand All @@ -2330,7 +2337,7 @@ def from_string(self, s):
return complex(s)


class CComplex(Complex):
class CComplex(Complex, TraitType[complex, t.Any]):
"""A casting version of the complex number trait."""

def validate(self, obj, value):
Expand All @@ -2343,7 +2350,7 @@ def validate(self, obj, value):
# We should always be explicit about whether we're using bytes or unicode, both
# for Python 3 conversion and for reliable unicode behaviour on Python 2. So
# we don't have a Str type.
class Bytes(TraitType):
class Bytes(TraitType[bytes, bytes]):
"""A trait for byte strings."""

default_value = b""
Expand Down Expand Up @@ -2372,7 +2379,7 @@ def from_string(self, s):
return s.encode("utf8")


class CBytes(Bytes):
class CBytes(Bytes, TraitType[bytes, t.Any]):
"""A casting version of the byte string trait."""

def validate(self, obj, value):
Expand All @@ -2382,7 +2389,7 @@ def validate(self, obj, value):
self.error(obj, value)


class Unicode(TraitType):
class Unicode(TraitType[str, str | bytes]):
"""A trait for unicode strings."""

default_value = ""
Expand Down Expand Up @@ -2417,7 +2424,7 @@ def from_string(self, s):
return s


class CUnicode(Unicode):
class CUnicode(Unicode, TraitType[str, t.Any]):
"""A casting version of the unicode trait."""

def validate(self, obj, value):
Expand All @@ -2427,7 +2434,7 @@ def validate(self, obj, value):
self.error(obj, value)


class ObjectName(TraitType):
class ObjectName(TraitType[str, str]):
"""A string holding a valid object name in this version of Python.

This does not check that the name exists in any scope."""
Expand Down Expand Up @@ -2460,7 +2467,7 @@ def validate(self, obj, value):
self.error(obj, value)


class Bool(TraitType):
class Bool(TraitType[bool, t.Union[bool, int]]):
"""A boolean (True, False) trait."""

default_value = False
Expand Down Expand Up @@ -2488,7 +2495,7 @@ def from_string(self, s):
raise ValueError("%r is not 1, 0, true, or false")


class CBool(Bool):
class CBool(Bool, TraitType[bool, t.Any]):
"""A casting version of the boolean trait."""

def validate(self, obj, value):
Expand All @@ -2498,7 +2505,7 @@ def validate(self, obj, value):
self.error(obj, value)


class Enum(TraitType):
class Enum(TraitType[t.Any, t.Any]):
"""An enum whose value must be in a given sequence."""

def __init__(self, values, default_value=Undefined, **kwargs):
Expand Down Expand Up @@ -3323,7 +3330,7 @@ def item_from_string(self, s):
return {key: value}


class TCPAddress(TraitType):
class TCPAddress(TraitType[tuple[str, int], tuple[str, int]]):
"""A trait for an (ip, port) tuple.

This allows for both IPv4 IP addresses as well as hostnames.
Expand Down Expand Up @@ -3351,7 +3358,7 @@ def from_string(self, s):
return (ip, port)


class CRegExp(TraitType):
class CRegExp(TraitType[re.Pattern[t.Any], re.Pattern[t.Any] | str]):
"""A casting compiled regular expression trait.

Accepts both strings and compiled regular expressions. The resulting
Expand All @@ -3366,7 +3373,7 @@ def validate(self, obj, value):
self.error(obj, value)


class UseEnum(TraitType):
class UseEnum(TraitType[t.Any, t.Any]):
"""Use a Enum class as model for the data type description.
Note that if no default-value is provided, the first enum-value is used
as default-value.
Expand Down Expand Up @@ -3463,7 +3470,9 @@ def info_rst(self):
return self._info(as_rst=True)


class Callable(TraitType):
class Callable(
TraitType[collections.abc.Callable[..., t.Any], collections.abc.Callable[..., t.Any]]
):
"""A trait which is callable.

Notes
Expand Down