diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..69cf479 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,29 @@ +# SOP Python circleci file + +version: 2.1 + +orbs: + python: circleci/python@2.1.1 + +jobs: + build_and_test: + executor: python/default + steps: + - checkout + - python/install-packages: + pkg-manager: pip + - run: + name: Build + command: pip3 install -r test-requirements.txt + - run: + name: Run tests + command: python -m pytest pywebpush + - persist_to_workspace: + root: ~/project + paths: + - . + +workflows: + build_and_test: + jobs: + - build_and_test diff --git a/.gitignore b/.gitignore index 5e68aad..d7a9866 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,6 @@ __pycache__/ # Distribution / packaging .Python -env/ bin/ build/ develop-eggs/ @@ -23,9 +22,12 @@ lib64/ parts/ sdist/ var/ +wheels/ +share/python-wheels/ *.egg-info/ .installed.cfg *.egg +MANIFEST # PyInstaller # Usually these files are written by a python script from a template @@ -40,13 +42,17 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ +.nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml -*,cover +*.cover +*.py,cover .hypothesis/ +.pytest_cache/ +cover/ # Translations *.mo @@ -54,13 +60,106 @@ coverage.xml # Django stuff: *.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy # Sphinx documentation docs/_build/ # PyBuilder +.pybuilder/ target/ -#Ipython Notebook +# Jupyter Notebook .ipynb_checkpoints -*.swp + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +.vscode/ \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 15e4d4e..0999302 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # I am terrible at keeping this up-to-date. +## 2.0.0 (2024-01-02) +chore: Update to modern python practices +* include pyproject.toml file +* use python typing +* update to use pytest + + *BREAKING_CHANGE* + `Webpusher.encode` will now return a `NoData` exception if no data is present to encode. Chances are + you probably won't be impacted by this change since most push messages contain data, but one never knows. + This alters the prior behavior where it would return `None`. + ## 1.14.0 (2021-07-28) bug: accept all VAPID key instances (thanks @mthu) diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..c900a40 --- /dev/null +++ b/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,13 @@ +## Description + +*_NOTE_*: All commits MUST be signed! See https://docs.github.com/en/github/authenticating-to-github/signing-commits + +_Describe these changes._ + +## Testing + +_How should reviewers test?_ + +## Issue(s) + +Closes _#IssueNumber_ diff --git a/README.md b/README.md index 1222d49..f20d3ec 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,9 @@ make of that what you will. ## Installation -You'll need to run `python -m venv venv`. -Then +To work with this repo locally, you'll need to run `python -m venv venv`. +Then `venv/bin/pip install --editable .` -```bash -venv/bin/pip install -r requirements.txt -venv/bin/python setup.py develop -``` ## Usage @@ -60,7 +56,7 @@ webpush(subscription_info, This will encode `data`, add the appropriate VAPID auth headers if required and send it to the push server identified in the `subscription_info` block. -**Parameters** +##### Parameters _subscription_info_ - The `dict` of the subscription info (described above). @@ -85,7 +81,7 @@ e.g. the output of: openssl ecparam -name prime256v1 -genkey -noout -out private_key.pem ``` -**Example** +##### Example ```python from pywebpush import webpush, WebPushException @@ -127,7 +123,7 @@ The following methods are available: Send the data using additional parameters. On error, returns a `WebPushException` -**Parameters** +##### Parameters _data_ Binary string of data to send @@ -148,7 +144,7 @@ named `encrpypted.data`. This command is meant to be used for debugging purposes _timeout_ timeout for requests POST query. See [requests documentation](http://docs.python-requests.org/en/master/user/quickstart/#timeouts). -**Example** +##### Example to send from Chrome using the old GCM mode: @@ -160,13 +156,17 @@ WebPusher(subscription_info).send(data, headers, ttl, gcm_key) Encode the `data` for future use. On error, returns a `WebPushException` -**Parameters** +##### Parameters _data_ Binary string of data to send _content_encoding_ ECE content encoding type (defaults to "aes128gcm") -**Example** +*Note* This will return a `NoData` exception if the data is not present or empty. It is completely +valid to send a WebPush notification with no data, but encoding is a no-op in that case. Best not +to call it if you don't have data. + +##### Example ```python encoded_data = WebPush(subscription_info).encode(data) diff --git a/entry_points.txt b/entry_points.txt new file mode 100644 index 0000000..4b8a7bf --- /dev/null +++ b/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +pywebpush = "pywebpush.__main__:main" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..fcca473 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,41 @@ +[build-system] +# This uses the semi-built-in "setuptools" which is currently the +# python pariah, but there are a lot of behaviors that still carry. +# This will draw a lot of information from `setup.py` and `setup.cfg` +# For more info see https://packaging.python.org/en/latest/ +# (although, be fore-warned, it gets fairly wonky and obsessed with +# details that you may not care about.) +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +# `dependencies` are taken from `setup.py` and include the contents of the +# `requirements.txt` file +name = "pywebpush" +authors = [{ name = "JR Conlin", email = "src+webpusher@jrconlin.com" }] +description = "WebPush publication library" +readme = "README.md" +# Use the LICENSE file for our license, since "MPL2" isn't included in the +# canonical list +license = { file = "LICENSE" } +keywords = ["webpush", "vapid", "notification"] +classifiers = [ + "Topic :: Internet :: WWW/HTTP", + "Programming Language :: Python :: Implementation :: PyPy", + "Programming Language :: Python", + "Programming Language :: Python :: 3", +] +# use the following fields defined in the setup.py file +# (When the guides talk about something being "dynamic", they +# want you to add the field here. +dynamic = ["version", "entry-points"] + +[project.urls] +Homepage = "https://github.com/web-push-libs/pywebpush" + +[project.optional-dependencies] +dev = ["black", "mock", "pytest"] + +# create the `pywebpush` helper using `python -m pip install --editable .` +[project.scripts] +pywebpush = "pywebpush.__main__:main" diff --git a/pywebpush/__init__.py b/pywebpush/__init__.py index 7ffc95e..e58cbd3 100644 --- a/pywebpush/__init__.py +++ b/pywebpush/__init__.py @@ -2,25 +2,30 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. +import asyncio import base64 -from copy import deepcopy import json import os import time import logging +from copy import deepcopy +from typing import cast, Union, Dict try: - from urllib.parse import urlparse -except ImportError: # pragma nocover from urlparse import urlparse +except ImportError: # pragma nocover + from urllib.parse import urlparse -import six +import aiohttp import http_ece import requests +import six from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives import serialization +from functools import partial from py_vapid import Vapid, Vapid01 +from requests import Response class WebPushException(Exception): @@ -46,6 +51,10 @@ def __str__(self): return "WebPushException: {}{}".format(self.message, extra) +class NoData(Exception): + """Message contained No Data, no encoding required.""" + + class CaseInsensitiveDict(dict): """A dictionary that has case-insensitive keys""" @@ -111,16 +120,25 @@ class WebPusher: WebPusher(subscription_info).send(data, headers) """ + subscription_info = {} valid_encodings = [ # "aesgcm128", # this is draft-0, but DO NOT USE. "aesgcm", # draft-httpbis-encryption-encoding-01 - "aes128gcm" # RFC8188 Standard encoding + "aes128gcm", # RFC8188 Standard encoding ] verbose = False - def __init__(self, subscription_info, requests_session=None, - verbose=False): + # Note: the type declarations are not valid under python 3.8, + def __init__( + self, + subscription_info: Dict[ + str, Union[Union[str, bytes], Dict[str, Union[str, bytes]]] + ], + requests_session: Union[None, requests.Session] = None, + aiohttp_session: Union[None, aiohttp.client.ClientSession] = None, + verbose: bool = False, + ): """Initialize using the info provided by the client PushSubscription object (See https://developer.mozilla.org/en-US/docs/Web/API/PushManager/subscribe) @@ -144,34 +162,42 @@ def __init__(self, subscription_info, requests_session=None, else: self.requests_method = requests_session - if 'endpoint' not in subscription_info: + self.aiohttp_session = aiohttp_session + + if "endpoint" not in subscription_info: raise WebPushException("subscription_info missing endpoint URL") self.subscription_info = deepcopy(subscription_info) self.auth_key = self.receiver_key = None - if 'keys' in subscription_info: - keys = self.subscription_info['keys'] - for k in ['p256dh', 'auth']: + if "keys" in subscription_info: + keys: Dict[str, Union[str, bytes]] = cast( + Dict[str, Union[str, bytes]], self.subscription_info["keys"] + ) + for k in ["p256dh", "auth"]: if keys.get(k) is None: raise WebPushException("Missing keys value: {}".format(k)) if isinstance(keys[k], six.text_type): - keys[k] = bytes(keys[k].encode('utf8')) + keys[k] = bytes(cast(str, keys[k]).encode("utf8")) receiver_raw = base64.urlsafe_b64decode( - self._repad(keys['p256dh'])) + self._repad(cast(bytes, keys["p256dh"])) + ) if len(receiver_raw) != 65 and receiver_raw[0] != "\x04": raise WebPushException("Invalid p256dh key specified") self.receiver_key = receiver_raw self.auth_key = base64.urlsafe_b64decode( - self._repad(keys['auth'])) + self._repad(cast(bytes, keys["auth"])) + ) - def verb(self, msg, *args, **kwargs): + def verb(self, msg: str, *args, **kwargs): if self.verbose: logging.info(msg.format(*args, **kwargs)) - def _repad(self, data): + def _repad(self, data: bytes): """Add base64 padding to the end of a string, if required""" - return data + b"===="[:len(data) % 4] + return data + b"===="[: len(data) % 4] - def encode(self, data, content_encoding="aes128gcm"): + def encode( + self, data: bytes, content_encoding: str = "aes128gcm" + ) -> CaseInsensitiveDict: """Encrypt the data. :param data: A serialized block of byte data (String, JSON, bit array, @@ -184,18 +210,20 @@ def encode(self, data, content_encoding="aes128gcm"): :type content_encoding: enum("aesgcm", "aes128gcm") """ + reply = CaseInsensitiveDict() # Salt is a random 16 byte array. if not data: self.verb("No data found...") - return + raise NoData() if not self.auth_key or not self.receiver_key: raise WebPushException("No keys specified in subscription info") self.verb("Encoding data...") salt = None if content_encoding not in self.valid_encodings: - raise WebPushException("Invalid content encoding specified. " - "Select from " + - json.dumps(self.valid_encodings)) + raise WebPushException( + "Invalid content encoding specified. " + "Select from " + json.dumps(self.valid_encodings) + ) if content_encoding == "aesgcm": self.verb("Generating salt for aesgcm...") salt = os.urandom(16) @@ -205,11 +233,11 @@ def encode(self, data, content_encoding="aes128gcm"): server_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) crypto_key = server_key.public_key().public_bytes( encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.UncompressedPoint + format=serialization.PublicFormat.UncompressedPoint, ) if isinstance(data, six.text_type): - data = bytes(data.encode('utf8')) + data = bytes(data.encode("utf8")) if content_encoding == "aes128gcm": self.verb("Encrypting to aes128gcm...") encrypted = http_ece.encrypt( @@ -218,13 +246,12 @@ def encode(self, data, content_encoding="aes128gcm"): private_key=server_key, dh=self.receiver_key, auth_secret=self.auth_key, - version=content_encoding) - reply = CaseInsensitiveDict({ - 'body': encrypted - }) + version=content_encoding, + ) + reply["body"] = encrypted else: self.verb("Encrypting to aesgcm...") - crypto_key = base64.urlsafe_b64encode(crypto_key).strip(b'=') + crypto_key = base64.urlsafe_b64encode(crypto_key).strip(b"=") encrypted = http_ece.encrypt( data, salt=salt, @@ -232,16 +259,15 @@ def encode(self, data, content_encoding="aes128gcm"): keyid=crypto_key.decode(), dh=self.receiver_key, auth_secret=self.auth_key, - version=content_encoding) - reply = CaseInsensitiveDict({ - 'crypto_key': crypto_key, - 'body': encrypted, - }) + version=content_encoding, + ) + reply["crypto_key"] = crypto_key + reply["body"] = encrypted if salt: - reply['salt'] = base64.urlsafe_b64encode(salt).strip(b'=') + reply["salt"] = base64.urlsafe_b64encode(salt).strip(b"=") return reply - def as_curl(self, endpoint, encoded_data, headers): + def as_curl(self, endpoint: str, encoded_data: bytes, headers: Dict[str, str]): """Return the send as a curl command. Useful for debugging. This will write out the encoded data to a local @@ -257,23 +283,32 @@ def as_curl(self, endpoint, encoded_data, headers): """ header_list = [ - '-H "{}: {}" \\ \n'.format( - key.lower(), val) for key, val in headers.items() + '-H "{}: {}" \\ \n'.format(key.lower(), val) for key, val in headers.items() ] data = "" if encoded_data: with open("encrypted.data", "wb") as f: f.write(encoded_data) data = "--data-binary @encrypted.data" - if 'content-length' not in headers: + if "content-length" not in headers: self.verb("Generating content-length header...") header_list.append( - '-H "content-length: {}" \\ \n'.format(len(encoded_data))) - return ("""curl -vX POST {url} \\\n{headers}{data}""".format( - url=endpoint, headers="".join(header_list), data=data)) + '-H "content-length: {}" \\ \n'.format(len(encoded_data)) + ) + return """curl -vX POST {url} \\\n{headers}{data}""".format( + url=endpoint, headers="".join(header_list), data=data + ) - def send(self, data=None, headers=None, ttl=0, gcm_key=None, reg_id=None, - content_encoding="aes128gcm", curl=False, timeout=None): + def _prepare_send_data( + self, + data: Union[None, bytes] = None, + headers: Union[None, Dict[str, str]] = None, + ttl: int = 0, + gcm_key: Union[None, str] = None, + reg_id: Union[None, str] = None, + content_encoding: str = "aes128gcm", + curl: bool = False, + ) -> dict: """Encode and send the data to the Push Service. :param data: A serialized block of data (see encode() ). @@ -294,14 +329,11 @@ def send(self, data=None, headers=None, ttl=0, gcm_key=None, reg_id=None, :type content_encoding: str :param curl: Display output as `curl` command instead of sending :type curl: bool - :param timeout: POST requests timeout - :type timeout: float or tuple - """ # Encode the data. if headers is None: headers = dict() - encoded = {} + encoded = CaseInsensitiveDict() headers = CaseInsensitiveDict(headers) if data: encoded = self.encode(data, content_encoding) @@ -313,80 +345,132 @@ def send(self, data=None, headers=None, ttl=0, gcm_key=None, reg_id=None, # should use ';' instead of ',' to append the headers. # see # https://github.com/webpush-wg/webpush-encryption/issues/6 - crypto_key += ';' - crypto_key += ( - "dh=" + encoded["crypto_key"].decode('utf8')) - headers.update({ - 'crypto-key': crypto_key - }) + crypto_key += ";" + crypto_key += "dh=" + encoded["crypto_key"].decode("utf8") + headers.update({"crypto-key": crypto_key}) if "salt" in encoded: - headers.update({ - 'encryption': "salt=" + encoded['salt'].decode('utf8') - }) - headers.update({ - 'content-encoding': content_encoding, - }) + headers.update({"encryption": "salt=" + encoded["salt"].decode("utf8")}) + headers.update( + { + "content-encoding": content_encoding, + } + ) if gcm_key: # guess if it is a legacy GCM project key or actual FCM key # gcm keys are all about 40 chars (use 100 for confidence), # fcm keys are 153-175 chars if len(gcm_key) < 100: self.verb("Guessing this is legacy GCM...") - endpoint = 'https://android.googleapis.com/gcm/send' + endpoint = "https://android.googleapis.com/gcm/send" else: self.verb("Guessing this is FCM...") - endpoint = 'https://fcm.googleapis.com/fcm/send' + endpoint = "https://fcm.googleapis.com/fcm/send" reg_ids = [] if not reg_id: - reg_id = self.subscription_info['endpoint'].rsplit('/', 1)[-1] + reg_id = cast(str, self.subscription_info["endpoint"]).rsplit("/", 1)[ + -1 + ] self.verb("Fetching out registration id: {}", reg_id) reg_ids.append(reg_id) gcm_data = dict() - gcm_data['registration_ids'] = reg_ids + gcm_data["registration_ids"] = reg_ids if data: - gcm_data['raw_data'] = base64.b64encode( - encoded.get('body')).decode('utf8') - gcm_data['time_to_live'] = int( - headers['ttl'] if 'ttl' in headers else ttl) + buffer = encoded.get("body") + if buffer: + gcm_data["raw_data"] = base64.b64encode(buffer).decode("utf8") + gcm_data["time_to_live"] = int(headers["ttl"] if "ttl" in headers else ttl) encoded_data = json.dumps(gcm_data) - headers.update({ - 'Authorization': 'key='+gcm_key, - 'Content-Type': 'application/json', - }) + headers.update( + { + "Authorization": "key=" + gcm_key, + "Content-Type": "application/json", + } + ) else: - encoded_data = encoded.get('body') - endpoint = self.subscription_info['endpoint'] + encoded_data = encoded.get("body") + endpoint = self.subscription_info["endpoint"] - if 'ttl' not in headers or ttl: + if "ttl" not in headers or ttl: self.verb("Generating TTL of 0...") - headers['ttl'] = str(ttl or 0) + headers["ttl"] = str(ttl or 0) # Additionally useful headers: # Authorization / Crypto-Key (VAPID headers) + + self.verb( + "\nSending request to" "\n\thost: {}\n\theaders: {}\n\tdata: {}", + endpoint, + headers, + encoded_data, + ) + + return {"endpoint": endpoint, "data": encoded_data, "headers": headers} + + def send(self, *args, **kwargs) -> Union[Response, str]: + """Encode and send the data to the Push Service""" + timeout = kwargs.pop("timeout", 10000) + curl = kwargs.pop("curl", False) + + params = self._prepare_send_data(*args, **kwargs) + endpoint = params.pop("endpoint") + if curl: - return self.as_curl(endpoint, encoded_data, headers) - self.verb("\nSending request to" - "\n\thost: {}\n\theaders: {}\n\tdata: {}", - endpoint, headers, encoded_data) - resp = self.requests_method.post(endpoint, - data=encoded_data, - headers=headers, - timeout=timeout) - self.verb("\nResponse:\n\tcode: {}\n\tbody: {}\n", - resp.status_code, resp.text or "Empty") + encoded_data = params["data"] + headers = params["headers"] + return self.as_curl(endpoint, encoded_data=encoded_data, headers=headers) + + resp = self.requests_method.post( + endpoint, + timeout=timeout, + **params, + ) + self.verb( + "\nResponse:\n\tcode: {}\n\tbody: {}\n", + resp.status_code, + resp.text or "Empty", + ) return resp + async def send_async(self, *args, **kwargs) -> Union[aiohttp.ClientResponse, str]: + timeout = kwargs.pop("timeout", 10000) + curl = kwargs.pop("curl", False) + + params = self._prepare_send_data(*args, **kwargs) + endpoint = params.pop("endpoint") -def webpush(subscription_info, - data=None, - vapid_private_key=None, - vapid_claims=None, - content_encoding="aes128gcm", - curl=False, - timeout=None, - ttl=0, - verbose=False, - headers=None, - requests_session=None): + if curl: + encoded_data = params["data"] + headers = params["headers"] + return self.as_curl(endpoint, encoded_data=encoded_data, headers=headers) + if self.aiohttp_session: + resp = await self.aiohttp_session.post(endpoint, timeout=timeout, **params) + resp_text = await resp.text() + else: + async with aiohttp.ClientSession() as session: + resp = await session.post(endpoint, timeout=timeout, **params) + resp_text = await resp.text() + self.verb( + "\nResponse:\n\tcode: {}\n\tbody: {}\n", + resp.status, + resp_text or "Empty", + ) + return resp + + +def webpush( + subscription_info: Dict[ + str, Union[Union[str, bytes], Dict[str, Union[str, bytes]]] + ], + data: Union[None, str] = None, + vapid_private_key: Union[None, Vapid, str] = None, + vapid_claims: Union[None, Dict[str, Union[str, int]]] = None, + content_encoding: str = "aes128gcm", + curl: bool = False, + timeout: Union[None, float] = None, + ttl: int = 0, + verbose: bool = False, + headers: Union[None, Dict[str, Union[str, int, float]]] = None, + requests_session: Union[None, requests.Session] = None, +) -> Union[str, requests.Response]: """ One call solution to endcode and send `data` to the endpoint contained in `subscription_info` using optional VAPID auth headers. @@ -425,7 +509,7 @@ def webpush(subscription_info, :param curl: Return as "curl" string instead of sending :type curl: bool :param timeout: POST requests timeout - :type timeout: float or tuple + :type timeout: float :param ttl: Time To Live :type ttl: int :param verbose: Provide verbose feedback @@ -445,19 +529,19 @@ def webpush(subscription_info, if vapid_claims: if verbose: logging.info("Generating VAPID headers...") - if not vapid_claims.get('aud'): - url = urlparse(subscription_info.get('endpoint')) + if not vapid_claims.get("aud"): + url = urlparse(cast(str, subscription_info.get("endpoint"))) aud = "{}://{}".format(url.scheme, url.netloc) - vapid_claims['aud'] = aud + vapid_claims["aud"] = aud # Remember, passed structures are mutable in python. # It's possible that a previously set `exp` field is no longer valid. - if (not vapid_claims.get('exp') - or vapid_claims.get('exp') < int(time.time())): + if not vapid_claims.get("exp") or int(vapid_claims.get("exp") or 0) < int( + time.time() + ): # encryption lives for 12 hours - vapid_claims['exp'] = int(time.time()) + (12 * 60 * 60) + vapid_claims["exp"] = int(time.time()) + (12 * 60 * 60) if verbose: - logging.info("Setting VAPID expry to {}...".format( - vapid_claims['exp'])) + logging.info("Setting VAPID expry to {}...".format(vapid_claims["exp"])) if not vapid_private_key: raise WebPushException("VAPID dict missing 'private_key'") if isinstance(vapid_private_key, Vapid01): @@ -468,10 +552,8 @@ def webpush(subscription_info, # Presume that key from file is handled correctly by # py_vapid. if verbose: - logging.info( - "Reading VAPID key from file {}".format(vapid_private_key)) - vv = Vapid.from_file( - private_key_file=vapid_private_key) # pragma no cover + logging.info("Reading VAPID key from file {}".format(vapid_private_key)) + vv = Vapid.from_file(private_key_file=vapid_private_key) # pragma no cover else: if verbose: logging.info("Reading VAPID key from arguments") @@ -493,8 +575,12 @@ def webpush(subscription_info, curl=curl, timeout=timeout, ) - if not curl and response.status_code > 202: - raise WebPushException("Push failed: {} {}\nResponse body:{}".format( - response.status_code, response.reason, response.text), - response=response) + if not curl and cast(Response, response).status_code > 202: + response = cast(Response, response) + raise WebPushException( + "Push failed: {} {}\nResponse body:{}".format( + response.status_code, response.reason, response.text + ), + response=response, + ) return response diff --git a/pywebpush/__main__.py b/pywebpush/__main__.py index 530bdc9..0a88859 100644 --- a/pywebpush/__main__.py +++ b/pywebpush/__main__.py @@ -10,16 +10,25 @@ def get_config(): parser = argparse.ArgumentParser(description="WebPush tool") - parser.add_argument("--data", '-d', help="Data file") + parser.add_argument("--data", "-d", help="Data file") parser.add_argument("--info", "-i", help="Subscription Info JSON file") parser.add_argument("--head", help="Header Info JSON file") parser.add_argument("--claims", help="Vapid claim file") parser.add_argument("--key", help="Vapid private key file path") - parser.add_argument("--curl", help="Don't send, display as curl command", - default=False, action="store_true") + parser.add_argument( + "--curl", + help="Don't send, display as curl command", + default=False, + action="store_true", + ) parser.add_argument("--encoding", default="aes128gcm") - parser.add_argument("--verbose", "-v", help="Provide verbose feedback", - default=False, action="store_true") + parser.add_argument( + "--verbose", + "-v", + help="Provide verbose feedback", + default=False, + action="store_true", + ) args = parser.parse_args() @@ -33,7 +42,8 @@ def get_config(): args.sub_info = json.loads(r.read()) except JSONDecodeError as e: raise WebPushException( - "Could not read the subscription info file: {}", e) + "Could not read the subscription info file: {}", e + ) if args.data: with open(args.data) as r: args.data = r.read() @@ -42,8 +52,7 @@ def get_config(): try: args.head = json.loads(r.read()) except JSONDecodeError as e: - raise WebPushException( - "Could not read the header arguments: {}", e) + raise WebPushException("Could not read the header arguments: {}", e) if args.claims: if not args.key: raise WebPushException("No private --key specified for claims") @@ -52,7 +61,8 @@ def get_config(): args.claims = json.loads(r.read()) except JSONDecodeError as e: raise WebPushException( - "Could not read the VAPID claims file {}".format(e)) + "Could not read the VAPID claims file {}".format(e) + ) except Exception as ex: logging.error("Couldn't read input {}.".format(ex)) raise ex @@ -60,7 +70,7 @@ def get_config(): def main(): - """ Send data """ + """Send data""" try: args = get_config() @@ -72,7 +82,8 @@ def main(): curl=args.curl, content_encoding=args.encoding, verbose=args.verbose, - headers=args.head) + headers=args.head, + ) print(result) except Exception as ex: logging.error("{}".format(ex)) diff --git a/pywebpush/tests/test_webpush.py b/pywebpush/tests/test_webpush.py index 93dc51a..6f8378b 100644 --- a/pywebpush/tests/test_webpush.py +++ b/pywebpush/tests/test_webpush.py @@ -3,19 +3,20 @@ import os import unittest import time +from typing import cast, Union, Dict -from mock import patch, Mock import http_ece +import py_vapid +import requests +from mock import patch, Mock, AsyncMock from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives import serialization from cryptography.hazmat.backends import default_backend -import py_vapid - -from pywebpush import WebPusher, WebPushException, CaseInsensitiveDict, webpush +from pywebpush import WebPusher, NoData, WebPushException, CaseInsensitiveDict, webpush -class WebpushTestCase(unittest.TestCase): +class WebpushTestUtils(unittest.TestCase): # This is a exported DER formatted string of an ECDH public key # This was lifted from the py_vapid tests. vapid_key = ( @@ -24,75 +25,80 @@ class WebpushTestCase(unittest.TestCase): "M5xqEwuPM7VuQcyiLDhvovthPIXx+gsQRQ==" ) - def _gen_subscription_info(self, - recv_key=None, - endpoint="https://example.com/"): + def _gen_subscription_info(self, recv_key=None, endpoint="https://example.com/"): if not recv_key: - recv_key = ec.generate_private_key(ec.SECP256R1, default_backend()) + recv_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) return { "endpoint": endpoint, "keys": { - 'auth': base64.urlsafe_b64encode(os.urandom(16)).strip(b'='), - 'p256dh': self._get_pubkey_str(recv_key), - } + "auth": base64.urlsafe_b64encode(os.urandom(16)).strip(b"="), + "p256dh": self._get_pubkey_str(recv_key), + }, } def _get_pubkey_str(self, priv_key): return base64.urlsafe_b64encode( priv_key.public_key().public_bytes( encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.UncompressedPoint - )).strip(b'=') + format=serialization.PublicFormat.UncompressedPoint, + ) + ).strip(b"=") def test_init(self): # use static values so we know what to look for in the reply subscription_info = { - u"endpoint": u"https://example.com/", - u"keys": { - u"p256dh": (u"BOrnIslXrUow2VAzKCUAE4sIbK00daEZCswOcf8m3T" - "F8V82B-OpOg5JbmYLg44kRcvQC1E2gMJshsUYA-_zMPR8"), - u"auth": u"k8JV6sjdbhAi1n3_LDBLvA" - } + "endpoint": "https://example.com/", + "keys": { + "p256dh": ( + "BOrnIslXrUow2VAzKCUAE4sIbK00daEZCswOcf8m3T" + "F8V82B-OpOg5JbmYLg44kRcvQC1E2gMJshsUYA-_zMPR8" + ), + "auth": "k8JV6sjdbhAi1n3_LDBLvA", + }, } - rk_decode = (b'\x04\xea\xe7"\xc9W\xadJ0\xd9P3(%\x00\x13\x8b' - b'\x08l\xad4u\xa1\x19\n\xcc\x0eq\xff&\xdd1' - b'|W\xcd\x81\xf8\xeaN\x83\x92[\x99\x82\xe0\xe3' - b'\x89\x11r\xf4\x02\xd4M\xa00\x9b!\xb1F\x00' - b'\xfb\xfc\xcc=\x1f') + rk_decode = ( + b'\x04\xea\xe7"\xc9W\xadJ0\xd9P3(%\x00\x13\x8b' + b"\x08l\xad4u\xa1\x19\n\xcc\x0eq\xff&\xdd1" + b"|W\xcd\x81\xf8\xeaN\x83\x92[\x99\x82\xe0\xe3" + b"\x89\x11r\xf4\x02\xd4M\xa00\x9b!\xb1F\x00" + b"\xfb\xfc\xcc=\x1f" + ) self.assertRaises( - WebPushException, - WebPusher, - {"keys": {'p256dh': 'AAA=', 'auth': 'AAA='}}) + WebPushException, WebPusher, {"keys": {"p256dh": "AAA=", "auth": "AAA="}} + ) self.assertRaises( WebPushException, WebPusher, - {"endpoint": "https://example.com", "keys": {'p256dh': 'AAA='}}) + {"endpoint": "https://example.com", "keys": {"p256dh": "AAA="}}, + ) self.assertRaises( WebPushException, WebPusher, - {"endpoint": "https://example.com", "keys": {'auth': 'AAA='}}) + {"endpoint": "https://example.com", "keys": {"auth": "AAA="}}, + ) self.assertRaises( WebPushException, WebPusher, - {"endpoint": "https://example.com", - "keys": {'p256dh': 'AAA=', 'auth': 'AAA='}}) + { + "endpoint": "https://example.com", + "keys": {"p256dh": "AAA=", "auth": "AAA="}, + }, + ) push = WebPusher(subscription_info) assert push.subscription_info != subscription_info - assert push.subscription_info['keys'] != subscription_info['keys'] - assert push.subscription_info['endpoint'] == \ - subscription_info['endpoint'] + assert push.subscription_info["keys"] != subscription_info["keys"] + assert push.subscription_info["endpoint"] == subscription_info["endpoint"] assert push.receiver_key == rk_decode assert push.auth_key == b'\x93\xc2U\xea\xc8\xddn\x10"\xd6}\xff,0K\xbc' def test_encode(self): for content_encoding in ["aesgcm", "aes128gcm"]: - recv_key = ec.generate_private_key( - ec.SECP256R1, default_backend()) + recv_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) subscription_info = self._gen_subscription_info(recv_key) data = "Mary had a little lamb, with some nice mint jelly" push = WebPusher(subscription_info) - encoded = push.encode(data, content_encoding=content_encoding) + encoded = push.encode(data.encode(), content_encoding=content_encoding) """ crypto_key = base64.urlsafe_b64encode( self._get_pubkey_str(recv_key) @@ -100,48 +106,45 @@ def test_encode(self): """ # Convert these b64 strings into their raw, binary form. raw_salt = None - if 'salt' in encoded: - raw_salt = base64.urlsafe_b64decode( - push._repad(encoded['salt'])) + if "salt" in encoded: + raw_salt = base64.urlsafe_b64decode(push._repad(encoded["salt"])) raw_dh = None if content_encoding != "aes128gcm": - raw_dh = base64.urlsafe_b64decode( - push._repad(encoded['crypto_key'])) + raw_dh = base64.urlsafe_b64decode(push._repad(encoded["crypto_key"])) raw_auth = base64.urlsafe_b64decode( - push._repad(subscription_info['keys']['auth'])) + push._repad(subscription_info["keys"]["auth"]) + ) decoded = http_ece.decrypt( - encoded['body'], + encoded["body"], salt=raw_salt, dh=raw_dh, private_key=recv_key, auth_secret=raw_auth, - version=content_encoding - ) - assert decoded.decode('utf8') == data + version=content_encoding, + ) + assert decoded.decode("utf8") == data def test_bad_content_encoding(self): subscription_info = self._gen_subscription_info() data = "Mary had a little lamb, with some nice mint jelly" push = WebPusher(subscription_info) - self.assertRaises(WebPushException, - push.encode, - data, - content_encoding="aesgcm128") + self.assertRaises( + WebPushException, push.encode, data, content_encoding="aesgcm128" + ) @patch("requests.post") def test_send(self, mock_post): subscription_info = self._gen_subscription_info() - headers = {"Crypto-Key": "pre-existing", - "Authentication": "bearer vapid"} + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} data = "Mary had a little lamb" WebPusher(subscription_info).send(data, headers) - assert subscription_info.get('endpoint') == mock_post.call_args[0][0] - pheaders = mock_post.call_args[1].get('headers') - assert pheaders.get('ttl') == '0' - assert pheaders.get('AUTHENTICATION') == headers.get('Authentication') - ckey = pheaders.get('crypto-key') - assert 'pre-existing' in ckey - assert pheaders.get('content-encoding') == 'aes128gcm' + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert pheaders.get("AUTHENTICATION") == headers.get("Authentication") + ckey = pheaders.get("crypto-key") + assert "pre-existing" in ckey + assert pheaders.get("content-encoding") == "aes128gcm" @patch("requests.post") def test_send_vapid(self, mock_post): @@ -154,26 +157,26 @@ def test_send_vapid(self, mock_post): vapid_private_key=self.vapid_key, vapid_claims={"sub": "mailto:ops@example.com"}, content_encoding="aesgcm", - headers={"Test-Header": "test-value"} + headers={"Test-Header": "test-value"}, ) - assert subscription_info.get('endpoint') == mock_post.call_args[0][0] - pheaders = mock_post.call_args[1].get('headers') - assert pheaders.get('ttl') == '0' + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" def repad(str): - return str + "===="[:len(str) % 4] + return str + "===="[: len(str) % 4] auth = json.loads( base64.urlsafe_b64decode( - repad(pheaders['authorization'].split('.')[1]) - ).decode('utf8') + repad(pheaders["authorization"].split(".")[1]) + ).decode("utf8") ) - assert subscription_info.get('endpoint').startswith(auth['aud']) - assert 'vapid' in pheaders.get('authorization') - ckey = pheaders.get('crypto-key') - assert 'dh=' in ckey - assert pheaders.get('content-encoding') == 'aesgcm' - assert pheaders.get('test-header') == 'test-value' + assert subscription_info.get("endpoint", "").startswith(auth["aud"]) + assert "vapid" in pheaders.get("authorization") + ckey = pheaders.get("crypto-key") + assert "dh=" in ckey + assert pheaders.get("content-encoding") == "aesgcm" + assert pheaders.get("test-header") == "test-value" @patch.object(WebPusher, "send") @patch.object(py_vapid.Vapid, "sign") @@ -182,7 +185,9 @@ def test_webpush_vapid_instance(self, vapid_sign, pusher_send): subscription_info = self._gen_subscription_info() data = "Mary had a little lamb" vapid_key = py_vapid.Vapid.from_string(self.vapid_key) - claims = dict(sub="mailto:ops@example.com", aud="https://example.com") + claims: Dict[str, Union[str, int]] = dict( + sub="mailto:ops@example.com", aud="https://example.com" + ) webpush( subscription_info=subscription_info, data=data, @@ -199,9 +204,11 @@ def test_webpush_vapid_exp(self, vapid_sign, pusher_send): subscription_info = self._gen_subscription_info() data = "Mary had a little lamb" vapid_key = py_vapid.Vapid.from_string(self.vapid_key) - claims = dict(sub="mailto:ops@example.com", - aud="https://example.com", - exp=int(time.time() - 48600)) + claims = dict( + sub="mailto:ops@example.com", + aud="https://example.com", + exp=int(time.time() - 48600), + ) webpush( subscription_info=subscription_info, data=data, @@ -210,7 +217,7 @@ def test_webpush_vapid_exp(self, vapid_sign, pusher_send): ) vapid_sign.assert_called_once_with(claims) pusher_send.assert_called_once() - assert claims['exp'] > int(time.time()) + assert int(claims["exp"]) > int(time.time()) @patch("requests.post") def test_send_bad_vapid_no_key(self, mock_post): @@ -225,8 +232,9 @@ def test_send_bad_vapid_no_key(self, mock_post): data=data, vapid_claims={ "aud": "https://example.com", - "sub": "mailto:ops@example.com" - }) + "sub": "mailto:ops@example.com", + }, + ) @patch("requests.post") def test_send_bad_vapid_bad_return(self, mock_post): @@ -241,53 +249,47 @@ def test_send_bad_vapid_bad_return(self, mock_post): data=data, vapid_claims={ "aud": "https://example.com", - "sub": "mailto:ops@example.com" + "sub": "mailto:ops@example.com", }, - vapid_private_key=self.vapid_key) + vapid_private_key=self.vapid_key, + ) @patch("requests.post") def test_send_empty(self, mock_post): subscription_info = self._gen_subscription_info() - headers = {"Crypto-Key": "pre-existing", - "Authentication": "bearer vapid"} - WebPusher(subscription_info).send('', headers) - assert subscription_info.get('endpoint') == mock_post.call_args[0][0] - pheaders = mock_post.call_args[1].get('headers') - assert pheaders.get('ttl') == '0' - assert 'encryption' not in pheaders - assert pheaders.get('AUTHENTICATION') == headers.get('Authentication') - ckey = pheaders.get('crypto-key') - assert 'pre-existing' in ckey + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + WebPusher(subscription_info).send("", headers) + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert "encryption" not in pheaders + assert pheaders.get("AUTHENTICATION") == headers.get("Authentication") + ckey = pheaders.get("crypto-key") + assert "pre-existing" in ckey def test_encode_empty(self): subscription_info = self._gen_subscription_info() - headers = {"Crypto-Key": "pre-existing", - "Authentication": "bearer vapid"} - encoded = WebPusher(subscription_info).encode('', headers) - assert encoded is None + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + pusher = WebPusher(subscription_info) + self.assertRaises(NoData, pusher.encode, "", headers) def test_encode_no_crypto(self): subscription_info = self._gen_subscription_info() - del(subscription_info['keys']) - headers = {"Crypto-Key": "pre-existing", - "Authentication": "bearer vapid"} - data = 'Something' + del subscription_info["keys"] + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + data = "Something" pusher = WebPusher(subscription_info) - self.assertRaises( - WebPushException, - pusher.encode, - data, - headers) + self.assertRaises(WebPushException, pusher.encode, data, headers) @patch("requests.post") def test_send_no_headers(self, mock_post): subscription_info = self._gen_subscription_info() data = "Mary had a little lamb" WebPusher(subscription_info).send(data) - assert subscription_info.get('endpoint') == mock_post.call_args[0][0] - pheaders = mock_post.call_args[1].get('headers') - assert pheaders.get('ttl') == '0' - assert pheaders.get('content-encoding') == 'aes128gcm' + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert pheaders.get("content-encoding") == "aes128gcm" @patch("pywebpush.open") def test_as_curl(self, opener): @@ -297,40 +299,40 @@ def test_as_curl(self, opener): data="Mary had a little lamb", vapid_claims={ "aud": "https://example.com", - "sub": "mailto:ops@example.com" + "sub": "mailto:ops@example.com", }, vapid_private_key=self.vapid_key, - curl=True + curl=True, ) + result = cast(str, result) for s in [ "curl -vX POST https://example.com", - "-H \"content-encoding: aes128gcm\"", - "-H \"authorization: vapid ", - "-H \"ttl: 0\"", - "-H \"content-length:" + '-H "content-encoding: aes128gcm"', + '-H "authorization: vapid ', + '-H "ttl: 0"', + '-H "content-length:', ]: assert s in result, "missing: {}".format(s) def test_ci_dict(self): ci = CaseInsensitiveDict({"Foo": "apple", "bar": "banana"}) - assert 'apple' == ci["foo"] - assert 'apple' == ci.get("FOO") - assert 'apple' == ci.get("Foo") - del (ci['FOO']) - assert ci.get('Foo') is None + assert "apple" == ci["foo"] + assert "apple" == ci.get("FOO") + assert "apple" == ci.get("Foo") + del ci["FOO"] + assert ci.get("Foo") is None @patch("requests.post") def test_gcm(self, mock_post): subscription_info = self._gen_subscription_info( - None, - endpoint="https://android.googleapis.com/gcm/send/regid123") - headers = {"Crypto-Key": "pre-existing", - "Authentication": "bearer vapid"} + None, endpoint="https://android.googleapis.com/gcm/send/regid123" + ) + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} data = "Mary had a little lamb" wp = WebPusher(subscription_info) wp.send(data, headers, gcm_key="gcm_key_value") - pdata = json.loads(mock_post.call_args[1].get('data')) - pheaders = mock_post.call_args[1].get('headers') + pdata = json.loads(mock_post.call_args[1].get("data")) + pheaders = mock_post.call_args[1].get("headers") assert pdata["registration_ids"][0] == "regid123" assert pheaders.get("authorization") == "key=gcm_key_value" assert pheaders.get("content-type") == "application/json" @@ -340,52 +342,127 @@ def test_timeout(self, mock_post): mock_post.return_value.status_code = 200 subscription_info = self._gen_subscription_info() WebPusher(subscription_info).send(timeout=5.2) - assert mock_post.call_args[1].get('timeout') == 5.2 + assert mock_post.call_args[1].get("timeout") == 5.2 webpush(subscription_info, timeout=10.001) - assert mock_post.call_args[1].get('timeout') == 10.001 + assert mock_post.call_args[1].get("timeout") == 10.001 @patch("requests.Session") def test_send_using_requests_session(self, mock_session): subscription_info = self._gen_subscription_info() - headers = {"Crypto-Key": "pre-existing", - "Authentication": "bearer vapid"} + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + data = "Mary had a little lamb" + WebPusher(subscription_info, requests_session=mock_session).send(data, headers) + assert subscription_info.get("endpoint") == mock_session.post.call_args[0][0] + pheaders = mock_session.post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert pheaders.get("AUTHENTICATION") == headers.get("Authentication") + ckey = pheaders.get("crypto-key") + assert "pre-existing" in ckey + assert pheaders.get("content-encoding") == "aes128gcm" + + +class WebPusherAsyncTestCase(WebpushTestUtils, unittest.IsolatedAsyncioTestCase): + @patch("aiohttp.ClientSession.post", new_callable=AsyncMock) + async def test_send(self, mock_post): + subscription_info = self._gen_subscription_info() + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + data = "Mary had a little lamb" + await WebPusher(subscription_info).send_async(data, headers) + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert pheaders.get("AUTHENTICATION") == headers.get("Authentication") + ckey = pheaders.get("crypto-key") + assert "pre-existing" in ckey + assert pheaders.get("content-encoding") == "aes128gcm" + + @patch("aiohttp.ClientSession.post", new_callable=AsyncMock) + async def test_send_empty(self, mock_post): + subscription_info = self._gen_subscription_info() + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + await WebPusher(subscription_info).send_async("", headers) + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert "encryption" not in pheaders + assert pheaders.get("AUTHENTICATION") == headers.get("Authentication") + ckey = pheaders.get("crypto-key") + assert "pre-existing" in ckey + + @patch("aiohttp.ClientSession.post", new_callable=AsyncMock) + async def test_send_no_headers(self, mock_post): + subscription_info = self._gen_subscription_info() + data = "Mary had a little lamb" + await WebPusher(subscription_info).send_async(data) + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert pheaders.get("content-encoding") == "aes128gcm" + + @patch("aiohttp.ClientSession.post", new_callable=AsyncMock) + async def test_fcm(self, mock_post): + subscription_info = self._gen_subscription_info( + None, endpoint="https://android.googleapis.com/fcm/send/regid123" + ) + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + data = "Mary had a little lamb" + wp = WebPusher(subscription_info) + await wp.send_async(data, headers, gcm_key="gcm_key_value") + pdata = json.loads(mock_post.call_args[1].get("data")) + pheaders = mock_post.call_args[1].get("headers") + assert pdata["registration_ids"][0] == "regid123" + assert pheaders.get("authorization") == "key=gcm_key_value" + assert pheaders.get("content-type") == "application/json" + + @patch("aiohttp.ClientSession.post", new_callable=AsyncMock) + async def test_timeout(self, mock_post): + mock_post.return_value.status_code = 200 + subscription_info = self._gen_subscription_info() + await WebPusher(subscription_info).send_async(timeout=5.2) + assert mock_post.call_args[1].get("timeout") == 5.2 + + @patch("aiohttp.ClientSession", new_callable=AsyncMock) + async def test_send_using_requests_session(self, mock_session): + subscription_info = self._gen_subscription_info() + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} data = "Mary had a little lamb" - WebPusher(subscription_info, - requests_session=mock_session).send(data, headers) - assert subscription_info.get( - 'endpoint') == mock_session.post.call_args[0][0] - pheaders = mock_session.post.call_args[1].get('headers') - assert pheaders.get('ttl') == '0' - assert pheaders.get('AUTHENTICATION') == headers.get('Authentication') - ckey = pheaders.get('crypto-key') - assert 'pre-existing' in ckey - assert pheaders.get('content-encoding') == 'aes128gcm' + await WebPusher(subscription_info, aiohttp_session=mock_session).send_async( + data, headers + ) + assert subscription_info.get("endpoint") == mock_session.post.call_args[0][0] + pheaders = mock_session.post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert pheaders.get("AUTHENTICATION") == headers.get("Authentication") + ckey = pheaders.get("crypto-key") + assert "pre-existing" in ckey + assert pheaders.get("content-encoding") == "aes128gcm" class WebpushExceptionTestCase(unittest.TestCase): - def test_exception(self): from requests import Response exp = WebPushException("foo") - assert ("{}".format(exp) == "WebPushException: foo") + assert "{}".format(exp) == "WebPushException: foo" # Really should try to load the response to verify, but this mock # covers what we need. response = Mock(spec=Response) response.text = ( - '{"code": 401, "errno": 109, "error": ' - '"Unauthorized", "more_info": "http://' - 'autopush.readthedocs.io/en/latest/htt' - 'p.html#error-codes", "message": "Requ' - 'est did not validate missing authoriz' - 'ation header"}') + '{"code": 401, "errno": 109, "error": ' + '"Unauthorized", "more_info": "http://' + "autopush.readthedocs.io/en/latest/htt" + 'p.html#error-codes", "message": "Requ' + "est did not validate missing authoriz" + 'ation header"}' + ) response.json.return_value = json.loads(response.text) response.status_code = 401 response.reason = "Unauthorized" exp = WebPushException("foo", response) assert "{}".format(exp) == "WebPushException: foo, Response {}".format( - response.text) - assert '{}'.format(exp.response), '' - assert exp.response.json().get('errno') == 109 + response.text + ) + assert "{}".format(exp.response), "" + assert cast(requests.Response, exp.response).json().get("errno") == 109 exp = WebPushException("foo", [1, 2, 3]) - assert '{}'.format(exp) == "WebPushException: foo, Response [1, 2, 3]" + assert "{}".format(exp) == "WebPushException: foo, Response [1, 2, 3]" diff --git a/requirements.txt b/requirements.txt index 74596b3..16f1e7c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +aiohttp cryptography>=2.6.1 http-ece>=1.1.0 requests>=2.21.0 diff --git a/setup.py b/setup.py index dacceaf..f8756d1 100644 --- a/setup.py +++ b/setup.py @@ -1,56 +1,65 @@ +### This is the much older setup script originally used by things like +### setuptools and distutils. It's fallen out of favor by more recent +### packaging tools, but is still referred to on occasion. +### It's a hold-over from the Python 2.7 days, so there are a fair number +### of sharp edges and stone clubs. +### For more info see https://python-packaging.readthedocs.io/en/latest/index.html import io import os from setuptools import find_packages, setup - -__version__ = "1.14.1" +__version__ = "2.0.0" def read_from(file): reply = [] - with io.open(os.path.join(here, file), encoding='utf8') as f: + with io.open(os.path.join(here, file), encoding="utf8") as f: for line in f: line = line.strip() if not line: break - if line[:2] == '-r': - reply += read_from(line.split(' ')[1]) + if line[:2] == "-r": + reply += read_from(line.split(" ")[1]) continue - if line[0] != '#' or line[:2] != '//': + if line[0] != "#" or line[:2] != "//": reply.append(line) return reply here = os.path.abspath(os.path.dirname(__file__)) -with io.open(os.path.join(here, 'README.rst'), encoding='utf8') as f: +with io.open(os.path.join(here, "README.rst"), encoding="utf8") as f: README = f.read() -with io.open(os.path.join(here, 'CHANGELOG.md'), encoding='utf8') as f: +with io.open(os.path.join(here, "CHANGELOG.md"), encoding="utf8") as f: CHANGES = f.read() setup( name="pywebpush", version=__version__, packages=find_packages(), - description='WebPush publication library', - long_description=README + '\n\n' + CHANGES, + description="WebPush publication library", + long_description=README + "\n\n" + CHANGES, classifiers=[ "Topic :: Internet :: WWW/HTTP", "Programming Language :: Python :: Implementation :: PyPy", - 'Programming Language :: Python', + "Programming Language :: Python", "Programming Language :: Python :: 3", ], - keywords='push webpush publication', + keywords="push webpush publication", author="JR Conlin", author_email="src+webpusher@jrconlin.com", - url='https://github.com/web-push-libs/pywebpush', + url="https://github.com/web-push-libs/pywebpush", license="MPL2", include_package_data=True, zip_safe=False, - install_requires=read_from('requirements.txt'), - tests_require=read_from('test-requirements.txt'), - entry_points=""" - [console_scripts] - pywebpush = pywebpush.__main__:main - """, + install_requires=read_from("requirements.txt"), + tests_require=read_from("test-requirements.txt"), + # This used to specify the entry point script that will + # be created, and still will if you run + # `python setup.py develop` + entry_points={ + "console_scripts": [ + "pywebpush=pywebpush.__main__:main" + ], + } ) diff --git a/test-requirements.txt b/test-requirements.txt index 4d5117e..c284191 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,5 +1,4 @@ -r requirements.txt -pytest -coverage>=4.4.1 -mock>=2.0.0 -flake8>=3.3.0 +black +mock +pytest \ No newline at end of file