diff --git a/firestore/google/cloud/firestore_v1beta1/watch.py b/firestore/google/cloud/firestore_v1beta1/watch.py index 05cc4f89c62b..31743913df75 100644 --- a/firestore/google/cloud/firestore_v1beta1/watch.py +++ b/firestore/google/cloud/firestore_v1beta1/watch.py @@ -79,9 +79,6 @@ def __init__(self): def keys(self): return list(self._dict.keys()) - def items(self): - return list(self._dict.items()) - def _copy(self): wdt = WatchDocTree() wdt._dict = self._dict.copy() @@ -115,9 +112,9 @@ def __contains__(self, k): class ChangeType(Enum): - ADDED = 0 - MODIFIED = 1 + ADDED = 1 REMOVED = 2 + MODIFIED = 3 class DocumentChange(object): @@ -380,9 +377,9 @@ def _on_snapshot_target_change_no_change(self, proto): def _on_snapshot_target_change_add(self, proto): _LOGGER.debug("on_snapshot: target change: ADD") - assert ( - WATCH_TARGET_ID == proto.target_change.target_ids[0] - ), "Unexpected target ID sent by server" + target_id = proto.target_change.target_ids[0] + if target_id != WATCH_TARGET_ID: + raise RuntimeError("Unexpected target ID %s sent by server" % target_id) def _on_snapshot_target_change_remove(self, proto): _LOGGER.debug("on_snapshot: target change: REMOVE") @@ -394,9 +391,9 @@ def _on_snapshot_target_change_remove(self, proto): code = change.cause.code message = change.cause.message - # TODO: Consider surfacing a code property on the exception. - # TODO: Consider a more exact exception - raise Exception("Error %s: %s" % (code, message)) + message = "Error %s: %s" % (code, message) + + raise RuntimeError(message) def _on_snapshot_target_change_reset(self, proto): # Whatever changes have happened so far no longer matter. @@ -495,7 +492,6 @@ def on_snapshot(self, proto): create_time=document.create_time, update_time=document.update_time, ) - self.change_map[document.name] = snapshot elif removed: @@ -503,9 +499,17 @@ def on_snapshot(self, proto): document = proto.document_change.document self.change_map[document.name] = ChangeType.REMOVED - elif proto.document_delete or proto.document_remove: - _LOGGER.debug("on_snapshot: document change: DELETE/REMOVE") - name = (proto.document_delete or proto.document_remove).document + # NB: document_delete and document_remove (as far as we, the client, + # are concerned) are functionally equivalent + + elif str(proto.document_delete): + _LOGGER.debug("on_snapshot: document change: DELETE") + name = proto.document_delete.document + self.change_map[name] = ChangeType.REMOVED + + elif str(proto.document_remove): + _LOGGER.debug("on_snapshot: document change: REMOVE") + name = proto.document_remove.document self.change_map[name] = ChangeType.REMOVED elif proto.filter: @@ -710,7 +714,8 @@ def _reset_docs(self): # Mark each document as deleted. If documents are not deleted # they will be sent again by the server. - for name, snapshot in self.doc_tree.items(): + for snapshot in self.doc_tree.keys(): + name = snapshot.reference._document_path self.change_map[name] = ChangeType.REMOVED self.current = False diff --git a/firestore/tests/unit/test_cross_language.py b/firestore/tests/unit/test_cross_language.py index e4a689337815..448ab6ff8cdf 100644 --- a/firestore/tests/unit/test_cross_language.py +++ b/firestore/tests/unit/test_cross_language.py @@ -205,10 +205,79 @@ def test_delete_testprotos(test_proto): _run_testcase(testcase, call, firestore_api, client) -@pytest.mark.skip(reason="Watch aka listen not yet implemented in Python.") @pytest.mark.parametrize("test_proto", _LISTEN_TESTPROTOS) def test_listen_testprotos(test_proto): # pragma: NO COVER - pass + # test_proto.listen has 'reponses' messages, + # 'google.firestore.v1beta1.ListenResponse' + # and then an expected list of 'snapshots' (local 'Snapshot'), containing + # 'docs' (list of 'google.firestore.v1beta1.Document'), + # 'changes' (list lof local 'DocChange', and 'read_time' timestamp. + from google.cloud.firestore_v1beta1 import Client + from google.cloud.firestore_v1beta1 import DocumentReference + from google.cloud.firestore_v1beta1 import DocumentSnapshot + from google.cloud.firestore_v1beta1 import Watch + import google.auth.credentials + + testcase = test_proto.listen + testname = test_proto.description + + credentials = mock.Mock(spec=google.auth.credentials.Credentials) + client = Client(project="project", credentials=credentials) + modulename = "google.cloud.firestore_v1beta1.watch" + with mock.patch("%s.Watch.ResumableBidiRpc" % modulename, DummyRpc): + with mock.patch( + "%s.Watch.BackgroundConsumer" % modulename, DummyBackgroundConsumer + ): + with mock.patch( # conformance data sets WATCH_TARGET_ID to 1 + "%s.WATCH_TARGET_ID" % modulename, 1 + ): + snapshots = [] + + def callback(keys, applied_changes, read_time): + snapshots.append((keys, applied_changes, read_time)) + + query = DummyQuery(client=client) + watch = Watch.for_query( + query, callback, DocumentSnapshot, DocumentReference + ) + # conformance data has db string as this + db_str = "projects/projectID/databases/(default)" + watch._firestore._database_string_internal = db_str + + if testcase.is_error: + try: + for proto in testcase.responses: + watch.on_snapshot(proto) + except RuntimeError: + # listen-target-add-wrong-id.textpro + # listen-target-remove.textpro + pass + + else: + for proto in testcase.responses: + watch.on_snapshot(proto) + + assert len(snapshots) == len(testcase.snapshots) + for i, (expected_snapshot, actual_snapshot) in enumerate( + zip(testcase.snapshots, snapshots) + ): + expected_changes = expected_snapshot.changes + actual_changes = actual_snapshot[1] + if len(expected_changes) != len(actual_changes): + raise AssertionError( + "change length mismatch in %s (snapshot #%s)" + % (testname, i) + ) + for y, (expected_change, actual_change) in enumerate( + zip(expected_changes, actual_changes) + ): + expected_change_kind = expected_change.kind + actual_change_kind = actual_change.type.value + if expected_change_kind != actual_change_kind: + raise AssertionError( + "change type mismatch in %s (snapshot #%s, change #%s')" + % (testname, i, y) + ) @pytest.mark.parametrize("test_proto", _QUERY_TESTPROTOS) @@ -272,6 +341,57 @@ def convert_precondition(precond): return Client.write_option(last_update_time=precond.update_time) +class DummyRpc(object): # pragma: NO COVER + def __init__(self, listen, initial_request, should_recover): + self.listen = listen + self.initial_request = initial_request + self.should_recover = should_recover + self.closed = False + self.callbacks = [] + + def add_done_callback(self, callback): + self.callbacks.append(callback) + + def close(self): + self.closed = True + + +class DummyBackgroundConsumer(object): # pragma: NO COVER + started = False + stopped = False + is_active = True + + def __init__(self, rpc, on_snapshot): + self._rpc = rpc + self.on_snapshot = on_snapshot + + def start(self): + self.started = True + + def stop(self): + self.stopped = True + self.is_active = False + + +class DummyQuery(object): # pragma: NO COVER + def __init__(self, **kw): + self._client = kw["client"] + self._comparator = lambda x, y: 1 + + def _to_protobuf(self): + from google.cloud.firestore_v1beta1.proto import query_pb2 + + query_kwargs = { + "select": None, + "from": None, + "where": None, + "order_by": None, + "start_at": None, + "end_at": None, + } + return query_pb2.StructuredQuery(**query_kwargs) + + def parse_query(testcase): # 'query' testcase contains: # - 'coll_path': collection ref path. diff --git a/firestore/tests/unit/test_watch.py b/firestore/tests/unit/test_watch.py index d0ce9d8ecc6c..78e543e493b9 100644 --- a/firestore/tests/unit/test_watch.py +++ b/firestore/tests/unit/test_watch.py @@ -272,7 +272,7 @@ def test_on_snapshot_target_add(self): proto.target_change.target_ids = [1] # not "Py" with self.assertRaises(Exception) as exc: inst.on_snapshot(proto) - self.assertEqual(str(exc.exception), "Unexpected target ID sent by server") + self.assertEqual(str(exc.exception), "Unexpected target ID 1 sent by server") def test_on_snapshot_target_remove(self): inst = self._makeOne() @@ -403,7 +403,7 @@ class DummyRemove(object): remove = DummyRemove() proto.document_remove = remove - proto.document_delete = None + proto.document_delete = "" inst.on_snapshot(proto) self.assertTrue(inst.change_map["fred"] is ChangeType.REMOVED) @@ -412,8 +412,8 @@ def test_on_snapshot_filter_update(self): proto = DummyProto() proto.target_change = "" proto.document_change = "" - proto.document_remove = None - proto.document_delete = None + proto.document_remove = "" + proto.document_delete = "" class DummyFilter(object): count = 999 @@ -432,8 +432,8 @@ def test_on_snapshot_filter_update_no_size_change(self): proto = DummyProto() proto.target_change = "" proto.document_change = "" - proto.document_remove = None - proto.document_delete = None + proto.document_remove = "" + proto.document_delete = "" class DummyFilter(object): count = 0 @@ -449,8 +449,8 @@ def test_on_snapshot_unknown_listen_type(self): proto = DummyProto() proto.target_change = "" proto.document_change = "" - proto.document_remove = None - proto.document_delete = None + proto.document_remove = "" + proto.document_delete = "" proto.filter = "" with self.assertRaises(Exception) as exc: inst.on_snapshot(proto) @@ -659,13 +659,11 @@ def test__reset_docs(self): inst.change_map = {None: None} from google.cloud.firestore_v1beta1.watch import WatchDocTree - doc = DummyDocumentReference() - doc._document_path = "/doc" + doc = DummyDocumentReference("doc") doc_tree = WatchDocTree() - doc_tree = doc_tree.insert("/doc", doc) - doc_tree = doc_tree.insert("/doc", doc) snapshot = DummyDocumentSnapshot(doc, None, True, None, None, None) snapshot.reference = doc + doc_tree = doc_tree.insert(snapshot, None) inst.doc_tree = doc_tree inst._reset_docs() self.assertEqual(inst.change_map, {"/doc": ChangeType.REMOVED}) @@ -691,10 +689,9 @@ def __init__(self, *document_path, **kw): self._client = kw["client"] self._path = document_path + self._document_path = "/" + "/".join(document_path) self.__dict__.update(kw) - _document_path = "/" - class DummyQuery(object): # pragma: NO COVER def __init__(self, **kw): @@ -737,6 +734,12 @@ def __init__(self, reference, data, exists, read_time, create_time, update_time) self.create_time = create_time self.update_time = update_time + def __str__(self): + return "%s-%s" % (self.reference._document_path, self.read_time) + + def __hash__(self): + return hash(str(self)) + class DummyBackgroundConsumer(object): started = False