Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add awareness features to handle server state #170

Merged
merged 21 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ test = [
"mypy",
"coverage[toml] >=7",
"exceptiongroup; python_version<'3.11'",
"dirty_equals",
brichet marked this conversation as resolved.
Show resolved Hide resolved
]
docs = [
"mkdocs",
Expand Down
143 changes: 130 additions & 13 deletions python/pycrdt/_awareness.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,62 @@

import json
import time
from typing import Any
from typing import Any, Callable
from uuid import uuid4

from ._doc import Doc
from ._sync import Decoder, read_message
from ._sync import Decoder, YMessageType, read_message, write_var_uint

DEFAULT_USER = {"username": str(uuid4()), "name": "Jupyter server"}
brichet marked this conversation as resolved.
Show resolved Hide resolved

class Awareness: # pragma: no cover
def __init__(self, ydoc: Doc):

class Awareness:
client_id: int
meta: dict[int, dict[str, Any]]
_states: dict[int, dict[str, Any]]
_subscriptions: list[Callable[[dict[str, Any]], None]]
_user: dict[str, str] | None
brichet marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
    self,
    ydoc: Doc,
    on_change: Callable[[bytes], None] | None = None,
    user: dict[str, str] | None = None,
):
    """
    Creates the awareness of a document.

    Args:
        ydoc: The document whose `client_id` identifies the local client.
        on_change: Optional callback, called with the serialized local state
            each time the local state is set.
        user: Optional user information. When omitted, a copy of `DEFAULT_USER`
            is used and stored in the local state without calling `on_change()`.
    """
    self.client_id = ydoc.client_id
    self.meta = {}
    self._states = {}
    self.on_change = on_change

    if user is not None:
        # Goes through the property setter, which publishes the local state.
        self.user = user
    else:
        # Copy the module-level default so that mutating this instance's user
        # cannot leak into other instances (or into DEFAULT_USER itself).
        self._user = dict(DEFAULT_USER)
        self._states[self.client_id] = {"user": self._user}

    self._subscriptions = []

@property
def states(self) -> dict[int, dict[str, Any]]:
    """The awareness states, keyed by client id (exposed without a setter)."""
    return self._states

@property
def user(self) -> dict[str, str] | None:
return self._user

@user.setter
def user(self, user: dict[str, str]):
self._user = user
self.set_local_state_field("user", self._user)
brichet marked this conversation as resolved.
Show resolved Hide resolved

def get_changes(self, message: bytes) -> dict[str, Any]:
brichet marked this conversation as resolved.
Show resolved Hide resolved
"""
Updates the states with a user state.
This function sends the changes to subscribers.
brichet marked this conversation as resolved.
Show resolved Hide resolved

Args:
message: Bytes representing the user state.
brichet marked this conversation as resolved.
Show resolved Hide resolved
"""
message = read_message(message)
decoder = Decoder(message)
timestamp = int(time.time() * 1000)
Expand All @@ -32,19 +75,19 @@ def get_changes(self, message: bytes) -> dict[str, Any]:
if state is not None:
states.append(state)
client_meta = self.meta.get(client_id)
prev_state = self.states.get(client_id)
prev_state = self._states.get(client_id)
curr_clock = 0 if client_meta is None else client_meta["clock"]
if curr_clock < clock or (
curr_clock == clock and state is None and client_id in self.states
curr_clock == clock and state is None and client_id in self._states
):
if state is None:
if client_id == self.client_id and self.states.get(client_id) is not None:
if client_id == self.client_id and self._states.get(client_id) is not None:
clock += 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it be

clock = curr_clock + 1

to update the local clock?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mimics the JavaScript implementation. Should it be different?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I misunderstood how the clock works, but my understanding is that each client has its own.
If that is the case, why should we rely on a value coming from a client to update the local clock?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to open an issue in https://github.com/yjs/y-protocols to discuss this, or maybe @dmonad can comment here?

else:
if client_id in self.states:
del self.states[client_id]
if client_id in self._states:
del self._states[client_id]
else:
self.states[client_id] = state
self._states[client_id] = state
self.meta[client_id] = {
"clock": clock,
"last_updated": timestamp,
Expand All @@ -57,10 +100,84 @@ def get_changes(self, message: bytes) -> dict[str, Any]:
if state != prev_state:
filtered_updated.append(client_id)
updated.append(client_id)
return {

changes = {
"added": added,
"updated": updated,
"filtered_updated": filtered_updated,
"removed": removed,
"states": states,
}

# Do not trigger the callbacks if it is only a keep alive update
if added or filtered_updated or removed:
for callback in self._subscriptions:
callback(changes)

return changes

def get_local_state(self) -> dict[str, Any]:
    """
    Returns the state stored for the local client (`self.client_id`),
    or a new empty dictionary when no local state is stored.
    """
    try:
        return self._states[self.client_id]
    except KeyError:
        return {}

def set_local_state(self, state: dict[str, Any]) -> None:
    """
    Replaces the local state and refreshes the local meta (clock and timestamp).
    When an `on_change()` callback was provided, it is called with the
    serialized awareness message.

    Args:
        state: The dictionary representing the state.
    """
    now_ms = int(time.time() * 1000)
    previous_meta = self.meta.get(self.client_id, {})
    # The first local update gets clock 0, each following one increments it.
    clock = previous_meta.get("clock", -1) + 1
    self._states[self.client_id] = state
    self.meta[self.client_id] = {"clock": clock, "last_updated": now_ms}

    # Message layout, in order:
    #   message type | length of the updates | number of updates (1)
    #   | client_id | clock | length of the encoded state | encoded state
    payload = json.dumps(state, separators=(",", ":")).encode("utf-8")
    update = (
        write_var_uint(1)
        + write_var_uint(self.client_id)
        + write_var_uint(clock)
        + write_var_uint(len(payload))
        + payload
    )
    message = write_var_uint(YMessageType.AWARENESS) + write_var_uint(len(update)) + update

    if self.on_change:
        self.on_change(message)

def set_local_state_field(self, field: str, value: Any) -> None:
    """
    Sets a single field of the local state and publishes the result
    through `set_local_state()`.

    Args:
        field: The name of the field to set.
        value: The value of the field.
    """
    # Build a new dictionary instead of mutating the current one in place:
    # the current state object may still be referenced by whoever passed it
    # to `set_local_state()`, and must not change behind their back.
    new_state = {**self.get_local_state(), field: value}
    self.set_local_state(new_state)

def observe(self, callback: Callable[[dict[str, Any]], None]) -> None:
    """
    Subscribes to awareness changes.

    Args:
        callback: Callback that will be called with the changes dictionary
            when `get_changes()` reports added, modified or removed states
            (keep-alive updates do not trigger it).
    """
    self._subscriptions.append(callback)

def unobserve(self) -> None:
    """
    Unsubscribes from awareness changes. Every registered callback is dropped.
    """
    # Rebind to a fresh list rather than clearing the existing one.
    self._subscriptions = list()
184 changes: 184 additions & 0 deletions tests/test_awareness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import json
from copy import deepcopy
from uuid import uuid4

from dirty_equals import IsStr
from pycrdt import Awareness, Doc, write_var_uint

# Expected default user. The real username is generated at import time, so it is
# matched with IsStr() (presumably "any string" — see dirty_equals) instead of a fixed value.
DEFAULT_USER = {"username": IsStr(), "name": "Jupyter server"}
# User given explicitly to the Awareness constructor.
TEST_USER = {"username": str(uuid4()), "name": "Test user"}
# Client id and state used to build messages for a client other than the local one.
REMOTE_CLIENT_ID = 853790970
REMOTE_USER = {
    "user": {
        "username": "2460ab00fd28415b87e49ec5aa2d482d",
        "name": "Anonymous Ersa",
        "display_name": "Anonymous Ersa",
        "initials": "AE",
        "avatar_url": None,
        "color": "var(--jp-collaborator-color7)",
    }
}


def create_bytes_message(client_id, user, clock=1) -> bytes:
    """
    Build a length-prefixed awareness update containing a single client state.

    No message-type prefix is written: the result starts with the length of the update.

    Args:
        client_id: The id of the client the state belongs to.
        user: The state, either as an already encoded JSON string (e.g. "null")
            or as an object to serialize as compact JSON.
        clock: The clock value attached to the state.

    Returns:
        The encoded update.
    """
    if isinstance(user, str):
        # Already JSON text: send it verbatim.
        new_user_bytes = user.encode("utf-8")
    else:
        new_user_bytes = json.dumps(user, separators=(",", ":")).encode("utf-8")
    msg = write_var_uint(len(new_user_bytes)) + new_user_bytes
    msg = write_var_uint(clock) + msg
    msg = write_var_uint(client_id) + msg
    # Number of client states in the update.
    msg = write_var_uint(1) + msg
    msg = write_var_uint(len(msg)) + msg
    return msg


def test_awareness_default_user():
    """Without an explicit user, the awareness exposes the default user."""
    awareness = Awareness(Doc())
    assert awareness.user == DEFAULT_USER


def test_awareness_with_user():
    """A user passed to the constructor is exposed unchanged."""
    awareness = Awareness(Doc(), user=TEST_USER)
    assert awareness.user == TEST_USER


def test_awareness_set_user():
    """Assigning the `user` attribute replaces the current user."""
    awareness = Awareness(Doc())
    new_user = {"username": "test_username", "name": "test_name"}

    awareness.user = new_user

    assert awareness.user == new_user


def test_awareness_get_local_state():
    """The initial local state only holds the default user."""
    awareness = Awareness(Doc())
    local_state = awareness.get_local_state()
    assert local_state == {"user": DEFAULT_USER}


def test_awareness_set_local_state_field():
    """Setting a field adds it to the local state next to the existing fields."""
    awareness = Awareness(Doc())

    awareness.set_local_state_field("new_field", "new_value")

    expected_state = {"user": DEFAULT_USER, "new_field": "new_value"}
    assert awareness.get_local_state() == expected_state


def test_awareness_add_user():
    """A first message from an unknown client is reported as an addition."""
    awareness = Awareness(Doc())
    message = create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)

    changes = awareness.get_changes(message)

    expected_changes = {
        "added": [REMOTE_CLIENT_ID],
        "updated": [],
        "filtered_updated": [],
        "removed": [],
        "states": [REMOTE_USER],
    }
    expected_states = {
        awareness.client_id: {"user": DEFAULT_USER},
        REMOTE_CLIENT_ID: REMOTE_USER,
    }
    assert changes == expected_changes
    assert awareness.states == expected_states


def test_awareness_update_user():
    """A newer, different state from a known client is reported as an update."""
    awareness = Awareness(Doc())

    # Register the remote client first (default clock of the helper).
    awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER))

    # Send a modified state with a higher clock.
    remote_user = deepcopy(REMOTE_USER)
    remote_user["user"]["name"] = "New user name"
    update_message = create_bytes_message(REMOTE_CLIENT_ID, remote_user, 2)
    changes = awareness.get_changes(update_message)

    expected_changes = {
        "added": [],
        "updated": [REMOTE_CLIENT_ID],
        "filtered_updated": [REMOTE_CLIENT_ID],
        "removed": [],
        "states": [remote_user],
    }
    expected_states = {
        awareness.client_id: {"user": DEFAULT_USER},
        REMOTE_CLIENT_ID: remote_user,
    }
    assert changes == expected_changes
    assert awareness.states == expected_states


def test_awareness_remove_user():
    """A "null" state from a known client removes it from the states."""
    awareness = Awareness(Doc())

    # Register the remote client first.
    awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER))

    # Then send a null state with a higher clock.
    removal_message = create_bytes_message(REMOTE_CLIENT_ID, "null", 2)
    changes = awareness.get_changes(removal_message)

    expected_changes = {
        "added": [],
        "updated": [],
        "filtered_updated": [],
        "removed": [REMOTE_CLIENT_ID],
        "states": [],
    }
    assert changes == expected_changes
    assert awareness.states == {awareness.client_id: {"user": DEFAULT_USER}}


def test_awareness_increment_clock():
    """A null state received for the local client bumps the local clock instead of removing it."""
    awareness = Awareness(Doc())
    message = create_bytes_message(awareness.client_id, "null")

    changes = awareness.get_changes(message)

    expected_changes = {
        "added": [],
        "updated": [],
        "filtered_updated": [],
        "removed": [],
        "states": [],
    }
    assert changes == expected_changes
    local_meta = awareness.meta.get(awareness.client_id, {})
    assert local_meta.get("clock", 0) == 2


def test_awareness_observes():
    """Observers receive the changes until `unobserve()` is called."""
    awareness = Awareness(Doc())

    received = []

    def callback(value):
        received.append(value)

    awareness.observe(callback)
    changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER))
    assert received == [changes]

    received.clear()
    awareness.unobserve()

    # Send a real change (new state, higher clock): replaying the same message
    # would not notify observers anyway, so it could not prove that the
    # callback was actually removed.
    remote_user = deepcopy(REMOTE_USER)
    remote_user["user"]["name"] = "New user name"
    changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, remote_user, 2))
    assert changes["filtered_updated"] == [REMOTE_CLIENT_ID]
    assert received == []


def test_awareness_on_change():
    """Setting a local state field calls `on_change` once, with bytes."""
    messages = []
    awareness = Awareness(Doc(), on_change=messages.append)

    awareness.set_local_state_field("new_field", "new_value")

    assert len(messages) == 1
    assert type(messages[0]) is bytes
Loading