diff --git a/python/pycrdt/_array.py b/python/pycrdt/_array.py index 3e146e2..9d80918 100644 --- a/python/pycrdt/_array.py +++ b/python/pycrdt/_array.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast, overload from ._base import BaseDoc, BaseEvent, BaseType, base_types, event_types from ._pycrdt import Array as _Array @@ -10,18 +10,20 @@ if TYPE_CHECKING: from ._doc import Doc +T = TypeVar("T") -class Array(BaseType): + +class Array(BaseType, Generic[T]): """ A collection used to store data in an indexed sequence structure, similar to a Python `list`. """ - _prelim: list | None + _prelim: list[T] | None _integrated: _Array | None def __init__( self, - init: list | None = None, + init: list[T] | None = None, *, _doc: Doc | None = None, _integrated: _Array | None = None, @@ -42,14 +44,14 @@ def __init__( _integrated=_integrated, ) - def _init(self, value: list[Any] | None) -> None: + def _init(self, value: list[T] | None) -> None: if value is None: return with self.doc.transaction(): for i, v in enumerate(value): self._set(i, v) - def _set(self, index: int, value: Any) -> None: + def _set(self, index: int, value: T) -> None: with self.doc.transaction() as txn: self._forbid_read_transaction(txn) if isinstance(value, BaseDoc): @@ -79,7 +81,7 @@ def __len__(self) -> int: with self.doc.transaction() as txn: return self.integrated.len(txn._txn) - def append(self, value: Any) -> None: + def append(self, value: T) -> None: """ Appends an item to the array. @@ -89,7 +91,7 @@ def append(self, value: Any) -> None: with self.doc.transaction(): self += [value] - def extend(self, value: list[Any]) -> None: + def extend(self, value: list[T]) -> None: """ Extends the array with a list of items. @@ -105,7 +107,7 @@ def clear(self) -> None: """ del self[:] - def insert(self, index: int, object: Any) -> None: + def insert(self, index: int, object: T) -> None: """ Inserts an item at a given index in the array. @@ -115,7 +117,7 @@ def insert(self, index: int, object: Any) -> None: """ self[index:index] = [object] - def pop(self, index: int = -1) -> Any: + def pop(self, index: int = -1) -> T: """ Removes the item at the given index from the array, and returns it. If no index is passed, removes and returns the last item. @@ -148,7 +150,7 @@ def move(self, source_index: int, destination_index: int) -> None: destination_index = self._check_index(destination_index) self.integrated.move_to(txn._txn, source_index, destination_index) - def __add__(self, value: list[Any]) -> Array: + def __add__(self, value: list[T]) -> Array[T]: """ Extends the array with a list of items: ```py @@ -168,7 +170,7 @@ def __add__(self, value: list[Any]) -> Array: self[length:length] = value return self - def __radd__(self, value: list[Any]) -> Array: + def __radd__(self, value: list[T]) -> Array[T]: """ Prepends a list of items to the array: ```py @@ -187,7 +189,13 @@ def __radd__(self, value: list[Any]) -> Array: self[0:0] = value return self - def __setitem__(self, key: int | slice, value: Any | list[Any]) -> None: + @overload + def __setitem__(self, key: int, value: T) -> None: ... + + @overload + def __setitem__(self, key: slice, value: list[T]) -> None: ... + + def __setitem__(self, key, value): """ Replaces the item at the given index with a new item: ```py @@ -271,7 +279,13 @@ def __delitem__(self, key: int | slice) -> None: f"Array indices must be integers or slices, not {type(key).__name__}" ) - def __getitem__(self, key: int) -> BaseType: + @overload + def __getitem__(self, key: int) -> T: ... + + @overload + def __getitem__(self, key: slice) -> list[T]: ... + + def __getitem__(self, key): """ Gets the item at the given index: ```py @@ -304,7 +318,7 @@ def __iter__(self) -> ArrayIterator: """ return ArrayIterator(self) - def __contains__(self, item: Any) -> bool: + def __contains__(self, item: T) -> bool: """ Checks if the given item is in the array: ```py @@ -333,7 +347,7 @@ def __str__(self) -> str: with self.doc.transaction() as txn: return self.integrated.to_json(txn._txn) - def to_py(self) -> list | None: + def to_py(self) -> list[T] | None: """ Recursively converts the array's items to Python objects, and returns them in a list. If the array was not yet inserted in a document, diff --git a/python/pycrdt/_map.py b/python/pycrdt/_map.py index 260e981..569aa3f 100644 --- a/python/pycrdt/_map.py +++ b/python/pycrdt/_map.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Iterable, cast +from typing import TYPE_CHECKING, Callable, Generic, Iterable, TypeVar, cast, overload from ._base import BaseDoc, BaseEvent, BaseType, base_types, event_types from ._pycrdt import Map as _Map @@ -10,8 +10,11 @@ if TYPE_CHECKING: from ._doc import Doc +T = TypeVar("T") +T_DefaultValue = TypeVar("T_DefaultValue") -class Map(BaseType): + +class Map(BaseType, Generic[T]): """ A collection used to store key-value entries in an unordered manner, similar to a Python `dict`. """ @@ -21,7 +24,7 @@ class Map(BaseType): def __init__( self, - init: dict | None = None, + init: dict[str, T] | None = None, *, _doc: Doc | None = None, _integrated: _Map | None = None, @@ -42,14 +45,14 @@ def __init__( _integrated=_integrated, ) - def _init(self, value: dict[str, Any] | None) -> None: + def _init(self, value: dict[str, T] | None) -> None: if value is None: return with self.doc.transaction(): for k, v in value.items(): self._set(k, v) - def _set(self, key: str, value: Any) -> None: + def _set(self, key: str, value: T) -> None: with self.doc.transaction() as txn: self._forbid_read_transaction(txn) if isinstance(value, BaseDoc): @@ -91,7 +94,7 @@ def __str__(self) -> str: with self.doc.transaction() as txn: return self.integrated.to_json(txn._txn) - def to_py(self) -> dict | None: + def to_py(self) -> dict[str, T] | None: """ Recursively converts the map's items to Python objects, and returns them in a `dict`. If the map was not yet inserted in a document, @@ -128,7 +131,7 @@ def __delitem__(self, key: str) -> None: self._check_key(key) self.integrated.remove(txn._txn, key) - def __getitem__(self, key: str) -> Any: + def __getitem__(self, key: str) -> T: """ Gets the value at the given key: ```py @@ -143,7 +146,7 @@ def __getitem__(self, key: str) -> Any: self._check_key(key) return self._maybe_as_type_or_doc(self.integrated.get(txn._txn, key)) - def __setitem__(self, key: str, value: Any) -> None: + def __setitem__(self, key: str, value: T) -> None: """ Sets a value at the given key: ```py @@ -192,24 +195,38 @@ def __contains__(self, item: str) -> bool: """ return item in self.keys() - def get(self, key: str, default_value: Any | None = None) -> Any | None: + @overload + def get(self, key: str) -> T | None: ... + + @overload + def get(self, key: str, default_value: T_DefaultValue) -> T | T_DefaultValue: ... + + def get(self, *args): """ Returns the value corresponding to the given key if it exists, otherwise - returns the `default_value`. + returns the default value if passed, or `None`. Args: - key: The key of the value to get. - default_value: The optional default value to return if the key is not found. + args: The key of the value to get, and an optional default value. Returns: - The value at the given key, or the default value. + The value at the given key, or the default value or `None`. """ + key, *default_value = args with self.doc.transaction(): if key in self.keys(): return self[key] - return default_value + if not default_value: + return None + return default_value[0] + + @overload + def pop(self, key: str) -> T: ... + + @overload + def pop(self, key: str, default_value: T_DefaultValue) -> T | T_DefaultValue: ... - def pop(self, *args: Any) -> Any: + def pop(self, *args): """ Removes the entry at the given key from the map, and returns the corresponding value. @@ -231,7 +248,7 @@ def pop(self, *args: Any) -> Any: del self[key] return res - def _check_key(self, key: str): + def _check_key(self, key: str) -> None: if not isinstance(key, str): raise RuntimeError("Key must be of type string") if key not in self.keys(): @@ -245,7 +262,7 @@ def keys(self) -> Iterable[str]: with self.doc.transaction() as txn: return iter(self.integrated.keys(txn._txn)) - def values(self) -> Iterable[Any]: + def values(self) -> Iterable[T]: """ Returns: An iterable over the values of the map. @@ -254,7 +271,7 @@ def values(self) -> Iterable[Any]: for k in self.integrated.keys(txn._txn): yield self[k] - def items(self) -> Iterable[tuple[str, Any]]: + def items(self) -> Iterable[tuple[str, T]]: """ Returns: An iterable over the key-value pairs of the map. @@ -271,7 +288,7 @@ def clear(self) -> None: for k in self.integrated.keys(txn._txn): del self[k] - def update(self, value: dict[str, Any]) -> None: + def update(self, value: dict[str, T]) -> None: """ Sets entries in the map from all entries in the passed `dict`.