diff --git a/tests/integration/test_integration_all_rust.py b/tests/integration/test_integration_all_rust.py index f3e2fab14..8330d2065 100644 --- a/tests/integration/test_integration_all_rust.py +++ b/tests/integration/test_integration_all_rust.py @@ -27,18 +27,14 @@ import twisted.internet.base from cryptography.fernet import Fernet from jose import jws -from twisted.internet import reactor -from twisted.internet.defer import inlineCallbacks, returnValue -from twisted.trial import unittest -from .async_push_test_client import AsyncPushTestClient +from .async_push_test_client import AsyncPushTestClient, ClientMessageType from .db import ( DynamoDBResource, base64url_encode, create_message_table_ddb, get_router_table, ) -from .push_test_client import ClientMessageType, PushTestClient app = bottle.Bottle() logging.basicConfig(level=logging.DEBUG) @@ -281,12 +277,12 @@ def max_logs(endpoint=None, conn=None): def max_logs_decorator(func): """Overwrite `max_endpoint_logs` with a given endpoint if it is specified.""" - def wrapper(self, *args, **kwargs): + async def wrapper(self, *args, **kwargs): if endpoint is not None: self.max_endpoint_logs = endpoint if conn is not None: self.max_conn_logs = conn - return func(self, *args, **kwargs) + return await func(self, *args, **kwargs) return wrapper @@ -552,9 +548,6 @@ def setup_module(): else: setup_dynamodb() - pool = reactor.getThreadPool() - pool.adjustPoolsize(minthreads=pool.max) - setup_mock_server() log.debug(f"🐍🟢 Rust Log: {RUST_LOG}") @@ -584,7 +577,7 @@ def teardown_module(): kill_process(EP_SERVER) -class TestRustWebPush(unittest.TestCase): +class TestRustWebPush: """Test class for Rust Web Push.""" # Max log lines allowed to be emitted by each node type @@ -606,42 +599,42 @@ def host_endpoint(self, client): parsed = urlparse(list(client.channels.values())[0]) return "{}://{}".format(parsed.scheme, parsed.netloc) - @inlineCallbacks - def quick_register(self): + @pytest.mark.asyncio + async def quick_register(self): """Perform a connection initialization, which includes a new connection, `hello`, and channel registration. """ log.debug("🐍#### Connecting to ws://localhost:{}/".format(CONNECTION_PORT)) - client = PushTestClient("ws://localhost:{}/".format(CONNECTION_PORT)) - yield client.connect() - yield client.hello() - yield client.register() + client = AsyncPushTestClient("ws://localhost:{}/".format(CONNECTION_PORT)) + await client.connect() + await client.hello() + await client.register() log.debug("🐍 Connected") - returnValue(client) + return client - @inlineCallbacks - def shut_down(self, client=None): + @pytest.mark.asyncio + async def shut_down(self, client=None): """Shut down client.""" if client: - yield client.disconnect() + await client.disconnect() @property def _ws_url(self): return "ws://localhost:{}/".format(CONNECTION_PORT) - @inlineCallbacks @max_logs(conn=4) - def test_sentry_output_autoconnect(self): + @pytest.mark.asyncio + async def test_sentry_output_autoconnect(self): """Test sentry output for autoconnect.""" if os.getenv("SKIP_SENTRY"): SkipTest("Skipping sentry test") return # Ensure bad data doesn't throw errors - client = PushTestClient(self._ws_url) - yield client.connect() - yield client.hello() - yield client.send_bad_data() - yield self.shut_down(client) + client = AsyncPushTestClient(self._ws_url) + await client.connect() + await client.hello() + await client.send_bad_data() + await self.shut_down(client) # LogCheck does throw an error every time httpx.get(f"http://localhost:{CONNECTION_PORT}/v1/err/crit", timeout=30) @@ -653,17 +646,17 @@ def test_sentry_output_autoconnect(self): pass assert event1["exception"]["values"][0]["value"] == "LogCheck" - @inlineCallbacks @max_logs(endpoint=1) - def test_sentry_output_autoendpoint(self): + @pytest.mark.asyncio + async def test_sentry_output_autoendpoint(self): """Test sentry output for autoendpoint.""" if os.getenv("SKIP_SENTRY"): SkipTest("Skipping sentry test") return - client = yield self.quick_register() + client = await self.quick_register() endpoint = self.host_endpoint(client) - yield self.shut_down(client) + await self.shut_down(client) httpx.get(f"{endpoint}/__error__", timeout=30) # 2 events excpted: 1 from a panic and 1 from a returned Error @@ -676,7 +669,7 @@ def test_sentry_output_autoendpoint(self): assert sorted(values) == ["ERROR:Success", "LogCheck"] @max_logs(conn=4) - def test_no_sentry_output(self): + async def test_no_sentry_output(self): """Test for no Sentry output.""" if os.getenv("SKIP_SENTRY"): SkipTest("Skipping sentry test") @@ -692,132 +685,132 @@ def test_no_sentry_output(self): except Empty: pass - @inlineCallbacks - def test_hello_echo(self): + @pytest.mark.asyncio + async def test_hello_echo(self): """Test hello echo.""" - client = PushTestClient(self._ws_url) - yield client.connect() - result = yield client.hello() + client = AsyncPushTestClient(self._ws_url) + await client.connect() + result = await client.hello() assert result != {} assert result["use_webpush"] is True - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_hello_with_bad_prior_uaid(self): + @pytest.mark.asyncio + async def test_hello_with_bad_prior_uaid(self): """Test hello with bad prior uaid.""" non_uaid = uuid.uuid4().hex - client = PushTestClient(self._ws_url) - yield client.connect() - result = yield client.hello(uaid=non_uaid) + client = AsyncPushTestClient(self._ws_url) + await client.connect() + result = await client.hello(uaid=non_uaid) assert result != {} assert result["uaid"] != non_uaid assert result["use_webpush"] is True - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_basic_delivery(self): + @pytest.mark.asyncio + async def test_basic_delivery(self): """Test basic regular push message delivery.""" data = str(uuid.uuid4()) - client: PushTestClient = yield self.quick_register() - result = yield client.send_notification(data=data) + client: AsyncPushTestClient = await self.quick_register() + result = await client.send_notification(data=data) # the following presumes that only `salt` is padded. clean_header = client._crypto_key.replace('"', "").rstrip("=") assert result["headers"]["encryption"] == clean_header assert result["data"] == base64url_encode(bytes(data, "utf-8")) assert result["messageType"] == "notification" - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_topic_basic_delivery(self): + @pytest.mark.asyncio + async def test_topic_basic_delivery(self): """Test basic topic push message delivery.""" data = str(uuid.uuid4()) - client = yield self.quick_register() - result = yield client.send_notification(data=data, topic="Inbox") + client = await self.quick_register() + result = await client.send_notification(data=data, topic="Inbox") # the following presumes that only `salt` is padded. clean_header = client._crypto_key.replace('"', "").rstrip("=") assert result["headers"]["encryption"] == clean_header assert result["data"] == base64url_encode(data) assert result["messageType"] == "notification" - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_topic_replacement_delivery(self): + @pytest.mark.asyncio + async def test_topic_replacement_delivery(self): """Test that a topic push message replaces it's prior version.""" data = str(uuid.uuid4()) data2 = str(uuid.uuid4()) - client = yield self.quick_register() - yield client.disconnect() - yield client.send_notification(data=data, topic="Inbox", status=201) - yield client.send_notification(data=data2, topic="Inbox", status=201) - yield client.connect() - yield client.hello() - result = yield client.get_notification() + client = await self.quick_register() + await client.disconnect() + await client.send_notification(data=data, topic="Inbox", status=201) + await client.send_notification(data=data2, topic="Inbox", status=201) + await client.connect() + await client.hello() + result = await client.get_notification() log.debug("get_notification result:", result) # the following presumes that only `salt` is padded. clean_header = client._crypto_key.replace('"', "").rstrip("=") assert result["headers"]["encryption"] == clean_header assert result["data"] == base64url_encode(data2) assert result["messageType"] == "notification" - result = yield client.get_notification() + result = await client.get_notification() assert result is None - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks + @pytest.mark.asyncio @max_logs(conn=4) - def test_topic_no_delivery_on_reconnect(self): + async def test_topic_no_delivery_on_reconnect(self): """Test that a topic message does not attempt to redeliver on reconnect.""" data = str(uuid.uuid4()) - client = yield self.quick_register() - yield client.disconnect() - yield client.send_notification(data=data, topic="Inbox", status=201) - yield client.connect() - yield client.hello() - result = yield client.get_notification(timeout=10) + client = await self.quick_register() + await client.disconnect() + await client.send_notification(data=data, topic="Inbox", status=201) + await client.connect() + await client.hello() + result = await client.get_notification(timeout=10) # the following presumes that only `salt` is padded. clean_header = client._crypto_key.replace('"', "").rstrip("=") assert result["headers"]["encryption"] == clean_header assert result["data"] == base64url_encode(data) assert result["messageType"] == "notification" - yield client.ack(result["channelID"], result["version"]) - yield client.disconnect() - yield client.connect() - yield client.hello() - result = yield client.get_notification() + await client.ack(result["channelID"], result["version"]) + await client.disconnect() + await client.connect() + await client.hello() + result = await client.get_notification() assert result is None - yield client.disconnect() - yield client.connect() - yield client.hello() - yield self.shut_down(client) + await client.disconnect() + await client.connect() + await client.hello() + await self.shut_down(client) - @inlineCallbacks - def test_basic_delivery_with_vapid(self): + @pytest.mark.asyncio + async def test_basic_delivery_with_vapid(self): """Test delivery of a basic push message with a VAPID header.""" data = str(uuid.uuid4()) - client = yield self.quick_register() + client = await self.quick_register() vapid_info = _get_vapid(payload=self.vapid_payload) - result = yield client.send_notification(data=data, vapid=vapid_info) + result = await client.send_notification(data=data, vapid=vapid_info) # the following presumes that only `salt` is padded. clean_header = client._crypto_key.replace('"', "").rstrip("=") assert result["headers"]["encryption"] == clean_header assert result["data"] == base64url_encode(data) assert result["messageType"] == "notification" - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_basic_delivery_with_invalid_vapid(self): + @pytest.mark.asyncio + async def test_basic_delivery_with_invalid_vapid(self): """Test basic delivery with invalid VAPID header.""" data = str(uuid.uuid4()) - client = yield self.quick_register() + client = await self.quick_register() vapid_info = _get_vapid(payload=self.vapid_payload, endpoint=self.host_endpoint(client)) vapid_info["crypto-key"] = "invalid" - yield client.send_notification(data=data, vapid=vapid_info, status=401) - yield self.shut_down(client) + await client.send_notification(data=data, vapid=vapid_info, status=401) + await self.shut_down(client) - @inlineCallbacks - def test_basic_delivery_with_invalid_vapid_exp(self): + @pytest.mark.asyncio + async def test_basic_delivery_with_invalid_vapid_exp(self): """Test basic delivery of a push message with invalid VAPID `exp` assertion.""" data = str(uuid.uuid4()) - client = yield self.quick_register() + client = await self.quick_register() vapid_info = _get_vapid( payload={ "aud": self.host_endpoint(client), @@ -826,27 +819,27 @@ def test_basic_delivery_with_invalid_vapid_exp(self): } ) vapid_info["crypto-key"] = "invalid" - yield client.send_notification(data=data, vapid=vapid_info, status=401) - yield self.shut_down(client) + await client.send_notification(data=data, vapid=vapid_info, status=401) + await self.shut_down(client) - @inlineCallbacks - def test_basic_delivery_with_invalid_vapid_auth(self): + @pytest.mark.asyncio + async def test_basic_delivery_with_invalid_vapid_auth(self): """Test basic delivery with invalid VAPID auth.""" data = str(uuid.uuid4()) - client = yield self.quick_register() + client = await self.quick_register() vapid_info = _get_vapid( payload=self.vapid_payload, endpoint=self.host_endpoint(client), ) vapid_info["auth"] = "" - yield client.send_notification(data=data, vapid=vapid_info, status=401) - yield self.shut_down(client) + await client.send_notification(data=data, vapid=vapid_info, status=401) + await self.shut_down(client) - @inlineCallbacks - def test_basic_delivery_with_invalid_signature(self): + @pytest.mark.asyncio + async def test_basic_delivery_with_invalid_signature(self): """Test that a basic delivery with invalid VAPID signature fails.""" data = str(uuid.uuid4()) - client = yield self.quick_register() + client = await self.quick_register() vapid_info = _get_vapid( payload={ "aud": self.host_endpoint(client), @@ -854,112 +847,112 @@ def test_basic_delivery_with_invalid_signature(self): } ) vapid_info["auth"] = vapid_info["auth"][:-3] + "bad" - yield client.send_notification(data=data, vapid=vapid_info, status=401) - yield self.shut_down(client) + await client.send_notification(data=data, vapid=vapid_info, status=401) + await self.shut_down(client) - @inlineCallbacks - def test_basic_delivery_with_invalid_vapid_ckey(self): + @pytest.mark.asyncio + async def test_basic_delivery_with_invalid_vapid_ckey(self): """Test that basic delivery with invalid VAPID crypto-key fails.""" data = str(uuid.uuid4()) - client = yield self.quick_register() + client = await self.quick_register() vapid_info = _get_vapid(payload=self.vapid_payload, endpoint=self.host_endpoint(client)) vapid_info["crypto-key"] = "invalid|" - yield client.send_notification(data=data, vapid=vapid_info, status=401) - yield self.shut_down(client) + await client.send_notification(data=data, vapid=vapid_info, status=401) + await self.shut_down(client) - @inlineCallbacks - def test_delivery_repeat_without_ack(self): + @pytest.mark.asyncio + async def test_delivery_repeat_without_ack(self): """Test that message delivery repeats if the client does not acknowledge messages.""" data = str(uuid.uuid4()) - client = yield self.quick_register() - yield client.disconnect() + client = await self.quick_register() + await client.disconnect() assert client.channels - yield client.send_notification(data=data, status=201) - yield client.connect() - yield client.hello() - result = yield client.get_notification() + await client.send_notification(data=data, status=201) + await client.connect() + await client.hello() + result = await client.get_notification() assert result is not None assert result["data"] == base64url_encode(data) - yield client.disconnect() - yield client.connect() - yield client.hello() - result = yield client.get_notification() + await client.disconnect() + await client.connect() + await client.hello() + result = await client.get_notification() assert result != {} assert result["data"] == base64url_encode(data) - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_repeat_delivery_with_disconnect_without_ack(self): + @pytest.mark.asyncio + async def test_repeat_delivery_with_disconnect_without_ack(self): """Test that message delivery repeats if the client disconnects without acknowledging the message. """ data = str(uuid.uuid4()) - client = yield self.quick_register() - result = yield client.send_notification(data=data) + client = await self.quick_register() + result = await client.send_notification(data=data) assert result != {} assert result["data"] == base64url_encode(data) - yield client.disconnect() - yield client.connect() - yield client.hello() - result = yield client.get_notification() + await client.disconnect() + await client.connect() + await client.hello() + result = await client.get_notification() assert result != {} assert result["data"] == base64url_encode(data) - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_multiple_delivery_repeat_without_ack(self): + @pytest.mark.asyncio + async def test_multiple_delivery_repeat_without_ack(self): """Test that the server will always try to deliver messages until the client acknowledges them. """ data = str(uuid.uuid4()) data2 = str(uuid.uuid4()) - client = yield self.quick_register() - yield client.disconnect() + client = await self.quick_register() + await client.disconnect() assert client.channels - yield client.send_notification(data=data, status=201) - yield client.send_notification(data=data2, status=201) - yield client.connect() - yield client.hello() - result = yield client.get_notification() + await client.send_notification(data=data, status=201) + await client.send_notification(data=data2, status=201) + await client.connect() + await client.hello() + result = await client.get_notification() assert result != {} assert result["data"] in map(base64url_encode, [data, data2]) - result = yield client.get_notification() + result = await client.get_notification() assert result != {} assert result["data"] in map(base64url_encode, [data, data2]) - yield client.disconnect() - yield client.connect() - yield client.hello() - result = yield client.get_notification() + await client.disconnect() + await client.connect() + await client.hello() + result = await client.get_notification() assert result != {} assert result["data"] in map(base64url_encode, [data, data2]) - result = yield client.get_notification() + result = await client.get_notification() assert result != {} assert result["data"] in map(base64url_encode, [data, data2]) - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_topic_expired(self): + @pytest.mark.asyncio + async def test_topic_expired(self): """Test that the server will not deliver a message topic that has expired.""" data = str(uuid.uuid4()) - client = yield self.quick_register() - yield client.disconnect() + client = await self.quick_register() + await client.disconnect() assert client.channels - yield client.send_notification(data=data, ttl=1, topic="test", status=201) - yield client.sleep(2) - yield client.connect() - yield client.hello() - result = yield client.get_notification(timeout=0.5) + await client.send_notification(data=data, ttl=1, topic="test", status=201) + await client.sleep(2) + await client.connect() + await client.hello() + result = await client.get_notification(timeout=0.5) assert result is None - result = yield client.send_notification(data=data, topic="test") + result = await client.send_notification(data=data, topic="test") assert result != {} assert result["data"] == base64url_encode(data) - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks + @pytest.mark.asyncio @max_logs(conn=4) - def test_multiple_delivery_with_single_ack(self): + async def test_multiple_delivery_with_single_ack(self): """Test that the server provides the right unacknowledged messages if the client only acknowledges one of the received messages. Note: the `data` fields are constructed so that they return @@ -967,44 +960,44 @@ def test_multiple_delivery_with_single_ack(self): """ data = b"\x16*\xec\xb4\xc7\xac\xb1\xa8\x1e" + str(uuid.uuid4()).encode() data2 = b":\xd8^\xac\xc7\xac\xb1\xa8\x1e" + str(uuid.uuid4()).encode() - client = yield self.quick_register() - yield client.disconnect() + client = await self.quick_register() + await client.disconnect() assert client.channels - yield client.send_notification(data=data, status=201) - yield client.send_notification(data=data2, status=201) - yield client.connect() - yield client.hello() - result = yield client.get_notification(timeout=0.5) + await client.send_notification(data=data, status=201) + await client.send_notification(data=data2, status=201) + await client.connect() + await client.hello() + result = await client.get_notification(timeout=0.5) assert result != {} assert result["data"] == base64url_encode(data) - result2 = yield client.get_notification(timeout=0.5) + result2 = await client.get_notification(timeout=0.5) assert result2 != {} assert result2["data"] == base64url_encode(data2) - yield client.ack(result["channelID"], result["version"]) + await client.ack(result["channelID"], result["version"]) - yield client.disconnect() - yield client.connect() - yield client.hello() - result = yield client.get_notification(timeout=0.5) + await client.disconnect() + await client.connect() + await client.hello() + result = await client.get_notification(timeout=0.5) assert result != {} assert result["data"] == base64url_encode(data) assert result["messageType"] == "notification" - result2 = yield client.get_notification() + result2 = await client.get_notification() assert result2 != {} assert result2["data"] == base64url_encode(data2) - yield client.ack(result["channelID"], result["version"]) - yield client.ack(result2["channelID"], result2["version"]) + await client.ack(result["channelID"], result["version"]) + await client.ack(result2["channelID"], result2["version"]) # Verify no messages are delivered - yield client.disconnect() - yield client.connect() - yield client.hello() - result = yield client.get_notification(timeout=0.5) + await client.disconnect() + await client.connect() + await client.hello() + result = await client.get_notification(timeout=0.5) assert result is None - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_multiple_delivery_with_multiple_ack(self): + @pytest.mark.asyncio + async def test_multiple_delivery_with_multiple_ack(self): """Test that the server provides the no additional unacknowledged messages if the client acknowledges both of the received messages. Note: the `data` fields are constructed so that they return @@ -1012,100 +1005,100 @@ def test_multiple_delivery_with_multiple_ack(self): """ data = b"\x16*\xec\xb4\xc7\xac\xb1\xa8\x1e" + str(uuid.uuid4()).encode() # "FirstMessage" data2 = b":\xd8^\xac\xc7\xac\xb1\xa8\x1e" + str(uuid.uuid4()).encode() # "OtherMessage" - client = yield self.quick_register() - yield client.disconnect() + client = await self.quick_register() + await client.disconnect() assert client.channels - yield client.send_notification(data=data, status=201) - yield client.send_notification(data=data2, status=201) - yield client.connect() - yield client.hello() - result = yield client.get_notification(timeout=0.5) + await client.send_notification(data=data, status=201) + await client.send_notification(data=data2, status=201) + await client.connect() + await client.hello() + result = await client.get_notification(timeout=0.5) assert result != {} assert result["data"] in map(base64url_encode, [data, data2]) log.debug("🟩🟩 Result:: {}".format(result["data"])) - result2 = yield client.get_notification() + result2 = await client.get_notification() assert result2 != {} assert result2["data"] in map(base64url_encode, [data, data2]) - yield client.ack(result2["channelID"], result2["version"]) - yield client.ack(result["channelID"], result["version"]) + await client.ack(result2["channelID"], result2["version"]) + await client.ack(result["channelID"], result["version"]) - yield client.disconnect() - yield client.connect() - yield client.hello() - result = yield client.get_notification(timeout=0.5) + await client.disconnect() + await client.connect() + await client.hello() + result = await client.get_notification(timeout=0.5) assert result is None - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_no_delivery_to_unregistered(self): + @pytest.mark.asyncio + async def test_no_delivery_to_unregistered(self): """Test that the server does not try to deliver to unregistered channel IDs.""" data = str(uuid.uuid4()) - client: PushTestClient = yield self.quick_register() + client: AsyncPushTestClient = await self.quick_register() assert client.channels chan = list(client.channels.keys())[0] - result = yield client.send_notification(data=data) + result = await client.send_notification(data=data) assert result["channelID"] == chan assert result["data"] == base64url_encode(data) - yield client.ack(result["channelID"], result["version"]) + await client.ack(result["channelID"], result["version"]) - yield client.unregister(chan) - result = yield client.send_notification(data=data, status=410) + await client.unregister(chan) + result = await client.send_notification(data=data, status=410) # Verify cache-control assert client.notif_response.headers.get("Cache-Control") == "max-age=86400" assert result is None - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_ttl_0_connected(self): + @pytest.mark.asyncio + async def test_ttl_0_connected(self): """Test that a message with a TTL=0 is delivered to a client that is actively connected.""" data = str(uuid.uuid4()) - client = yield self.quick_register() - result = yield client.send_notification(data=data, ttl=0) + client = await self.quick_register() + result = await client.send_notification(data=data, ttl=0) assert result is not None # the following presumes that only `salt` is padded. clean_header = client._crypto_key.replace('"', "").rstrip("=") assert result["headers"]["encryption"] == clean_header assert result["data"] == base64url_encode(data) assert result["messageType"] == "notification" - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_ttl_0_not_connected(self): + @pytest.mark.asyncio + async def test_ttl_0_not_connected(self): """Test that a message with a TTL=0 and a recipient client that is not connected, is not delivered when the client reconnects. """ data = str(uuid.uuid4()) - client = yield self.quick_register() - yield client.disconnect() - yield client.send_notification(data=data, ttl=0, status=201) - yield client.connect() - yield client.hello() - result = yield client.get_notification(timeout=0.5) + client = await self.quick_register() + await client.disconnect() + await client.send_notification(data=data, ttl=0, status=201) + await client.connect() + await client.hello() + result = await client.get_notification(timeout=0.5) assert result is None - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_ttl_expired(self): + @pytest.mark.asyncio + async def test_ttl_expired(self): """Test that messages with a TTL that has expired are not delivered to a recipient client. """ data = str(uuid.uuid4()) - client = yield self.quick_register() - yield client.disconnect() - yield client.send_notification(data=data, ttl=1, status=201) + client = await self.quick_register() + await client.disconnect() + await client.send_notification(data=data, ttl=1, status=201) time.sleep(1) - yield client.connect() - yield client.hello() - result = yield client.get_notification(timeout=0.5) + await client.connect() + await client.hello() + result = await client.get_notification(timeout=0.5) assert result is None - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks + @pytest.mark.asyncio @max_logs(endpoint=28) - def test_ttl_batch_expired_and_good_one(self): + async def test_ttl_batch_expired_and_good_one(self): """Test that if a batch of messages are received while the recipient is offline, only messages that have not expired are sent to the recipient. This test checks if the latest pending message is not expired. @@ -1113,30 +1106,30 @@ def test_ttl_batch_expired_and_good_one(self): data = str(uuid.uuid4()).encode() data2 = base64.urlsafe_b64decode("0012") + str(uuid.uuid4()).encode() print(data2) - client = yield self.quick_register() - yield client.disconnect() + client = await self.quick_register() + await client.disconnect() for x in range(0, 12): prefix = base64.urlsafe_b64decode("{:04d}".format(x)) - yield client.send_notification(data=prefix + data, ttl=1, status=201) + await client.send_notification(data=prefix + data, ttl=1, status=201) - yield client.send_notification(data=data2, status=201) + await client.send_notification(data=data2, status=201) time.sleep(1) - yield client.connect() - yield client.hello() - result = yield client.get_notification(timeout=4) + await client.connect() + await client.hello() + result = await client.get_notification(timeout=4) assert result is not None # the following presumes that only `salt` is padded. clean_header = client._crypto_key.replace('"', "").rstrip("=") assert result["headers"]["encryption"] == clean_header assert result["data"] == base64url_encode(data2) assert result["messageType"] == "notification" - result = yield client.get_notification(timeout=0.5) + result = await client.get_notification(timeout=0.5) assert result is None - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks @max_logs(endpoint=28) - def test_ttl_batch_partly_expired_and_good_one(self): + @pytest.mark.asyncio + async def test_ttl_batch_partly_expired_and_good_one(self): """Test that if a batch of messages are received while the recipient is offline, only messages that have not expired are sent to the recipient. This test checks if there is an equal mix of expired and unexpired messages. @@ -1144,105 +1137,105 @@ def test_ttl_batch_partly_expired_and_good_one(self): data = str(uuid.uuid4()) data1 = str(uuid.uuid4()) data2 = str(uuid.uuid4()) - client = yield self.quick_register() - yield client.disconnect() + client = await self.quick_register() + await client.disconnect() for x in range(0, 6): - yield client.send_notification(data=data, status=201) + await client.send_notification(data=data, status=201) for x in range(0, 6): - yield client.send_notification(data=data1, ttl=1, status=201) + await client.send_notification(data=data1, ttl=1, status=201) - yield client.send_notification(data=data2, status=201) + await client.send_notification(data=data2, status=201) time.sleep(1) - yield client.connect() - yield client.hello() + await client.connect() + await client.hello() # Pull out and ack the first for x in range(0, 6): - result = yield client.get_notification(timeout=4) + result = await client.get_notification(timeout=4) assert result is not None assert result["data"] == base64url_encode(data) - yield client.ack(result["channelID"], result["version"]) + await client.ack(result["channelID"], result["version"]) # Should have one more that is data2, this will only arrive if the # other six were acked as that hits the batch size - result = yield client.get_notification(timeout=4) + result = await client.get_notification(timeout=4) assert result is not None assert result["data"] == base64url_encode(data2) # No more - result = yield client.get_notification() + result = await client.get_notification() assert result is None - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_message_without_crypto_headers(self): + @pytest.mark.asyncio + async def test_message_without_crypto_headers(self): """Test that a message without crypto headers, but has data is not accepted.""" data = str(uuid.uuid4()) - client = yield self.quick_register() - result = yield client.send_notification(data=data, use_header=False, status=400) + client = await self.quick_register() + result = await client.send_notification(data=data, use_header=False, status=400) assert result is None - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_empty_message_without_crypto_headers(self): + @pytest.mark.asyncio + async def test_empty_message_without_crypto_headers(self): """Test that a message without crypto headers, and does not have data, is accepted.""" - client = yield self.quick_register() - result = yield client.send_notification(use_header=False) + client = await self.quick_register() + result = await client.send_notification(use_header=False) assert result is not None assert result["messageType"] == "notification" assert "headers" not in result assert "data" not in result - yield client.ack(result["channelID"], result["version"]) + await client.ack(result["channelID"], result["version"]) - yield client.disconnect() - yield client.send_notification(use_header=False, status=201) - yield client.connect() - yield client.hello() - result = yield client.get_notification() + await client.disconnect() + await client.send_notification(use_header=False, status=201) + await client.connect() + await client.hello() + result = await client.get_notification() assert result is not None assert "headers" not in result assert "data" not in result - yield client.ack(result["channelID"], result["version"]) + await client.ack(result["channelID"], result["version"]) - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_empty_message_with_crypto_headers(self): + @pytest.mark.asyncio + async def test_empty_message_with_crypto_headers(self): """Test that an empty message with crypto headers does not send either `headers` or `data` as part of the incoming websocket `notification` message. """ - client = yield self.quick_register() - result = yield client.send_notification() + client = await self.quick_register() + result = await client.send_notification() assert result is not None assert result["messageType"] == "notification" assert "headers" not in result assert "data" not in result - result2 = yield client.send_notification() + result2 = await client.send_notification() # We shouldn't store headers for blank messages. assert result2 is not None assert result2["messageType"] == "notification" assert "headers" not in result2 assert "data" not in result2 - yield client.ack(result["channelID"], result["version"]) - yield client.ack(result2["channelID"], result2["version"]) + await client.ack(result["channelID"], result["version"]) + await client.ack(result2["channelID"], result2["version"]) - yield client.disconnect() - yield client.send_notification(status=201) - yield client.connect() - yield client.hello() - result3 = yield client.get_notification() + await client.disconnect() + await client.send_notification(status=201) + await client.connect() + await client.hello() + result3 = await client.get_notification() assert result3 is not None assert "headers" not in result3 assert "data" not in result3 - yield client.ack(result3["channelID"], result3["version"]) + await client.ack(result3["channelID"], result3["version"]) - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_big_message(self): + @pytest.mark.asyncio + async def test_big_message(self): """Test that we accept a large message. Using pywebpush I encoded a 4096 block @@ -1252,13 +1245,13 @@ def test_big_message(self): """ import base64 - client = yield self.quick_register() + client = await self.quick_register() bulk = "".join( random.choice(string.ascii_letters + string.digits + string.punctuation) for _ in range(0, 4216) ) data = base64.urlsafe_b64encode(bytes(bulk, "utf-8")) - result = yield client.send_notification(data=data) + result = await client.send_notification(data=data) dd = result.get("data") dh = base64.b64decode(dd + "==="[: len(dd) % 4]) assert dh == data @@ -1273,25 +1266,25 @@ def test_big_message(self): # Skipping test for now. # Note: dict_keys obj was not iterable, corrected by converting to iterable. - @inlineCallbacks - def test_delete_saved_notification(self): + @pytest.mark.asyncio + async def test_delete_saved_notification(self): """Test deleting a saved notification in client server.""" - client = yield self.quick_register() - yield client.disconnect() + client = await self.quick_register() + await client.disconnect() assert client.channels chan = list(client.channels.keys())[0] - yield client.send_notification() + await client.send_notification() status_code: int = 204 - delete_resp = yield client.delete_notification(chan, status=status_code) + delete_resp = await client.delete_notification(chan, status=status_code) assert delete_resp.status_code == status_code - yield client.connect() - yield client.hello() - result = yield client.get_notification() + await client.connect() + await client.hello() + result = await client.get_notification() assert result is None - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_with_key(self): + @pytest.mark.asyncio + async def test_with_key(self): """Test getting a locked subscription with a valid VAPID public key.""" private_key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p) claims = { @@ -1302,76 +1295,76 @@ def test_with_key(self): vapid = _get_vapid(private_key, claims) pk_hex = vapid["crypto-key"] chid = str(uuid.uuid4()) - client = PushTestClient("ws://localhost:{}/".format(CONNECTION_PORT)) - yield client.connect() - yield client.hello() - yield client.register(channel_id=chid, key=pk_hex) + client = AsyncPushTestClient("ws://localhost:{}/".format(CONNECTION_PORT)) + await client.connect() + await client.hello() + await client.register(channel_id=chid, key=pk_hex) # Send an update with a properly formatted key. - yield client.send_notification(vapid=vapid) + await client.send_notification(vapid=vapid) # now try an invalid key. new_key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p) vapid = _get_vapid(new_key, claims) - yield client.send_notification(vapid=vapid, status=401) + await client.send_notification(vapid=vapid, status=401) - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_with_bad_key(self): + @pytest.mark.asyncio + async def test_with_bad_key(self): """Test that a message registration request with bad VAPID public key is rejected.""" chid = str(uuid.uuid4()) - client = PushTestClient("ws://localhost:{}/".format(CONNECTION_PORT)) - yield client.connect() - yield client.hello() - result = yield client.register(channel_id=chid, key="af1883%&!@#*(", status=400) + client = AsyncPushTestClient("ws://localhost:{}/".format(CONNECTION_PORT)) + await client.connect() + await client.hello() + result = await client.register(channel_id=chid, key="af1883%&!@#*(", status=400) assert result["status"] == 400 - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks + @pytest.mark.asyncio @max_logs(endpoint=44) - def test_msg_limit(self): + async def test_msg_limit(self): """Test that sent messages that are larger than our size limit are rejected.""" - client = yield self.quick_register() + client = await self.quick_register() uaid = client.uaid - yield client.disconnect() + await client.disconnect() for i in range(MSG_LIMIT + 1): - yield client.send_notification(status=201) - yield client.connect() - yield client.hello() + await client.send_notification(status=201) + await client.connect() + await client.hello() assert client.uaid == uaid for i in range(MSG_LIMIT): - result = yield client.get_notification() + result = await client.get_notification() assert result is not None, f"failed at {i}" - yield client.ack(result["channelID"], result["version"]) - yield client.disconnect() - yield client.connect() - yield client.hello() + await client.ack(result["channelID"], result["version"]) + await client.disconnect() + await client.connect() + await client.hello() assert client.uaid != uaid - yield self.shut_down(client) - - @inlineCallbacks - def test_can_ping(self): - """Test that the client can send an active ping message and get a valid response.""" - client = yield self.quick_register() - result = yield client.ping() - assert result == "{}" - assert client.ws.connected - try: - yield client.ping() - except AssertionError: - # pinging too quickly should disconnect without a valid ping - # repsonse - pass - assert not client.ws.connected - yield self.shut_down(client) + await self.shut_down(client) - @inlineCallbacks - def test_internal_endpoints(self): + # @pytest.mark.asyncio + # async def test_can_ping(self): + # """Test that the client can send an active ping message and get a valid response.""" + # client = await self.quick_register() + # result = await client.ping() + # assert result == "{}" + # assert client.ws.open + # try: + # await client.ping() + # except AssertionError: + # # pinging too quickly should disconnect without a valid ping + # # repsonse + # pass + # assert not client.ws.open + # await self.shut_down(client) + + @pytest.mark.asyncio + async def test_internal_endpoints(self): """Ensure an internal router endpoint isn't exposed on the public CONNECTION_PORT""" - client = yield self.quick_register() + client = await self.quick_register() parsed = ( urlparse(self._ws_url)._replace(scheme="http")._replace(path=f"/notif/{client.uaid}") ) @@ -1411,7 +1404,7 @@ async def quick_register(self, connection_port=None): await client.connect() await client.hello() await client.register() - returnValue(client) + return client @pytest.mark.asyncio async def shut_down(self, client=None):