diff --git a/newrelic/hooks/datastore_firestore.py b/newrelic/hooks/datastore_firestore.py index 9f28e7caf3..69fd8f1cc9 100644 --- a/newrelic/hooks/datastore_firestore.py +++ b/newrelic/hooks/datastore_firestore.py @@ -21,25 +21,9 @@ from newrelic.api.datastore_trace import DatastoreTrace -def _get_object_id(obj, *args, **kwargs): - try: - return obj.id - except Exception: - return None - - -def _get_parent_id(obj, *args, **kwargs): - try: - return obj._parent.id - except Exception: - return None - - -def _get_collection_ref_id(obj, *args, **kwargs): - try: - return obj._collection_ref.id - except Exception: - return None +_get_object_id = lambda obj, *args, **kwargs: getattr(obj, "id", None) +_get_parent_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_parent", None), "id", None) +_get_collection_ref_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_collection_ref", None), "id", None) def wrap_generator_method(module, class_name, method_name, target, is_async=False): diff --git a/tests/datastore_firestore/conftest.py b/tests/datastore_firestore/conftest.py index 95f2d6cc54..2e8b883e4f 100644 --- a/tests/datastore_firestore/conftest.py +++ b/tests/datastore_firestore/conftest.py @@ -15,7 +15,10 @@ import uuid import pytest + +from google.cloud.firestore import Client from google.cloud.firestore import Client, AsyncClient + from testing_support.db_settings import firestore_settings from testing_support.fixture.event_loop import event_loop as loop # noqa: F401; pylint: disable=W0611 from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 diff --git a/tests/datastore_firestore/test_batching.py b/tests/datastore_firestore/test_batching.py index c0f52530b2..6b3bccf892 100644 --- a/tests/datastore_firestore/test_batching.py +++ b/tests/datastore_firestore/test_batching.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from testing_support.validators.validate_database_duration import ( validate_database_duration, ) @@ -24,16 +26,19 @@ # ===== WriteBatch ===== -def _exercise_write_batch(client, collection): - docs = [collection.document(str(x)) for x in range(1, 4)] - batch = client.batch() - for doc in docs: - batch.set(doc, {}) +@pytest.fixture() +def exercise_write_batch(client, collection): + def _exercise_write_batch(): + docs = [collection.document(str(x)) for x in range(1, 4)] + batch = client.batch() + for doc in docs: + batch.set(doc, {}) - batch.commit() + batch.commit() + return _exercise_write_batch -def test_firestore_write_batch(client, collection): +def test_firestore_write_batch(exercise_write_batch): _test_scoped_metrics = [ ("Datastore/operation/Firestore/commit", 1), ] @@ -52,7 +57,7 @@ def test_firestore_write_batch(client, collection): ) @background_task(name="test_firestore_write_batch") def _test(): - _exercise_write_batch(client, collection) + exercise_write_batch() _test() @@ -60,18 +65,21 @@ def _test(): # ===== BulkWriteBatch ===== -def _exercise_bulk_write_batch(client, collection): - from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch +@pytest.fixture() +def exercise_bulk_write_batch(client, collection): + def _exercise_bulk_write_batch(): + from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch - docs = [collection.document(str(x)) for x in range(1, 4)] - batch = BulkWriteBatch(client) - for doc in docs: - batch.set(doc, {}) + docs = [collection.document(str(x)) for x in range(1, 4)] + batch = BulkWriteBatch(client) + for doc in docs: + batch.set(doc, {}) - batch.commit() + batch.commit() + return _exercise_bulk_write_batch -def test_firestore_bulk_write_batch(client, collection): +def test_firestore_bulk_write_batch(exercise_bulk_write_batch): _test_scoped_metrics = [ ("Datastore/operation/Firestore/commit", 1), ] @@ -90,6 +98,6 @@ def test_firestore_bulk_write_batch(client, collection): ) @background_task(name="test_firestore_bulk_write_batch") def _test(): - _exercise_bulk_write_batch(client, collection) + exercise_bulk_write_batch() _test() diff --git a/tests/datastore_firestore/test_client.py b/tests/datastore_firestore/test_client.py index d9d46ea206..a44d9d7500 100644 --- a/tests/datastore_firestore/test_client.py +++ b/tests/datastore_firestore/test_client.py @@ -29,13 +29,16 @@ def sample_data(collection): return doc -def _exercise_client(client, collection, sample_data): - assert len([_ for _ in client.collections()]) - doc = [_ for _ in client.get_all([sample_data])][0] - assert doc.to_dict()["x"] == 1 +@pytest.fixture() +def exercise_client(client, sample_data): + def _exercise_client(): + assert len([_ for _ in client.collections()]) + doc = [_ for _ in client.get_all([sample_data])][0] + assert doc.to_dict()["x"] == 1 + return _exercise_client -def test_firestore_client(client, collection, sample_data): +def test_firestore_client(exercise_client): _test_scoped_metrics = [ ("Datastore/operation/Firestore/collections", 1), ("Datastore/operation/Firestore/get_all", 1), @@ -55,12 +58,12 @@ def test_firestore_client(client, collection, sample_data): ) @background_task(name="test_firestore_client") def _test(): - _exercise_client(client, collection, sample_data) + exercise_client() _test() @background_task() -def test_firestore_client_generators(client, collection, sample_data, assert_trace_for_generator): +def test_firestore_client_generators(client, sample_data, assert_trace_for_generator): assert_trace_for_generator(client.collections) assert_trace_for_generator(client.get_all, [sample_data]) diff --git a/tests/datastore_firestore/test_collections.py b/tests/datastore_firestore/test_collections.py index 49a97810af..ec3bb7ed5e 100644 --- a/tests/datastore_firestore/test_collections.py +++ b/tests/datastore_firestore/test_collections.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from testing_support.validators.validate_database_duration import ( validate_database_duration, ) @@ -22,20 +24,23 @@ from newrelic.api.background_task import background_task -def _exercise_collections(collection): - collection.document("DoesNotExist") - collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy") - collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico") +@pytest.fixture() +def exercise_collections(collection): + def _exercise_collections(): + collection.document("DoesNotExist") + collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy") + collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico") - documents_get = collection.get() - assert len(documents_get) == 2 - documents_stream = [_ for _ in collection.stream()] - assert len(documents_stream) == 2 - documents_list = [_ for _ in collection.list_documents()] - assert len(documents_list) == 2 + documents_get = collection.get() + assert len(documents_get) == 2 + documents_stream = [_ for _ in collection.stream()] + assert len(documents_stream) == 2 + documents_list = [_ for _ in collection.list_documents()] + assert len(documents_list) == 2 + return _exercise_collections -def test_firestore_collections(collection): +def test_firestore_collections(exercise_collections, collection): _test_scoped_metrics = [ ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), ("Datastore/statement/Firestore/%s/get" % collection.id, 1), @@ -60,7 +65,7 @@ def test_firestore_collections(collection): ) @background_task(name="test_firestore_collections") def _test(): - _exercise_collections(collection) + exercise_collections() _test() diff --git a/tests/datastore_firestore/test_documents.py b/tests/datastore_firestore/test_documents.py index 7873cf024d..3634f66575 100644 --- a/tests/datastore_firestore/test_documents.py +++ b/tests/datastore_firestore/test_documents.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from testing_support.validators.validate_database_duration import ( validate_database_duration, ) @@ -22,23 +24,26 @@ from newrelic.api.background_task import background_task -def _exercise_documents(collection): - italy_doc = collection.document("Italy") - italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"}) - italy_doc.get() - italian_cities = italy_doc.collection("cities") - italian_cities.add({"capital": "Rome"}) - retrieved_coll = [_ for _ in italy_doc.collections()] - assert len(retrieved_coll) == 1 +@pytest.fixture() +def exercise_documents(collection): + def _exercise_documents(): + italy_doc = collection.document("Italy") + italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"}) + italy_doc.get() + italian_cities = italy_doc.collection("cities") + italian_cities.add({"capital": "Rome"}) + retrieved_coll = [_ for _ in italy_doc.collections()] + assert len(retrieved_coll) == 1 - usa_doc = collection.document("USA") - usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"}) - usa_doc.update({"president": "Joe Biden"}) + usa_doc = collection.document("USA") + usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"}) + usa_doc.update({"president": "Joe Biden"}) - collection.document("USA").delete() + collection.document("USA").delete() + return _exercise_documents -def test_firestore_documents(collection): +def test_firestore_documents(exercise_documents): _test_scoped_metrics = [ ("Datastore/statement/Firestore/Italy/set", 1), ("Datastore/statement/Firestore/Italy/get", 1), @@ -69,7 +74,7 @@ def test_firestore_documents(collection): ) @background_task(name="test_firestore_documents") def _test(): - _exercise_documents(collection) + exercise_documents() _test() diff --git a/tests/datastore_firestore/test_query.py b/tests/datastore_firestore/test_query.py index 144b582f49..a71ad5c8e0 100644 --- a/tests/datastore_firestore/test_query.py +++ b/tests/datastore_firestore/test_query.py @@ -36,13 +36,16 @@ def sample_data(collection): # ===== Query ===== -def _exercise_query(collection): - query = collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3) - assert len(query.get()) == 3 - assert len([_ for _ in query.stream()]) == 3 +@pytest.fixture() +def exercise_query(collection): + def _exercise_query(): + query = collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3) + assert len(query.get()) == 3 + assert len([_ for _ in query.stream()]) == 3 + return _exercise_query -def test_firestore_query(collection): +def test_firestore_query(exercise_query, collection): _test_scoped_metrics = [ ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), ("Datastore/statement/Firestore/%s/get" % collection.id, 1), @@ -64,7 +67,7 @@ def test_firestore_query(collection): ) @background_task(name="test_firestore_query") def _test(): - _exercise_query(collection) + exercise_query() _test() @@ -78,13 +81,16 @@ def test_firestore_query_generators(collection, assert_trace_for_generator): # ===== AggregationQuery ===== -def _exercise_aggregation_query(collection): - aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count() - assert aggregation_query.get()[0][0].value == 3 - assert [_ for _ in aggregation_query.stream()][0][0].value == 3 +@pytest.fixture() +def exercise_aggregation_query(collection): + def _exercise_aggregation_query(): + aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert aggregation_query.get()[0][0].value == 3 + assert [_ for _ in aggregation_query.stream()][0][0].value == 3 + return _exercise_aggregation_query -def test_firestore_aggregation_query(collection): +def test_firestore_aggregation_query(exercise_aggregation_query, collection): _test_scoped_metrics = [ ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), ("Datastore/statement/Firestore/%s/get" % collection.id, 1), @@ -106,7 +112,7 @@ def test_firestore_aggregation_query(collection): ) @background_task(name="test_firestore_aggregation_query") def _test(): - _exercise_aggregation_query(collection) + exercise_aggregation_query() _test() @@ -143,22 +149,23 @@ def mock_partition_query(*args, **kwargs): yield -def _exercise_collection_group(collection): - from google.cloud.firestore import CollectionGroup - - collection_group = CollectionGroup(collection) - assert len(collection_group.get()) - assert len([d for d in collection_group.stream()]) - - partitions = [p for p in collection_group.get_partitions(1)] - assert len(partitions) == 2 - documents = [] - while partitions: - documents.extend(partitions.pop().query().get()) - assert len(documents) == 6 - - -def test_firestore_collection_group(collection, patch_partition_queries): +@pytest.fixture() +def exercise_collection_group(client, collection): + def _exercise_collection_group(): + collection_group = client.collection_group(collection.id) + assert len(collection_group.get()) + assert len([d for d in collection_group.stream()]) + + partitions = [p for p in collection_group.get_partitions(1)] + assert len(partitions) == 2 + documents = [] + while partitions: + documents.extend(partitions.pop().query().get()) + assert len(documents) == 6 + return _exercise_collection_group + + +def test_firestore_collection_group(exercise_collection_group, client, collection, patch_partition_queries): _test_scoped_metrics = [ ("Datastore/statement/Firestore/%s/get" % collection.id, 3), ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), @@ -182,14 +189,12 @@ def test_firestore_collection_group(collection, patch_partition_queries): ) @background_task(name="test_firestore_collection_group") def _test(): - _exercise_collection_group(collection) + exercise_collection_group() _test() @background_task() -def test_firestore_collection_group_generators(collection, assert_trace_for_generator, patch_partition_queries): - from google.cloud.firestore import CollectionGroup - - collection_group = CollectionGroup(collection) +def test_firestore_collection_group_generators(client, collection, assert_trace_for_generator, patch_partition_queries): + collection_group = client.collection_group(collection.id) assert_trace_for_generator(collection_group.get_partitions, 1) diff --git a/tests/datastore_firestore/test_transaction.py b/tests/datastore_firestore/test_transaction.py index a32ad90754..8fc524c4e5 100644 --- a/tests/datastore_firestore/test_transaction.py +++ b/tests/datastore_firestore/test_transaction.py @@ -28,46 +28,52 @@ def sample_data(collection): collection.add({"x": x}, "doc%d" % x) -def _exercise_transaction_commit(client, collection): - from google.cloud.firestore import transactional +@pytest.fixture() +def exercise_transaction_commit(client, collection): + def _exercise_transaction_commit(): + from google.cloud.firestore_v1.transaction import transactional - @transactional - def _exercise(transaction): - # get a DocumentReference - [_ for _ in transaction.get(collection.document("doc1"))] + @transactional + def _exercise(transaction): + # get a DocumentReference + [_ for _ in transaction.get(collection.document("doc1"))] - # get a Query - query = collection.select("x").where(field_path="x", op_string=">", value=2) - assert len([_ for _ in transaction.get(query)]) == 1 + # get a Query + query = collection.select("x").where(field_path="x", op_string=">", value=2) + assert len([_ for _ in transaction.get(query)]) == 1 - # get_all on a list of DocumentReferences - all_docs = transaction.get_all([collection.document("doc%d" % x) for x in range(1, 4)]) - assert len([_ for _ in all_docs]) == 3 + # get_all on a list of DocumentReferences + all_docs = transaction.get_all([collection.document("doc%d" % x) for x in range(1, 4)]) + assert len([_ for _ in all_docs]) == 3 - # set and delete methods - transaction.set(collection.document("doc2"), {"x": 0}) - transaction.delete(collection.document("doc3")) + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 0}) + transaction.delete(collection.document("doc3")) - _exercise(client.transaction()) - assert len([_ for _ in collection.list_documents()]) == 2 + _exercise(client.transaction()) + assert len([_ for _ in collection.list_documents()]) == 2 + return _exercise_transaction_commit -def _exercise_transaction_rollback(client, collection): - from google.cloud.firestore import transactional +@pytest.fixture() +def exercise_transaction_rollback(client, collection): + def _exercise_transaction_rollback(): + from google.cloud.firestore_v1.transaction import transactional - @transactional - def _exercise(transaction): - # set and delete methods - transaction.set(collection.document("doc2"), {"x": 99}) - transaction.delete(collection.document("doc1")) - raise RuntimeError() + @transactional + def _exercise(transaction): + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 99}) + transaction.delete(collection.document("doc1")) + raise RuntimeError() - with pytest.raises(RuntimeError): - _exercise(client.transaction()) - assert len([_ for _ in collection.list_documents()]) == 3 + with pytest.raises(RuntimeError): + _exercise(client.transaction()) + assert len([_ for _ in collection.list_documents()]) == 3 + return _exercise_transaction_rollback -def test_firestore_transaction_commit(client, collection): +def test_firestore_transaction_commit(exercise_transaction_commit, collection): _test_scoped_metrics = [ ("Datastore/operation/Firestore/commit", 1), ("Datastore/operation/Firestore/get_all", 2), @@ -91,12 +97,12 @@ def test_firestore_transaction_commit(client, collection): ) @background_task(name="test_firestore_transaction") def _test(): - _exercise_transaction_commit(client, collection) + exercise_transaction_commit() _test() -def test_firestore_transaction_rollback(client, collection): +def test_firestore_transaction_rollback(exercise_transaction_rollback, collection): _test_scoped_metrics = [ ("Datastore/operation/Firestore/rollback", 1), ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), @@ -117,6 +123,6 @@ def test_firestore_transaction_rollback(client, collection): ) @background_task(name="test_firestore_transaction") def _test(): - _exercise_transaction_rollback(client, collection) + exercise_transaction_rollback() _test()