Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite: Minor refactors and documentation #104

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 1 addition & 28 deletions pypush/apns/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from typing import Generic, TypeVar

import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from . import filters
from anyio.abc import ObjectSendStream

T = TypeVar("T")

Expand Down Expand Up @@ -44,31 +42,6 @@ async def open_stream(self, backlog: bool = True):
await send.aclose()


W = TypeVar("W")
F = TypeVar("F")


class FilteredStream(ObjectReceiveStream[F]):
"""
A stream that filters out unwanted items

filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it
"""

def __init__(self, source: ObjectReceiveStream[W], filter: filters.Filter[W, F]):
self.source = source
self.filter = filter

async def receive(self) -> F:
async for item in self.source:
if (filtered := self.filter(item)) is not None:
return filtered
raise anyio.EndOfStream

async def aclose(self):
await self.source.aclose()


def exponential_backoff(f):
async def wrapper(*args, **kwargs):
backoff = 1
Expand Down
7 changes: 7 additions & 0 deletions pypush/apns/albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ async def activate(
build: str = "10.6.4",
model: str = "windows1,1",
) -> Tuple[x509.Certificate, rsa.RSAPrivateKey]:
"""
Activate with Apple's Albert service, obtaining an activation certificate and private key.

By default, this will activate a Windows device with a random UDID, serial, version, build, and model.

Windows activations will not function for iMessage or FaceTime.
"""
if http_client is None:
# Do this here to ensure the client is not accidentally reused during tests
http_client = httpx.AsyncClient()
Expand Down
28 changes: 28 additions & 0 deletions pypush/apns/filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import logging
from typing import Callable, Optional, Type, TypeVar

import anyio
from anyio.abc import ObjectReceiveStream

from pypush.apns import protocol

T1 = TypeVar("T1")
Expand Down Expand Up @@ -42,3 +45,28 @@ def ALL(c):

def NONE(_):
return None


W = TypeVar("W")
F = TypeVar("F")


class FilteredStream(ObjectReceiveStream[F]):
"""
A stream that filters out unwanted items

filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it
"""

def __init__(self, source: ObjectReceiveStream[W], filter: Filter[W, F]):
self.source = source
self.filter = filter

async def receive(self) -> F:
async for item in self.source:
if (filtered := self.filter(item)) is not None:
return filtered
raise anyio.EndOfStream

async def aclose(self):
await self.source.aclose()
84 changes: 69 additions & 15 deletions pypush/apns/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ async def create_apns_connection(
sandbox: bool = False,
courier: typing.Optional[str] = None,
):
"""
This context manager will create a connection to the APNs server and yield a Connection object.

Args:
certificate (x509.Certificate): A valid activation certificate obtained from Albert.
private_key (rsa.RSAPrivateKey): The private key corresponding to the activation certificate.
token (bytes, optional): An optional base token to use for the connection. If not provided, the connection will be established and the base token will be set to the token provided in the ConnectAck command.
sandbox (bool, optional): A boolean indicating whether to connect to the APNs sandbox or production server.
courier (str, optional): An optional string indicating the courier server to connect to. If not provided, a random courier server will be selected based on the `sandbox` parameter.
"""
async with anyio.create_task_group() as tg:
conn = Connection(
tg, certificate, private_key, token, sandbox, courier
Expand Down Expand Up @@ -71,11 +81,16 @@ def __init__(
logging.debug(f"Using courier: {courier}")
self.courier = courier

self._tg.start_soon(self.reconnect)
self._tg.start_soon(self._reconnect)
self._tg.start_soon(self._ping_task)

@property
async def base_token(self) -> bytes:
"""
`base_token` must be awaited to ensure a token is available

This may not complete until a connection has been established
"""
if self._base_token is None:
await self._connected.wait()
assert self._base_token is not None
Expand All @@ -98,8 +113,10 @@ async def _ping_task(self):
) # Explicitly disable the backlog since we don't want to receive old acks

@_util.exponential_backoff
async def reconnect(self):
async with self._reconnect_lock: # Prevent weird situations where multiple reconnects are happening at once
async def _reconnect(self):
async with (
self._reconnect_lock
): # Prevent weird situations where multiple reconnects are happening at once
if self._conn is not None:
logging.warning("Closing existing connection")
await self._conn.aclose()
Expand Down Expand Up @@ -159,9 +176,15 @@ async def reconnect(self):
await self._update_filter()

async def aclose(self):
"""
Closes the connection to the APNS server.

If the connection is open, it will be closed. This method is typically unnecessary if the connection is managed by `create_apns_connection`.

Note: The connection will be reopened if the task group is still open (the ping task is still running).
"""
if self._conn is not None:
await self._conn.aclose()
# Note: Will be reopened if task group is still running and ping task is still running

T = typing.TypeVar("T")

Expand All @@ -172,7 +195,7 @@ async def _receive_stream(
backlog: bool = True,
):
async with self._broadcast.open_stream(backlog) as stream:
yield _util.FilteredStream(stream, filter)
yield filters.FilteredStream(stream, filter)

async def _receive(
self, filter: filters.Filter[protocol.Command, T], backlog: bool = True
Expand All @@ -189,7 +212,7 @@ async def _send(self, command: protocol.Command):
await self._conn.send(command)
except Exception:
logging.warning("Error sending command, reconnecting")
await self.reconnect()
await self._reconnect()
await self._send(command)

async def _update_filter(self):
Expand All @@ -205,6 +228,8 @@ async def _update_filter(self):
@asynccontextmanager
async def _filter(self, topics: list[str]):
for topic in topics:
if topic not in protocol.KNOWN_TOPICS:
protocol.note_topic(topic)
self._filters[topic] = self._filters.get(topic, 0) + 1
await self._update_filter()
yield
Expand All @@ -215,6 +240,14 @@ async def _filter(self, topics: list[str]):
await self._update_filter()

async def mint_scoped_token(self, topic: str) -> bytes:
"""
Mint a "scoped token" for the given topic/bundle ID.

This token is equivalent to the token provided to `application:didRegisterForRemoteNotificationsWithDeviceToken:` in iOS,
for an app with the given bundle ID.

This token can be used with `expect_notification` or `notification_stream`, but it will only function on connections with the same base token as the connection that originally minted the token.
"""
topic_hash = sha1(topic.encode()).digest()
await self._send(
protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash)
Expand All @@ -232,23 +265,37 @@ async def notification_stream(
protocol.SendMessageCommand, protocol.SendMessageCommand
] = filters.ALL,
):
"""
Create a stream of notifications for the given topic and token.
If the token is not provided, the base token will be used.

A custom `Filter` can be provided to filter out unwanted notifications.

Notifications will NOT be ack'd automatically, you must call `ack` on each notification you process.
"""
if token is None:
token = await self.base_token
async with self._filter([topic]), self._receive_stream(
filters.chain(
async with (
self._filter([topic]),
self._receive_stream(
filters.chain(
filters.chain(
filters.cmd(protocol.SendMessageCommand),
lambda c: c if c.token == token else None,
filters.chain(
filters.cmd(protocol.SendMessageCommand),
lambda c: c if c.token == token else None,
),
lambda c: (c if c.topic == topic else None),
),
lambda c: (c if c.topic == topic else None),
),
filter,
)
) as stream:
filter,
)
) as stream,
):
yield stream

async def ack(self, command: protocol.SendMessageCommand, status: int = 0):
"""
Acknowledge a notification.
"""
await self._send(
protocol.SendMessageAck(status=status, token=command.token, id=command.id)
)
Expand All @@ -261,6 +308,13 @@ async def expect_notification(
protocol.SendMessageCommand, protocol.SendMessageCommand
] = filters.ALL,
) -> protocol.SendMessageCommand:
"""
Wait for a notification that matches the given topic and token.
If the token is not provided, the base token will be used.

A custom `Filter` can be provided to filter out unwanted notifications.

This method WILL ack the notification automatically."""
async with self.notification_stream(topic, token, filter) as stream:
command = await stream.receive()
await self.ack(command)
Expand Down
17 changes: 17 additions & 0 deletions pypush/apns/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
# fmt: on


def note_topic(topic: str):
"""
Add a topic to the KNOWN_TOPICS set, such that it can be recognized later.
This is mostly just a convenience, so that you do not have to work with SHA1 hashes directly.
"""
KNOWN_TOPICS.add(topic)
KNOWN_TOPICS_LOOKUP[sha1(topic.encode()).digest()] = topic


@dataclass
class Command:
@classmethod
Expand Down Expand Up @@ -148,6 +157,14 @@ class SetStateCommand(Command):
@command
@dataclass
class SendMessageCommand(Command):
"""
The most common form of command, used to send a message, also represents incoming messages.

May also be called a "Notification" in the context of APNs.

Important note: The `topic` field may be a string or bytes, depending on if the topic is in the `KNOWN_TOPICS` set. This should not happen if you are using the proper `Connection` API.
"""

PacketType = Packet.Type.SendMessage

payload: bytes = fid(3)
Expand Down
1 change: 1 addition & 0 deletions tests/test_apns.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async def test_lifecycle_2():
ASSETS_DIR = Path(__file__).parent / "assets"


# Not a part of pypush, this is public API
async def send_test_notification(device_token, payload=b"hello, world"):
async with httpx.AsyncClient(
cert=str(ASSETS_DIR / "dev.jjtech.pypush.tests.pem"), http2=True
Expand Down
Loading