Skip to content

Commit

Permalink
Add observe_subdocs (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart authored Nov 6, 2023
1 parent 9056567 commit 93d567e
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ python-source = "python"
module-name = "pycrdt._pycrdt"

[tool.ruff]
line-length = 88
line-length = 100
select = ["F", "E", "W", "I001"]
11 changes: 10 additions & 1 deletion python/pycrdt/_pycrdt.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ class Doc:
"""Create a new document with an optional global client ID.
If no client ID is passed, a random one will be generated."""
def client_id(self) -> int:
"""Returns the document global client ID."""
"""Returns the document unique client identifier."""
def guid(self) -> int:
"""Returns the document globally unique identifier."""
def create_transaction(self) -> Transaction:
"""Create a document transaction."""
def get_or_insert_text(self, name: str) -> Text:
Expand All @@ -25,6 +27,9 @@ class Doc:
def observe(self, callback: Callable[[TransactionEvent], None]) -> int:
"""Subscribes a callback to be called with the shared document change event.
Returns a subscription ID that can be used to unsubscribe."""
def observe_subdocs(self, callback: Callable[[SubdocsEvent], None]) -> int:
"""Subscribes a callback to be called with the shared document subdoc change event.
Returns a subscription ID that can be used to unsubscribe."""

class Transaction:
"""Document transaction"""
Expand All @@ -38,6 +43,10 @@ class TransactionEvent:
"""Event generated by `Doc.observe` method. Emitted during transaction commit
phase."""

class SubdocsEvent:
"""Event generated by `Doc.observe_subdocs` method. Emitted during transaction commit
phase."""

class TextEvent:
"""Event generated by `Text.observe` method. Emitted during transaction commit
phase."""
Expand Down
15 changes: 14 additions & 1 deletion python/pycrdt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,11 @@ def __radd__(self, value: list[Any]) -> Array:
def __setitem__(self, key: int | slice, value: Any | list[Any]) -> None:
with self.doc.transaction():
if isinstance(key, int):
length = len(self)
if length == 0:
raise IndexError("Array index out of range")
if key < 0:
key += len(self)
key += length
del self[key]
self[key:key] = [value]
elif isinstance(key, slice):
Expand All @@ -96,6 +99,11 @@ def __setitem__(self, key: int | slice, value: Any | list[Any]) -> None:
def __delitem__(self, key: int | slice) -> None:
with self.doc.transaction() as txn:
if isinstance(key, int):
length = len(self)
if length == 0:
raise IndexError("Array index out of range")
if key < 0:
key += length
self.integrated.remove_range(txn, key, 1)
elif isinstance(key, slice):
if key.step is not None:
Expand All @@ -119,6 +127,11 @@ def __delitem__(self, key: int | slice) -> None:
def __getitem__(self, key: int) -> BaseType:
with self.doc.transaction() as txn:
if isinstance(key, int):
length = len(self)
if length == 0:
raise IndexError("Array index out of range")
if key < 0:
key += length
return self._maybe_as_type_or_doc(self.integrated.get(txn, key))
elif isinstance(key, slice):
i0 = 0 if key.start is None else key.start
Expand Down
9 changes: 8 additions & 1 deletion python/pycrdt/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
from typing import Callable

from ._pycrdt import Doc as _Doc
from ._pycrdt import TransactionEvent
from ._pycrdt import TransactionEvent, SubdocsEvent
from .base import BaseDoc, BaseType, integrated_types
from .transaction import Transaction


class Doc(BaseDoc):
@property
def guid(self) -> int:
return self._doc.guid()

@property
def client_id(self) -> int:
return self._doc.client_id()
Expand Down Expand Up @@ -39,5 +43,8 @@ def __setitem__(self, key: str, value: BaseType) -> None:
def observe(self, callback: Callable[[TransactionEvent], None]) -> int:
return self._doc.observe(callback)

def observe_subdocs(self, callback: Callable[[SubdocsEvent], None]) -> int:
return self._doc.observe_subdocs(callback)


integrated_types[_Doc] = Doc
9 changes: 6 additions & 3 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,17 @@ impl ArrayEvent {
pub fn new(event: &_ArrayEvent, txn: &TransactionMut) -> Self {
let event = event as *const _ArrayEvent;
let txn = unsafe { std::mem::transmute::<&TransactionMut, &TransactionMut<'static>>(txn) };
ArrayEvent {
let mut array_event = ArrayEvent {
event,
txn,
target: None,
delta: None,
path: None,
}
};
array_event.target();
array_event.path();
array_event.delta();
array_event
}

fn event(&self) -> &_ArrayEvent {
Expand Down Expand Up @@ -216,7 +220,6 @@ impl ArrayEvent {
let delta = self.event().delta(self.txn()).iter().map(|change| {
Python::with_gil(|py| change.clone().into_py(py))
});

PyList::new(py, delta).into()
});
self.delta = Some(delta.clone());
Expand Down
63 changes: 62 additions & 1 deletion src/doc.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyLong};
use pyo3::types::{PyBytes, PyLong, PyList};
use yrs::{
Doc as _Doc,
ReadTxn,
Transact,
TransactionMut,
TransactionCleanupEvent,
SubdocsEvent as _SubdocsEvent,
StateVector,
Update,
};
Expand Down Expand Up @@ -42,6 +43,10 @@ impl Doc {
Doc { doc }
}

fn guid(&mut self) -> String {
self.doc.guid().to_string()
}

fn client_id(&mut self) -> u64 {
self.doc.client_id()
}
Expand Down Expand Up @@ -109,6 +114,21 @@ impl Doc {
.into();
Ok(id)
}

pub fn observe_subdocs(&mut self, f: PyObject) -> PyResult<u32> {
let id: u32 = self.doc
.observe_subdocs(move |_, event| {
Python::with_gil(|py| {
let event = SubdocsEvent::new(event);
if let Err(err) = f.call1(py, (event,)) {
err.restore(py)
}
})
})
.unwrap()
.into();
Ok(id)
}
}

#[pyclass(unsendable)]
Expand Down Expand Up @@ -161,3 +181,44 @@ impl TransactionEvent {
self.update.clone()
}
}

#[pyclass(unsendable)]
pub struct SubdocsEvent {
added: PyObject,
removed: PyObject,
loaded: PyObject,
}

impl SubdocsEvent {
fn new(event: &_SubdocsEvent) -> Self {
let added: Vec<String> = event.added().map(|d| d.guid().clone().to_string()).collect();
let added: PyObject = Python::with_gil(|py| PyList::new(py, &added).into());
let removed: Vec<String> = event.removed().map(|d| d.guid().clone().to_string()).collect();
let removed: PyObject = Python::with_gil(|py| PyList::new(py, &removed).into());
let loaded: Vec<String> = event.loaded().map(|d| d.guid().clone().to_string()).collect();
let loaded: PyObject = Python::with_gil(|py| PyList::new(py, &loaded).into());
SubdocsEvent {
added,
removed,
loaded,
}
}
}

#[pymethods]
impl SubdocsEvent {
#[getter]
pub fn added(&mut self) -> PyObject {
self.added.clone()
}

#[getter]
pub fn removed(&mut self) -> PyObject {
self.removed.clone()
}

#[getter]
pub fn loaded(&mut self) -> PyObject {
self.loaded.clone()
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod transaction;
mod type_conversions;
use crate::doc::Doc;
use crate::doc::TransactionEvent;
use crate::doc::SubdocsEvent;
use crate::text::{Text, TextEvent};
use crate::array::{Array, ArrayEvent};
use crate::map::{Map, MapEvent};
Expand All @@ -16,6 +17,7 @@ use crate::transaction::Transaction;
fn _pycrdt(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Doc>()?;
m.add_class::<TransactionEvent>()?;
m.add_class::<SubdocsEvent>()?;
m.add_class::<Text>()?;
m.add_class::<TextEvent>()?;
m.add_class::<Array>()?;
Expand Down
8 changes: 6 additions & 2 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,17 @@ impl MapEvent {
pub fn new(event: &_MapEvent, txn: &TransactionMut) -> Self {
let event = event as *const _MapEvent;
let txn = unsafe { std::mem::transmute::<&TransactionMut, &TransactionMut<'static>>(txn) };
MapEvent {
let mut map_event = MapEvent {
event,
txn,
target: None,
keys: None,
path: None,
}
};
map_event.target();
map_event.path();
map_event.keys();
map_event
}

fn event(&self) -> &_MapEvent {
Expand Down
8 changes: 6 additions & 2 deletions src/text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,17 @@ impl TextEvent {
pub fn new(event: &_TextEvent, txn: &TransactionMut) -> Self {
let event = event as *const _TextEvent;
let txn = unsafe { std::mem::transmute::<&TransactionMut, &TransactionMut<'static>>(txn) };
TextEvent {
let mut text_event = TextEvent {
event,
txn,
target: None,
delta: None,
path: None,
}
};
text_event.target();
text_event.path();
text_event.delta();
text_event
}

fn event(&self) -> &_TextEvent {
Expand Down
21 changes: 15 additions & 6 deletions tests/test_doc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial

import pytest
from pycrdt import Array, Doc, Map, Text


Expand All @@ -20,26 +21,25 @@ def encode_client_id(client_id_bytes):

def test_subdoc():
doc0 = Doc()
state0 = doc0.get_state()
map0 = Map()
doc0["map0"] = map0

doc1 = Doc()
state1 = doc1.get_state()
map1 = Map()
doc1["map1"] = map1

doc2 = Doc()
state2 = doc2.get_state()
array2 = Array()
doc2["array2"] = array2

doc0["array0"] = Array(["hello", 1, doc1])
map0.update({"key0": "val0", "key1": doc2})

update0 = doc0.get_update(state0)
update0 = doc0.get_update()

remote_doc = Doc()
events = []
remote_doc.observe_subdocs(partial(callback, events))
remote_doc.apply_update(update0)
remote_array0 = Array()
remote_map0 = Map()
Expand All @@ -56,18 +56,27 @@ def test_subdoc():

map1["foo"] = "bar"

update1 = doc1.get_update(state1)
update1 = doc1.get_update()

array2 += ["baz", 3]

update2 = doc2.get_update(state2)
update2 = doc2.get_update()

remote_doc1.apply_update(update1)
remote_doc2.apply_update(update2)

assert str(map1) == str(remote_map1)
assert str(array2) == str(remote_array2)

assert len(events) == 1
event = events[0]
assert len(event.added) == 2
assert event.added[0] in (doc1.guid, doc2.guid)
assert event.added[1] in (doc1.guid, doc2.guid)
assert doc1.guid != doc2.guid
assert event.removed == []
assert event.loaded == []


def test_transaction_event():
doc = Doc()
Expand Down

0 comments on commit 93d567e

Please sign in to comment.