Skip to content

Commit

Permalink
builtins.sum: Items in the iterable must support addition with `int…
Browse files Browse the repository at this point in the history
…` if no `start` value is given (#8000)
  • Loading branch information
AlexWaygood authored Jun 13, 2022
1 parent 7c47324 commit 1828ba2
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 7 deletions.
7 changes: 5 additions & 2 deletions stdlib/_typeshed/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,11 @@ SupportsRichComparisonT = TypeVar("SupportsRichComparisonT", bound=SupportsRichC

# Dunder protocols

class SupportsAdd(Protocol):
def __add__(self, __x: Any) -> Any: ...
class SupportsAdd(Protocol[_T_contra, _T_co]):
def __add__(self, __x: _T_contra) -> _T_co: ...

class SupportsRAdd(Protocol[_T_contra, _T_co]):
def __radd__(self, __x: _T_contra) -> _T_co: ...

class SupportsDivMod(Protocol[_T_contra, _T_co]):
def __divmod__(self, __other: _T_contra) -> _T_co: ...
Expand Down
15 changes: 10 additions & 5 deletions stdlib/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from _typeshed import (
SupportsKeysAndGetItem,
SupportsLenAndGetItem,
SupportsNext,
SupportsRAdd,
SupportsRDivMod,
SupportsRichComparison,
SupportsRichComparisonT,
Expand Down Expand Up @@ -1637,8 +1638,12 @@ def sorted(
@overload
def sorted(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsRichComparison], reverse: bool = ...) -> list[_T]: ...

_SumT = TypeVar("_SumT", bound=SupportsAdd)
_SumS = TypeVar("_SumS", bound=SupportsAdd)
_AddableT1 = TypeVar("_AddableT1", bound=SupportsAdd[Any, Any])
_AddableT2 = TypeVar("_AddableT2", bound=SupportsAdd[Any, Any])

class _SupportsSumWithNoDefaultGiven(SupportsAdd[Any, Any], SupportsRAdd[int, Any], Protocol): ...

_SupportsSumNoDefaultT = TypeVar("_SupportsSumNoDefaultT", bound=_SupportsSumWithNoDefaultGiven)

# In general, the return type of `x + x` is *not* guaranteed to be the same type as x.
# However, we can't express that in the stub for `sum()`
Expand All @@ -1653,15 +1658,15 @@ else:
def sum(__iterable: Iterable[bool], __start: int = ...) -> int: ... # type: ignore[misc]

@overload
def sum(__iterable: Iterable[_SumT]) -> _SumT | Literal[0]: ...
def sum(__iterable: Iterable[_SupportsSumNoDefaultT]) -> _SupportsSumNoDefaultT | Literal[0]: ...

if sys.version_info >= (3, 8):
@overload
def sum(__iterable: Iterable[_SumT], start: _SumS) -> _SumT | _SumS: ...
def sum(__iterable: Iterable[_AddableT1], start: _AddableT2) -> _AddableT1 | _AddableT2: ...

else:
@overload
def sum(__iterable: Iterable[_SumT], __start: _SumS) -> _SumT | _SumS: ...
def sum(__iterable: Iterable[_AddableT1], __start: _AddableT2) -> _AddableT1 | _AddableT2: ...

# The argument to `vars()` has to have a `__dict__` attribute, so can't be annotated with `object`
# (A "SupportsDunderDict" protocol doesn't work)
Expand Down
52 changes: 52 additions & 0 deletions test_cases/stdlib/builtins/test_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# pyright: reportUnnecessaryTypeIgnoreComment=true

from typing import Any, List, Union
from typing_extensions import Literal, assert_type


class Foo:
def __add__(self, other: Any) -> "Foo":
return Foo()


class Bar:
def __radd__(self, other: Any) -> "Bar":
return Bar()


class Baz:
def __add__(self, other: Any) -> "Baz":
return Baz()

def __radd__(self, other: Any) -> "Baz":
return Baz()


assert_type(sum([2, 4]), int)
assert_type(sum([3, 5], 4), int)

assert_type(sum([True, False]), int)
assert_type(sum([True, False], True), int)

assert_type(sum([["foo"], ["bar"]], ["baz"]), List[str])

assert_type(sum([Foo(), Foo()], Foo()), Foo)
assert_type(sum([Baz(), Baz()]), Union[Baz, Literal[0]])

# mypy and pyright infer the types differently for these, so we can't use assert_type
# Just test that no error is emitted for any of these
sum([("foo",), ("bar", "baz")], ()) # mypy: `tuple[str, ...]`; pyright: `tuple[()] | tuple[str] | tuple[str, str]`
sum([5.6, 3.2]) # mypy: `float`; pyright: `float | Literal[0]`
sum([2.5, 5.8], 5) # mypy: `float`; pyright: `float | int`

# These all fail at runtime
sum("abcde") # type: ignore[arg-type]
sum([["foo"], ["bar"]]) # type: ignore[list-item]
sum([("foo",), ("bar", "baz")]) # type: ignore[list-item]
sum([Foo(), Foo()]) # type: ignore[list-item]
sum([Bar(), Bar()], Bar()) # type: ignore[call-overload]
sum([Bar(), Bar()]) # type: ignore[list-item]

# TODO: these pass pyright with the current stubs, but mypy erroneously emits an error:
# sum([3, Fraction(7, 22), complex(8, 0), 9.83])
# sum([3, Decimal('0.98')])

0 comments on commit 1828ba2

Please sign in to comment.