Skip to content

Commit

Permalink
Use generic defaults for Counter
Browse files Browse the repository at this point in the history
  • Loading branch information
max-muoto committed Jul 16, 2024
1 parent 1b9e90b commit 8790f72
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 22 deletions.
45 changes: 45 additions & 0 deletions stdlib/@tests/test_cases/collections/check_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

from collections import Counter
from decimal import Decimal
from typing_extensions import assert_type

# Initialize a Counter for strings with integer values
word_counts: Counter[str] = Counter()
word_counts["foo"] += 3
word_counts["bar"] += 2
assert_type(word_counts, "Counter[str, int]")

# Initialize a Counter for strings with float values
floating_point_counts: Counter[str, float] = Counter()
floating_point_counts["foo"] += 3.0
floating_point_counts["bar"] += 5.0

# Initialize a Counter for strings with Decimal values
decimal_counts: Counter[str, Decimal] = Counter()
decimal_counts["foo"] += Decimal("3.0")
decimal_counts["bar"] += Decimal("5.0")

# Counter combining integers and floats
mixed_type_counter = Counter({"foo": 3, "bar": 2.5})
mixed_type_counter["baz"] += 1.5
assert_type(mixed_type_counter, "Counter[str, int | float]")

# Check ORing and ANDing Counters with different value types
assert_type(mixed_type_counter or decimal_counts, "Counter[str, int | float] | Counter[str, Decimal]")
assert_type(decimal_counts or mixed_type_counter, "Counter[str, Decimal] | Counter[str, int | float]")
assert_type(mixed_type_counter and decimal_counts, "Counter[str, int | float] | Counter[str, Decimal]")
assert_type(decimal_counts and mixed_type_counter, "Counter[str, Decimal] | Counter[str, int | float]")

# We shouldn't be able to add Counters with incompatible value types
_ = mixed_type_counter + decimal_counts # type: ignore
mixed_type_counter += decimal_counts # type: ignore

# Adding Counters with compatible types
_ = word_counts + Counter[str]()
word_counts += Counter[str]()

# Combining Counters of different key types
integer_key_counter = Counter[int]()
combined_word_and_integer_keys = word_counts + integer_key_counter
assert_type(combined_word_and_integer_keys, "Counter[str | int, int]")
46 changes: 24 additions & 22 deletions stdlib/collections/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import sys
from _collections_abc import dict_items, dict_keys, dict_values
from _typeshed import SupportsItems, SupportsKeysAndGetItem, SupportsRichComparison, SupportsRichComparisonT
from typing import Any, Generic, NoReturn, SupportsIndex, TypeVar, final, overload
from typing_extensions import Self
from typing import Any, Generic, NoReturn, SupportsIndex, final, overload
from typing_extensions import Self, TypeVar

if sys.version_info >= (3, 9):
from types import GenericAlias
Expand All @@ -28,6 +28,8 @@ __all__ = ["ChainMap", "Counter", "OrderedDict", "UserDict", "UserList", "UserSt

_S = TypeVar("_S")
_T = TypeVar("_T")
_V = TypeVar("_V")
_V_I = TypeVar("_V_I", default=int)
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_KT = TypeVar("_KT")
Expand Down Expand Up @@ -273,24 +275,24 @@ class deque(MutableSequence[_T]):
if sys.version_info >= (3, 9):
def __class_getitem__(cls, item: Any, /) -> GenericAlias: ...

class Counter(dict[_T, int], Generic[_T]):
class Counter(dict[_T, _V_I], Generic[_T, _V_I]):
@overload
def __init__(self, iterable: None = None, /) -> None: ...
@overload
def __init__(self: Counter[str], iterable: None = None, /, **kwargs: int) -> None: ...
def __init__(self: Counter[str], iterable: None = None, /, **kwargs: _V_I) -> None: ...
@overload
def __init__(self, mapping: SupportsKeysAndGetItem[_T, int], /) -> None: ...
def __init__(self, mapping: SupportsKeysAndGetItem[_T, _V_I], /) -> None: ...
@overload
def __init__(self, iterable: Iterable[_T], /) -> None: ...
def copy(self) -> Self: ...
def elements(self) -> Iterator[_T]: ...
def most_common(self, n: int | None = None) -> list[tuple[_T, int]]: ...
def most_common(self, n: int | None = None) -> list[tuple[_T, _V_I]]: ...
@classmethod
def fromkeys(cls, iterable: Any, v: int | None = None) -> NoReturn: ... # type: ignore[override]
@overload
def subtract(self, iterable: None = None, /) -> None: ...
@overload
def subtract(self, mapping: Mapping[_T, int], /) -> None: ...
def subtract(self, mapping: Mapping[_T, _V_I], /) -> None: ...
@overload
def subtract(self, iterable: Iterable[_T], /) -> None: ...
# Unlike dict.update(), use Mapping instead of SupportsKeysAndGetItem for the first overload
Expand All @@ -305,29 +307,29 @@ class Counter(dict[_T, int], Generic[_T]):
def update(self, iterable: Iterable[_T], /, **kwargs: int) -> None: ...
@overload
def update(self, iterable: None = None, /, **kwargs: int) -> None: ...
def __missing__(self, key: _T) -> int: ...
def __missing__(self, key: _T) -> _V_I: ...
def __delitem__(self, elem: object) -> None: ...
if sys.version_info >= (3, 10):
def __eq__(self, other: object) -> bool: ...
def __ne__(self, other: object) -> bool: ...

def __add__(self, other: Counter[_S]) -> Counter[_T | _S]: ...
def __sub__(self, other: Counter[_T]) -> Counter[_T]: ...
def __and__(self, other: Counter[_T]) -> Counter[_T]: ...
def __or__(self, other: Counter[_S]) -> Counter[_T | _S]: ... # type: ignore[override]
def __pos__(self) -> Counter[_T]: ...
def __neg__(self) -> Counter[_T]: ...
def __add__(self, other: Counter[_S, _V_I]) -> Counter[_T | _S, _V_I]: ...
def __sub__(self, other: Counter[_T, _V_I]) -> Counter[_T, _V_I]: ...
def __and__(self, other: Counter[_T, _V_I]) -> Counter[_T, _V_I]: ...
def __or__(self, other: Counter[_S, _V]) -> Counter[_T | _S, _V_I | _V]: ... # type: ignore[override]
def __pos__(self) -> Counter[_T, _V_I]: ...
def __neg__(self) -> Counter[_T, _V_I]: ...
# several type: ignores because __iadd__ is supposedly incompatible with __add__, etc.
def __iadd__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[misc]
def __isub__(self, other: SupportsItems[_T, int]) -> Self: ...
def __iand__(self, other: SupportsItems[_T, int]) -> Self: ...
def __ior__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[override,misc]
def __iadd__(self, other: SupportsItems[_T, _V_I]) -> Self: ... # type: ignore[misc]
def __isub__(self, other: SupportsItems[_T, _V_I]) -> Self: ...
def __iand__(self, other: SupportsItems[_T, _V_I]) -> Self: ...
def __ior__(self, other: SupportsItems[_T, _V_I]) -> Self: ... # type: ignore[override,misc]
if sys.version_info >= (3, 10):
def total(self) -> int: ...
def __le__(self, other: Counter[Any]) -> bool: ...
def __lt__(self, other: Counter[Any]) -> bool: ...
def __ge__(self, other: Counter[Any]) -> bool: ...
def __gt__(self, other: Counter[Any]) -> bool: ...
def __le__(self, other: Counter[Any, _V_I]) -> bool: ...
def __lt__(self, other: Counter[Any, _V_I]) -> bool: ...
def __ge__(self, other: Counter[Any, _V_I]) -> bool: ...
def __gt__(self, other: Counter[Any, _V_I]) -> bool: ...

# The pure-Python implementations of the "views" classes
# These are exposed at runtime in `collections/__init__.py`
Expand Down

0 comments on commit 8790f72

Please sign in to comment.