From 4f3aa3a13c13a97dbab642ad7d5fccef392327b7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 Jun 2023 10:08:32 -0700 Subject: [PATCH 01/10] fixed up unit tests --- tests/unit/test_client.py | 367 ++++++++++++++++++-------------------- 1 file changed, 175 insertions(+), 192 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 7009069d1..e1663b016 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -280,32 +280,38 @@ async def test_start_background_channel_refresh_tasks_names(self): @pytest.mark.asyncio async def test__ping_and_warm_instances(self): - # test with no instances + """ + test ping and warm with mocked asyncio.gather + """ + client_mock = mock.Mock() with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: - client = self._make_one(project="project-id", pool_size=1) - channel = client.transport._grpc_channel._pool[0] - await client._ping_and_warm_instances(channel) + # simulate gather by returning the same number of items as passed in + gather.side_effect = lambda *args, **kwargs: [None for _ in args] + channel = mock.Mock() + # test with no instances + client_mock._active_instances = [] + result = await self._get_target_class()._ping_and_warm_instances(client_mock, channel) + assert len(result) == 0 gather.assert_called_once() gather.assert_awaited_once() assert not gather.call_args.args assert gather.call_args.kwargs == {"return_exceptions": True} # test with instances - client._active_instances = [ + client_mock._active_instances = [ "instance-1", "instance-2", "instance-3", "instance-4", ] - with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: - await client._ping_and_warm_instances(channel) + gather.reset_mock() + result = await self._get_target_class()._ping_and_warm_instances(client_mock, channel) + assert len(result) == 4 gather.assert_called_once() gather.assert_awaited_once() assert len(gather.call_args.args) == 4 assert gather.call_args.kwargs == {"return_exceptions": True} for idx, call in enumerate(gather.call_args.args): - assert isinstance(call, grpc.aio.UnaryUnaryCall) - call._request["name"] = client._active_instances[idx] - await client.close() + assert call == channel.unary_unary()() @pytest.mark.asyncio @pytest.mark.parametrize( @@ -866,6 +872,17 @@ def _make_client(self, *args, **kwargs): return BigtableDataClient(*args, **kwargs) + def _make_table(self, *args, **kwargs): + from google.cloud.bigtable.client import Table + client_mock = mock.Mock() + client_mock._register_instance.side_effect = lambda *args, **kwargs: asyncio.sleep(0) + client_mock._remove_instance_registration.side_effect = lambda *args, **kwargs: asyncio.sleep(0) + kwargs["instance_id"] = kwargs.get("instance_id", args[0] if args else "instance") + kwargs["table_id"] = kwargs.get("table_id", args[1] if len(args) > 1 else "table") + client_mock._gapic_client.table_path.return_value = kwargs["table_id"] + client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] + return Table(client_mock, *args, **kwargs) + def _make_stats(self): from google.cloud.bigtable_v2.types import RequestStats from google.cloud.bigtable_v2.types import FullReadStatsView @@ -928,14 +945,13 @@ def cancel(self): @pytest.mark.asyncio async def test_read_rows(self): - client = self._make_client() - table = client.get_table("instance", "table") query = ReadRowsQuery() chunks = [ self._make_chunk(row_key=b"test_1"), self._make_chunk(row_key=b"test_2"), ] - with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows: + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( chunks ) @@ -943,18 +959,16 @@ async def test_read_rows(self): assert len(results) == 2 assert results[0].row_key == b"test_1" assert results[1].row_key == b"test_2" - await client.close() @pytest.mark.asyncio async def test_read_rows_stream(self): - client = self._make_client() - table = client.get_table("instance", "table") query = ReadRowsQuery() chunks = [ self._make_chunk(row_key=b"test_1"), self._make_chunk(row_key=b"test_2"), ] - with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows: + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( chunks ) @@ -963,16 +977,18 @@ async def test_read_rows_stream(self): assert len(results) == 2 assert results[0].row_key == b"test_1" assert results[1].row_key == b"test_2" - await client.close() @pytest.mark.parametrize("include_app_profile", [True, False]) @pytest.mark.asyncio async def test_read_rows_query_matches_request(self, include_app_profile): from google.cloud.bigtable import RowRange - async with self._make_client() as client: - app_profile_id = "app_profile_id" if include_app_profile else None - table = client.get_table("instance", "table", app_profile_id=app_profile_id) + app_profile_id = "app_profile_id" if include_app_profile else None + async with self._make_table(app_profile_id=app_profile_id) as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [] + ) row_keys = [b"test_1", "test_2"] row_ranges = RowRange("start", "end") filter_ = {"test": "filter"} @@ -983,52 +999,44 @@ async def test_read_rows_query_matches_request(self, include_app_profile): row_filter=filter_, limit=limit, ) - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - [] - ) - results = await table.read_rows(query, operation_timeout=3) - assert len(results) == 0 - call_request = read_rows.call_args_list[0][0][0] - query_dict = query._to_dict() - if include_app_profile: - assert set(call_request.keys()) == set(query_dict.keys()) | { - "table_name", - "app_profile_id", - } - else: - assert set(call_request.keys()) == set(query_dict.keys()) | { - "table_name" - } - assert call_request["rows"] == query_dict["rows"] - assert call_request["filter"] == filter_ - assert call_request["rows_limit"] == limit - assert call_request["table_name"] == table.table_name - if include_app_profile: - assert call_request["app_profile_id"] == app_profile_id + + results = await table.read_rows(query, operation_timeout=3) + assert len(results) == 0 + call_request = read_rows.call_args_list[0][0][0] + query_dict = query._to_dict() + if include_app_profile: + assert set(call_request.keys()) == set(query_dict.keys()) | { + "table_name", + "app_profile_id", + } + else: + assert set(call_request.keys()) == set(query_dict.keys()) | { + "table_name" + } + assert call_request["rows"] == query_dict["rows"] + assert call_request["filter"] == filter_ + assert call_request["rows_limit"] == limit + assert call_request["table_name"] == table.table_name + if include_app_profile: + assert call_request["app_profile_id"] == app_profile_id @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) @pytest.mark.asyncio async def test_read_rows_timeout(self, operation_timeout): - async with self._make_client() as client: - table = client.get_table("instance", "table") + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows query = ReadRowsQuery() chunks = [self._make_chunk(row_key=b"test_1")] - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=1 + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=1 + ) + try: + await table.read_rows(query, operation_timeout=operation_timeout) + except core_exceptions.DeadlineExceeded as e: + assert ( + e.message + == f"operation_timeout of {operation_timeout:0.1f}s exceeded" ) - try: - await table.read_rows(query, operation_timeout=operation_timeout) - except core_exceptions.DeadlineExceeded as e: - assert ( - e.message - == f"operation_timeout of {operation_timeout:0.1f}s exceeded" - ) @pytest.mark.parametrize( "per_request_t, operation_t, expected_num", @@ -1056,46 +1064,44 @@ async def test_read_rows_per_request_timeout( # mocking uniform ensures there are no sleeps between retries with mock.patch("random.uniform", side_effect=lambda a, b: 0): - async with self._make_client() as client: - table = client.get_table("instance", "table") + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = ( + lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=per_request_t + ) + ) query = ReadRowsQuery() chunks = [core_exceptions.DeadlineExceeded("mock deadline")] - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - read_rows.side_effect = ( - lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=per_request_t - ) + + try: + await table.read_rows( + query, + operation_timeout=operation_t, + per_request_timeout=per_request_t, ) - try: - await table.read_rows( - query, - operation_timeout=operation_t, - per_request_timeout=per_request_t, - ) - except core_exceptions.DeadlineExceeded as e: - retry_exc = e.__cause__ - if expected_num == 0: - assert retry_exc is None - else: - assert type(retry_exc) == RetryExceptionGroup - assert f"{expected_num} failed attempts" in str(retry_exc) - assert len(retry_exc.exceptions) == expected_num - for sub_exc in retry_exc.exceptions: - assert sub_exc.message == "mock deadline" - assert read_rows.call_count == expected_num - # check timeouts - for _, call_kwargs in read_rows.call_args_list[:-1]: - assert call_kwargs["timeout"] == per_request_t - # last timeout should be adjusted to account for the time spent - assert ( - abs( - read_rows.call_args_list[-1][1]["timeout"] - - expected_last_timeout - ) - < 0.05 + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + if expected_num == 0: + assert retry_exc is None + else: + assert type(retry_exc) == RetryExceptionGroup + assert f"{expected_num} failed attempts" in str(retry_exc) + assert len(retry_exc.exceptions) == expected_num + for sub_exc in retry_exc.exceptions: + assert sub_exc.message == "mock deadline" + assert read_rows.call_count == expected_num + # check timeouts + for _, call_kwargs in read_rows.call_args_list[:-1]: + assert call_kwargs["timeout"] == per_request_t + # last timeout should be adjusted to account for the time spent + assert ( + abs( + read_rows.call_args_list[-1][1]["timeout"] + - expected_last_timeout ) + < 0.05 + ) @pytest.mark.asyncio async def test_read_rows_idle_timeout(self): @@ -1154,23 +1160,20 @@ async def test_read_rows_idle_timeout(self): ) @pytest.mark.asyncio async def test_read_rows_retryable_error(self, exc_type): - async with self._make_client() as client: - table = client.get_table("instance", "table") + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) query = ReadRowsQuery() expected_error = exc_type("mock error") - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - [expected_error] - ) - try: - await table.read_rows(query, operation_timeout=0.1) - except core_exceptions.DeadlineExceeded as e: - retry_exc = e.__cause__ - root_cause = retry_exc.exceptions[0] - assert type(root_cause) == exc_type - assert root_cause == expected_error + try: + await table.read_rows(query, operation_timeout=0.1) + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + root_cause = retry_exc.exceptions[0] + assert type(root_cause) == exc_type + assert root_cause == expected_error @pytest.mark.parametrize( "exc_type", @@ -1188,20 +1191,17 @@ async def test_read_rows_retryable_error(self, exc_type): ) @pytest.mark.asyncio async def test_read_rows_non_retryable_error(self, exc_type): - async with self._make_client() as client: - table = client.get_table("instance", "table") + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) query = ReadRowsQuery() expected_error = exc_type("mock error") - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - [expected_error] - ) - try: - await table.read_rows(query, operation_timeout=0.1) - except exc_type as e: - assert e == expected_error + try: + await table.read_rows(query, operation_timeout=0.1) + except exc_type as e: + assert e == expected_error @pytest.mark.asyncio async def test_read_rows_revise_request(self): @@ -1216,32 +1216,29 @@ async def test_read_rows_revise_request(self): ) as revise_rowset: with mock.patch.object(_ReadRowsOperation, "aclose"): revise_rowset.return_value = "modified" - async with self._make_client() as client: - table = client.get_table("instance", "table") + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = ( + lambda *args, **kwargs: self._make_gapic_stream(chunks) + ) row_keys = [b"test_1", b"test_2", b"test_3"] query = ReadRowsQuery(row_keys=row_keys) chunks = [ self._make_chunk(row_key=b"test_1"), core_exceptions.Aborted("mock retryable error"), ] - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - read_rows.side_effect = ( - lambda *args, **kwargs: self._make_gapic_stream(chunks) + try: + await table.read_rows(query) + except InvalidChunk: + revise_rowset.assert_called() + revise_call_kwargs = revise_rowset.call_args_list[0].kwargs + assert ( + revise_call_kwargs["row_set"] + == query._to_dict()["rows"] ) - try: - await table.read_rows(query) - except InvalidChunk: - revise_rowset.assert_called() - revise_call_kwargs = revise_rowset.call_args_list[0].kwargs - assert ( - revise_call_kwargs["row_set"] - == query._to_dict()["rows"] - ) - assert revise_call_kwargs["last_seen_row_key"] == b"test_1" - read_rows_request = read_rows.call_args_list[1].args[0] - assert read_rows_request["rows"] == "modified" + assert revise_call_kwargs["last_seen_row_key"] == b"test_1" + read_rows_request = read_rows.call_args_list[1].args[0] + assert read_rows_request["rows"] == "modified" @pytest.mark.asyncio async def test_read_rows_default_timeouts(self): @@ -1254,20 +1251,14 @@ async def test_read_rows_default_timeouts(self): per_request_timeout = 4 with mock.patch.object(_ReadRowsOperation, "__init__") as mock_op: mock_op.side_effect = RuntimeError("mock error") - async with self._make_client() as client: - async with client.get_table( - "instance", - "table", - default_operation_timeout=operation_timeout, - default_per_request_timeout=per_request_timeout, - ) as table: - try: - await table.read_rows(ReadRowsQuery()) - except RuntimeError: - pass - kwargs = mock_op.call_args_list[0].kwargs - assert kwargs["operation_timeout"] == operation_timeout - assert kwargs["per_request_timeout"] == per_request_timeout + async with self._make_table(default_operation_timeout=operation_timeout,default_per_request_timeout=per_request_timeout) as table: + try: + await table.read_rows(ReadRowsQuery()) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["per_request_timeout"] == per_request_timeout @pytest.mark.asyncio async def test_read_rows_default_timeout_override(self): @@ -1280,24 +1271,18 @@ async def test_read_rows_default_timeout_override(self): per_request_timeout = 4 with mock.patch.object(_ReadRowsOperation, "__init__") as mock_op: mock_op.side_effect = RuntimeError("mock error") - async with self._make_client() as client: - async with client.get_table( - "instance", - "table", - default_operation_timeout=99, - default_per_request_timeout=97, - ) as table: - try: - await table.read_rows( - ReadRowsQuery(), - operation_timeout=operation_timeout, - per_request_timeout=per_request_timeout, - ) - except RuntimeError: - pass - kwargs = mock_op.call_args_list[0].kwargs - assert kwargs["operation_timeout"] == operation_timeout - assert kwargs["per_request_timeout"] == per_request_timeout + async with self._make_table(default_operation_timeout=99, default_per_request_timeout=97) as table: + try: + await table.read_rows( + ReadRowsQuery(), + operation_timeout=operation_timeout, + per_request_timeout=per_request_timeout, + ) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["per_request_timeout"] == per_request_timeout @pytest.mark.asyncio async def test_read_row(self): @@ -1456,24 +1441,22 @@ async def test_row_exists_w_invalid_input(self, input_row): async def test_read_rows_metadata(self, include_app_profile): """request should attach metadata headers""" profile = "profile" if include_app_profile else None - async with self._make_client() as client: - async with client.get_table("i", "t", app_profile_id=profile) as table: - with mock.patch.object( - client._gapic_client, "read_rows", AsyncMock() - ) as read_rows: - await table.read_rows(ReadRowsQuery()) - kwargs = read_rows.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata + async with self._make_table(app_profile_id=profile) as table: + read_rows = table.client._gapic_client.read_rows + read_rows.return_value = self._make_gapic_stream([]) + await table.read_rows(ReadRowsQuery()) + kwargs = read_rows.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata class TestMutateRow: From 0b99b89c8caf7907f3b440070b9c283e95888251 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 Jun 2023 10:08:46 -0700 Subject: [PATCH 02/10] added ping and warm to system test --- tests/system/test_system.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 692911b10..b8abebcd6 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -224,6 +224,21 @@ async def test_ping_and_warm_gapic(client, table): await client._gapic_client.ping_and_warm(request) +@retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_ping_and_warm(client, table): + """ + Test ping and warm from handwritten client + """ + try: + channel = client.transport._grpc_channel.pool[0] + except Exception: + # for sync client + channel = client.transport._grpc_channel + results = await client._ping_and_warm_instances(channel) + assert len(results) == 1 + assert results[0] is None + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) @pytest.mark.asyncio async def test_mutation_set_cell(table, temp_rows): From 6a35be46ae5357487804bbb86217d4223111d773 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 Jun 2023 10:11:50 -0700 Subject: [PATCH 03/10] fixed broken ping and warm --- google/cloud/bigtable/client.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/client.py b/google/cloud/bigtable/client.py index f75613098..6cbd3fa16 100644 --- a/google/cloud/bigtable/client.py +++ b/google/cloud/bigtable/client.py @@ -38,6 +38,7 @@ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledBigtableGrpcAsyncIOTransport, ) +from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest from google.cloud.client import ClientWithProject from google.api_core.exceptions import GoogleAPICallError from google.api_core import retry_async as retries @@ -190,10 +191,13 @@ async def _ping_and_warm_instances( - sequence of results or exceptions from the ping requests """ ping_rpc = channel.unary_unary( - "/google.bigtable.v2.Bigtable/PingAndWarmChannel" + "/google.bigtable.v2.Bigtable/PingAndWarm", + request_serializer=PingAndWarmRequest.serialize, ) tasks = [ping_rpc({"name": n}) for n in self._active_instances] - return await asyncio.gather(*tasks, return_exceptions=True) + result_list = await asyncio.gather(*tasks, return_exceptions=True) + # return None in place of empty successful responses + return [r or None for r in result_list] async def _manage_channel( self, From aa3590193cf1ceba3575b820ad4a610327efa48d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 Jun 2023 10:30:31 -0700 Subject: [PATCH 04/10] added table and app profile to warming params --- google/cloud/bigtable/client.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/google/cloud/bigtable/client.py b/google/cloud/bigtable/client.py index 6cbd3fa16..06af5f31e 100644 --- a/google/cloud/bigtable/client.py +++ b/google/cloud/bigtable/client.py @@ -32,6 +32,8 @@ import sys import random +from collections import namedtuple + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO @@ -66,6 +68,11 @@ from google.cloud.bigtable.mutations_batcher import MutationsBatcher from google.cloud.bigtable import RowKeySamples +# used to register instance data with the client for channel warming +_WarmedInstanceKey = namedtuple( + "_WarmedInstanceKey", ["instance_name", "table_name", "app_profile_id"] +) + class BigtableDataClient(ClientWithProject): def __init__( @@ -131,10 +138,10 @@ def __init__( PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport ) # keep track of active instances to for warmup on channel refresh - self._active_instances: Set[str] = set() + self._active_instances: Set[_WarmedInstanceKey] = set() # keep track of table objects associated with each instance # only remove instance from _active_instances when all associated tables remove it - self._instance_owners: dict[str, Set[int]] = {} + self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} # attempt to start background tasks self._channel_init_time = time.time() self._channel_refresh_tasks: list[asyncio.Task[None]] = [] @@ -194,7 +201,15 @@ async def _ping_and_warm_instances( "/google.bigtable.v2.Bigtable/PingAndWarm", request_serializer=PingAndWarmRequest.serialize, ) - tasks = [ping_rpc({"name": n}) for n in self._active_instances] + tasks = [] + for (instance_name, table_name, app_profile_id) in self._active_instances: + tasks.append( + ping_rpc( + request={"name": instance_name, "app_profile_id": app_profile_id}, + metadata=_make_metadata(table_name, app_profile_id), + wait_for_ready=True, + ) + ) result_list = await asyncio.gather(*tasks, return_exceptions=True) # return None in place of empty successful responses return [r or None for r in result_list] @@ -263,9 +278,12 @@ async def _register_instance(self, instance_id: str, owner: Table) -> None: owners call _remove_instance_registration """ instance_name = self._gapic_client.instance_path(self.project, instance_id) - self._instance_owners.setdefault(instance_name, set()).add(id(owner)) + instance_key = _WarmedInstanceKey( + instance_name, owner.table_name, owner.app_profile_id + ) + self._instance_owners.setdefault(instance_key, set()).add(id(owner)) if instance_name not in self._active_instances: - self._active_instances.add(instance_name) + self._active_instances.add(instance_key) if self._channel_refresh_tasks: # refresh tasks already running # call ping and warm on all existing channels From 4e1ec6f895736eadcf6bdff0ab10d88d65eede5d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 Jun 2023 11:35:30 -0700 Subject: [PATCH 05/10] fixed _remove_instance_registration --- google/cloud/bigtable/client.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/client.py b/google/cloud/bigtable/client.py index 06af5f31e..b9d632799 100644 --- a/google/cloud/bigtable/client.py +++ b/google/cloud/bigtable/client.py @@ -311,11 +311,14 @@ async def _remove_instance_registration( - True if instance was removed """ instance_name = self._gapic_client.instance_path(self.project, instance_id) - owner_list = self._instance_owners.get(instance_name, set()) + instance_key = _WarmedInstanceKey( + instance_name, owner.table_name, owner.app_profile_id + ) + owner_list = self._instance_owners.get(instance_key, set()) try: owner_list.remove(id(owner)) if len(owner_list) == 0: - self._active_instances.remove(instance_name) + self._active_instances.remove(instance_key) return True except KeyError: return False From 8f4b34f5ec9cb35fe16b978efc72a90ae5d570d7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 Jun 2023 11:35:46 -0700 Subject: [PATCH 06/10] improved tests --- tests/system/test_system.py | 1 + tests/unit/test_client.py | 215 +++++++++++++++++++++--------------- 2 files changed, 130 insertions(+), 86 deletions(-) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index b8abebcd6..6b9d69242 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -239,6 +239,7 @@ async def test_ping_and_warm(client, table): assert len(results) == 1 assert results[0] is None + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) @pytest.mark.asyncio async def test_mutation_set_cell(table, temp_rows): diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index e1663b016..8f54cc4d0 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -290,28 +290,36 @@ async def test__ping_and_warm_instances(self): channel = mock.Mock() # test with no instances client_mock._active_instances = [] - result = await self._get_target_class()._ping_and_warm_instances(client_mock, channel) + result = await self._get_target_class()._ping_and_warm_instances( + client_mock, channel + ) assert len(result) == 0 gather.assert_called_once() gather.assert_awaited_once() assert not gather.call_args.args assert gather.call_args.kwargs == {"return_exceptions": True} # test with instances - client_mock._active_instances = [ - "instance-1", - "instance-2", - "instance-3", - "instance-4", - ] + client_mock._active_instances = [(mock.Mock(), mock.Mock(), mock.Mock())] * 4 gather.reset_mock() - result = await self._get_target_class()._ping_and_warm_instances(client_mock, channel) + channel.reset_mock() + result = await self._get_target_class()._ping_and_warm_instances( + client_mock, channel + ) assert len(result) == 4 gather.assert_called_once() gather.assert_awaited_once() assert len(gather.call_args.args) == 4 - assert gather.call_args.kwargs == {"return_exceptions": True} - for idx, call in enumerate(gather.call_args.args): - assert call == channel.unary_unary()() + # check grpc call arguments + grpc_call_args = channel.unary_unary().call_args_list + for idx, (_, kwargs) in enumerate(grpc_call_args): + expected_instance, expected_table, expected_app_profile = client_mock._active_instances[idx] + request = kwargs["request"] + assert request["name"] == expected_instance + assert request["app_profile_id"] == expected_app_profile + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == 'x-goog-request-params' + assert metadata[0][1] == f'table_name={expected_table},app_profile_id={expected_app_profile}' @pytest.mark.asyncio @pytest.mark.parametrize( @@ -508,58 +516,84 @@ async def test__manage_channel_refresh(self, num_cycles): await client.close() @pytest.mark.asyncio - @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test__register_instance(self): - # create the client without calling start_background_channel_refresh - with mock.patch.object(asyncio, "get_running_loop") as get_event_loop: - get_event_loop.side_effect = RuntimeError("no event loop") - client = self._make_one(project="project-id") - assert not client._channel_refresh_tasks - # first call should start background refresh - assert client._active_instances == set() - await client._register_instance("instance-1", mock.Mock()) - assert len(client._active_instances) == 1 - assert client._active_instances == {"projects/project-id/instances/instance-1"} - assert client._channel_refresh_tasks - # next call should not - with mock.patch.object( - type(self._make_one()), "start_background_channel_refresh" - ) as refresh_mock: - await client._register_instance("instance-2", mock.Mock()) - assert len(client._active_instances) == 2 - assert client._active_instances == { - "projects/project-id/instances/instance-1", - "projects/project-id/instances/instance-2", - } - refresh_mock.assert_not_called() - - @pytest.mark.asyncio - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - async def test__register_instance_ping_and_warm(self): - # should ping and warm each new instance - pool_size = 7 - with mock.patch.object(asyncio, "get_running_loop") as get_event_loop: - get_event_loop.side_effect = RuntimeError("no event loop") - client = self._make_one(project="project-id", pool_size=pool_size) + """ + test instance registration + """ + # set up mock client + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: f"prefix/{b}" + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_tasks = [] + client_mock.start_background_channel_refresh.side_effect = lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + mock_channels = [mock.Mock() for i in range(5)] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = AsyncMock() + table_mock = mock.Mock() + await self._get_target_class()._register_instance(client_mock, "instance-1", table_mock) # first call should start background refresh - assert not client._channel_refresh_tasks - await client._register_instance("instance-1", mock.Mock()) - client = self._make_one(project="project-id", pool_size=pool_size) - assert len(client._channel_refresh_tasks) == pool_size - assert not client._active_instances - # next calls should trigger ping and warm - with mock.patch.object( - type(self._make_one()), "_ping_and_warm_instances" - ) as ping_mock: - # new instance should trigger ping and warm - await client._register_instance("instance-2", mock.Mock()) - assert ping_mock.call_count == pool_size - await client._register_instance("instance-3", mock.Mock()) - assert ping_mock.call_count == pool_size * 2 - # duplcate instances should not trigger ping and warm - await client._register_instance("instance-3", mock.Mock()) - assert ping_mock.call_count == pool_size * 2 - await client.close() + assert client_mock.start_background_channel_refresh.call_count == 1 + # ensure active_instances and instance_owners were updated properly + expected_key = ("prefix/instance-1", table_mock.table_name, table_mock.app_profile_id) + assert len(active_instances) == 1 + assert expected_key == tuple(list(active_instances)[0]) + assert len(instance_owners) == 1 + assert expected_key == tuple(list(instance_owners)[0]) + # should be a new task set + assert client_mock._channel_refresh_tasks + # # next call should not call start_background_channel_refresh again + table_mock2 = mock.Mock() + await self._get_target_class()._register_instance(client_mock, "instance-2", table_mock2) + assert client_mock.start_background_channel_refresh.call_count == 1 + # but it should call ping and warm with new instance key + assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) + for channel in mock_channels: + assert channel in [call[0][0] for call in client_mock._ping_and_warm_instances.call_args_list] + # check for updated lists + assert len(active_instances) == 2 + assert len(instance_owners) == 2 + expected_key2 = ("prefix/instance-2", table_mock2.table_name, table_mock2.app_profile_id) + assert any([expected_key2 == tuple(list(active_instances)[i]) for i in range(len(active_instances))]) + assert any([expected_key2 == tuple(list(instance_owners)[i]) for i in range(len(instance_owners))]) + + @pytest.mark.asyncio + @pytest.mark.parametrize("insert_instances,expected_active,expected_owner_keys", [ + ([('i','t',None)], [('i','t',None)], [('i','t',None)]), + ([('i','t','p')], [('i','t','p')], [('i','t','p')]), + ([('1','t','p'), ('1','t','p')], [('1','t','p')], [('1','t','p')]), + ([('1','t','p'), ('2','t','p')], [('1','t','p'), ('2','t','p')], [('1','t','p'), ('2','t','p')]), + ]) + async def test__register_instance_state(self, insert_instances, expected_active, expected_owner_keys): + """ + test that active_instances and instance_owners are updated as expected + """ + # set up mock client + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: b + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_tasks = [] + client_mock.start_background_channel_refresh.side_effect = lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + mock_channels = [mock.Mock() for i in range(5)] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = AsyncMock() + table_mock = mock.Mock() + # register instances + for instance, table, profile in insert_instances: + table_mock.table_name = table + table_mock.app_profile_id = profile + await self._get_target_class()._register_instance(client_mock, instance, table_mock) + assert len(active_instances) == len(expected_active) + assert len(instance_owners) == len(expected_owner_keys) + for expected in expected_active: + assert any([expected == tuple(list(active_instances)[i]) for i in range(len(active_instances))]) + for expected in expected_owner_keys: + assert any([expected == tuple(list(instance_owners)[i]) for i in range(len(instance_owners))]) @pytest.mark.asyncio async def test__remove_instance_registration(self): @@ -572,20 +606,22 @@ async def test__remove_instance_registration(self): instance_1_path = client._gapic_client.instance_path( client.project, "instance-1" ) + instance_1_key = (instance_1_path, table.table_name, table.app_profile_id) instance_2_path = client._gapic_client.instance_path( client.project, "instance-2" ) - assert len(client._instance_owners[instance_1_path]) == 1 - assert list(client._instance_owners[instance_1_path])[0] == id(table) - assert len(client._instance_owners[instance_2_path]) == 1 - assert list(client._instance_owners[instance_2_path])[0] == id(table) + instance_2_key = (instance_2_path, table.table_name, table.app_profile_id) + assert len(client._instance_owners[instance_1_key]) == 1 + assert list(client._instance_owners[instance_1_key])[0] == id(table) + assert len(client._instance_owners[instance_2_key]) == 1 + assert list(client._instance_owners[instance_2_key])[0] == id(table) success = await client._remove_instance_registration("instance-1", table) assert success assert len(client._active_instances) == 1 - assert len(client._instance_owners[instance_1_path]) == 0 - assert len(client._instance_owners[instance_2_path]) == 1 - assert client._active_instances == {"projects/project-id/instances/instance-2"} - success = await client._remove_instance_registration("nonexistant", table) + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 1 + assert client._active_instances == {instance_2_key} + success = await client._remove_instance_registration("fake-key", table) assert not success assert len(client._active_instances) == 1 await client.close() @@ -874,11 +910,20 @@ def _make_client(self, *args, **kwargs): def _make_table(self, *args, **kwargs): from google.cloud.bigtable.client import Table + client_mock = mock.Mock() - client_mock._register_instance.side_effect = lambda *args, **kwargs: asyncio.sleep(0) - client_mock._remove_instance_registration.side_effect = lambda *args, **kwargs: asyncio.sleep(0) - kwargs["instance_id"] = kwargs.get("instance_id", args[0] if args else "instance") - kwargs["table_id"] = kwargs.get("table_id", args[1] if len(args) > 1 else "table") + client_mock._register_instance.side_effect = ( + lambda *args, **kwargs: asyncio.sleep(0) + ) + client_mock._remove_instance_registration.side_effect = ( + lambda *args, **kwargs: asyncio.sleep(0) + ) + kwargs["instance_id"] = kwargs.get( + "instance_id", args[0] if args else "instance" + ) + kwargs["table_id"] = kwargs.get( + "table_id", args[1] if len(args) > 1 else "table" + ) client_mock._gapic_client.table_path.return_value = kwargs["table_id"] client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] return Table(client_mock, *args, **kwargs) @@ -986,9 +1031,7 @@ async def test_read_rows_query_matches_request(self, include_app_profile): app_profile_id = "app_profile_id" if include_app_profile else None async with self._make_table(app_profile_id=app_profile_id) as table: read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - [] - ) + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream([]) row_keys = [b"test_1", "test_2"] row_ranges = RowRange("start", "end") filter_ = {"test": "filter"} @@ -1066,10 +1109,8 @@ async def test_read_rows_per_request_timeout( with mock.patch("random.uniform", side_effect=lambda a, b: 0): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = ( - lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=per_request_t - ) + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=per_request_t ) query = ReadRowsQuery() chunks = [core_exceptions.DeadlineExceeded("mock deadline")] @@ -1232,10 +1273,7 @@ async def test_read_rows_revise_request(self): except InvalidChunk: revise_rowset.assert_called() revise_call_kwargs = revise_rowset.call_args_list[0].kwargs - assert ( - revise_call_kwargs["row_set"] - == query._to_dict()["rows"] - ) + assert revise_call_kwargs["row_set"] == query._to_dict()["rows"] assert revise_call_kwargs["last_seen_row_key"] == b"test_1" read_rows_request = read_rows.call_args_list[1].args[0] assert read_rows_request["rows"] == "modified" @@ -1251,7 +1289,10 @@ async def test_read_rows_default_timeouts(self): per_request_timeout = 4 with mock.patch.object(_ReadRowsOperation, "__init__") as mock_op: mock_op.side_effect = RuntimeError("mock error") - async with self._make_table(default_operation_timeout=operation_timeout,default_per_request_timeout=per_request_timeout) as table: + async with self._make_table( + default_operation_timeout=operation_timeout, + default_per_request_timeout=per_request_timeout, + ) as table: try: await table.read_rows(ReadRowsQuery()) except RuntimeError: @@ -1271,7 +1312,9 @@ async def test_read_rows_default_timeout_override(self): per_request_timeout = 4 with mock.patch.object(_ReadRowsOperation, "__init__") as mock_op: mock_op.side_effect = RuntimeError("mock error") - async with self._make_table(default_operation_timeout=99, default_per_request_timeout=97) as table: + async with self._make_table( + default_operation_timeout=99, default_per_request_timeout=97 + ) as table: try: await table.read_rows( ReadRowsQuery(), From 9c8df9fe6cafaedfdb9537ff9cf996135a6dcd55 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 Jun 2023 12:52:17 -0700 Subject: [PATCH 07/10] allow warming for single instance --- google/cloud/bigtable/client.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/google/cloud/bigtable/client.py b/google/cloud/bigtable/client.py index b9d632799..e9f90098e 100644 --- a/google/cloud/bigtable/client.py +++ b/google/cloud/bigtable/client.py @@ -143,7 +143,7 @@ def __init__( # only remove instance from _active_instances when all associated tables remove it self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} # attempt to start background tasks - self._channel_init_time = time.time() + self._channel_init_time = time.monotonic() self._channel_refresh_tasks: list[asyncio.Task[None]] = [] try: self.start_background_channel_refresh() @@ -185,7 +185,7 @@ async def close(self, timeout: float = 2.0): self._channel_refresh_tasks = [] async def _ping_and_warm_instances( - self, channel: grpc.aio.Channel + self, channel: grpc.aio.Channel, instance_key: _WarmedInstanceKey | None = None ) -> list[GoogleAPICallError | None]: """ Prepares the backend for requests on a channel @@ -193,23 +193,27 @@ async def _ping_and_warm_instances( Pings each Bigtable instance registered in `_active_instances` on the client Args: - channel: grpc channel to ping + - channel: grpc channel to warm + - instance_key: if provided, only warm the instance associated with the key Returns: - sequence of results or exceptions from the ping requests """ + instance_list = ( + [instance_key] if instance_key is not None else self._active_instances + ) ping_rpc = channel.unary_unary( "/google.bigtable.v2.Bigtable/PingAndWarm", request_serializer=PingAndWarmRequest.serialize, ) - tasks = [] - for (instance_name, table_name, app_profile_id) in self._active_instances: - tasks.append( - ping_rpc( - request={"name": instance_name, "app_profile_id": app_profile_id}, - metadata=_make_metadata(table_name, app_profile_id), - wait_for_ready=True, - ) - ) + # prepare list of coroutines to run + tasks = [ + ping_rpc( + request={"name": instance_name, "app_profile_id": app_profile_id}, + metadata=_make_metadata(table_name, app_profile_id), + wait_for_ready=True, + ) for (instance_name, table_name, app_profile_id) in instance_list + ] + # execute coroutines in parallel result_list = await asyncio.gather(*tasks, return_exceptions=True) # return None in place of empty successful responses return [r or None for r in result_list] @@ -243,7 +247,7 @@ async def _manage_channel( first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) - next_sleep = max(first_refresh - time.time(), 0) + next_sleep = max(first_refresh - time.monotonic(), 0) if next_sleep > 0: # warm the current channel immediately channel = self.transport.channels[channel_idx] @@ -288,7 +292,7 @@ async def _register_instance(self, instance_id: str, owner: Table) -> None: # refresh tasks already running # call ping and warm on all existing channels for channel in self.transport.channels: - await self._ping_and_warm_instances(channel) + await self._ping_and_warm_instances(channel, instance_key) else: # refresh tasks aren't active. start them as background tasks self.start_background_channel_refresh() From 584f94df9dafb3f243896a0b900fc35b77fd5a3b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 Jun 2023 12:56:34 -0700 Subject: [PATCH 08/10] updated tests --- google/cloud/bigtable/client.py | 3 +- tests/unit/test_client.py | 319 +++++++++++++++++++++++--------- 2 files changed, 234 insertions(+), 88 deletions(-) diff --git a/google/cloud/bigtable/client.py b/google/cloud/bigtable/client.py index e9f90098e..326932f79 100644 --- a/google/cloud/bigtable/client.py +++ b/google/cloud/bigtable/client.py @@ -211,7 +211,8 @@ async def _ping_and_warm_instances( request={"name": instance_name, "app_profile_id": app_profile_id}, metadata=_make_metadata(table_name, app_profile_id), wait_for_ready=True, - ) for (instance_name, table_name, app_profile_id) in instance_list + ) + for (instance_name, table_name, app_profile_id) in instance_list ] # execute coroutines in parallel result_list = await asyncio.gather(*tasks, return_exceptions=True) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 8f54cc4d0..9997205d7 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -299,7 +299,9 @@ async def test__ping_and_warm_instances(self): assert not gather.call_args.args assert gather.call_args.kwargs == {"return_exceptions": True} # test with instances - client_mock._active_instances = [(mock.Mock(), mock.Mock(), mock.Mock())] * 4 + client_mock._active_instances = [ + (mock.Mock(), mock.Mock(), mock.Mock()) + ] * 4 gather.reset_mock() channel.reset_mock() result = await self._get_target_class()._ping_and_warm_instances( @@ -312,14 +314,54 @@ async def test__ping_and_warm_instances(self): # check grpc call arguments grpc_call_args = channel.unary_unary().call_args_list for idx, (_, kwargs) in enumerate(grpc_call_args): - expected_instance, expected_table, expected_app_profile = client_mock._active_instances[idx] + ( + expected_instance, + expected_table, + expected_app_profile, + ) = client_mock._active_instances[idx] request = kwargs["request"] assert request["name"] == expected_instance assert request["app_profile_id"] == expected_app_profile metadata = kwargs["metadata"] assert len(metadata) == 1 - assert metadata[0][0] == 'x-goog-request-params' - assert metadata[0][1] == f'table_name={expected_table},app_profile_id={expected_app_profile}' + assert metadata[0][0] == "x-goog-request-params" + assert ( + metadata[0][1] + == f"table_name={expected_table},app_profile_id={expected_app_profile}" + ) + + @pytest.mark.asyncio + async def test_ping_and_warm_single_instance(self): + """ + should be able to call ping and warm with single instance + """ + client_mock = mock.Mock() + with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: + # simulate gather by returning the same number of items as passed in + gather.side_effect = lambda *args, **kwargs: [None for _ in args] + channel = mock.Mock() + # test with large set of instances + client_mock._active_instances = [mock.Mock()] * 100 + test_key = ("test-instance", "test-table", "test-app-profile") + result = await self._get_target_class()._ping_and_warm_instances( + client_mock, channel, test_key + ) + # should only have been called with test instance + assert len(result) == 1 + # check grpc call arguments + grpc_call_args = channel.unary_unary().call_args_list + assert len(grpc_call_args) == 1 + kwargs = grpc_call_args[0][1] + request = kwargs["request"] + assert request["name"] == "test-instance" + assert request["app_profile_id"] == "test-app-profile" + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert ( + metadata[0][1] + == "table_name=test-table,app_profile_id=test-app-profile" + ) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -339,7 +381,7 @@ async def test__manage_channel_first_sleep( # first sleep time should be `refresh_interval` seconds after client init import time - with mock.patch.object(time, "time") as time: + with mock.patch.object(time, "monotonic") as time: time.return_value = 0 with mock.patch.object(asyncio, "sleep") as sleep: sleep.side_effect = asyncio.CancelledError @@ -358,46 +400,47 @@ async def test__manage_channel_first_sleep( @pytest.mark.asyncio async def test__manage_channel_ping_and_warm(self): - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) + """ + _manage channel should call ping and warm internally + """ + import time + client_mock = mock.Mock() + client_mock._channel_init_time = time.monotonic() + channel_list = [mock.Mock(), mock.Mock()] + client_mock.transport.channels = channel_list + new_channel = mock.Mock() + client_mock.transport.grpc_channel._create_channel.return_value = new_channel # should ping an warm all new channels, and old channels if sleeping - client = self._make_one(project="project-id") - new_channel = grpc.aio.insecure_channel("localhost:8080") with mock.patch.object(asyncio, "sleep"): - create_channel = mock.Mock() - create_channel.return_value = new_channel - client.transport.grpc_channel._create_channel = create_channel - with mock.patch.object( - PooledBigtableGrpcAsyncIOTransport, "replace_channel" - ) as replace_channel: - replace_channel.side_effect = asyncio.CancelledError - # should ping and warm old channel then new if sleep > 0 - with mock.patch.object( - type(self._make_one()), "_ping_and_warm_instances" - ) as ping_and_warm: - try: - channel_idx = 2 - old_channel = client.transport._grpc_channel._pool[channel_idx] - await client._manage_channel(channel_idx, 10) - except asyncio.CancelledError: - pass - assert ping_and_warm.call_count == 2 - assert old_channel != new_channel - called_with = [call[0][0] for call in ping_and_warm.call_args_list] - assert old_channel in called_with - assert new_channel in called_with - # should ping and warm instantly new channel only if not sleeping - with mock.patch.object( - type(self._make_one()), "_ping_and_warm_instances" - ) as ping_and_warm: - try: - await client._manage_channel(0, 0, 0) - except asyncio.CancelledError: - pass - ping_and_warm.assert_called_once_with(new_channel) - await client.close() + # stop process after replace_channel is called + client_mock.transport.replace_channel.side_effect = asyncio.CancelledError + ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock() + # should ping and warm old channel then new if sleep > 0 + try: + channel_idx = 1 + await self._get_target_class()._manage_channel( + client_mock, channel_idx, 10 + ) + except asyncio.CancelledError: + pass + # should have called at loop start, and after replacement + assert ping_and_warm.call_count == 2 + # should have replaced channel once + assert client_mock.transport.replace_channel.call_count == 1 + # make sure new and old channels were warmed + old_channel = channel_list[channel_idx] + assert old_channel != new_channel + called_with = [call[0][0] for call in ping_and_warm.call_args_list] + assert old_channel in called_with + assert new_channel in called_with + # should ping and warm instantly new channel only if not sleeping + ping_and_warm.reset_mock() + try: + await self._get_target_class()._manage_channel(client_mock, 0, 0, 0) + except asyncio.CancelledError: + pass + ping_and_warm.assert_called_once_with(new_channel) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -528,16 +571,24 @@ async def test__register_instance(self): client_mock._active_instances = active_instances client_mock._instance_owners = instance_owners client_mock._channel_refresh_tasks = [] - client_mock.start_background_channel_refresh.side_effect = lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + client_mock.start_background_channel_refresh.side_effect = ( + lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + ) mock_channels = [mock.Mock() for i in range(5)] client_mock.transport.channels = mock_channels client_mock._ping_and_warm_instances = AsyncMock() table_mock = mock.Mock() - await self._get_target_class()._register_instance(client_mock, "instance-1", table_mock) + await self._get_target_class()._register_instance( + client_mock, "instance-1", table_mock + ) # first call should start background refresh assert client_mock.start_background_channel_refresh.call_count == 1 # ensure active_instances and instance_owners were updated properly - expected_key = ("prefix/instance-1", table_mock.table_name, table_mock.app_profile_id) + expected_key = ( + "prefix/instance-1", + table_mock.table_name, + table_mock.app_profile_id, + ) assert len(active_instances) == 1 assert expected_key == tuple(list(active_instances)[0]) assert len(instance_owners) == 1 @@ -546,27 +597,55 @@ async def test__register_instance(self): assert client_mock._channel_refresh_tasks # # next call should not call start_background_channel_refresh again table_mock2 = mock.Mock() - await self._get_target_class()._register_instance(client_mock, "instance-2", table_mock2) + await self._get_target_class()._register_instance( + client_mock, "instance-2", table_mock2 + ) assert client_mock.start_background_channel_refresh.call_count == 1 # but it should call ping and warm with new instance key assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) for channel in mock_channels: - assert channel in [call[0][0] for call in client_mock._ping_and_warm_instances.call_args_list] + assert channel in [ + call[0][0] + for call in client_mock._ping_and_warm_instances.call_args_list + ] # check for updated lists assert len(active_instances) == 2 assert len(instance_owners) == 2 - expected_key2 = ("prefix/instance-2", table_mock2.table_name, table_mock2.app_profile_id) - assert any([expected_key2 == tuple(list(active_instances)[i]) for i in range(len(active_instances))]) - assert any([expected_key2 == tuple(list(instance_owners)[i]) for i in range(len(instance_owners))]) - - @pytest.mark.asyncio - @pytest.mark.parametrize("insert_instances,expected_active,expected_owner_keys", [ - ([('i','t',None)], [('i','t',None)], [('i','t',None)]), - ([('i','t','p')], [('i','t','p')], [('i','t','p')]), - ([('1','t','p'), ('1','t','p')], [('1','t','p')], [('1','t','p')]), - ([('1','t','p'), ('2','t','p')], [('1','t','p'), ('2','t','p')], [('1','t','p'), ('2','t','p')]), - ]) - async def test__register_instance_state(self, insert_instances, expected_active, expected_owner_keys): + expected_key2 = ( + "prefix/instance-2", + table_mock2.table_name, + table_mock2.app_profile_id, + ) + assert any( + [ + expected_key2 == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + assert any( + [ + expected_key2 == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "insert_instances,expected_active,expected_owner_keys", + [ + ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]), + ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]), + ([("1", "t", "p"), ("1", "t", "p")], [("1", "t", "p")], [("1", "t", "p")]), + ( + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + ), + ], + ) + async def test__register_instance_state( + self, insert_instances, expected_active, expected_owner_keys + ): """ test that active_instances and instance_owners are updated as expected """ @@ -578,7 +657,9 @@ async def test__register_instance_state(self, insert_instances, expected_active, client_mock._active_instances = active_instances client_mock._instance_owners = instance_owners client_mock._channel_refresh_tasks = [] - client_mock.start_background_channel_refresh.side_effect = lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + client_mock.start_background_channel_refresh.side_effect = ( + lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + ) mock_channels = [mock.Mock() for i in range(5)] client_mock.transport.channels = mock_channels client_mock._ping_and_warm_instances = AsyncMock() @@ -587,13 +668,25 @@ async def test__register_instance_state(self, insert_instances, expected_active, for instance, table, profile in insert_instances: table_mock.table_name = table table_mock.app_profile_id = profile - await self._get_target_class()._register_instance(client_mock, instance, table_mock) + await self._get_target_class()._register_instance( + client_mock, instance, table_mock + ) assert len(active_instances) == len(expected_active) assert len(instance_owners) == len(expected_owner_keys) for expected in expected_active: - assert any([expected == tuple(list(active_instances)[i]) for i in range(len(active_instances))]) + assert any( + [ + expected == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) for expected in expected_owner_keys: - assert any([expected == tuple(list(instance_owners)[i]) for i in range(len(instance_owners))]) + assert any( + [ + expected == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) @pytest.mark.asyncio async def test__remove_instance_registration(self): @@ -628,58 +721,96 @@ async def test__remove_instance_registration(self): @pytest.mark.asyncio async def test__multiple_table_registration(self): + """ + registering with multiple tables with the same key should + add multiple owners to instance_owners, but only keep one copy + of shared key in active_instances + """ + from google.cloud.bigtable.client import _WarmedInstanceKey + async with self._make_one(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: instance_1_path = client._gapic_client.instance_path( client.project, "instance_1" ) - assert len(client._instance_owners[instance_1_path]) == 1 + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 1 assert len(client._active_instances) == 1 - assert id(table_1) in client._instance_owners[instance_1_path] - async with client.get_table("instance_1", "table_2") as table_2: - assert len(client._instance_owners[instance_1_path]) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + # duplicate table should register in instance_owners under same key + async with client.get_table("instance_1", "table_1") as table_2: + assert len(client._instance_owners[instance_1_key]) == 2 assert len(client._active_instances) == 1 - assert id(table_1) in client._instance_owners[instance_1_path] - assert id(table_2) in client._instance_owners[instance_1_path] - # table_2 should be unregistered, but instance should still be active + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + # unique table should register in instance_owners and active_instances + async with client.get_table("instance_1", "table_3") as table_3: + instance_3_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_3_key = _WarmedInstanceKey( + instance_3_path, table_3.table_name, table_3.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._instance_owners[instance_3_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + assert id(table_3) in client._instance_owners[instance_3_key] + # sub-tables should be unregistered, but instance should still be active assert len(client._active_instances) == 1 - assert instance_1_path in client._active_instances - assert id(table_2) not in client._instance_owners[instance_1_path] + assert instance_1_key in client._active_instances + assert id(table_2) not in client._instance_owners[instance_1_key] # both tables are gone. instance should be unregistered assert len(client._active_instances) == 0 - assert instance_1_path not in client._active_instances - assert len(client._instance_owners[instance_1_path]) == 0 + assert instance_1_key not in client._active_instances + assert len(client._instance_owners[instance_1_key]) == 0 @pytest.mark.asyncio async def test__multiple_instance_registration(self): + """ + registering with multiple instance keys should update the key + in instance_owners and active_instances + """ + from google.cloud.bigtable.client import _WarmedInstanceKey + async with self._make_one(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: async with client.get_table("instance_2", "table_2") as table_2: instance_1_path = client._gapic_client.instance_path( client.project, "instance_1" ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) instance_2_path = client._gapic_client.instance_path( client.project, "instance_2" ) - assert len(client._instance_owners[instance_1_path]) == 1 - assert len(client._instance_owners[instance_2_path]) == 1 + instance_2_key = _WarmedInstanceKey( + instance_2_path, table_2.table_name, table_2.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 1 + assert len(client._instance_owners[instance_2_key]) == 1 assert len(client._active_instances) == 2 - assert id(table_1) in client._instance_owners[instance_1_path] - assert id(table_2) in client._instance_owners[instance_2_path] + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_2_key] # instance2 should be unregistered, but instance1 should still be active assert len(client._active_instances) == 1 - assert instance_1_path in client._active_instances - assert len(client._instance_owners[instance_2_path]) == 0 - assert len(client._instance_owners[instance_1_path]) == 1 - assert id(table_1) in client._instance_owners[instance_1_path] + assert instance_1_key in client._active_instances + assert len(client._instance_owners[instance_2_key]) == 0 + assert len(client._instance_owners[instance_1_key]) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] # both tables are gone. instances should both be unregistered assert len(client._active_instances) == 0 - assert len(client._instance_owners[instance_1_path]) == 0 - assert len(client._instance_owners[instance_2_path]) == 0 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 0 @pytest.mark.asyncio async def test_get_table(self): from google.cloud.bigtable.client import Table + from google.cloud.bigtable.client import _WarmedInstanceKey client = self._make_one(project="project-id") assert not client._active_instances @@ -705,12 +836,17 @@ async def test_get_table(self): ) assert table.app_profile_id == expected_app_profile_id assert table.client is client - assert table.instance_name in client._active_instances + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} await client.close() @pytest.mark.asyncio async def test_get_table_context_manager(self): from google.cloud.bigtable.client import Table + from google.cloud.bigtable.client import _WarmedInstanceKey expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -738,7 +874,11 @@ async def test_get_table_context_manager(self): ) assert table.app_profile_id == expected_app_profile_id assert table.client is client - assert table.instance_name in client._active_instances + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} assert close_mock.call_count == 1 @pytest.mark.asyncio @@ -829,6 +969,7 @@ class TestTable: async def test_table_ctor(self): from google.cloud.bigtable.client import BigtableDataClient from google.cloud.bigtable.client import Table + from google.cloud.bigtable.client import _WarmedInstanceKey expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -851,7 +992,11 @@ async def test_table_ctor(self): assert table.instance_id == expected_instance_id assert table.app_profile_id == expected_app_profile_id assert table.client is client - assert table.instance_name in client._active_instances + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} assert table.default_operation_timeout == expected_operation_timeout assert table.default_per_request_timeout == expected_per_request_timeout # ensure task reaches completion From 3c605ba6d98b01ca2020036693822246ef28efea Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 22 Jun 2023 13:00:37 -0700 Subject: [PATCH 09/10] fixed metadata format --- google/cloud/bigtable/_helpers.py | 2 +- tests/unit/test__helpers.py | 2 +- tests/unit/test_client.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/google/cloud/bigtable/_helpers.py b/google/cloud/bigtable/_helpers.py index dec4c2014..722fac9f4 100644 --- a/google/cloud/bigtable/_helpers.py +++ b/google/cloud/bigtable/_helpers.py @@ -35,7 +35,7 @@ def _make_metadata( params.append(f"table_name={table_name}") if app_profile_id is not None: params.append(f"app_profile_id={app_profile_id}") - params_str = ",".join(params) + params_str = "&".join(params) return [("x-goog-request-params", params_str)] diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index 2765afe24..9aa1a7bb4 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -23,7 +23,7 @@ class TestMakeMetadata: @pytest.mark.parametrize( "table,profile,expected", [ - ("table", "profile", "table_name=table,app_profile_id=profile"), + ("table", "profile", "table_name=table&app_profile_id=profile"), ("table", None, "table_name=table"), ], ) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 9997205d7..75adffc72 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -327,7 +327,7 @@ async def test__ping_and_warm_instances(self): assert metadata[0][0] == "x-goog-request-params" assert ( metadata[0][1] - == f"table_name={expected_table},app_profile_id={expected_app_profile}" + == f"table_name={expected_table}&app_profile_id={expected_app_profile}" ) @pytest.mark.asyncio @@ -360,7 +360,7 @@ async def test_ping_and_warm_single_instance(self): assert metadata[0][0] == "x-goog-request-params" assert ( metadata[0][1] - == "table_name=test-table,app_profile_id=test-app-profile" + == "table_name=test-table&app_profile_id=test-app-profile" ) @pytest.mark.asyncio From 0ce16c61ab4c38bc4305462ef6429f4b6b19ff82 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 23 Jun 2023 16:04:23 -0700 Subject: [PATCH 10/10] fixed metadata --- google/cloud/bigtable/client.py | 7 ++++++- tests/unit/test_client.py | 7 +++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/google/cloud/bigtable/client.py b/google/cloud/bigtable/client.py index a01bc0e61..3d33eebf9 100644 --- a/google/cloud/bigtable/client.py +++ b/google/cloud/bigtable/client.py @@ -217,7 +217,12 @@ async def _ping_and_warm_instances( tasks = [ ping_rpc( request={"name": instance_name, "app_profile_id": app_profile_id}, - metadata=_make_metadata(table_name, app_profile_id), + metadata=[ + ( + "x-goog-request-params", + f"name={instance_name}&app_profile_id={app_profile_id}", + ) + ], wait_for_ready=True, ) for (instance_name, table_name, app_profile_id) in instance_list diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 10711ce52..805a6340d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -327,11 +327,11 @@ async def test__ping_and_warm_instances(self): assert metadata[0][0] == "x-goog-request-params" assert ( metadata[0][1] - == f"table_name={expected_table}&app_profile_id={expected_app_profile}" + == f"name={expected_instance}&app_profile_id={expected_app_profile}" ) @pytest.mark.asyncio - async def test_ping_and_warm_single_instance(self): + async def test__ping_and_warm_single_instance(self): """ should be able to call ping and warm with single instance """ @@ -359,8 +359,7 @@ async def test_ping_and_warm_single_instance(self): assert len(metadata) == 1 assert metadata[0][0] == "x-goog-request-params" assert ( - metadata[0][1] - == "table_name=test-table&app_profile_id=test-app-profile" + metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" ) @pytest.mark.asyncio