diff --git a/newrelic/config.py b/newrelic/config.py index 3e258a19cb..7083cc872b 100644 --- a/newrelic/config.py +++ b/newrelic/config.py @@ -2269,23 +2269,52 @@ def _process_module_builtin_defaults(): "instrument_graphql_validate", ) + _process_module_definition( + "google.cloud.firestore_v1.base_client", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_base_client", + ) + _process_module_definition( + "google.cloud.firestore_v1.client", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_client", + ) _process_module_definition( "google.cloud.firestore_v1.document", "newrelic.hooks.datastore_firestore", "instrument_google_cloud_firestore_v1_document", ) - _process_module_definition( "google.cloud.firestore_v1.collection", "newrelic.hooks.datastore_firestore", "instrument_google_cloud_firestore_v1_collection", ) - _process_module_definition( - "google.cloud.firestore_v1.base_client", + "google.cloud.firestore_v1.query", "newrelic.hooks.datastore_firestore", - "instrument_google_cloud_firestore_v1_base_client", + "instrument_google_cloud_firestore_v1_query", ) + _process_module_definition( + "google.cloud.firestore_v1.aggregation", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_aggregation", + ) + _process_module_definition( + "google.cloud.firestore_v1.batch", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_batch", + ) + _process_module_definition( + "google.cloud.firestore_v1.bulk_batch", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_bulk_batch", + ) + _process_module_definition( + "google.cloud.firestore_v1.transaction", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_transaction", + ) + _process_module_definition( "ariadne.asgi", "newrelic.hooks.framework_ariadne", diff --git a/newrelic/hooks/datastore_firestore.py b/newrelic/hooks/datastore_firestore.py index b59df5d0dc..76ed5a4894 100644 --- a/newrelic/hooks/datastore_firestore.py +++ b/newrelic/hooks/datastore_firestore.py @@ -12,22 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from newrelic.common.object_wrapper import wrap_function_wrapper -from newrelic.api.datastore_trace import wrap_datastore_trace +from newrelic.api.datastore_trace import DatastoreTrace, wrap_datastore_trace from newrelic.api.function_trace import wrap_function_trace from newrelic.common.async_wrapper import generator_wrapper -from newrelic.api.datastore_trace import DatastoreTrace +from newrelic.common.object_wrapper import wrap_function_wrapper -_get_object_id = lambda obj, *args, **kwargs: obj.id +_get_object_id = lambda obj, *args, **kwargs: getattr(obj, "id", None) +_get_parent_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_parent", None), "id", None) +_get_collection_ref_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_collection_ref", None), "id", None) -def wrap_generator_method(module, class_name, method_name): +def wrap_generator_method(module, class_name, method_name, target): def _wrapper(wrapped, instance, args, kwargs): - trace = DatastoreTrace(product="Firestore", target=instance.id, operation=method_name) + target_ = target(instance) if callable(target) else target + trace = DatastoreTrace(product="Firestore", target=target_, operation=method_name) wrapped = generator_wrapper(wrapped, trace) return wrapped(*args, **kwargs) - + class_ = getattr(module, class_name) if class_ is not None: if hasattr(class_, method_name): @@ -41,18 +43,30 @@ def instrument_google_cloud_firestore_v1_base_client(module): ) +def instrument_google_cloud_firestore_v1_client(module): + if hasattr(module, "Client"): + class_ = module.Client + for method in ("collections", "get_all"): + if hasattr(class_, method): + wrap_generator_method(module, "Client", method, target=None) + + def instrument_google_cloud_firestore_v1_collection(module): if hasattr(module, "CollectionReference"): class_ = module.CollectionReference for method in ("add", "get"): if hasattr(class_, method): wrap_datastore_trace( - module, "CollectionReference.%s" % method, product="Firestore", target=_get_object_id, operation=method + module, + "CollectionReference.%s" % method, + product="Firestore", + target=_get_object_id, + operation=method, ) for method in ("stream", "list_documents"): if hasattr(class_, method): - wrap_generator_method(module, "CollectionReference", method) + wrap_generator_method(module, "CollectionReference", method, target=_get_object_id) def instrument_google_cloud_firestore_v1_document(module): @@ -61,9 +75,82 @@ def instrument_google_cloud_firestore_v1_document(module): for method in ("create", "delete", "get", "set", "update"): if hasattr(class_, method): wrap_datastore_trace( - module, "DocumentReference.%s" % method, product="Firestore", target=_get_object_id, operation=method + module, + "DocumentReference.%s" % method, + product="Firestore", + target=_get_object_id, + operation=method, ) for method in ("collections",): if hasattr(class_, method): - wrap_generator_method(module, "DocumentReference", method) + wrap_generator_method(module, "DocumentReference", method, target=_get_object_id) + + +def instrument_google_cloud_firestore_v1_query(module): + if hasattr(module, "Query"): + class_ = module.Query + for method in ("get",): + if hasattr(class_, method): + wrap_datastore_trace( + module, "Query.%s" % method, product="Firestore", target=_get_parent_id, operation=method + ) + + for method in ("stream",): + if hasattr(class_, method): + wrap_generator_method(module, "Query", method, target=_get_parent_id) + + if hasattr(module, "CollectionGroup"): + class_ = module.CollectionGroup + for method in ("get_partitions",): + if hasattr(class_, method): + wrap_generator_method(module, "CollectionGroup", method, target=_get_parent_id) + + +def instrument_google_cloud_firestore_v1_aggregation(module): + if hasattr(module, "AggregationQuery"): + class_ = module.AggregationQuery + for method in ("get",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AggregationQuery.%s" % method, + product="Firestore", + target=_get_collection_ref_id, + operation=method, + ) + + for method in ("stream",): + if hasattr(class_, method): + wrap_generator_method(module, "AggregationQuery", method, target=_get_collection_ref_id) + + +def instrument_google_cloud_firestore_v1_batch(module): + if hasattr(module, "WriteBatch"): + class_ = module.WriteBatch + for method in ("commit",): + if hasattr(class_, method): + wrap_datastore_trace( + module, "WriteBatch.%s" % method, product="Firestore", target=None, operation=method + ) + + +def instrument_google_cloud_firestore_v1_bulk_batch(module): + if hasattr(module, "BulkWriteBatch"): + class_ = module.BulkWriteBatch + for method in ("commit",): + if hasattr(class_, method): + wrap_datastore_trace( + module, "BulkWriteBatch.%s" % method, product="Firestore", target=None, operation=method + ) + + +def instrument_google_cloud_firestore_v1_transaction(module): + if hasattr(module, "Transaction"): + class_ = module.Transaction + for method in ("_commit", "_rollback"): + if hasattr(class_, method): + operation = method[1:] # Trim leading underscore + wrap_datastore_trace( + module, "Transaction.%s" % method, product="Firestore", target=None, operation=operation + ) diff --git a/tests/datastore_firestore/conftest.py b/tests/datastore_firestore/conftest.py index f4d76c3e41..e104bcd443 100644 --- a/tests/datastore_firestore/conftest.py +++ b/tests/datastore_firestore/conftest.py @@ -15,47 +15,64 @@ import uuid import pytest - from google.cloud.firestore import Client - from testing_support.db_settings import firestore_settings -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) +from newrelic.api.datastore_trace import DatastoreTrace +from newrelic.api.time_trace import current_trace DB_SETTINGS = firestore_settings()[0] FIRESTORE_HOST = DB_SETTINGS["host"] FIRESTORE_PORT = DB_SETTINGS["port"] _default_settings = { - 'transaction_tracer.explain_threshold': 0.0, - 'transaction_tracer.transaction_threshold': 0.0, - 'transaction_tracer.stack_trace_threshold': 0.0, - 'debug.log_data_collector_payloads': True, - 'debug.record_transaction_failure': True, - 'debug.log_explain_plan_queries': True + "transaction_tracer.explain_threshold": 0.0, + "transaction_tracer.transaction_threshold": 0.0, + "transaction_tracer.stack_trace_threshold": 0.0, + "debug.log_data_collector_payloads": True, + "debug.record_transaction_failure": True, + "debug.log_explain_plan_queries": True, } collector_agent_registration = collector_agent_registration_fixture( - app_name='Python Agent Test (datastore_firestore)', - default_settings=_default_settings, - linked_applications=['Python Agent Test (datastore)']) + app_name="Python Agent Test (datastore_firestore)", + default_settings=_default_settings, + linked_applications=["Python Agent Test (datastore)"], +) @pytest.fixture(scope="session") def client(): os.environ["FIRESTORE_EMULATOR_HOST"] = "%s:%d" % (FIRESTORE_HOST, FIRESTORE_PORT) client = Client() - client.collection("healthcheck").document("healthcheck").set({}, retry=None, timeout=5) # Ensure connection is available + client.collection("healthcheck").document("healthcheck").set( + {}, retry=None, timeout=5 + ) # Ensure connection is available return client @pytest.fixture(scope="function") def collection(client): - yield client.collection("firestore_collection_" + str(uuid.uuid4())) + collection_ = client.collection("firestore_collection_" + str(uuid.uuid4())) + yield collection_ + client.recursive_delete(collection_) + + +@pytest.fixture(scope="session") +def assert_trace_for_generator(): + def _assert_trace_for_generator(generator_func, *args, **kwargs): + txn = current_trace() + assert not isinstance(txn, DatastoreTrace) + # Check for generator trace on collections + _trace_check = [] + for _ in generator_func(*args, **kwargs): + _trace_check.append(isinstance(current_trace(), DatastoreTrace)) + assert _trace_check and all(_trace_check) # All checks are True, and at least 1 is present. + assert current_trace() is txn # Generator trace has exited. -@pytest.fixture(scope="function", autouse=True) -def reset_firestore(client): - for coll in client.collections(): - for document in coll.list_documents(): - document.delete() + return _assert_trace_for_generator diff --git a/tests/datastore_firestore/test_batching.py b/tests/datastore_firestore/test_batching.py new file mode 100644 index 0000000000..6b3bccf892 --- /dev/null +++ b/tests/datastore_firestore/test_batching.py @@ -0,0 +1,103 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task + +# ===== WriteBatch ===== + + +@pytest.fixture() +def exercise_write_batch(client, collection): + def _exercise_write_batch(): + docs = [collection.document(str(x)) for x in range(1, 4)] + batch = client.batch() + for doc in docs: + batch.set(doc, {}) + + batch.commit() + return _exercise_write_batch + + +def test_firestore_write_batch(exercise_write_batch): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_write_batch", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_write_batch") + def _test(): + exercise_write_batch() + + _test() + + +# ===== BulkWriteBatch ===== + + +@pytest.fixture() +def exercise_bulk_write_batch(client, collection): + def _exercise_bulk_write_batch(): + from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch + + docs = [collection.document(str(x)) for x in range(1, 4)] + batch = BulkWriteBatch(client) + for doc in docs: + batch.set(doc, {}) + + batch.commit() + return _exercise_bulk_write_batch + + +def test_firestore_bulk_write_batch(exercise_bulk_write_batch): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_bulk_write_batch", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_bulk_write_batch") + def _test(): + exercise_bulk_write_batch() + + _test() diff --git a/tests/datastore_firestore/test_client.py b/tests/datastore_firestore/test_client.py new file mode 100644 index 0000000000..a44d9d7500 --- /dev/null +++ b/tests/datastore_firestore/test_client.py @@ -0,0 +1,69 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def sample_data(collection): + doc = collection.document("document") + doc.set({"x": 1}) + return doc + + +@pytest.fixture() +def exercise_client(client, sample_data): + def _exercise_client(): + assert len([_ for _ in client.collections()]) + doc = [_ for _ in client.get_all([sample_data])][0] + assert doc.to_dict()["x"] == 1 + return _exercise_client + + +def test_firestore_client(exercise_client): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/get_all", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_client", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_client") + def _test(): + exercise_client() + + _test() + + +@background_task() +def test_firestore_client_generators(client, sample_data, assert_trace_for_generator): + assert_trace_for_generator(client.collections) + assert_trace_for_generator(client.get_all, [sample_data]) diff --git a/tests/datastore_firestore/test_collections.py b/tests/datastore_firestore/test_collections.py index 30f26eb8f9..7fd035a7a2 100644 --- a/tests/datastore_firestore/test_collections.py +++ b/tests/datastore_firestore/test_collections.py @@ -12,29 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -from newrelic.api.time_trace import current_trace -from newrelic.api.datastore_trace import DatastoreTrace -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from newrelic.api.background_task import background_task +import pytest + from testing_support.validators.validate_database_duration import ( validate_database_duration, ) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task -def _exercise_firestore(collection): - collection.document("DoesNotExist") - collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy") - collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico") - - documents_get = collection.get() - assert len(documents_get) == 2 - documents_stream = list(collection.stream()) - assert len(documents_stream) == 2 - documents_list = list(collection.list_documents()) - assert len(documents_list) == 2 +@pytest.fixture() +def exercise_collections(collection): + def _exercise_collections(): + collection.document("DoesNotExist") + collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy") + collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico") + documents_get = collection.get() + assert len(documents_get) == 2 + documents_stream = [_ for _ in collection.stream()] + assert len(documents_stream) == 2 + documents_list = [_ for _ in collection.list_documents()] + assert len(documents_list) == 2 + return _exercise_collections -def test_firestore_collections(collection): + +def test_firestore_collections(exercise_collections, collection): _test_scoped_metrics = [ ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), ("Datastore/statement/Firestore/%s/get" % collection.id, 1), @@ -50,6 +56,8 @@ def test_firestore_collections(collection): ("Datastore/all", 5), ("Datastore/allOther", 5), ] + + @validate_database_duration() @validate_transaction_metrics( "test_firestore_collections", scoped_metrics=_test_scoped_metrics, @@ -58,34 +66,16 @@ def test_firestore_collections(collection): ) @background_task(name="test_firestore_collections") def _test(): - _exercise_firestore(collection) + exercise_collections() _test() @background_task() -def test_firestore_collections_generators(collection): - txn = current_trace() +def test_firestore_collections_generators(collection, assert_trace_for_generator): collection.add({}) collection.add({}) - assert len(list(collection.list_documents())) == 2 - - # Check for generator trace on stream - _trace_check = [] - for _ in collection.stream(): - _trace_check.append(isinstance(current_trace(), DatastoreTrace)) - assert _trace_check and all(_trace_check) - assert current_trace() is txn - - # Check for generator trace on list_documents - _trace_check = [] - for _ in collection.list_documents(): - _trace_check.append(isinstance(current_trace(), DatastoreTrace)) - assert _trace_check and all(_trace_check) - assert current_trace() is txn + assert len([_ for _ in collection.list_documents()]) == 2 - -@validate_database_duration() -@background_task() -def test_firestore_collections_db_duration(collection): - _exercise_firestore(collection) + assert_trace_for_generator(collection.stream) + assert_trace_for_generator(collection.list_documents) diff --git a/tests/datastore_firestore/test_documents.py b/tests/datastore_firestore/test_documents.py index be47e820fd..38b3c4dd76 100644 --- a/tests/datastore_firestore/test_documents.py +++ b/tests/datastore_firestore/test_documents.py @@ -12,32 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -from newrelic.api.time_trace import current_trace -from newrelic.api.datastore_trace import DatastoreTrace -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from newrelic.api.background_task import background_task +import pytest + from testing_support.validators.validate_database_duration import ( validate_database_duration, ) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task -def _exercise_firestore(collection): - italy_doc = collection.document("Italy") - italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"}) - italy_doc.get() - italian_cities = italy_doc.collection("cities") - italian_cities.add({"capital": "Rome"}) - retrieved_coll = list(italy_doc.collections()) - assert len(retrieved_coll) == 1 +@pytest.fixture() +def exercise_documents(collection): + def _exercise_documents(): + italy_doc = collection.document("Italy") + italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"}) + italy_doc.get() + italian_cities = italy_doc.collection("cities") + italian_cities.add({"capital": "Rome"}) + retrieved_coll = [_ for _ in italy_doc.collections()] + assert len(retrieved_coll) == 1 - usa_doc = collection.document("USA") - usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"}) - usa_doc.update({"president": "Joe Biden"}) + usa_doc = collection.document("USA") + usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"}) + usa_doc.update({"president": "Joe Biden"}) - collection.document("USA").delete() + collection.document("USA").delete() + return _exercise_documents -def test_firestore_documents(collection): +def test_firestore_documents(exercise_documents): _test_scoped_metrics = [ ("Datastore/statement/Firestore/Italy/set", 1), ("Datastore/statement/Firestore/Italy/get", 1), @@ -59,6 +65,8 @@ def test_firestore_documents(collection): ("Datastore/all", 7), ("Datastore/allOther", 7), ] + + @validate_database_duration() @validate_transaction_metrics( "test_firestore_documents", scoped_metrics=_test_scoped_metrics, @@ -67,30 +75,17 @@ def test_firestore_documents(collection): ) @background_task(name="test_firestore_documents") def _test(): - _exercise_firestore(collection) + exercise_documents() _test() @background_task() -def test_firestore_documents_generators(collection): - txn = current_trace() - +def test_firestore_documents_generators(collection, assert_trace_for_generator): subcollection_doc = collection.document("SubCollections") subcollection_doc.set({}) subcollection_doc.collection("collection1").add({}) subcollection_doc.collection("collection2").add({}) - assert len(list(subcollection_doc.collections())) == 2 - - # Check for generator trace on collections - _trace_check = [] - for _ in subcollection_doc.collections(): - _trace_check.append(isinstance(current_trace(), DatastoreTrace)) - assert _trace_check and all(_trace_check) - assert current_trace() is txn + assert len([_ for _ in subcollection_doc.collections()]) == 2 - -@validate_database_duration() -@background_task() -def test_firestore_documents_db_duration(collection): - _exercise_firestore(collection) + assert_trace_for_generator(subcollection_doc.collections) diff --git a/tests/datastore_firestore/test_query.py b/tests/datastore_firestore/test_query.py new file mode 100644 index 0000000000..a71ad5c8e0 --- /dev/null +++ b/tests/datastore_firestore/test_query.py @@ -0,0 +1,200 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 6): + collection.add({"x": x}) + + subcollection_doc = collection.document("subcollection") + subcollection_doc.set({}) + subcollection_doc.collection("subcollection1").add({}) + + +# ===== Query ===== + + +@pytest.fixture() +def exercise_query(collection): + def _exercise_query(): + query = collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3) + assert len(query.get()) == 3 + assert len([_ for _ in query.stream()]) == 3 + return _exercise_query + + +def test_firestore_query(exercise_query, collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_query") + def _test(): + exercise_query() + + _test() + + +@background_task() +def test_firestore_query_generators(collection, assert_trace_for_generator): + query = collection.select("x").where(field_path="x", op_string="<=", value=3) + assert_trace_for_generator(query.stream) + + +# ===== AggregationQuery ===== + + +@pytest.fixture() +def exercise_aggregation_query(collection): + def _exercise_aggregation_query(): + aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert aggregation_query.get()[0][0].value == 3 + assert [_ for _ in aggregation_query.stream()][0][0].value == 3 + return _exercise_aggregation_query + + +def test_firestore_aggregation_query(exercise_aggregation_query, collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_aggregation_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_aggregation_query") + def _test(): + exercise_aggregation_query() + + _test() + + +@background_task() +def test_firestore_aggregation_query_generators(collection, assert_trace_for_generator): + aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert_trace_for_generator(aggregation_query.stream) + + +# ===== CollectionGroup ===== + + +@pytest.fixture() +def patch_partition_queries(monkeypatch, client, collection, sample_data): + """ + Partitioning is not implemented in the Firestore emulator. + + Ordinarily this method would return a generator of Cursor objects. Each Cursor must point at a valid document path. + To test this, we can patch the RPC to return 1 Cursor which is pointed at any document available. + The get_partitions will take that and make 2 QueryPartition objects out of it, which should be enough to ensure + we can exercise the generator's tracing. + """ + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.types.query import Cursor + + subcollection = collection.document("subcollection").collection("subcollection1") + documents = [d for d in subcollection.list_documents()] + + def mock_partition_query(*args, **kwargs): + yield Cursor(before=False, values=[Value(reference_value=documents[0].path)]) + + monkeypatch.setattr(client._firestore_api, "partition_query", mock_partition_query) + yield + + +@pytest.fixture() +def exercise_collection_group(client, collection): + def _exercise_collection_group(): + collection_group = client.collection_group(collection.id) + assert len(collection_group.get()) + assert len([d for d in collection_group.stream()]) + + partitions = [p for p in collection_group.get_partitions(1)] + assert len(partitions) == 2 + documents = [] + while partitions: + documents.extend(partitions.pop().query().get()) + assert len(documents) == 6 + return _exercise_collection_group + + +def test_firestore_collection_group(exercise_collection_group, client, collection, patch_partition_queries): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/get" % collection.id, 3), + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get_partitions" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 3), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/get_partitions", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_collection_group", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_collection_group") + def _test(): + exercise_collection_group() + + _test() + + +@background_task() +def test_firestore_collection_group_generators(client, collection, assert_trace_for_generator, patch_partition_queries): + collection_group = client.collection_group(collection.id) + assert_trace_for_generator(collection_group.get_partitions, 1) diff --git a/tests/datastore_firestore/test_transaction.py b/tests/datastore_firestore/test_transaction.py new file mode 100644 index 0000000000..8fc524c4e5 --- /dev/null +++ b/tests/datastore_firestore/test_transaction.py @@ -0,0 +1,128 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 4): + collection.add({"x": x}, "doc%d" % x) + + +@pytest.fixture() +def exercise_transaction_commit(client, collection): + def _exercise_transaction_commit(): + from google.cloud.firestore_v1.transaction import transactional + + @transactional + def _exercise(transaction): + # get a DocumentReference + [_ for _ in transaction.get(collection.document("doc1"))] + + # get a Query + query = collection.select("x").where(field_path="x", op_string=">", value=2) + assert len([_ for _ in transaction.get(query)]) == 1 + + # get_all on a list of DocumentReferences + all_docs = transaction.get_all([collection.document("doc%d" % x) for x in range(1, 4)]) + assert len([_ for _ in all_docs]) == 3 + + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 0}) + transaction.delete(collection.document("doc3")) + + _exercise(client.transaction()) + assert len([_ for _ in collection.list_documents()]) == 2 + return _exercise_transaction_commit + + +@pytest.fixture() +def exercise_transaction_rollback(client, collection): + def _exercise_transaction_rollback(): + from google.cloud.firestore_v1.transaction import transactional + + @transactional + def _exercise(transaction): + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 99}) + transaction.delete(collection.document("doc1")) + raise RuntimeError() + + with pytest.raises(RuntimeError): + _exercise(client.transaction()) + assert len([_ for _ in collection.list_documents()]) == 3 + return _exercise_transaction_rollback + + +def test_firestore_transaction_commit(exercise_transaction_commit, collection): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ("Datastore/operation/Firestore/get_all", 2), + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_transaction") + def _test(): + exercise_transaction_commit() + + _test() + + +def test_firestore_transaction_rollback(exercise_transaction_rollback, collection): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/rollback", 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_transaction") + def _test(): + exercise_transaction_rollback() + + _test()