Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add explicit timeout for subscriber creation #1029

Merged
merged 2 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions google_nest_sdm/google_nest_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@

MESSAGE_ACK_TIMEOUT_SECONDS = 30.0

NEW_SUBSCRIBER_TIMEOUT_SECONDS = 30.0

# Note: Users of non-prod instances will have to manually configure a topic
TOPIC_FORMAT = "projects/sdm-prod/topics/enterprise-{project_id}"

Expand Down Expand Up @@ -294,14 +296,17 @@ def callback_wrapper(message: pubsub_v1.subscriber.message.Message) -> None:
callback_wrapper,
)


def _new_subscriber(
self,
creds: Credentials,
subscription_name: str,
callback_wrapper: Callable[[pubsub_v1.subscriber.message.Message], None],
) -> pubsub_v1.subscriber.futures.StreamingPullFuture:
"""Issue a command to verify subscriber creds are correct."""
_LOGGER.debug("Creating subscriber '%s'", subscription_name)
creds = refresh_creds(creds)
_LOGGER.debug("Subscriber credentials refreshed")
subscriber = pubsub_v1.SubscriberClient(credentials=creds)
subscription = subscriber.get_subscription(subscription=subscription_name)
if subscription.topic:
Expand All @@ -311,6 +316,7 @@ def _new_subscriber(
subscription_name,
subscription.topic,
)
_LOGGER.debug("Starting subscriber '%s'", subscription_name)
return subscriber.subscribe(subscription_name, callback_wrapper)


Expand Down Expand Up @@ -432,11 +438,17 @@ async def start_async(self) -> None:
raise AuthException(f"Access token failure: {err}") from err

try:
self._subscriber_future = (
await self._subscriber_factory.async_new_subscriber(
creds, self._subscriber_id, self._loop, self._async_message_callback_with_timeout
async with asyncio.timeout(NEW_SUBSCRIBER_TIMEOUT_SECONDS):
self._subscriber_future = (
await self._subscriber_factory.async_new_subscriber(
creds, self._subscriber_id, self._loop, self._async_message_callback_with_timeout
)
)
)
except asyncio.TimeoutError as err:
DIAGNOSTICS.increment("start.timeout_error")
raise SubscriberException(
f"Failed to create subscriber '{self._subscriber_id}': {err}"
) from err
except NotFound as err:
DIAGNOSTICS.increment("start.not_found_error")
raise ConfigurationException(
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = google_nest_sdm
version = 4.0.4
version = 4.0.5
description = Library for the Google Nest SDM API
long_description = file: README.md
long_description_content_type = text/markdown
Expand Down
36 changes: 36 additions & 0 deletions tests/test_google_nest_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,42 @@ async def task2() -> None:
subscriber.stop_async()


async def test_subscriber_timeout(
app: aiohttp.web.Application,
device_handler: DeviceHandler,
structure_handler: StructureHandler,
subscriber_client: Callable[
[Optional[AbstractSubscriberFactory]], Awaitable[GoogleNestSubscriber]
],
) -> None:
class FailingFactory(FakeSubscriberFactory):
async def async_new_subscriber(
self,
creds: Credentials,
subscription_name: str,
loop: asyncio.AbstractEventLoop,
async_callback: Callable[
[pubsub_v1.subscriber.message.Message], Awaitable[None]
],
) -> pubsub_v1.subscriber.futures.StreamingPullFuture:
raise asyncio.TimeoutError("Some error")

subscriber = await subscriber_client(FailingFactory())

with pytest.raises(SubscriberException):
await subscriber.start_async()
subscriber.stop_async()

assert_diagnostics(
diagnostics.get_diagnostics(),
{
"subscriber": {
"start": 1,
"start.timeout_error": 1,
"stop": 1,
},
},
)
async def test_subscriber_error(
app: aiohttp.web.Application,
device_handler: DeviceHandler,
Expand Down