diff --git a/autoendpoint/src/routers/webpush.rs b/autoendpoint/src/routers/webpush.rs index 53506adde..4fcc636e0 100644 --- a/autoendpoint/src/routers/webpush.rs +++ b/autoendpoint/src/routers/webpush.rs @@ -187,7 +187,7 @@ impl WebPushRouter { /// Update metrics and create a response for when a notification has been stored in the database /// for future transmission. fn make_stored_response(&self, notification: &Notification) -> RouterResponse { - self.make_response(notification, "Stored", StatusCode::ACCEPTED) + self.make_response(notification, "Stored", StatusCode::CREATED) } /// Update metrics and create a response after routing a notification diff --git a/tests/test_integration_all_rust.py b/tests/test_integration_all_rust.py index 2f61c1fe0..eeedc735b 100644 --- a/tests/test_integration_all_rust.py +++ b/tests/test_integration_all_rust.py @@ -3,8 +3,10 @@ """ import copy +import json import logging import os +import random import signal import socket import subprocess @@ -19,18 +21,20 @@ import bottle import ecdsa +import httplib import psutil import requests +import websocket import twisted.internet.base from autopush.db import ( DynamoDBResource, create_message_table, get_router_table ) -from autopush.tests.test_integration import Client from autopush.utils import base64url_encode from cryptography.fernet import Fernet from Queue import Empty, Queue from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue +from twisted.internet.threads import deferToThread from twisted.trial import unittest from typing import Optional from urlparse import urlparse @@ -127,6 +131,243 @@ def get_free_port(): ) +class Client(object): + """Test Client""" + def __init__(self, url, sslcontext=None): + self.url = url + self.uaid = None + self.ws = None + self.use_webpush = True + self.channels = {} + self.messages = {} + self.notif_response = None # type: Optional[httplib.HTTPResponse] + self._crypto_key = """\ +keyid="http://example.org/bob/keys/123";salt="XZwpw6o37R-6qoZjw6KwAw=="\ +""" + self.sslcontext = sslcontext + self.headers = { + "User-Agent": + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.13; rv:61.0) " + "Gecko/20100101 Firefox/61.0" + } + + def __getattribute__(self, name): + # Python fun to turn all functions into deferToThread functions + f = object.__getattribute__(self, name) + if name.startswith("__"): + return f + + if callable(f): + return lambda *args, **kwargs: deferToThread(f, *args, **kwargs) + else: + return f + + def connect(self, connection_port=None): + url = self.url + if connection_port: # pragma: nocover + url = "ws://localhost:{}/".format(connection_port) + self.ws = websocket.create_connection(url, header=self.headers) + return self.ws.connected + + def hello(self, uaid=None, services=None): + if self.channels: + chans = self.channels.keys() + else: + chans = [] + hello_dict = dict(messageType="hello", + use_webpush=True, + channelIDs=chans) + if uaid or self.uaid: + hello_dict["uaid"] = uaid or self.uaid + if services: # pragma: nocover + hello_dict["broadcasts"] = services + msg = json.dumps(hello_dict) + log.debug("Send: %s", msg) + self.ws.send(msg) + result = json.loads(self.ws.recv()) + log.debug("Recv: %s", result) + assert result["status"] == 200 + assert "-" not in result["uaid"] + if self.uaid and self.uaid != result["uaid"]: # pragma: nocover + log.debug("Mismatch on re-using uaid. Old: %s, New: %s", + self.uaid, result["uaid"]) + self.channels = {} + self.uaid = result["uaid"] + return result + + def broadcast_subscribe(self, services): # pragma: nocover + msg = json.dumps(dict(messageType="broadcast_subscribe", + broadcasts=services)) + log.debug("Send: %s", msg) + self.ws.send(msg) + + def register(self, chid=None, key=None, status=200): + chid = chid or str(uuid.uuid4()) + msg = json.dumps(dict(messageType="register", + channelID=chid, + key=key)) + log.debug("Send: %s", msg) + self.ws.send(msg) + rcv = self.ws.recv() + result = json.loads(rcv) + log.debug("Recv: %s", result) + assert result["status"] == status + assert result["channelID"] == chid + if status == 200: + self.channels[chid] = result["pushEndpoint"] + return result + + def unregister(self, chid): + msg = json.dumps(dict(messageType="unregister", channelID=chid)) + log.debug("Send: %s", msg) + self.ws.send(msg) + result = json.loads(self.ws.recv()) + log.debug("Recv: %s", result) + return result + + def delete_notification(self, channel, message=None, status=204): + messages = self.messages[channel] + if not message: + message = random.choice(messages) + + log.debug("Delete: %s", message) + url = urlparse(message) + http = None + if url.scheme == "https": # pragma: nocover + http = httplib.HTTPSConnection(url.netloc, context=self.sslcontext) + else: + http = httplib.HTTPConnection(url.netloc) + + http.request("DELETE", url.path) + resp = http.getresponse() + http.close() + assert resp.status == status + + def send_notification(self, channel=None, version=None, data=None, + use_header=True, status=None, ttl=200, + timeout=0.2, vapid=None, endpoint=None, + topic=None): + if not channel: + channel = random.choice(self.channels.keys()) + + endpoint = endpoint or self.channels[channel] + url = urlparse(endpoint) + http = None + if url.scheme == "https": # pragma: nocover + http = httplib.HTTPSConnection(url.netloc, context=self.sslcontext) + else: + http = httplib.HTTPConnection(url.netloc) + + headers = {} + if ttl is not None: + headers = {"TTL": str(ttl)} + if use_header: + headers.update({ + "Content-Type": "application/octet-stream", + "Content-Encoding": "aesgcm", + "Encryption": self._crypto_key, + "Crypto-Key": 'keyid="a1"; dh="JcqK-OLkJZlJ3sJJWstJCA"', + }) + if vapid: + headers.update({ + "Authorization": "Bearer " + vapid.get('auth') + }) + ckey = 'p256ecdsa="' + vapid.get('crypto-key') + '"' + headers.update({ + 'Crypto-Key': headers.get('Crypto-Key') + ';' + ckey + }) + if topic: + headers["Topic"] = topic + body = data or "" + method = "POST" + # 202 status reserved for yet to be implemented push w/ reciept. + status = status or 201 + + log.debug("%s body: %s", method, body) + http.request(method, url.path.encode("utf-8"), body, headers) + resp = http.getresponse() + log.debug("%s Response (%s): %s", method, resp.status, resp.read()) + http.close() + assert resp.status == status, \ + "Expected %d, got %d" % (status, resp.status) + self.notif_response = resp + location = resp.getheader("Location", None) + log.debug("Response Headers: %s", resp.getheaders()) + if status >= 200 and status < 300: + assert location is not None + if status == 201 and ttl is not None: + ttl_header = resp.getheader("TTL") + assert ttl_header == str(ttl) + if ttl != 0 and status == 201: + assert location is not None + if channel in self.messages: + self.messages[channel].append(location) + else: + self.messages[channel] = [location] + + # Pull the notification if connected + if self.ws and self.ws.connected: + return object.__getattribute__(self, "get_notification")(timeout) + else: + return resp + + def get_notification(self, timeout=1): + orig_timeout = self.ws.gettimeout() + self.ws.settimeout(timeout) + try: + d = self.ws.recv() + log.debug("Recv: %s", d) + return json.loads(d) + except Exception: + return None + finally: + self.ws.settimeout(orig_timeout) + + def get_broadcast(self, timeout=1): # pragma: nocover + orig_timeout = self.ws.gettimeout() + self.ws.settimeout(timeout) + try: + d = self.ws.recv() + log.debug("Recv: %s", d) + result = json.loads(d) + assert result.get("messageType") == "broadcast" + return result + except Exception: # pragma: nocover + return None + finally: + self.ws.settimeout(orig_timeout) + + def ping(self): + log.debug("Send: %s", "{}") + self.ws.send("{}") + result = self.ws.recv() + log.debug("Recv: %s", result) + assert result == "{}" + return result + + def ack(self, channel, version): + msg = json.dumps(dict(messageType="ack", + updates=[dict(channelID=channel, + version=version)])) + log.debug("Send: %s", msg) + self.ws.send(msg) + + def disconnect(self): + self.ws.close() + + def sleep(self, duration): # pragma: nocover + time.sleep(duration) + + def wait_for(self, func): + """Waits several seconds for a function to return True""" + times = 0 + while not func(): # pragma: nocover + time.sleep(1) + times += 1 + if times > 9: # pragma: nocover + break + + def _get_vapid(key=None, payload=None, endpoint=None): global CONNECTION_CONFIG @@ -491,8 +732,8 @@ def test_topic_replacement_delivery(self): data2 = str(uuid.uuid4()) client = yield self.quick_register() yield client.disconnect() - yield client.send_notification(data=data, topic="Inbox", status=202) - yield client.send_notification(data=data2, topic="Inbox", status=202) + yield client.send_notification(data=data, topic="Inbox", status=201) + yield client.send_notification(data=data2, topic="Inbox", status=201) yield client.connect() yield client.hello() result = yield client.get_notification() @@ -512,7 +753,7 @@ def test_topic_no_delivery_on_reconnect(self): data = str(uuid.uuid4()) client = yield self.quick_register() yield client.disconnect() - yield client.send_notification(data=data, topic="Inbox", status=202) + yield client.send_notification(data=data, topic="Inbox", status=201) yield client.connect() yield client.hello() result = yield client.get_notification(timeout=10) @@ -627,7 +868,7 @@ def test_delivery_repeat_without_ack(self): client = yield self.quick_register() yield client.disconnect() assert client.channels - yield client.send_notification(data=data, status=202) + yield client.send_notification(data=data, status=201) yield client.connect() yield client.hello() result = yield client.get_notification() @@ -664,8 +905,8 @@ def test_multiple_delivery_repeat_without_ack(self): client = yield self.quick_register() yield client.disconnect() assert client.channels - yield client.send_notification(data=data, status=202) - yield client.send_notification(data=data2, status=202) + yield client.send_notification(data=data, status=201) + yield client.send_notification(data=data2, status=201) yield client.connect() yield client.hello() result = yield client.get_notification() @@ -692,7 +933,7 @@ def test_topic_expired(self): client = yield self.quick_register() yield client.disconnect() assert client.channels - yield client.send_notification(data=data, ttl=1, topic="test", status=202) + yield client.send_notification(data=data, ttl=1, topic="test", status=201) yield client.sleep(2) yield client.connect() yield client.hello() @@ -711,8 +952,8 @@ def test_multiple_delivery_with_single_ack(self): client = yield self.quick_register() yield client.disconnect() assert client.channels - yield client.send_notification(data=data, status=202) - yield client.send_notification(data=data2, status=202) + yield client.send_notification(data=data, status=201) + yield client.send_notification(data=data2, status=201) yield client.connect() yield client.hello() result = yield client.get_notification(timeout=0.5) @@ -751,8 +992,8 @@ def test_multiple_delivery_with_multiple_ack(self): client = yield self.quick_register() yield client.disconnect() assert client.channels - yield client.send_notification(data=data, status=202) - yield client.send_notification(data=data2, status=202) + yield client.send_notification(data=data, status=201) + yield client.send_notification(data=data2, status=201) yield client.connect() yield client.hello() result = yield client.get_notification(timeout=0.5) @@ -824,7 +1065,7 @@ def test_ttl_expired(self): data = str(uuid.uuid4()) client = yield self.quick_register() yield client.disconnect() - yield client.send_notification(data=data, ttl=1, status=202) + yield client.send_notification(data=data, ttl=1, status=201) time.sleep(1) yield client.connect() yield client.hello() @@ -840,9 +1081,9 @@ def test_ttl_batch_expired_and_good_one(self): client = yield self.quick_register() yield client.disconnect() for x in range(0, 12): - yield client.send_notification(data=data, ttl=1, status=202) + yield client.send_notification(data=data, ttl=1, status=201) - yield client.send_notification(data=data2, status=202) + yield client.send_notification(data=data2, status=201) time.sleep(1) yield client.connect() yield client.hello() @@ -867,12 +1108,12 @@ def test_ttl_batch_partly_expired_and_good_one(self): client = yield self.quick_register() yield client.disconnect() for x in range(0, 6): - yield client.send_notification(data=data, status=202) + yield client.send_notification(data=data, status=201) for x in range(0, 6): - yield client.send_notification(data=data1, ttl=1, status=202) + yield client.send_notification(data=data1, ttl=1, status=201) - yield client.send_notification(data=data2, status=202) + yield client.send_notification(data=data2, status=201) time.sleep(1) yield client.connect() yield client.hello() @@ -915,7 +1156,7 @@ def test_empty_message_without_crypto_headers(self): yield client.ack(result["channelID"], result["version"]) yield client.disconnect() - yield client.send_notification(use_header=False, status=202) + yield client.send_notification(use_header=False, status=201) yield client.connect() yield client.hello() result = yield client.get_notification() @@ -946,7 +1187,7 @@ def test_empty_message_with_crypto_headers(self): yield client.ack(result2["channelID"], result2["version"]) yield client.disconnect() - yield client.send_notification(status=202) + yield client.send_notification(status=201) yield client.connect() yield client.hello() result3 = yield client.get_notification() @@ -1027,7 +1268,7 @@ def test_msg_limit(self): uaid = client.uaid yield client.disconnect() for i in range(MSG_LIMIT + 1): - yield client.send_notification(status=202) + yield client.send_notification(status=201) yield client.connect() yield client.hello() assert client.uaid == uaid