Skip to content

Commit

Permalink
feat: ping and warm with metadata (#810)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche authored Jun 26, 2023
1 parent ec2b983 commit ceaf598
Show file tree
Hide file tree
Showing 5 changed files with 554 additions and 367 deletions.
2 changes: 1 addition & 1 deletion google/cloud/bigtable/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _make_metadata(
params.append(f"table_name={table_name}")
if app_profile_id is not None:
params.append(f"app_profile_id={app_profile_id}")
params_str = ",".join(params)
params_str = "&".join(params)
return [("x-goog-request-params", params_str)]


Expand Down
59 changes: 45 additions & 14 deletions google/cloud/bigtable/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import sys
import random

from collections import namedtuple

from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta
from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient
from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO
Expand Down Expand Up @@ -74,6 +76,11 @@
# used by read_rows_sharded to limit how many requests are attempted in parallel
CONCURRENCY_LIMIT = 10

# used to register instance data with the client for channel warming
_WarmedInstanceKey = namedtuple(
"_WarmedInstanceKey", ["instance_name", "table_name", "app_profile_id"]
)


class BigtableDataClient(ClientWithProject):
def __init__(
Expand Down Expand Up @@ -139,12 +146,12 @@ def __init__(
PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport
)
# keep track of active instances for warmup on channel refresh
self._active_instances: Set[str] = set()
self._active_instances: Set[_WarmedInstanceKey] = set()
# keep track of table objects associated with each instance
# only remove instance from _active_instances when all associated tables remove it
self._instance_owners: dict[str, Set[int]] = {}
self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {}
# attempt to start background tasks
self._channel_init_time = time.time()
self._channel_init_time = time.monotonic()
self._channel_refresh_tasks: list[asyncio.Task[None]] = []
try:
self.start_background_channel_refresh()
Expand Down Expand Up @@ -186,26 +193,44 @@ async def close(self, timeout: float = 2.0):
self._channel_refresh_tasks = []

async def _ping_and_warm_instances(
self, channel: grpc.aio.Channel
self, channel: grpc.aio.Channel, instance_key: _WarmedInstanceKey | None = None
) -> list[GoogleAPICallError | None]:
"""
Prepares the backend for requests on a channel
Pings each Bigtable instance registered in `_active_instances` on the client
Args:
channel: grpc channel to ping
- channel: grpc channel to warm
- instance_key: if provided, only warm the instance associated with the key
Returns:
- sequence of results or exceptions from the ping requests
"""
instance_list = (
[instance_key] if instance_key is not None else self._active_instances
)
ping_rpc = channel.unary_unary(
"/google.bigtable.v2.Bigtable/PingAndWarm",
request_serializer=PingAndWarmRequest.serialize,
)
tasks = [ping_rpc({"name": n}) for n in self._active_instances]
result = await asyncio.gather(*tasks, return_exceptions=True)
# prepare list of coroutines to run
tasks = [
ping_rpc(
request={"name": instance_name, "app_profile_id": app_profile_id},
metadata=[
(
"x-goog-request-params",
f"name={instance_name}&app_profile_id={app_profile_id}",
)
],
wait_for_ready=True,
)
for (instance_name, table_name, app_profile_id) in instance_list
]
# execute coroutines in parallel
result_list = await asyncio.gather(*tasks, return_exceptions=True)
# return None in place of empty successful responses
return [r or None for r in result]
return [r or None for r in result_list]

async def _manage_channel(
self,
Expand Down Expand Up @@ -236,7 +261,7 @@ async def _manage_channel(
first_refresh = self._channel_init_time + random.uniform(
refresh_interval_min, refresh_interval_max
)
next_sleep = max(first_refresh - time.time(), 0)
next_sleep = max(first_refresh - time.monotonic(), 0)
if next_sleep > 0:
# warm the current channel immediately
channel = self.transport.channels[channel_idx]
Expand Down Expand Up @@ -271,14 +296,17 @@ async def _register_instance(self, instance_id: str, owner: Table) -> None:
owners call _remove_instance_registration
"""
instance_name = self._gapic_client.instance_path(self.project, instance_id)
self._instance_owners.setdefault(instance_name, set()).add(id(owner))
instance_key = _WarmedInstanceKey(
instance_name, owner.table_name, owner.app_profile_id
)
self._instance_owners.setdefault(instance_key, set()).add(id(owner))
if instance_name not in self._active_instances:
self._active_instances.add(instance_name)
self._active_instances.add(instance_key)
if self._channel_refresh_tasks:
# refresh tasks already running
# call ping and warm on all existing channels
for channel in self.transport.channels:
await self._ping_and_warm_instances(channel)
await self._ping_and_warm_instances(channel, instance_key)
else:
# refresh tasks aren't active. start them as background tasks
self.start_background_channel_refresh()
Expand All @@ -301,11 +329,14 @@ async def _remove_instance_registration(
- True if instance was removed
"""
instance_name = self._gapic_client.instance_path(self.project, instance_id)
owner_list = self._instance_owners.get(instance_name, set())
instance_key = _WarmedInstanceKey(
instance_name, owner.table_name, owner.app_profile_id
)
owner_list = self._instance_owners.get(instance_key, set())
try:
owner_list.remove(id(owner))
if len(owner_list) == 0:
self._active_instances.remove(instance_name)
self._active_instances.remove(instance_key)
return True
except KeyError:
return False
Expand Down
1 change: 1 addition & 0 deletions tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ async def test_ping_and_warm(client, table):
assert results[0] is None


@retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
@pytest.mark.asyncio
async def test_mutation_set_cell(table, temp_rows):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class TestMakeMetadata:
@pytest.mark.parametrize(
"table,profile,expected",
[
("table", "profile", "table_name=table,app_profile_id=profile"),
("table", "profile", "table_name=table&app_profile_id=profile"),
("table", None, "table_name=table"),
],
)
Expand Down
Loading

0 comments on commit ceaf598

Please sign in to comment.