Skip to content

Commit

Permalink
apns: document public API functions
Browse files Browse the repository at this point in the history
  • Loading branch information
JJTech0130 committed May 23, 2024
1 parent 51b04f8 commit 29c66bd
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 4 deletions.
7 changes: 7 additions & 0 deletions pypush/apns/albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ async def activate(
build: str = "10.6.4",
model: str = "windows1,1",
) -> Tuple[x509.Certificate, rsa.RSAPrivateKey]:
"""
Activate with Apple's Albert service, obtaining an activation certificate and private key.
By default, this will activate a Windows device with a random UDID, serial, version, build, and model.
Windows activations will not function for iMessage or FaceTime.
"""
if http_client is None:
# Do this here to ensure the client is not accidentally reused during tests
http_client = httpx.AsyncClient()
Expand Down
57 changes: 53 additions & 4 deletions pypush/apns/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ async def create_apns_connection(
sandbox: bool = False,
courier: typing.Optional[str] = None,
):
"""
This context manager will create a connection to the APNs server and yield a Connection object.
Args:
certificate (x509.Certificate): A valid activation certificate obtained from Albert.
private_key (rsa.RSAPrivateKey): The private key corresponding to the activation certificate.
token (bytes, optional): An optional base token to use for the connection. If not provided, the connection will be established and the base token will be set to the token provided in the ConnectAck command.
sandbox (bool, optional): A boolean indicating whether to connect to the APNs sandbox or production server.
courier (str, optional): An optional string indicating the courier server to connect to. If not provided, a random courier server will be selected based on the `sandbox` parameter.
"""
async with anyio.create_task_group() as tg:
conn = Connection(
tg, certificate, private_key, token, sandbox, courier
Expand Down Expand Up @@ -71,11 +81,16 @@ def __init__(
logging.debug(f"Using courier: {courier}")
self.courier = courier

self._tg.start_soon(self.reconnect)
self._tg.start_soon(self._reconnect)
self._tg.start_soon(self._ping_task)

@property
async def base_token(self) -> bytes:
"""
`base_token` must be awaited to ensure a token is available
This may not complete until a connection has been established
"""
if self._base_token is None:
await self._connected.wait()
assert self._base_token is not None
Expand All @@ -98,7 +113,7 @@ async def _ping_task(self):
) # Explicitly disable the backlog since we don't want to receive old acks

@_util.exponential_backoff
async def reconnect(self):
async def _reconnect(self):
async with (
self._reconnect_lock
): # Prevent weird situations where multiple reconnects are happening at once
Expand Down Expand Up @@ -161,9 +176,15 @@ async def reconnect(self):
await self._update_filter()

async def aclose(self):
"""
Closes the connection to the APNS server.
If the connection is open, it will be closed. This method is typically unnecessary if the connection is managed by `create_apns_connection`.
Note: The connection will be reopened if the task group is still open (the ping task is still running).
"""
if self._conn is not None:
await self._conn.aclose()
# Note: Will be reopened if task group is still running and ping task is still running

T = typing.TypeVar("T")

Expand Down Expand Up @@ -191,7 +212,7 @@ async def _send(self, command: protocol.Command):
await self._conn.send(command)
except Exception:
logging.warning("Error sending command, reconnecting")
await self.reconnect()
await self._reconnect()
await self._send(command)

async def _update_filter(self):
Expand All @@ -207,6 +228,8 @@ async def _update_filter(self):
@asynccontextmanager
async def _filter(self, topics: list[str]):
for topic in topics:
if topic not in protocol.KNOWN_TOPICS:
protocol.note_topic(topic)
self._filters[topic] = self._filters.get(topic, 0) + 1
await self._update_filter()
yield
Expand All @@ -217,6 +240,14 @@ async def _filter(self, topics: list[str]):
await self._update_filter()

async def mint_scoped_token(self, topic: str) -> bytes:
"""
Mint a "scoped token" for the given topic/bundle ID.
This token is equivalent to the token provided to `application:didRegisterForRemoteNotificationsWithDeviceToken:` in iOS,
for an app with the given bundle ID.
This token can be used with `expect_notification` or `notification_stream`, but it will only function on connections with the same base token as the connection that originally minted the token.
"""
topic_hash = sha1(topic.encode()).digest()
await self._send(
protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash)
Expand All @@ -234,6 +265,14 @@ async def notification_stream(
protocol.SendMessageCommand, protocol.SendMessageCommand
] = filters.ALL,
):
"""
Create a stream of notifications for the given topic and token.
If the token is not provided, the base token will be used.
A custom `Filter` can be provided to filter out unwanted notifications.
Notifications will NOT be ack'd automatically, you must call `ack` on each notification you process.
"""
if token is None:
token = await self.base_token
async with (
Expand All @@ -254,6 +293,9 @@ async def notification_stream(
yield stream

async def ack(self, command: protocol.SendMessageCommand, status: int = 0):
"""
Acknowledge a notification.
"""
await self._send(
protocol.SendMessageAck(status=status, token=command.token, id=command.id)
)
Expand All @@ -266,6 +308,13 @@ async def expect_notification(
protocol.SendMessageCommand, protocol.SendMessageCommand
] = filters.ALL,
) -> protocol.SendMessageCommand:
"""
Wait for a notification that matches the given topic and token.
If the token is not provided, the base token will be used.
A custom `Filter` can be provided to filter out unwanted notifications.
This method WILL ack the notification automatically."""
async with self.notification_stream(topic, token, filter) as stream:
command = await stream.receive()
await self.ack(command)
Expand Down
17 changes: 17 additions & 0 deletions pypush/apns/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
# fmt: on


def note_topic(topic: str):
"""
Add a topic to the KNOWN_TOPICS set, such that it can be recognized later.
This is mostly just a convenience, so that you do not have to work with SHA1 hashes directly.
"""
KNOWN_TOPICS.add(topic)
KNOWN_TOPICS_LOOKUP[sha1(topic.encode()).digest()] = topic


@dataclass
class Command:
@classmethod
Expand Down Expand Up @@ -148,6 +157,14 @@ class SetStateCommand(Command):
@command
@dataclass
class SendMessageCommand(Command):
"""
The most common form of command, used to send a message, also represents incoming messages.
May also be called a "Notification" in the context of APNs.
Important note: The `topic` field may be a string or bytes, depending on if the topic is in the `KNOWN_TOPICS` set. This should not happen if you are using the proper `Connection` API.
"""

PacketType = Packet.Type.SendMessage

payload: bytes = fid(3)
Expand Down
1 change: 1 addition & 0 deletions tests/test_apns.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async def test_lifecycle_2():
ASSETS_DIR = Path(__file__).parent / "assets"


# Not a part of pypush, this is public API
async def send_test_notification(device_token, payload=b"hello, world"):
async with httpx.AsyncClient(
cert=str(ASSETS_DIR / "dev.jjtech.pypush.tests.pem"), http2=True
Expand Down

0 comments on commit 29c66bd

Please sign in to comment.