From 51b04f816db62c21083c310df0b25c83fc785977 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Thu, 23 May 2024 08:29:51 -0400 Subject: [PATCH 1/2] apns: filters: move FilteredStream --- pypush/apns/_util.py | 29 +---------------------------- pypush/apns/filters.py | 28 ++++++++++++++++++++++++++++ pypush/apns/lifecycle.py | 27 ++++++++++++++++----------- 3 files changed, 45 insertions(+), 39 deletions(-) diff --git a/pypush/apns/_util.py b/pypush/apns/_util.py index 3564892..77e5d35 100644 --- a/pypush/apns/_util.py +++ b/pypush/apns/_util.py @@ -3,9 +3,7 @@ from typing import Generic, TypeVar import anyio -from anyio.abc import ObjectReceiveStream, ObjectSendStream - -from . import filters +from anyio.abc import ObjectSendStream T = TypeVar("T") @@ -44,31 +42,6 @@ async def open_stream(self, backlog: bool = True): await send.aclose() -W = TypeVar("W") -F = TypeVar("F") - - -class FilteredStream(ObjectReceiveStream[F]): - """ - A stream that filters out unwanted items - - filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it - """ - - def __init__(self, source: ObjectReceiveStream[W], filter: filters.Filter[W, F]): - self.source = source - self.filter = filter - - async def receive(self) -> F: - async for item in self.source: - if (filtered := self.filter(item)) is not None: - return filtered - raise anyio.EndOfStream - - async def aclose(self): - await self.source.aclose() - - def exponential_backoff(f): async def wrapper(*args, **kwargs): backoff = 1 diff --git a/pypush/apns/filters.py b/pypush/apns/filters.py index 63bb784..66ba48e 100644 --- a/pypush/apns/filters.py +++ b/pypush/apns/filters.py @@ -1,6 +1,9 @@ import logging from typing import Callable, Optional, Type, TypeVar +import anyio +from anyio.abc import ObjectReceiveStream + from pypush.apns import protocol T1 = TypeVar("T1") @@ -42,3 +45,28 @@ def ALL(c): def NONE(_): return None + + +W = TypeVar("W") +F = TypeVar("F") + + +class FilteredStream(ObjectReceiveStream[F]): + """ + A stream that filters out unwanted items + + filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it + """ + + def __init__(self, source: ObjectReceiveStream[W], filter: Filter[W, F]): + self.source = source + self.filter = filter + + async def receive(self) -> F: + async for item in self.source: + if (filtered := self.filter(item)) is not None: + return filtered + raise anyio.EndOfStream + + async def aclose(self): + await self.source.aclose() diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 23d3f94..8b97a9c 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -99,7 +99,9 @@ async def _ping_task(self): @_util.exponential_backoff async def reconnect(self): - async with self._reconnect_lock: # Prevent weird situations where multiple reconnects are happening at once + async with ( + self._reconnect_lock + ): # Prevent weird situations where multiple reconnects are happening at once if self._conn is not None: logging.warning("Closing existing connection") await self._conn.aclose() @@ -172,7 +174,7 @@ async def _receive_stream( backlog: bool = True, ): async with self._broadcast.open_stream(backlog) as stream: - yield _util.FilteredStream(stream, filter) + yield filters.FilteredStream(stream, filter) async def _receive( self, filter: filters.Filter[protocol.Command, T], backlog: bool = True @@ -234,18 +236,21 @@ async def notification_stream( ): if token is None: token = await self.base_token - async with self._filter([topic]), self._receive_stream( - filters.chain( + async with ( + self._filter([topic]), + self._receive_stream( filters.chain( filters.chain( - filters.cmd(protocol.SendMessageCommand), - lambda c: c if c.token == token else None, + filters.chain( + filters.cmd(protocol.SendMessageCommand), + lambda c: c if c.token == token else None, + ), + lambda c: (c if c.topic == topic else None), ), - lambda c: (c if c.topic == topic else None), - ), - filter, - ) - ) as stream: + filter, + ) + ) as stream, + ): yield stream async def ack(self, command: protocol.SendMessageCommand, status: int = 0): From 29c66bde667669b26e3a2e322ddcb2bc45db0426 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Thu, 23 May 2024 09:26:27 -0400 Subject: [PATCH 2/2] apns: document public API functions --- pypush/apns/albert.py | 7 +++++ pypush/apns/lifecycle.py | 57 +++++++++++++++++++++++++++++++++++++--- pypush/apns/protocol.py | 17 ++++++++++++ tests/test_apns.py | 1 + 4 files changed, 78 insertions(+), 4 deletions(-) diff --git a/pypush/apns/albert.py b/pypush/apns/albert.py index 3706807..24a2ad6 100644 --- a/pypush/apns/albert.py +++ b/pypush/apns/albert.py @@ -50,6 +50,13 @@ async def activate( build: str = "10.6.4", model: str = "windows1,1", ) -> Tuple[x509.Certificate, rsa.RSAPrivateKey]: + """ + Activate with Apple's Albert service, obtaining an activation certificate and private key. + + By default, this will activate a Windows device with a random UDID, serial, version, build, and model. + + Windows activations will not function for iMessage or FaceTime. + """ if http_client is None: # Do this here to ensure the client is not accidentally reused during tests http_client = httpx.AsyncClient() diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 8b97a9c..77917e2 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -25,6 +25,16 @@ async def create_apns_connection( sandbox: bool = False, courier: typing.Optional[str] = None, ): + """ + This context manager will create a connection to the APNs server and yield a Connection object. + + Args: + certificate (x509.Certificate): A valid activation certificate obtained from Albert. + private_key (rsa.RSAPrivateKey): The private key corresponding to the activation certificate. + token (bytes, optional): An optional base token to use for the connection. If not provided, the connection will be established and the base token will be set to the token provided in the ConnectAck command. + sandbox (bool, optional): A boolean indicating whether to connect to the APNs sandbox or production server. + courier (str, optional): An optional string indicating the courier server to connect to. If not provided, a random courier server will be selected based on the `sandbox` parameter. + """ async with anyio.create_task_group() as tg: conn = Connection( tg, certificate, private_key, token, sandbox, courier @@ -71,11 +81,16 @@ def __init__( logging.debug(f"Using courier: {courier}") self.courier = courier - self._tg.start_soon(self.reconnect) + self._tg.start_soon(self._reconnect) self._tg.start_soon(self._ping_task) @property async def base_token(self) -> bytes: + """ + `base_token` must be awaited to ensure a token is available + + This may not complete until a connection has been established + """ if self._base_token is None: await self._connected.wait() assert self._base_token is not None @@ -98,7 +113,7 @@ async def _ping_task(self): ) # Explicitly disable the backlog since we don't want to receive old acks @_util.exponential_backoff - async def reconnect(self): + async def _reconnect(self): async with ( self._reconnect_lock ): # Prevent weird situations where multiple reconnects are happening at once @@ -161,9 +176,15 @@ async def reconnect(self): await self._update_filter() async def aclose(self): + """ + Closes the connection to the APNS server. + + If the connection is open, it will be closed. This method is typically unnecessary if the connection is managed by `create_apns_connection`. + + Note: The connection will be reopened if the task group is still open (the ping task is still running). + """ if self._conn is not None: await self._conn.aclose() - # Note: Will be reopened if task group is still running and ping task is still running T = typing.TypeVar("T") @@ -191,7 +212,7 @@ async def _send(self, command: protocol.Command): await self._conn.send(command) except Exception: logging.warning("Error sending command, reconnecting") - await self.reconnect() + await self._reconnect() await self._send(command) async def _update_filter(self): @@ -207,6 +228,8 @@ async def _update_filter(self): @asynccontextmanager async def _filter(self, topics: list[str]): for topic in topics: + if topic not in protocol.KNOWN_TOPICS: + protocol.note_topic(topic) self._filters[topic] = self._filters.get(topic, 0) + 1 await self._update_filter() yield @@ -217,6 +240,14 @@ async def _filter(self, topics: list[str]): await self._update_filter() async def mint_scoped_token(self, topic: str) -> bytes: + """ + Mint a "scoped token" for the given topic/bundle ID. + + This token is equivalent to the token provided to `application:didRegisterForRemoteNotificationsWithDeviceToken:` in iOS, + for an app with the given bundle ID. + + This token can be used with `expect_notification` or `notification_stream`, but it will only function on connections with the same base token as the connection that originally minted the token. + """ topic_hash = sha1(topic.encode()).digest() await self._send( protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash) @@ -234,6 +265,14 @@ async def notification_stream( protocol.SendMessageCommand, protocol.SendMessageCommand ] = filters.ALL, ): + """ + Create a stream of notifications for the given topic and token. + If the token is not provided, the base token will be used. + + A custom `Filter` can be provided to filter out unwanted notifications. + + Notifications will NOT be ack'd automatically, you must call `ack` on each notification you process. + """ if token is None: token = await self.base_token async with ( @@ -254,6 +293,9 @@ async def notification_stream( yield stream async def ack(self, command: protocol.SendMessageCommand, status: int = 0): + """ + Acknowledge a notification. + """ await self._send( protocol.SendMessageAck(status=status, token=command.token, id=command.id) ) @@ -266,6 +308,13 @@ async def expect_notification( protocol.SendMessageCommand, protocol.SendMessageCommand ] = filters.ALL, ) -> protocol.SendMessageCommand: + """ + Wait for a notification that matches the given topic and token. + If the token is not provided, the base token will be used. + + A custom `Filter` can be provided to filter out unwanted notifications. + + This method WILL ack the notification automatically.""" async with self.notification_stream(topic, token, filter) as stream: command = await stream.receive() await self.ack(command) diff --git a/pypush/apns/protocol.py b/pypush/apns/protocol.py index 147119c..75ce939 100644 --- a/pypush/apns/protocol.py +++ b/pypush/apns/protocol.py @@ -13,6 +13,15 @@ # fmt: on +def note_topic(topic: str): + """ + Add a topic to the KNOWN_TOPICS set, such that it can be recognized later. + This is mostly just a convenience, so that you do not have to work with SHA1 hashes directly. + """ + KNOWN_TOPICS.add(topic) + KNOWN_TOPICS_LOOKUP[sha1(topic.encode()).digest()] = topic + + @dataclass class Command: @classmethod @@ -148,6 +157,14 @@ class SetStateCommand(Command): @command @dataclass class SendMessageCommand(Command): + """ + The most common form of command, used to send a message, also represents incoming messages. + + May also be called a "Notification" in the context of APNs. + + Important note: The `topic` field may be a string or bytes, depending on if the topic is in the `KNOWN_TOPICS` set. This should not happen if you are using the proper `Connection` API. + """ + PacketType = Packet.Type.SendMessage payload: bytes = fid(3) diff --git a/tests/test_apns.py b/tests/test_apns.py index 501e862..46313c2 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -28,6 +28,7 @@ async def test_lifecycle_2(): ASSETS_DIR = Path(__file__).parent / "assets" +# Not a part of pypush, this is public API async def send_test_notification(device_token, payload=b"hello, world"): async with httpx.AsyncClient( cert=str(ASSETS_DIR / "dev.jjtech.pypush.tests.pem"), http2=True