From 89d4b856197362f59a4bbcbf1a006533bdd1939a Mon Sep 17 00:00:00 2001 From: David Brochart Date: Tue, 7 Nov 2023 14:23:37 +0100 Subject: [PATCH] Improve Array and Map API (#27) --- python/pycrdt/array.py | 9 +++++++++ python/pycrdt/map.py | 12 ++++++++++++ tests/test_array.py | 19 +++++++++++++++++++ tests/test_map.py | 11 +++++++++++ 4 files changed, 51 insertions(+) diff --git a/python/pycrdt/array.py b/python/pycrdt/array.py index b7166a8..072a4a2 100644 --- a/python/pycrdt/array.py +++ b/python/pycrdt/array.py @@ -63,6 +63,15 @@ def extend(self, value: list[Any]) -> None: def clear(self) -> None: del self[:] + def insert(self, index, object) -> None: + self[index:index] = [object] + + def pop(self, index: int = -1) -> Any: + with self.doc.transaction(): + res = self[index] + del self[index] + return res + def __add__(self, value: list[Any]) -> Array: with self.doc.transaction(): length = len(self) diff --git a/python/pycrdt/map.py b/python/pycrdt/map.py index 84b3877..bae393e 100644 --- a/python/pycrdt/map.py +++ b/python/pycrdt/map.py @@ -80,6 +80,18 @@ def get(self, key: str, default_value: Any | None = None) -> Any | None: return self[key] return default_value + def pop(self, key: str, default_value: Any | None = None) -> Any: + with self.doc.transaction(): + if key not in self.keys(): + if ( + default_value is None + ): # FIXME: how to know if default_value was passed? + raise KeyError + return default_value + res = self[key] + del self[key] + return res + def keys(self): with self.doc.transaction() as txn: return iter(self.integrated.keys(txn)) diff --git a/tests/test_array.py b/tests/test_array.py index 7db8cd8..369c118 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -109,3 +109,22 @@ def callback(e): assert sid1 == "o_1" assert sid2 == "od0" assert sid3 == "od1" + + +def test_api(): + # pop + doc = Doc() + array = Array([1, 2, 3]) + doc["array"] = array + v = array.pop() + assert v == 3 + v = array.pop(0) + assert v == 1 + assert str(array) == "[2]" + + # insert + doc = Doc() + array = Array([1, 2, 3]) + doc["array"] = array + array.insert(1, 4) + assert str(array) == "[1,4,2,3]" diff --git a/tests/test_map.py b/tests/test_map.py index 64808fb..52020ec 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -45,3 +45,14 @@ def test_api(): assert dict(map0.items()) == items map0.clear() assert len(map0) == 0 + + # pop + doc = Doc() + map0 = Map({"foo": 1, "bar": 2}) + doc["map0"] = map0 + v = map0.pop("foo") + assert v == 1 + assert str(map0) == '{"bar":2}' + v = map0.pop("bar") + assert v == 2 + assert str(map0) == "{}"