Skip to content

Commit

Permalink
Rewrite: APNs: Scoped App Tokens (#101)
Browse files Browse the repository at this point in the history
+ Adds a lot of new API surface around filters
+ Adds CI type checking and linting
  • Loading branch information
JJTech0130 authored May 19, 2024
1 parent b1c30a9 commit ce88a8e
Show file tree
Hide file tree
Showing 17 changed files with 495 additions and 111 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/pyright.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name: Pyright
on: [push, pull_request]
jobs:
pyright:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
cache: 'pip'

- run: |
python -m venv .venv
source .venv/bin/activate
pip install -e '.[test,cli]'
- run: echo "$PWD/.venv/bin" >> $GITHUB_PATH
- uses: jakebailey/pyright-action@v2
8 changes: 8 additions & 0 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: Ruff
on: [push, pull_request]
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,8 @@ version_file = "pypush/_version.py"
[tool.pytest.ini_options]
minversion = "6.0"
addopts = ["-ra", "-q"]
testpaths = ["tests"]
testpaths = ["tests"]

[tool.ruff.lint]
select = ["E", "F", "B", "SIM", "I"]
ignore = ["E501", "B010"]
6 changes: 3 additions & 3 deletions pypush/apns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__all__ = ["protocol", "create_apns_connection", "activate"]
__all__ = ["protocol", "create_apns_connection", "activate", "filters"]

from . import protocol
from .lifecycle import create_apns_connection
from . import filters, protocol
from .albert import activate
from .lifecycle import create_apns_connection
14 changes: 7 additions & 7 deletions pypush/apns/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from dataclasses import MISSING, field
from dataclasses import fields as dataclass_fields
from typing import Any, TypeVar, get_origin, get_args, Union
from typing import Any, TypeVar, Union, get_args, get_origin

from pypush.apns.transport import Packet

Expand Down Expand Up @@ -67,14 +67,14 @@ def from_packet(cls, packet: Packet):
)

# Check for extra fields
for field in packet.fields:
if field.id not in [
for current_field in packet.fields:
if current_field.id not in [
f.metadata["packet_id"]
for f in dataclass_fields(cls)
if f.metadata is not None and "packet_id" in f.metadata
]:
logging.warning(
f"Unexpected field with packet ID {field.id} in packet {packet}"
f"Unexpected field with packet ID {current_field.id} in packet {packet}"
)
return cls(**field_values)

Expand Down Expand Up @@ -122,15 +122,15 @@ def fid(
:param byte_len: The length of the field in bytes (for int fields)
:param default: The default value of the field
"""
if not default == MISSING and not default_factory == MISSING:
if default != MISSING and default_factory != MISSING:
raise ValueError("Cannot specify both default and default_factory")
if not default == MISSING:
if default != MISSING:
return field(
metadata={"packet_id": packet_id, "packet_bytes": byte_len},
default=default,
repr=repr,
)
if not default_factory == MISSING:
if default_factory != MISSING:
return field(
metadata={"packet_id": packet_id, "packet_bytes": byte_len},
default_factory=default_factory,
Expand Down
50 changes: 45 additions & 5 deletions pypush/apns/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,72 @@
from typing import Generic, TypeVar

import anyio
from anyio.abc import ObjectSendStream
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from . import filters

T = TypeVar("T")


class BroadcastStream(Generic[T]):
def __init__(self):
def __init__(self, backlog: int = 50):
self.streams: list[ObjectSendStream[T]] = []
self.backlog: list[T] = []
self._backlog_size = backlog

async def broadcast(self, packet):
logging.debug(f"Broadcasting {packet} to {len(self.streams)} streams")
for stream in self.streams:
try:
await stream.send(packet)
except anyio.BrokenResourceError:
self.streams.remove(stream)
logging.error("Broken resource error")
# self.streams.remove(stream)
# If we have a backlog, add the packet to it
if len(self.backlog) >= self._backlog_size:
self.backlog.pop(0)
self.backlog.append(packet)

@asynccontextmanager
async def open_stream(self):
send, recv = anyio.create_memory_object_stream[T]()
async def open_stream(self, backlog: bool = True):
# 1000 seems like a reasonable number, if more than 1000 messages come in before someone deals with them it will
# start stalling the APNs connection itself
send, recv = anyio.create_memory_object_stream[T](max_buffer_size=1000)
if backlog:
for packet in self.backlog:
await send.send(packet)
self.streams.append(send)
async with recv:
yield recv
self.streams.remove(send)
await send.aclose()


W = TypeVar("W")
F = TypeVar("F")


class FilteredStream(ObjectReceiveStream[F]):
"""
A stream that filters out unwanted items
filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it
"""

def __init__(self, source: ObjectReceiveStream[W], filter: filters.Filter[W, F]):
self.source = source
self.filter = filter

async def receive(self) -> F:
async for item in self.source:
if (filtered := self.filter(item)) is not None:
return filtered
raise anyio.EndOfStream

async def aclose(self):
await self.source.aclose()


def exponential_backoff(f):
async def wrapper(*args, **kwargs):
backoff = 1
Expand Down
6 changes: 3 additions & 3 deletions pypush/apns/albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import uuid
from base64 import b64decode
from typing import Tuple, Optional
from typing import Optional, Tuple

import httpx
from cryptography import x509
Expand Down Expand Up @@ -96,10 +96,10 @@ async def activate(

try:
protocol = re.search("<Protocol>(.*)</Protocol>", resp.text).group(1) # type: ignore
except AttributeError:
except AttributeError as e:
# Search for error text between <b> and </b>
error = re.search("<b>(.*)</b>", resp.text).group(1) # type: ignore
raise Exception(f"Failed to get certificate from Albert: {error}")
raise Exception(f"Failed to get certificate from Albert: {error}") from e

protocol = plistlib.loads(protocol.encode("utf-8"))

Expand Down
44 changes: 44 additions & 0 deletions pypush/apns/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging
from typing import Callable, Optional, Type, TypeVar

from pypush.apns import protocol

T1 = TypeVar("T1")
T2 = TypeVar("T2")
Filter = Callable[[T1], Optional[T2]]

# Chain with proper types so that subsequent filters only need to take output type of previous filter
T_IN = TypeVar("T_IN", bound=protocol.Command)
T_MIDDLE = TypeVar("T_MIDDLE", bound=protocol.Command)
T_OUT = TypeVar("T_OUT", bound=protocol.Command)


def chain(first: Filter[T_IN, T_MIDDLE], second: Filter[T_MIDDLE, T_OUT]):
def filter(command: T_IN) -> Optional[T_OUT]:
logging.debug(f"Filtering {command} with {first} and {second}")
filtered = first(command)
if filtered is None:
return None
return second(filtered)

return filter


T = TypeVar("T", bound=protocol.Command)


def cmd(type: Type[T]) -> Filter[protocol.Command, T]:
def filter(command: protocol.Command) -> Optional[T]:
if isinstance(command, type):
return command
return None

return filter


def ALL(c):
return c


def NONE(_):
return None
Loading

0 comments on commit ce88a8e

Please sign in to comment.