Skip to content

Commit

Permalink
Implement listen conformance (#6935)
Browse files Browse the repository at this point in the history
Closes #6533
  • Loading branch information
mcdonc authored and tseaver committed Dec 18, 2018
1 parent e09995f commit 6b7d6cf
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 32 deletions.
37 changes: 21 additions & 16 deletions firestore/google/cloud/firestore_v1beta1/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,6 @@ def __init__(self):
def keys(self):
return list(self._dict.keys())

def items(self):
return list(self._dict.items())

def _copy(self):
wdt = WatchDocTree()
wdt._dict = self._dict.copy()
Expand Down Expand Up @@ -115,9 +112,9 @@ def __contains__(self, k):


class ChangeType(Enum):
ADDED = 0
MODIFIED = 1
ADDED = 1
REMOVED = 2
MODIFIED = 3


class DocumentChange(object):
Expand Down Expand Up @@ -380,9 +377,9 @@ def _on_snapshot_target_change_no_change(self, proto):

def _on_snapshot_target_change_add(self, proto):
_LOGGER.debug("on_snapshot: target change: ADD")
assert (
WATCH_TARGET_ID == proto.target_change.target_ids[0]
), "Unexpected target ID sent by server"
target_id = proto.target_change.target_ids[0]
if target_id != WATCH_TARGET_ID:
raise RuntimeError("Unexpected target ID %s sent by server" % target_id)

def _on_snapshot_target_change_remove(self, proto):
_LOGGER.debug("on_snapshot: target change: REMOVE")
Expand All @@ -394,9 +391,9 @@ def _on_snapshot_target_change_remove(self, proto):
code = change.cause.code
message = change.cause.message

# TODO: Consider surfacing a code property on the exception.
# TODO: Consider a more exact exception
raise Exception("Error %s: %s" % (code, message))
message = "Error %s: %s" % (code, message)

raise RuntimeError(message)

def _on_snapshot_target_change_reset(self, proto):
# Whatever changes have happened so far no longer matter.
Expand Down Expand Up @@ -495,17 +492,24 @@ def on_snapshot(self, proto):
create_time=document.create_time,
update_time=document.update_time,
)

self.change_map[document.name] = snapshot

elif removed:
_LOGGER.debug("on_snapshot: document change: REMOVED")
document = proto.document_change.document
self.change_map[document.name] = ChangeType.REMOVED

elif proto.document_delete or proto.document_remove:
_LOGGER.debug("on_snapshot: document change: DELETE/REMOVE")
name = (proto.document_delete or proto.document_remove).document
# NB: document_delete and document_remove (as far as we, the client,
# are concerned) are functionally equivalent

elif str(proto.document_delete):
_LOGGER.debug("on_snapshot: document change: DELETE")
name = proto.document_delete.document
self.change_map[name] = ChangeType.REMOVED

elif str(proto.document_remove):
_LOGGER.debug("on_snapshot: document change: REMOVE")
name = proto.document_remove.document
self.change_map[name] = ChangeType.REMOVED

elif proto.filter:
Expand Down Expand Up @@ -710,7 +714,8 @@ def _reset_docs(self):

# Mark each document as deleted. If documents are not deleted
# they will be sent again by the server.
for name, snapshot in self.doc_tree.items():
for snapshot in self.doc_tree.keys():
name = snapshot.reference._document_path
self.change_map[name] = ChangeType.REMOVED

self.current = False
124 changes: 122 additions & 2 deletions firestore/tests/unit/test_cross_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,79 @@ def test_delete_testprotos(test_proto):
_run_testcase(testcase, call, firestore_api, client)


@pytest.mark.skip(reason="Watch aka listen not yet implemented in Python.")
@pytest.mark.parametrize("test_proto", _LISTEN_TESTPROTOS)
def test_listen_testprotos(test_proto): # pragma: NO COVER
pass
# test_proto.listen has 'reponses' messages,
# 'google.firestore.v1beta1.ListenResponse'
# and then an expected list of 'snapshots' (local 'Snapshot'), containing
# 'docs' (list of 'google.firestore.v1beta1.Document'),
# 'changes' (list lof local 'DocChange', and 'read_time' timestamp.
from google.cloud.firestore_v1beta1 import Client
from google.cloud.firestore_v1beta1 import DocumentReference
from google.cloud.firestore_v1beta1 import DocumentSnapshot
from google.cloud.firestore_v1beta1 import Watch
import google.auth.credentials

testcase = test_proto.listen
testname = test_proto.description

credentials = mock.Mock(spec=google.auth.credentials.Credentials)
client = Client(project="project", credentials=credentials)
modulename = "google.cloud.firestore_v1beta1.watch"
with mock.patch("%s.Watch.ResumableBidiRpc" % modulename, DummyRpc):
with mock.patch(
"%s.Watch.BackgroundConsumer" % modulename, DummyBackgroundConsumer
):
with mock.patch( # conformance data sets WATCH_TARGET_ID to 1
"%s.WATCH_TARGET_ID" % modulename, 1
):
snapshots = []

def callback(keys, applied_changes, read_time):
snapshots.append((keys, applied_changes, read_time))

query = DummyQuery(client=client)
watch = Watch.for_query(
query, callback, DocumentSnapshot, DocumentReference
)
# conformance data has db string as this
db_str = "projects/projectID/databases/(default)"
watch._firestore._database_string_internal = db_str

if testcase.is_error:
try:
for proto in testcase.responses:
watch.on_snapshot(proto)
except RuntimeError:
# listen-target-add-wrong-id.textpro
# listen-target-remove.textpro
pass

else:
for proto in testcase.responses:
watch.on_snapshot(proto)

assert len(snapshots) == len(testcase.snapshots)
for i, (expected_snapshot, actual_snapshot) in enumerate(
zip(testcase.snapshots, snapshots)
):
expected_changes = expected_snapshot.changes
actual_changes = actual_snapshot[1]
if len(expected_changes) != len(actual_changes):
raise AssertionError(
"change length mismatch in %s (snapshot #%s)"
% (testname, i)
)
for y, (expected_change, actual_change) in enumerate(
zip(expected_changes, actual_changes)
):
expected_change_kind = expected_change.kind
actual_change_kind = actual_change.type.value
if expected_change_kind != actual_change_kind:
raise AssertionError(
"change type mismatch in %s (snapshot #%s, change #%s')"
% (testname, i, y)
)


@pytest.mark.parametrize("test_proto", _QUERY_TESTPROTOS)
Expand Down Expand Up @@ -272,6 +341,57 @@ def convert_precondition(precond):
return Client.write_option(last_update_time=precond.update_time)


class DummyRpc(object): # pragma: NO COVER
def __init__(self, listen, initial_request, should_recover):
self.listen = listen
self.initial_request = initial_request
self.should_recover = should_recover
self.closed = False
self.callbacks = []

def add_done_callback(self, callback):
self.callbacks.append(callback)

def close(self):
self.closed = True


class DummyBackgroundConsumer(object): # pragma: NO COVER
started = False
stopped = False
is_active = True

def __init__(self, rpc, on_snapshot):
self._rpc = rpc
self.on_snapshot = on_snapshot

def start(self):
self.started = True

def stop(self):
self.stopped = True
self.is_active = False


class DummyQuery(object): # pragma: NO COVER
def __init__(self, **kw):
self._client = kw["client"]
self._comparator = lambda x, y: 1

def _to_protobuf(self):
from google.cloud.firestore_v1beta1.proto import query_pb2

query_kwargs = {
"select": None,
"from": None,
"where": None,
"order_by": None,
"start_at": None,
"end_at": None,
}
return query_pb2.StructuredQuery(**query_kwargs)


def parse_query(testcase):
# 'query' testcase contains:
# - 'coll_path': collection ref path.
Expand Down
31 changes: 17 additions & 14 deletions firestore/tests/unit/test_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def test_on_snapshot_target_add(self):
proto.target_change.target_ids = [1] # not "Py"
with self.assertRaises(Exception) as exc:
inst.on_snapshot(proto)
self.assertEqual(str(exc.exception), "Unexpected target ID sent by server")
self.assertEqual(str(exc.exception), "Unexpected target ID 1 sent by server")

def test_on_snapshot_target_remove(self):
inst = self._makeOne()
Expand Down Expand Up @@ -403,7 +403,7 @@ class DummyRemove(object):

remove = DummyRemove()
proto.document_remove = remove
proto.document_delete = None
proto.document_delete = ""
inst.on_snapshot(proto)
self.assertTrue(inst.change_map["fred"] is ChangeType.REMOVED)

Expand All @@ -412,8 +412,8 @@ def test_on_snapshot_filter_update(self):
proto = DummyProto()
proto.target_change = ""
proto.document_change = ""
proto.document_remove = None
proto.document_delete = None
proto.document_remove = ""
proto.document_delete = ""

class DummyFilter(object):
count = 999
Expand All @@ -432,8 +432,8 @@ def test_on_snapshot_filter_update_no_size_change(self):
proto = DummyProto()
proto.target_change = ""
proto.document_change = ""
proto.document_remove = None
proto.document_delete = None
proto.document_remove = ""
proto.document_delete = ""

class DummyFilter(object):
count = 0
Expand All @@ -449,8 +449,8 @@ def test_on_snapshot_unknown_listen_type(self):
proto = DummyProto()
proto.target_change = ""
proto.document_change = ""
proto.document_remove = None
proto.document_delete = None
proto.document_remove = ""
proto.document_delete = ""
proto.filter = ""
with self.assertRaises(Exception) as exc:
inst.on_snapshot(proto)
Expand Down Expand Up @@ -659,13 +659,11 @@ def test__reset_docs(self):
inst.change_map = {None: None}
from google.cloud.firestore_v1beta1.watch import WatchDocTree

doc = DummyDocumentReference()
doc._document_path = "/doc"
doc = DummyDocumentReference("doc")
doc_tree = WatchDocTree()
doc_tree = doc_tree.insert("/doc", doc)
doc_tree = doc_tree.insert("/doc", doc)
snapshot = DummyDocumentSnapshot(doc, None, True, None, None, None)
snapshot.reference = doc
doc_tree = doc_tree.insert(snapshot, None)
inst.doc_tree = doc_tree
inst._reset_docs()
self.assertEqual(inst.change_map, {"/doc": ChangeType.REMOVED})
Expand All @@ -691,10 +689,9 @@ def __init__(self, *document_path, **kw):
self._client = kw["client"]

self._path = document_path
self._document_path = "/" + "/".join(document_path)
self.__dict__.update(kw)

_document_path = "/"


class DummyQuery(object): # pragma: NO COVER
def __init__(self, **kw):
Expand Down Expand Up @@ -737,6 +734,12 @@ def __init__(self, reference, data, exists, read_time, create_time, update_time)
self.create_time = create_time
self.update_time = update_time

def __str__(self):
return "%s-%s" % (self.reference._document_path, self.read_time)

def __hash__(self):
return hash(str(self))


class DummyBackgroundConsumer(object):
started = False
Expand Down

0 comments on commit 6b7d6cf

Please sign in to comment.