From 8790f7292af8ce5f6e172d1fc936ffa8fcd00001 Mon Sep 17 00:00:00 2001 From: Maxwell Muoto <41130755+max-muoto@users.noreply.github.com> Date: Mon, 15 Jul 2024 19:38:37 -0500 Subject: [PATCH] Use generic defaults for Counter --- .../test_cases/collections/check_counter.py | 45 ++++++++++++++++++ stdlib/collections/__init__.pyi | 46 ++++++++++--------- 2 files changed, 69 insertions(+), 22 deletions(-) create mode 100644 stdlib/@tests/test_cases/collections/check_counter.py diff --git a/stdlib/@tests/test_cases/collections/check_counter.py b/stdlib/@tests/test_cases/collections/check_counter.py new file mode 100644 index 000000000000..c30b26df6383 --- /dev/null +++ b/stdlib/@tests/test_cases/collections/check_counter.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from collections import Counter +from decimal import Decimal +from typing_extensions import assert_type + +# Initialize a Counter for strings with integer values +word_counts: Counter[str] = Counter() +word_counts["foo"] += 3 +word_counts["bar"] += 2 +assert_type(word_counts, "Counter[str, int]") + +# Initialize a Counter for strings with float values +floating_point_counts: Counter[str, float] = Counter() +floating_point_counts["foo"] += 3.0 +floating_point_counts["bar"] += 5.0 + +# Initialize a Counter for strings with Decimal values +decimal_counts: Counter[str, Decimal] = Counter() +decimal_counts["foo"] += Decimal("3.0") +decimal_counts["bar"] += Decimal("5.0") + +# Counter combining integers and floats +mixed_type_counter = Counter({"foo": 3, "bar": 2.5}) +mixed_type_counter["baz"] += 1.5 +assert_type(mixed_type_counter, "Counter[str, int | float]") + +# Check ORing and ANDing Counters with different value types +assert_type(mixed_type_counter or decimal_counts, "Counter[str, int | float] | Counter[str, Decimal]") +assert_type(decimal_counts or mixed_type_counter, "Counter[str, Decimal] | Counter[str, int | float]") +assert_type(mixed_type_counter and decimal_counts, "Counter[str, int | float] | Counter[str, Decimal]") +assert_type(decimal_counts and mixed_type_counter, "Counter[str, Decimal] | Counter[str, int | float]") + +# We shouldn't be able to add Counters with incompatible value types +_ = mixed_type_counter + decimal_counts # type: ignore +mixed_type_counter += decimal_counts # type: ignore + +# Adding Counters with compatible types +_ = word_counts + Counter[str]() +word_counts += Counter[str]() + +# Combining Counters of different key types +integer_key_counter = Counter[int]() +combined_word_and_integer_keys = word_counts + integer_key_counter +assert_type(combined_word_and_integer_keys, "Counter[str | int, int]") diff --git a/stdlib/collections/__init__.pyi b/stdlib/collections/__init__.pyi index 71e3c564dd57..98e94c03a9db 100644 --- a/stdlib/collections/__init__.pyi +++ b/stdlib/collections/__init__.pyi @@ -1,8 +1,8 @@ import sys from _collections_abc import dict_items, dict_keys, dict_values from _typeshed import SupportsItems, SupportsKeysAndGetItem, SupportsRichComparison, SupportsRichComparisonT -from typing import Any, Generic, NoReturn, SupportsIndex, TypeVar, final, overload -from typing_extensions import Self +from typing import Any, Generic, NoReturn, SupportsIndex, final, overload +from typing_extensions import Self, TypeVar if sys.version_info >= (3, 9): from types import GenericAlias @@ -28,6 +28,8 @@ __all__ = ["ChainMap", "Counter", "OrderedDict", "UserDict", "UserList", "UserSt _S = TypeVar("_S") _T = TypeVar("_T") +_V = TypeVar("_V") +_V_I = TypeVar("_V_I", default=int) _T1 = TypeVar("_T1") _T2 = TypeVar("_T2") _KT = TypeVar("_KT") @@ -273,24 +275,24 @@ class deque(MutableSequence[_T]): if sys.version_info >= (3, 9): def __class_getitem__(cls, item: Any, /) -> GenericAlias: ... -class Counter(dict[_T, int], Generic[_T]): +class Counter(dict[_T, _V_I], Generic[_T, _V_I]): @overload def __init__(self, iterable: None = None, /) -> None: ... @overload - def __init__(self: Counter[str], iterable: None = None, /, **kwargs: int) -> None: ... + def __init__(self: Counter[str], iterable: None = None, /, **kwargs: _V_I) -> None: ... @overload - def __init__(self, mapping: SupportsKeysAndGetItem[_T, int], /) -> None: ... + def __init__(self, mapping: SupportsKeysAndGetItem[_T, _V_I], /) -> None: ... @overload def __init__(self, iterable: Iterable[_T], /) -> None: ... def copy(self) -> Self: ... def elements(self) -> Iterator[_T]: ... - def most_common(self, n: int | None = None) -> list[tuple[_T, int]]: ... + def most_common(self, n: int | None = None) -> list[tuple[_T, _V_I]]: ... @classmethod def fromkeys(cls, iterable: Any, v: int | None = None) -> NoReturn: ... # type: ignore[override] @overload def subtract(self, iterable: None = None, /) -> None: ... @overload - def subtract(self, mapping: Mapping[_T, int], /) -> None: ... + def subtract(self, mapping: Mapping[_T, _V_I], /) -> None: ... @overload def subtract(self, iterable: Iterable[_T], /) -> None: ... # Unlike dict.update(), use Mapping instead of SupportsKeysAndGetItem for the first overload @@ -305,29 +307,29 @@ class Counter(dict[_T, int], Generic[_T]): def update(self, iterable: Iterable[_T], /, **kwargs: int) -> None: ... @overload def update(self, iterable: None = None, /, **kwargs: int) -> None: ... - def __missing__(self, key: _T) -> int: ... + def __missing__(self, key: _T) -> _V_I: ... def __delitem__(self, elem: object) -> None: ... if sys.version_info >= (3, 10): def __eq__(self, other: object) -> bool: ... def __ne__(self, other: object) -> bool: ... - def __add__(self, other: Counter[_S]) -> Counter[_T | _S]: ... - def __sub__(self, other: Counter[_T]) -> Counter[_T]: ... - def __and__(self, other: Counter[_T]) -> Counter[_T]: ... - def __or__(self, other: Counter[_S]) -> Counter[_T | _S]: ... # type: ignore[override] - def __pos__(self) -> Counter[_T]: ... - def __neg__(self) -> Counter[_T]: ... + def __add__(self, other: Counter[_S, _V_I]) -> Counter[_T | _S, _V_I]: ... + def __sub__(self, other: Counter[_T, _V_I]) -> Counter[_T, _V_I]: ... + def __and__(self, other: Counter[_T, _V_I]) -> Counter[_T, _V_I]: ... + def __or__(self, other: Counter[_S, _V]) -> Counter[_T | _S, _V_I | _V]: ... # type: ignore[override] + def __pos__(self) -> Counter[_T, _V_I]: ... + def __neg__(self) -> Counter[_T, _V_I]: ... # several type: ignores because __iadd__ is supposedly incompatible with __add__, etc. - def __iadd__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[misc] - def __isub__(self, other: SupportsItems[_T, int]) -> Self: ... - def __iand__(self, other: SupportsItems[_T, int]) -> Self: ... - def __ior__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[override,misc] + def __iadd__(self, other: SupportsItems[_T, _V_I]) -> Self: ... # type: ignore[misc] + def __isub__(self, other: SupportsItems[_T, _V_I]) -> Self: ... + def __iand__(self, other: SupportsItems[_T, _V_I]) -> Self: ... + def __ior__(self, other: SupportsItems[_T, _V_I]) -> Self: ... # type: ignore[override,misc] if sys.version_info >= (3, 10): def total(self) -> int: ... - def __le__(self, other: Counter[Any]) -> bool: ... - def __lt__(self, other: Counter[Any]) -> bool: ... - def __ge__(self, other: Counter[Any]) -> bool: ... - def __gt__(self, other: Counter[Any]) -> bool: ... + def __le__(self, other: Counter[Any, _V_I]) -> bool: ... + def __lt__(self, other: Counter[Any, _V_I]) -> bool: ... + def __ge__(self, other: Counter[Any, _V_I]) -> bool: ... + def __gt__(self, other: Counter[Any, _V_I]) -> bool: ... # The pure-Python implementations of the "views" classes # These are exposed at runtime in `collections/__init__.py`