Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Firestore: add driver for query conformance tests. #6839

Merged
merged 5 commits into from
Dec 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions firestore/google/cloud/firestore_v1beta1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import math

from google.protobuf import wrappers_pb2
import six

from google.cloud.firestore_v1beta1 import _helpers
from google.cloud.firestore_v1beta1 import document
Expand Down Expand Up @@ -648,10 +649,8 @@ def _normalize_cursor(self, cursor, orders):
msg = _INVALID_CURSOR_TRANSFORM
raise ValueError(msg)

if key == "__name__" and "/" not in field:
document_fields[index] = "{}/{}/{}".format(
self._client._database_string, "/".join(self._parent._path), field
)
if key == "__name__" and isinstance(field, six.string_types):
document_fields[index] = self._parent.document(field)

return document_fields, before

Expand Down
126 changes: 125 additions & 1 deletion firestore/tests/unit/test_cross_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def _load_testproto(filename):
if test_proto.WhichOneof("test") == "listen"
]

_QUERY_TESTPROTOS = [
test_proto
for test_proto in ALL_TESTPROTOS
if test_proto.WhichOneof("test") == "query"
]


def _mock_firestore_api():
firestore_api = mock.Mock(spec=["commit"])
Expand Down Expand Up @@ -201,10 +207,23 @@ def test_delete_testprotos(test_proto):

@pytest.mark.skip(reason="Watch aka listen not yet implemented in Python.")
@pytest.mark.parametrize("test_proto", _LISTEN_TESTPROTOS)
def test_listen_paths_testprotos(test_proto): # pragma: NO COVER
def test_listen_testprotos(test_proto): # pragma: NO COVER
pass


@pytest.mark.parametrize("test_proto", _QUERY_TESTPROTOS)
def test_query_testprotos(test_proto): # pragma: NO COVER
testcase = test_proto.query
if testcase.is_error:
with pytest.raises(Exception):
query = parse_query(testcase)
query._to_protobuf()
else:
query = parse_query(testcase)
found = query._to_protobuf()
assert found == testcase.query


def convert_data(v):
# Replace the strings 'ServerTimestamp' and 'Delete' with the corresponding
# sentinels.
Expand All @@ -225,6 +244,8 @@ def convert_data(v):
return [convert_data(e) for e in v]
elif isinstance(v, dict):
return {k: convert_data(v2) for k, v2 in v.items()}
elif v == "NaN":
return float(v)
else:
return v

Expand All @@ -249,3 +270,106 @@ def convert_precondition(precond):

assert precond.HasField("update_time")
return Client.write_option(last_update_time=precond.update_time)


def parse_query(testcase):
# 'query' testcase contains:
# - 'coll_path': collection ref path.
# - 'clauses': array of one or more 'Clause' elements
# - 'query': the actual google.firestore.v1beta1.StructuredQuery message
# to be constructed.
# - 'is_error' (as other testcases).
#
# 'Clause' elements are unions of:
# - 'select': [field paths]
# - 'where': (field_path, op, json_value)
# - 'order_by': (field_path, direction)
# - 'offset': int
# - 'limit': int
# - 'start_at': 'Cursor'
# - 'start_after': 'Cursor'
# - 'end_at': 'Cursor'
# - 'end_before': 'Cursor'
#
# 'Cursor' contains either:
# - 'doc_snapshot': 'DocSnapshot'
# - 'json_values': [string]
#
# 'DocSnapshot' contains:
# 'path': str
# 'json_data': str
from google.auth.credentials import Credentials
from google.cloud.firestore_v1beta1 import Client
from google.cloud.firestore_v1beta1 import Query

_directions = {"asc": Query.ASCENDING, "desc": Query.DESCENDING}

credentials = mock.create_autospec(Credentials)
client = Client("projectID", credentials)
path = parse_path(testcase.coll_path)
collection = client.collection(*path)
query = collection

for clause in testcase.clauses:
kind = clause.WhichOneof("clause")

if kind == "select":
field_paths = [
".".join(field_path.field) for field_path in clause.select.fields
]
query = query.select(field_paths)
elif kind == "where":
path = ".".join(clause.where.path.field)
value = convert_data(json.loads(clause.where.json_value))
query = query.where(path, clause.where.op, value)
elif kind == "order_by":
path = ".".join(clause.order_by.path.field)
direction = clause.order_by.direction
direction = _directions.get(direction, direction)
query = query.order_by(path, direction=direction)
elif kind == "offset":
query = query.offset(clause.offset)
elif kind == "limit":
query = query.limit(clause.limit)
elif kind == "start_at":
cursor = parse_cursor(clause.start_at, client)
query = query.start_at(cursor)
elif kind == "start_after":
cursor = parse_cursor(clause.start_after, client)
query = query.start_after(cursor)
elif kind == "end_at":
cursor = parse_cursor(clause.end_at, client)
query = query.end_at(cursor)
elif kind == "end_before":
cursor = parse_cursor(clause.end_before, client)
query = query.end_before(cursor)
else: # pragma: NO COVER
raise ValueError("Unknown query clause: {}".format(kind))

return query


def parse_path(path):
_, relative = path.split("documents/")
return relative.split("/")


def parse_cursor(cursor, client):
from google.cloud.firestore_v1beta1 import DocumentReference
from google.cloud.firestore_v1beta1 import DocumentSnapshot

if cursor.HasField("doc_snapshot"):
path = parse_path(cursor.doc_snapshot.path)
doc_ref = DocumentReference(*path, client=client)

return DocumentSnapshot(
reference=doc_ref,
data=json.loads(cursor.doc_snapshot.json_data),
exists=True,
read_time=None,
create_time=None,
update_time=None,
)

values = [json.loads(value) for value in cursor.json_values]
return convert_data(values)
15 changes: 10 additions & 5 deletions firestore/tests/unit/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,16 +708,19 @@ def test__normalize_cursor_as_snapshot_hit(self):

self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True))

def test__normalize_cursor_w___name___w_slash(self):
def test__normalize_cursor_w___name___w_reference(self):
db_string = "projects/my-project/database/(default)"
client = mock.Mock(spec=["_database_string"])
client._database_string = db_string
parent = mock.Mock(spec=["_path", "_client"])
parent._client = client
parent._path = ["C"]
query = self._make_one(parent).order_by("__name__", "ASCENDING")
expected = "{}/C/b".format(db_string)
cursor = ([expected], True)
docref = self._make_docref("here", "doc_id")
values = {"a": 7}
snapshot = self._make_snapshot(docref, values)
expected = docref
cursor = (snapshot, True)

self.assertEqual(
query._normalize_cursor(cursor, query._orders), ([expected], True)
Expand All @@ -727,16 +730,18 @@ def test__normalize_cursor_w___name___wo_slash(self):
db_string = "projects/my-project/database/(default)"
client = mock.Mock(spec=["_database_string"])
client._database_string = db_string
parent = mock.Mock(spec=["_path", "_client"])
parent = mock.Mock(spec=["_path", "_client", "document"])
parent._client = client
parent._path = ["C"]
document = parent.document.return_value = mock.Mock(spec=[])
query = self._make_one(parent).order_by("__name__", "ASCENDING")
cursor = (["b"], True)
expected = "{}/C/b".format(db_string)
expected = document

self.assertEqual(
query._normalize_cursor(cursor, query._orders), ([expected], True)
)
parent.document.assert_called_once_with("b")

def test__to_protobuf_all_fields(self):
from google.protobuf import wrappers_pb2
Expand Down