Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
feat: add type hints to autopush/utils.py
Browse files Browse the repository at this point in the history
issue #712
  • Loading branch information
bbangert committed Oct 15, 2016
1 parent 8aa1a7e commit 7790ce3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 49 deletions.
1 change: 0 additions & 1 deletion autopush/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ def __init__(self, results):
def __call__(self, *args, **kwargs):
try:
r = self.results[self.cur]
print r
if callable(r):
return r()
else:
Expand Down
91 changes: 43 additions & 48 deletions autopush/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@
attrs,
attrib
)
from boto.dynamodb2.items import Item # flake8: noqa
from cryptography.fernet import Fernet # flake8: noqa
from jose import jwt
from typing import (
Optional,
Union,
Tuple,
) # flake8: noqa
from ua_parser import user_agent_parser

from autopush.web.push_validation import WebPushRequestSchema # flake8: noqa
from autopush.exceptions import InvalidTokenException


Expand Down Expand Up @@ -50,6 +58,7 @@


def normalize_id(ident):
# type: (Union[uuid.UUID, str]) -> str
if isinstance(ident, uuid.UUID):
return str(ident)
try:
Expand All @@ -59,13 +68,15 @@ def normalize_id(ident):


def canonical_url(scheme, hostname, port=None):
# type: (str, str, Optional[int]) -> str
"""Return a canonical URL given a scheme/hostname and optional port"""
if port is None or port == default_ports.get(scheme):
return "%s://%s" % (scheme, hostname)
return "%s://%s:%s" % (scheme, hostname, port)


def resolve_ip(hostname):
# type: (str) -> str
"""Resolve a hostname to its IP if possible"""
interfaces = socket.getaddrinfo(hostname, 0, socket.AF_INET,
socket.SOCK_STREAM,
Expand All @@ -77,6 +88,7 @@ def resolve_ip(hostname):


def validate_uaid(uaid):
# type: (str) -> Tuple[bool, str]
"""Validates a UAID a tuple indicating if its valid and the original
uaid, or a new uaid if its invalid"""
if uaid:
Expand All @@ -89,6 +101,7 @@ def validate_uaid(uaid):


def generate_hash(key, payload):
# type: (str, str) -> str
"""Generate a HMAC for the uaid using the secret
:returns: HMAC hash and the nonce used as a tuple (nonce, hash).
Expand All @@ -99,19 +112,21 @@ def generate_hash(key, payload):


def base64url_encode(string):
# type: (str) -> str
"""Encodes an unpadded Base64 URL-encoded string per RFC 7515."""
return base64.urlsafe_b64encode(string).strip('=')


def repad(string):
# type: (str) -> str
"""Adds padding to strings for base64 decoding"""

if len(string) % 4:
string += '===='[len(string) % 4:]
return string


def base64url_decode(string):
# type: (str) -> str
"""Decodes a Base64 URL-encoded string per RFC 7515.
RFC 7515 (used for Encrypted Content-Encoding and JWT) requires unpadded
Expand All @@ -122,6 +137,7 @@ def base64url_decode(string):


def get_amid():
# type: () -> str
"""Fetch the AMI instance ID
"""
Expand All @@ -135,6 +151,7 @@ def get_amid():


def decipher_public_key(key_data):
# type: (str) -> str
"""A public key may come in several flavors. Attempt to extract the
valid key bits from keys doing minimal validation checks.
Expand All @@ -161,6 +178,7 @@ def decipher_public_key(key_data):


def extract_jwt(token, crypto_key):
# type: (str, str) -> dict
"""Extract the claims from the validated JWT. """
# first split and convert the jwt.
if not token or not crypto_key:
Expand All @@ -180,14 +198,14 @@ def extract_jwt(token, crypto_key):


def parse_user_agent(agent_string):
# type: (str) -> Tuple[dict, dict]
"""Extracts user-agent data from a UA string
Parses the user-agent into two forms. A limited one suitable for Datadog
logging with limited tags, and a full string suitable for complete logging.
:returns: A tuple of dicts, the first being the Datadog limited and the
second being the complete info.
:rtype: (dict, dict)
"""
parsed = user_agent_parser.Parse(agent_string)
Expand Down Expand Up @@ -247,10 +265,10 @@ class WebPushNotification(object):
uaid = attrib() # type: uuid.UUID
channel_id = attrib() # type: uuid.UUID
ttl = attrib() # type: int
data = attrib(default=None)
headers = attrib(default=None) # type: dict
timestamp = attrib(default=Factory(lambda: int(time.time())))
topic = attrib(default=None)
data = attrib(default=None) # type: Optional[str]
headers = attrib(default=None) # type: Optional[dict]
timestamp = attrib(default=Factory(lambda: int(time.time()))) # type: int
topic = attrib(default=None) # type: Optional[str]

message_id = attrib(default=None) # type: str

Expand All @@ -259,6 +277,7 @@ class WebPushNotification(object):
update_id = attrib(default=None) # type: str

def generate_message_id(self, fernet):
# type: (Fernet) -> str
"""Generate a message-id suitable for accessing the message
For non-topic messages, no sort_key version is currently used and the
Expand All @@ -273,8 +292,6 @@ def generate_message_id(self, fernet):
This is a blocking call.
:type fernet: cryptography.fernet.Fernet
"""
if self.topic:
msg_key = ":".join(["01", self.uaid.hex, self.channel_id.hex,
Expand All @@ -287,12 +304,8 @@ def generate_message_id(self, fernet):

@staticmethod
def parse_decrypted_message_id(decrypted_token):
"""Parses a decrypted message-id into component parts
:type decrypted_token: str
:rtype: dict
"""
# type: (str) -> dict
"""Parses a decrypted message-id into component parts"""
topic = None
if decrypted_token.startswith("01:"):
info = decrypted_token.split(":")
Expand All @@ -313,6 +326,7 @@ def parse_decrypted_message_id(decrypted_token):
)

def cleanup_headers(self):
# type: () -> None
"""Sanitize the headers for this notification
This only needs to be run when creating a notification from passed
Expand Down Expand Up @@ -342,6 +356,7 @@ def cleanup_headers(self):

@property
def sort_key(self):
# type: () -> str
"""Return an appropriate sort_key for this notification"""
chid = normalize_id(self.channel_id)
if self.topic:
Expand All @@ -352,12 +367,8 @@ def sort_key(self):

@staticmethod
def parse_sort_key(sort_key):
"""Parse the sort key from the database
:type sort_key: str
:rtype: dict
"""
# type: (str) -> dict
"""Parse the sort key from the database"""
topic = None
message_id = None
if re.match(r'^\d\d:', sort_key):
Expand All @@ -370,29 +381,24 @@ def parse_sort_key(sort_key):

@property
def location(self):
# type: () -> str
"""Return an appropriate value for the Location header"""
return self.message_id

def expired(self, at_time=None):
# type: (Optional[int]) -> bool
"""Indicates whether the message has expired or not
:param at_time: Optional time to compare for expiration
:type at_time: int
"""
now = at_time or int(time.time())
return now >= ((self.ttl or 0) + self.timestamp)

@classmethod
def from_message_table(cls, uaid, item):
"""Create a WebPushNotification from a message table item
:type uaid: uuid.UUID
:type item: dict or boto.dynamodb2.item.Item
:rtype: WebPushNotification
"""
# type: (uuid.UUID, Union[dict, Item]) -> WebPushNotification
"""Create a WebPushNotification from a message table item"""
key_info = cls.parse_sort_key(item["chidmessageid"])
if key_info.get("topic"):
key_info["message_id"] = item["updateid"]
Expand All @@ -410,15 +416,10 @@ def from_message_table(cls, uaid, item):

@classmethod
def from_webpush_request_schema(cls, data, fernet):
# type: (WebPushRequestSchema, Fernet) -> WebPushNotification
"""Create a WebPushNotification from a validated WebPushRequestSchema
This is a blocking call.
:type data: autopush.web.push_validation.WebPushRequestSchema
:type fernet: cryptography.fernet.Fernet
:rtype: WebPushNotification
"""
sub = data["subscription"]
notif = cls(uaid=sub["uaid"], channel_id=sub["chid"],
Expand All @@ -436,6 +437,7 @@ def from_webpush_request_schema(cls, data, fernet):

@classmethod
def from_message_id(cls, message_id, fernet):
# type: (str, Fernet) -> WebPushNotification
"""Create a WebPushNotification from a message_id
This is a blocking call.
Expand All @@ -446,11 +448,6 @@ def from_message_id(cls, message_id, fernet):
This is suitable for passing to delete calls.
:type message_id: str
:type fernet: cryptography.fernet.Fernet
:rtype: WebPushNotification
"""
decrypted_message_id = fernet.decrypt(message_id)
key_info = cls.parse_decrypted_message_id(decrypted_message_id)
Expand All @@ -467,14 +464,8 @@ def from_message_id(cls, message_id, fernet):

@classmethod
def from_serialized(cls, uaid, data):
"""Create a WebPushNotification from a deserialized JSON dict
:type uaid: uuid.UUID
:type data: dict
:rtype: WebPushNotification
"""
# type: (uuid.UUID, dict) -> WebPushNotification
"""Create a WebPushNotification from a deserialized JSON dict"""
notif = cls(uaid=uaid, channel_id=uuid.UUID(data["channelID"]),
data=data.get("data"),
headers=data.get("headers"),
Expand All @@ -488,6 +479,7 @@ def from_serialized(cls, uaid, data):

@property
def version(self):
# type: () -> str
"""Return a 'version' for use with a websocket client
In our case we use the message-id as its a unique value for every
Expand All @@ -497,6 +489,7 @@ def version(self):
return self.message_id

def serialize(self):
# type: () -> dict
"""Serialize to a dict for delivery to a connection node"""
payload = dict(
channelID=normalize_id(self.channel_id),
Expand All @@ -511,6 +504,7 @@ def serialize(self):
return payload

def websocket_format(self):
# type: () -> dict
"""Format a notification for a websocket client"""
# Firefox currently requires channelIDs to be '-' formatted.
payload = dict(
Expand All @@ -527,5 +521,6 @@ def websocket_format(self):


def ms_time():
# type: () -> int
"""Return current time.time call as ms and a Python int"""
return int(time.time() * 1000)

0 comments on commit 7790ce3

Please sign in to comment.