diff --git a/stdlib/collections/__init__.pyi b/stdlib/collections/__init__.pyi index 1d23ecd66a8d..ccd0023a9190 100644 --- a/stdlib/collections/__init__.pyi +++ b/stdlib/collections/__init__.pyi @@ -34,6 +34,7 @@ _KT = TypeVar("_KT") _VT = TypeVar("_VT") _KT_co = TypeVar("_KT_co", covariant=True) _VT_co = TypeVar("_VT_co", covariant=True) +_C = TypeVar("_C", default=int) # namedtuple is special-cased in the type checker; the initializer is ignored. def namedtuple( @@ -261,13 +262,13 @@ class deque(MutableSequence[_T]): if sys.version_info >= (3, 9): def __class_getitem__(cls, item: Any, /) -> GenericAlias: ... -class Counter(dict[_T, int], Generic[_T]): +class Counter(dict[_T, _C]): @overload def __init__(self, iterable: None = None, /) -> None: ... @overload - def __init__(self: Counter[str], iterable: None = None, /, **kwargs: int) -> None: ... + def __init__(self: Counter[str, _C], iterable: None = None, /, **kwargs: _C) -> None: ... @overload - def __init__(self, mapping: SupportsKeysAndGetItem[_T, int], /) -> None: ... + def __init__(self: Counter[_T, _C], mapping: SupportsKeysAndGetItem[_T, _C], /) -> None: ... @overload def __init__(self, iterable: Iterable[_T], /) -> None: ... def copy(self) -> Self: ... diff --git a/test_cases/stdlib/collections/check_counter.py b/test_cases/stdlib/collections/check_counter.py new file mode 100644 index 000000000000..9b6f74b4736a --- /dev/null +++ b/test_cases/stdlib/collections/check_counter.py @@ -0,0 +1,31 @@ +from collections import Counter +from typing_extensions import assert_type, Never + + +class Foo: ... + + +# Test the constructor +assert_type(Counter(), Counter[Never, int]) +assert_type(Counter(foo=42.2), Counter[str, float]) +assert_type(Counter({42: "bar"}), Counter[int, str]) +assert_type(Counter([1, 2, 3]), Counter[int, int]) + +int_c: Counter[str] = Counter() +assert_type(int_c, Counter[str, int]) +int_c["a"] = 1 +int_c["a"] += 3 +int_c["a"] += 3.5 # type: ignore + +float_c = Counter(foo=42.2) +assert_type(float_c, Counter[str, float]) +float_c["a"] = 1.0 +float_c["a"] += 3.0 +float_c["a"] += 42 +float_c["a"] += "42" # type: ignore + +custom_c: Counter[str, Foo] = Counter() +assert_type(custom_c, Counter[str, Foo]) +custom_c["a"] = Foo() +custom_c["a"] += Foo() # type: ignore +custom_c["a"] += 42 # type: ignore