Skip to content

Commit

Permalink
Support type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Dec 11, 2024
1 parent e4e5c88 commit 30bf9ec
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 35 deletions.
46 changes: 30 additions & 16 deletions python/pycrdt/_array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast, overload

from ._base import BaseDoc, BaseEvent, BaseType, base_types, event_types
from ._pycrdt import Array as _Array
Expand All @@ -10,18 +10,20 @@
if TYPE_CHECKING:
from ._doc import Doc

T = TypeVar("T")

class Array(BaseType):

class Array(BaseType, Generic[T]):
"""
A collection used to store data in an indexed sequence structure, similar to a Python `list`.
"""

_prelim: list | None
_prelim: list[T] | None
_integrated: _Array | None

def __init__(
self,
init: list | None = None,
init: list[T] | None = None,
*,
_doc: Doc | None = None,
_integrated: _Array | None = None,
Expand All @@ -42,14 +44,14 @@ def __init__(
_integrated=_integrated,
)

def _init(self, value: list[Any] | None) -> None:
def _init(self, value: list[T] | None) -> None:
if value is None:
return
with self.doc.transaction():
for i, v in enumerate(value):
self._set(i, v)

def _set(self, index: int, value: Any) -> None:
def _set(self, index: int, value: T) -> None:
with self.doc.transaction() as txn:
self._forbid_read_transaction(txn)
if isinstance(value, BaseDoc):
Expand Down Expand Up @@ -79,7 +81,7 @@ def __len__(self) -> int:
with self.doc.transaction() as txn:
return self.integrated.len(txn._txn)

def append(self, value: Any) -> None:
def append(self, value: T) -> None:
"""
Appends an item to the array.
Expand All @@ -89,7 +91,7 @@ def append(self, value: Any) -> None:
with self.doc.transaction():
self += [value]

def extend(self, value: list[Any]) -> None:
def extend(self, value: list[T]) -> None:
"""
Extends the array with a list of items.
Expand All @@ -105,7 +107,7 @@ def clear(self) -> None:
"""
del self[:]

def insert(self, index: int, object: Any) -> None:
def insert(self, index: int, object: T) -> None:
"""
Inserts an item at a given index in the array.
Expand All @@ -115,7 +117,7 @@ def insert(self, index: int, object: Any) -> None:
"""
self[index:index] = [object]

def pop(self, index: int = -1) -> Any:
def pop(self, index: int = -1) -> T:
"""
Removes the item at the given index from the array, and returns it.
If no index is passed, removes and returns the last item.
Expand Down Expand Up @@ -148,7 +150,7 @@ def move(self, source_index: int, destination_index: int) -> None:
destination_index = self._check_index(destination_index)
self.integrated.move_to(txn._txn, source_index, destination_index)

def __add__(self, value: list[Any]) -> Array:
def __add__(self, value: list[T]) -> Array[T]:
"""
Extends the array with a list of items:
```py
Expand All @@ -168,7 +170,7 @@ def __add__(self, value: list[Any]) -> Array:
self[length:length] = value
return self

def __radd__(self, value: list[Any]) -> Array:
def __radd__(self, value: list[T]) -> Array[T]:
"""
Prepends a list of items to the array:
```py
Expand All @@ -187,7 +189,13 @@ def __radd__(self, value: list[Any]) -> Array:
self[0:0] = value
return self

def __setitem__(self, key: int | slice, value: Any | list[Any]) -> None:
@overload
def __setitem__(self, key: int, value: T) -> None: ...

@overload
def __setitem__(self, key: slice, value: list[T]) -> None: ...

def __setitem__(self, key, value):
"""
Replaces the item at the given index with a new item:
```py
Expand Down Expand Up @@ -271,7 +279,13 @@ def __delitem__(self, key: int | slice) -> None:
f"Array indices must be integers or slices, not {type(key).__name__}"
)

def __getitem__(self, key: int) -> BaseType:
@overload
def __getitem__(self, key: int) -> T: ...

@overload
def __getitem__(self, key: slice) -> list[T]: ...

def __getitem__(self, key):
"""
Gets the item at the given index:
```py
Expand Down Expand Up @@ -304,7 +318,7 @@ def __iter__(self) -> ArrayIterator:
"""
return ArrayIterator(self)

def __contains__(self, item: Any) -> bool:
def __contains__(self, item: T) -> bool:
"""
Checks if the given item is in the array:
```py
Expand Down Expand Up @@ -333,7 +347,7 @@ def __str__(self) -> str:
with self.doc.transaction() as txn:
return self.integrated.to_json(txn._txn)

def to_py(self) -> list | None:
def to_py(self) -> list[T] | None:
"""
Recursively converts the array's items to Python objects, and
returns them in a list. If the array was not yet inserted in a document,
Expand Down
55 changes: 36 additions & 19 deletions python/pycrdt/_map.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Iterable, cast
from typing import TYPE_CHECKING, Callable, Generic, Iterable, TypeVar, cast, overload

from ._base import BaseDoc, BaseEvent, BaseType, base_types, event_types
from ._pycrdt import Map as _Map
Expand All @@ -10,8 +10,11 @@
if TYPE_CHECKING:
from ._doc import Doc

T = TypeVar("T")
T_DefaultValue = TypeVar("T_DefaultValue")

class Map(BaseType):

class Map(BaseType, Generic[T]):
"""
A collection used to store key-value entries in an unordered manner, similar to a Python `dict`.
"""
Expand All @@ -21,7 +24,7 @@ class Map(BaseType):

def __init__(
self,
init: dict | None = None,
init: dict[str, T] | None = None,
*,
_doc: Doc | None = None,
_integrated: _Map | None = None,
Expand All @@ -42,14 +45,14 @@ def __init__(
_integrated=_integrated,
)

def _init(self, value: dict[str, Any] | None) -> None:
def _init(self, value: dict[str, T] | None) -> None:
if value is None:
return
with self.doc.transaction():
for k, v in value.items():
self._set(k, v)

def _set(self, key: str, value: Any) -> None:
def _set(self, key: str, value: T) -> None:
with self.doc.transaction() as txn:
self._forbid_read_transaction(txn)
if isinstance(value, BaseDoc):
Expand Down Expand Up @@ -91,7 +94,7 @@ def __str__(self) -> str:
with self.doc.transaction() as txn:
return self.integrated.to_json(txn._txn)

def to_py(self) -> dict | None:
def to_py(self) -> dict[str, T] | None:
"""
Recursively converts the map's items to Python objects, and
returns them in a `dict`. If the map was not yet inserted in a document,
Expand Down Expand Up @@ -128,7 +131,7 @@ def __delitem__(self, key: str) -> None:
self._check_key(key)
self.integrated.remove(txn._txn, key)

def __getitem__(self, key: str) -> Any:
def __getitem__(self, key: str) -> T:
"""
Gets the value at the given key:
```py
Expand All @@ -143,7 +146,7 @@ def __getitem__(self, key: str) -> Any:
self._check_key(key)
return self._maybe_as_type_or_doc(self.integrated.get(txn._txn, key))

def __setitem__(self, key: str, value: Any) -> None:
def __setitem__(self, key: str, value: T) -> None:
"""
Sets a value at the given key:
```py
Expand Down Expand Up @@ -192,24 +195,38 @@ def __contains__(self, item: str) -> bool:
"""
return item in self.keys()

def get(self, key: str, default_value: Any | None = None) -> Any | None:
@overload
def get(self, key: str) -> T | None: ...

@overload
def get(self, key: str, default_value: T_DefaultValue) -> T | T_DefaultValue: ...

def get(self, *args):
"""
Returns the value corresponding to the given key if it exists, otherwise
returns the `default_value`.
returns the default value if passed, or `None`.
Args:
key: The key of the value to get.
default_value: The optional default value to return if the key is not found.
args: The key of the value to get, and an optional default value.
Returns:
The value at the given key, or the default value.
The value at the given key, or the default value or `None`.
"""
key, *default_value = args
with self.doc.transaction():
if key in self.keys():
return self[key]
return default_value
if not default_value:
return None
return default_value[0]

@overload
def pop(self, key: str) -> T: ...

@overload
def pop(self, key: str, default_value: T_DefaultValue) -> T | T_DefaultValue: ...

def pop(self, *args: Any) -> Any:
def pop(self, *args):
"""
Removes the entry at the given key from the map, and returns the corresponding value.
Expand All @@ -231,7 +248,7 @@ def pop(self, *args: Any) -> Any:
del self[key]
return res

def _check_key(self, key: str):
def _check_key(self, key: str) -> None:
if not isinstance(key, str):
raise RuntimeError("Key must be of type string")
if key not in self.keys():
Expand All @@ -245,7 +262,7 @@ def keys(self) -> Iterable[str]:
with self.doc.transaction() as txn:
return iter(self.integrated.keys(txn._txn))

def values(self) -> Iterable[Any]:
def values(self) -> Iterable[T]:
"""
Returns:
An iterable over the values of the map.
Expand All @@ -254,7 +271,7 @@ def values(self) -> Iterable[Any]:
for k in self.integrated.keys(txn._txn):
yield self[k]

def items(self) -> Iterable[tuple[str, Any]]:
def items(self) -> Iterable[tuple[str, T]]:
"""
Returns:
An iterable over the key-value pairs of the map.
Expand All @@ -271,7 +288,7 @@ def clear(self) -> None:
for k in self.integrated.keys(txn._txn):
del self[k]

def update(self, value: dict[str, Any]) -> None:
def update(self, value: dict[str, T]) -> None:
"""
Sets entries in the map from all entries in the passed `dict`.
Expand Down

0 comments on commit 30bf9ec

Please sign in to comment.