-
Notifications
You must be signed in to change notification settings - Fork 949
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Handle buffers in Widget._should_send_property
#1595
Changes from 7 commits
6445ea9
afe6ae6
40cd07c
349f556
db04817
4d912e4
7d56205
b9eb8c8
0f0761c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright (c) Jupyter Development Team. | ||
# Distributed under the terms of the Modified BSD License. | ||
|
||
from ipython_genutils.py3compat import PY3 | ||
|
||
import nose.tools as nt | ||
|
||
from traitlets import Bool, Tuple, List, Instance | ||
|
||
from .utils import setup, teardown | ||
|
||
from ..widget import Widget | ||
|
||
# | ||
# First some widgets to test on: | ||
# | ||
|
||
# A widget with simple traits (list + tuple to ensure both are handled) | ||
class SimpleWidget(Widget):
    """Widget with three synced traits: a bool, a 3-tuple of bools and a list of bools."""

    a = Bool().tag(sync=True)
    b = Tuple(
        Bool(), Bool(), Bool(),
        default_value=(False, False, False)
    ).tag(sync=True)
    c = List(Bool()).tag(sync=True)
|
||
|
||
|
||
# A widget where the data might be changed on reception: | ||
def transform_fromjson(data, widget):
    """Deserializer that deliberately alters the list it receives.

    If ``data`` is empty or its first element is falsy, ``data`` is returned
    unchanged.  Otherwise a new list is returned whose first element is
    ``False`` and whose last two elements are swapped.

    The swap only concerns elements after the first one, so lists with fewer
    than three elements just get their first element cleared.  (Previously a
    two-element list grew to three elements and shorter lists raised
    ``IndexError``.)

    ``widget`` is part of the ``from_json`` signature and is not used.
    """
    if not data or not data[0]:
        return data
    if len(data) < 3:
        # Not enough trailing elements to swap without touching the first.
        return [False] + data[1:]
    return [False] + data[1:-2] + [data[-1], data[-2]]
|
||
class TransformerWidget(Widget):
    """Widget whose list trait is rewritten by ``transform_fromjson`` on reception."""

    d = List(Bool()).tag(
        sync=True,
        from_json=transform_fromjson
    )
|
||
|
||
|
||
# A widget that has a buffer: | ||
class DataInstance:
    """Minimal value holder: stores whatever is passed as ``data`` (default ``None``)."""

    def __init__(self, data=None):
        self.data = data
|
||
def mview_serializer(instance, widget):
    """Serialize ``instance`` to ``{'data': <memoryview>}``.

    A falsy ``instance.data`` (``None`` or empty) is serialized as
    ``{'data': None}``.  ``widget`` is not used.
    """
    payload = instance.data
    if payload:
        return {'data': memoryview(payload)}
    return {'data': None}
|
||
def bytes_serializer(instance, widget):
    """Serialize ``instance`` to ``{'data': <bytearray copy>}``.

    A falsy ``instance.data`` (``None`` or empty) is serialized as
    ``{'data': None}``.  ``widget`` is not used.
    """
    payload = instance.data
    if not payload:
        return {'data': None}
    copied = memoryview(payload).tobytes()
    return {'data': bytearray(copied)}
|
||
def deserializer(json_data, widget):
    """Build a ``DataInstance`` from the received state.

    ``json_data`` is expected to be a dict with a ``'data'`` entry holding a
    bytes-like object; its bytes are copied into the new instance.

    A falsy ``json_data`` or a ``'data'`` entry of ``None`` yields an empty
    ``DataInstance``.  The ``None`` case matters because the serializers in
    this module emit ``{'data': None}`` for an instance without data, and
    ``memoryview(None)`` would raise ``TypeError``.

    ``widget`` is part of the ``from_json`` signature and is not used.
    """
    raw = json_data['data'] if json_data else None
    if raw is None:
        return DataInstance(None)
    return DataInstance(memoryview(raw).tobytes())
|
||
class DataWidget(SimpleWidget):
    """SimpleWidget plus a buffer-carrying trait that is sent as a memoryview."""

    d = Instance(DataInstance).tag(
        sync=True,
        to_json=mview_serializer,
        from_json=deserializer
    )
|
||
# A widget that has a buffer that might be changed on reception: | ||
def truncate_deserializer(json_data, widget):
    """Build a ``DataInstance`` holding at most the first 20 bytes received.

    ``json_data`` is expected to be a dict with a ``'data'`` entry holding a
    bytes-like object.  The payload is wrapped in a ``memoryview`` before
    slicing so that ``bytes`` and ``bytearray`` payloads work as well as
    memoryviews (only memoryviews have ``tobytes``).

    A falsy ``json_data`` or a ``'data'`` entry of ``None`` yields an empty
    ``DataInstance``.

    ``widget`` is part of the ``from_json`` signature and is not used.
    """
    raw = json_data['data'] if json_data else None
    if raw is None:
        return DataInstance(None)
    return DataInstance(memoryview(raw)[:20].tobytes())
|
||
class TruncateDataWidget(SimpleWidget):
    """SimpleWidget plus a buffer-carrying trait that is truncated on reception."""

    d = Instance(DataInstance).tag(
        sync=True,
        to_json=bytes_serializer,
        from_json=truncate_deserializer
    )
|
||
|
||
# | ||
# Actual tests: | ||
# | ||
|
||
def test_set_state_simple():
    # Plain values that survive the round trip unchanged must not be
    # echoed back to the front-end.
    widget = SimpleWidget()
    incoming = {
        'a': True,
        'b': [True, False, True],
        'c': [False, True, False],
    }
    widget.set_state(incoming)

    nt.assert_equal(widget.comm.messages, [])
|
||
|
||
def test_set_state_transformer():
    # from_json rewrites the incoming list, so the stored value differs from
    # what was received and an update has to be sent back.
    widget = TransformerWidget()
    widget.set_state({'d': [True, False, True]})

    expected_update = {
        'buffers': [],
        'data': {
            'buffer_paths': [],
            'method': 'update',
            'state': {'d': [False, True, False]},
        },
    }
    nt.assert_equal(widget.comm.messages, [((), expected_update)])
|
||
|
||
def test_set_state_data():
    # A buffer that round-trips unchanged must not be echoed back.
    widget = DataWidget()
    payload = memoryview(b'x' * 30)
    incoming = {
        'a': True,
        'd': {'data': payload},
    }
    widget.set_state(incoming)

    nt.assert_equal(widget.comm.messages, [])
|
||
|
||
def test_set_state_data_truncate():
    widget = TruncateDataWidget()
    payload = memoryview(b'x' * 30)
    incoming = {
        'a': True,
        'd': {'data': payload},
    }
    widget.set_state(incoming)

    # Exactly one message is expected; more would mean redundant syncs.
    nt.assert_equal(len(widget.comm.messages), 1)
    message = widget.comm.messages[0]

    # Detach the binary parts so the JSON part can be compared on its own.
    # The truncation on reception is what forces this update to be sent.
    binary_parts = message[1].pop('buffers')
    expected_json = {
        'data': {
            'buffer_paths': [['d', 'data']],
            'method': 'update',
            'state': {'d': {}},
        },
    }
    nt.assert_equal(message, ((), expected_json))

    # Sanity: a single buffer holding the first 20 bytes of the payload.
    nt.assert_equal(len(binary_parts), 1)
    nt.assert_equal(binary_parts[0], payload[:20].tobytes())
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,9 +50,9 @@ def _json_to_widget(x, obj): | |
} | ||
|
||
# Types that are treated as binary message parts: memoryview, bytearray,
# and (Python 3 only) bytes.
if PY3:
    _binary_types = (memoryview, bytearray, bytes)
else:
    # On Python 2 ``bytes`` is an alias of ``str``, so it is left out.
    _binary_types = (memoryview, bytearray)
|
||
def _put_buffers(state, buffer_paths, buffers): | ||
"""The inverse of _remove_buffers, except here we modify the existing dict/lists. | ||
|
@@ -115,6 +115,8 @@ def _separate_buffers(substate, path, buffer_paths, buffers): | |
def _remove_buffers(state): | ||
"""Return (state_without_buffers, buffer_paths, buffers) for binary message parts | ||
|
||
A binary message part is a memoryview, bytearray, or python 3 bytes object. | ||
|
||
As an example: | ||
>>> state = {'plain': [0, 'text'], 'x': {'ar': memoryview(ar1)}, 'y': {'shape': (10,10), 'data': memoryview(ar2)}} | ||
>>> _remove_buffers(state) | ||
|
@@ -483,6 +485,30 @@ def _compare(self, a, b): | |
else: | ||
return a == b | ||
|
||
def _buffer_list_equal(self, a, b): | ||
"""Compare two lists of buffers for equality. | ||
|
||
Used to decide whether two sequences of buffers (memoryviews, | ||
bytearrays, or python 3 bytes) differ, such that a sync is needed. | ||
|
||
Returns True if equal, False if unequal | ||
""" | ||
if len(a) != len(b): | ||
return False | ||
if a == b: | ||
return True | ||
for ia, ib in zip(a, b): | ||
# Check byte equality: | ||
# NOTE: Simple ia != ib does not always work as intended, as | ||
# e.g. memoryview(np.frombuffer(ia, dtype='float32')) != | ||
# memoryview(np.frombuffer(b)), since the format info differs. | ||
# However, since we only transfer bytes, we use `tobytes()`. | ||
ia_bytes = ia.tobytes() if isinstance(ia, memoryview) else ia | ||
ib_bytes = ib.tobytes() if isinstance(ib, memoryview) else ib | ||
if ia_bytes != ib_bytes: | ||
return False | ||
return True | ||
|
||
def set_state(self, sync_data): | ||
"""Called when a state is received from the front-end.""" | ||
# The order of these context managers is important. Properties must | ||
|
@@ -597,10 +623,16 @@ def _should_send_property(self, key, value): | |
# A roundtrip conversion through json in the comparison takes care of | ||
# idiosyncrasies of how python data structures map to json, for example | ||
# tuples get converted to lists. | ||
if (key in self._property_lock | ||
and jsonloads(jsondumps(to_json(value, self))) == self._property_lock[key]): | ||
return False | ||
elif self._holding_sync: | ||
if key in self._property_lock: | ||
# model_state, buffer_paths, buffers | ||
split_value = _remove_buffers({ key: to_json(value, self)}) | ||
split_lock = _remove_buffers({ key: self._property_lock[key]}) | ||
# Compare state and buffer_paths | ||
if (jsonloads(jsondumps(split_value[0])) == jsonloads(jsondumps(split_lock[0])) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would rather pass There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. I also factored out this comparison into a top-level function. |
||
and split_value[1] == split_lock[1] | ||
and self._buffer_list_equal(split_value[2], split_lock[2])): | ||
return False | ||
if self._holding_sync: | ||
self._states_to_send.add(key) | ||
return False | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we have this as an extension point, keeping the key could be very informative as to which kind of comparison to perform. If anything, it's a problem with the current `_compare` that it does not have it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
E.g. for certain keys you might want to do a floating point comparison with a given precision, while for another you need exact equality. Without the key, you will not be able to determine which comparison to apply.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you elaborate on this being an extension point? I don't understand how you're thinking of this as an extension point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean extension point as in "something an inheriting class might want to override".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, this comparison depends on the top-level put/remove buffer function implementations, so I factored out the comparison to also be a top-level function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah. Since it depends on the implementation of the top-level functions (i.e., what exactly can be in the buffers list), I think it makes more sense to not override it, and instead put it at the top level too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation works for me, but if someone uses a serializer/deserializer pair that might introduce noise when data is passed through a round trip (`a != serializer(deserialize(a))`), e.g. floating point inaccuracies, they will want to override this comparison.
It is still possible to do that by overriding the `_should_send_property` function, but it will be more brittle, as it needs to take into account the internals of that method (e.g. `_property_lock`).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I might be over-engineering this one. Other fields in the state other than buffers might also have this issue, so there is no obvious reason to give buffers preferential treatment.