From 93d567e478c72281c2ddfb7fc0ae9fc298b3560e Mon Sep 17 00:00:00 2001 From: David Brochart Date: Mon, 6 Nov 2023 18:11:55 +0100 Subject: [PATCH] Add observe_subdocs (#25) --- pyproject.toml | 2 +- python/pycrdt/_pycrdt.pyi | 11 ++++++- python/pycrdt/array.py | 15 +++++++++- python/pycrdt/doc.py | 9 +++++- src/array.rs | 9 ++++-- src/doc.rs | 63 ++++++++++++++++++++++++++++++++++++++- src/lib.rs | 2 ++ src/map.rs | 8 +++-- src/text.rs | 8 +++-- tests/test_doc.py | 21 +++++++++---- 10 files changed, 130 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f908306..0323885 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,5 +44,5 @@ python-source = "python" module-name = "pycrdt._pycrdt" [tool.ruff] -line-length = 88 +line-length = 100 select = ["F", "E", "W", "I001"] diff --git a/python/pycrdt/_pycrdt.pyi b/python/pycrdt/_pycrdt.pyi index f0cd04f..d47e124 100644 --- a/python/pycrdt/_pycrdt.pyi +++ b/python/pycrdt/_pycrdt.pyi @@ -7,7 +7,9 @@ class Doc: """Create a new document with an optional global client ID. If no client ID is passed, a random one will be generated.""" def client_id(self) -> int: - """Returns the document global client ID.""" + """Returns the document unique client identifier.""" + def guid(self) -> int: + """Returns the document globally unique identifier.""" def create_transaction(self) -> Transaction: """Create a document transaction.""" def get_or_insert_text(self, name: str) -> Text: @@ -25,6 +27,9 @@ class Doc: def observe(self, callback: Callable[[TransactionEvent], None]) -> int: """Subscribes a callback to be called with the shared document change event. Returns a subscription ID that can be used to unsubscribe.""" + def observe_subdocs(self, callback: Callable[[SubdocsEvent], None]) -> int: + """Subscribes a callback to be called with the shared document subdoc change event. + Returns a subscription ID that can be used to unsubscribe.""" class Transaction: """Document transaction""" @@ -38,6 +43,10 @@ class TransactionEvent: """Event generated by `Doc.observe` method. Emitted during transaction commit phase.""" +class SubdocsEvent: + """Event generated by `Doc.observe_subdocs` method. Emitted during transaction commit + phase.""" + class TextEvent: """Event generated by `Text.observe` method. Emitted during transaction commit phase.""" diff --git a/python/pycrdt/array.py b/python/pycrdt/array.py index 420245b..dae00b8 100644 --- a/python/pycrdt/array.py +++ b/python/pycrdt/array.py @@ -77,8 +77,11 @@ def __radd__(self, value: list[Any]) -> Array: def __setitem__(self, key: int | slice, value: Any | list[Any]) -> None: with self.doc.transaction(): if isinstance(key, int): + length = len(self) + if length == 0: + raise IndexError("Array index out of range") if key < 0: - key += len(self) + key += length del self[key] self[key:key] = [value] elif isinstance(key, slice): @@ -96,6 +99,11 @@ def __setitem__(self, key: int | slice, value: Any | list[Any]) -> None: def __delitem__(self, key: int | slice) -> None: with self.doc.transaction() as txn: if isinstance(key, int): + length = len(self) + if length == 0: + raise IndexError("Array index out of range") + if key < 0: + key += length self.integrated.remove_range(txn, key, 1) elif isinstance(key, slice): if key.step is not None: @@ -119,6 +127,11 @@ def __delitem__(self, key: int | slice) -> None: def __getitem__(self, key: int) -> BaseType: with self.doc.transaction() as txn: if isinstance(key, int): + length = len(self) + if length == 0: + raise IndexError("Array index out of range") + if key < 0: + key += length return self._maybe_as_type_or_doc(self.integrated.get(txn, key)) elif isinstance(key, slice): i0 = 0 if key.start is None else key.start diff --git a/python/pycrdt/doc.py b/python/pycrdt/doc.py index 6a86c69..53ccac1 100644 --- a/python/pycrdt/doc.py +++ b/python/pycrdt/doc.py @@ -3,12 +3,16 @@ from typing import Callable from ._pycrdt import Doc as _Doc -from ._pycrdt import TransactionEvent +from ._pycrdt import TransactionEvent, SubdocsEvent from .base import BaseDoc, BaseType, integrated_types from .transaction import Transaction class Doc(BaseDoc): + @property + def guid(self) -> int: + return self._doc.guid() + @property def client_id(self) -> int: return self._doc.client_id() @@ -39,5 +43,8 @@ def __setitem__(self, key: str, value: BaseType) -> None: def observe(self, callback: Callable[[TransactionEvent], None]) -> int: return self._doc.observe(callback) + def observe_subdocs(self, callback: Callable[[SubdocsEvent], None]) -> int: + return self._doc.observe_subdocs(callback) + integrated_types[_Doc] = Doc diff --git a/src/array.rs b/src/array.rs index ef0312a..b7a2c4f 100644 --- a/src/array.rs +++ b/src/array.rs @@ -165,13 +165,17 @@ impl ArrayEvent { pub fn new(event: &_ArrayEvent, txn: &TransactionMut) -> Self { let event = event as *const _ArrayEvent; let txn = unsafe { std::mem::transmute::<&TransactionMut, &TransactionMut<'static>>(txn) }; - ArrayEvent { + let mut array_event = ArrayEvent { event, txn, target: None, delta: None, path: None, - } + }; + array_event.target(); + array_event.path(); + array_event.delta(); + array_event } fn event(&self) -> &_ArrayEvent { @@ -216,7 +220,6 @@ impl ArrayEvent { let delta = self.event().delta(self.txn()).iter().map(|change| { Python::with_gil(|py| change.clone().into_py(py)) }); - PyList::new(py, delta).into() }); self.delta = Some(delta.clone()); diff --git a/src/doc.rs b/src/doc.rs index aa0f910..384e72c 100644 --- a/src/doc.rs +++ b/src/doc.rs @@ -1,11 +1,12 @@ use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyLong}; +use pyo3::types::{PyBytes, PyLong, PyList}; use yrs::{ Doc as _Doc, ReadTxn, Transact, TransactionMut, TransactionCleanupEvent, + SubdocsEvent as _SubdocsEvent, StateVector, Update, }; @@ -42,6 +43,10 @@ impl Doc { Doc { doc } } + fn guid(&mut self) -> String { + self.doc.guid().to_string() + } + fn client_id(&mut self) -> u64 { self.doc.client_id() } @@ -109,6 +114,21 @@ impl Doc { .into(); Ok(id) } + + pub fn observe_subdocs(&mut self, f: PyObject) -> PyResult { + let id: u32 = self.doc + .observe_subdocs(move |_, event| { + Python::with_gil(|py| { + let event = SubdocsEvent::new(event); + if let Err(err) = f.call1(py, (event,)) { + err.restore(py) + } + }) + }) + .unwrap() + .into(); + Ok(id) + } } #[pyclass(unsendable)] @@ -161,3 +181,44 @@ impl TransactionEvent { self.update.clone() } } + +#[pyclass(unsendable)] +pub struct SubdocsEvent { + added: PyObject, + removed: PyObject, + loaded: PyObject, +} + +impl SubdocsEvent { + fn new(event: &_SubdocsEvent) -> Self { + let added: Vec = event.added().map(|d| d.guid().clone().to_string()).collect(); + let added: PyObject = Python::with_gil(|py| PyList::new(py, &added).into()); + let removed: Vec = event.removed().map(|d| d.guid().clone().to_string()).collect(); + let removed: PyObject = Python::with_gil(|py| PyList::new(py, &removed).into()); + let loaded: Vec = event.loaded().map(|d| d.guid().clone().to_string()).collect(); + let loaded: PyObject = Python::with_gil(|py| PyList::new(py, &loaded).into()); + SubdocsEvent { + added, + removed, + loaded, + } + } +} + +#[pymethods] +impl SubdocsEvent { + #[getter] + pub fn added(&mut self) -> PyObject { + self.added.clone() + } + + #[getter] + pub fn removed(&mut self) -> PyObject { + self.removed.clone() + } + + #[getter] + pub fn loaded(&mut self) -> PyObject { + self.loaded.clone() + } +} diff --git a/src/lib.rs b/src/lib.rs index 0e96c61..76364fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ mod transaction; mod type_conversions; use crate::doc::Doc; use crate::doc::TransactionEvent; +use crate::doc::SubdocsEvent; use crate::text::{Text, TextEvent}; use crate::array::{Array, ArrayEvent}; use crate::map::{Map, MapEvent}; @@ -16,6 +17,7 @@ use crate::transaction::Transaction; fn _pycrdt(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/map.rs b/src/map.rs index 664a685..7cc357d 100644 --- a/src/map.rs +++ b/src/map.rs @@ -176,13 +176,17 @@ impl MapEvent { pub fn new(event: &_MapEvent, txn: &TransactionMut) -> Self { let event = event as *const _MapEvent; let txn = unsafe { std::mem::transmute::<&TransactionMut, &TransactionMut<'static>>(txn) }; - MapEvent { + let mut map_event = MapEvent { event, txn, target: None, keys: None, path: None, - } + }; + map_event.target(); + map_event.path(); + map_event.keys(); + map_event } fn event(&self) -> &_MapEvent { diff --git a/src/text.rs b/src/text.rs index 33461e3..5a3c1bb 100644 --- a/src/text.rs +++ b/src/text.rs @@ -87,13 +87,17 @@ impl TextEvent { pub fn new(event: &_TextEvent, txn: &TransactionMut) -> Self { let event = event as *const _TextEvent; let txn = unsafe { std::mem::transmute::<&TransactionMut, &TransactionMut<'static>>(txn) }; - TextEvent { + let mut text_event = TextEvent { event, txn, target: None, delta: None, path: None, - } + }; + text_event.target(); + text_event.path(); + text_event.delta(); + text_event } fn event(&self) -> &_TextEvent { diff --git a/tests/test_doc.py b/tests/test_doc.py index 49eb905..c59948e 100644 --- a/tests/test_doc.py +++ b/tests/test_doc.py @@ -1,5 +1,6 @@ from functools import partial +import pytest from pycrdt import Array, Doc, Map, Text @@ -20,26 +21,25 @@ def encode_client_id(client_id_bytes): def test_subdoc(): doc0 = Doc() - state0 = doc0.get_state() map0 = Map() doc0["map0"] = map0 doc1 = Doc() - state1 = doc1.get_state() map1 = Map() doc1["map1"] = map1 doc2 = Doc() - state2 = doc2.get_state() array2 = Array() doc2["array2"] = array2 doc0["array0"] = Array(["hello", 1, doc1]) map0.update({"key0": "val0", "key1": doc2}) - update0 = doc0.get_update(state0) + update0 = doc0.get_update() remote_doc = Doc() + events = [] + remote_doc.observe_subdocs(partial(callback, events)) remote_doc.apply_update(update0) remote_array0 = Array() remote_map0 = Map() @@ -56,11 +56,11 @@ def test_subdoc(): map1["foo"] = "bar" - update1 = doc1.get_update(state1) + update1 = doc1.get_update() array2 += ["baz", 3] - update2 = doc2.get_update(state2) + update2 = doc2.get_update() remote_doc1.apply_update(update1) remote_doc2.apply_update(update2) @@ -68,6 +68,15 @@ def test_subdoc(): assert str(map1) == str(remote_map1) assert str(array2) == str(remote_array2) + assert len(events) == 1 + event = events[0] + assert len(event.added) == 2 + assert event.added[0] in (doc1.guid, doc2.guid) + assert event.added[1] in (doc1.guid, doc2.guid) + assert doc1.guid != doc2.guid + assert event.removed == [] + assert event.loaded == [] + def test_transaction_event(): doc = Doc()