diff --git a/scripts/fernet_key.py b/scripts/fernet_key.py index a9667d32b..0793c5ef4 100644 --- a/scripts/fernet_key.py +++ b/scripts/fernet_key.py @@ -1,6 +1,4 @@ -""" -A command-line utility that generates endpoint encryption keys. -""" +"""A command-line utility that generates endpoint encryption keys.""" from __future__ import print_function from cryptography.fernet import Fernet diff --git a/scripts/gendpoint.py b/scripts/gendpoint.py index 17bbb1d16..cc05e86ef 100644 --- a/scripts/gendpoint.py +++ b/scripts/gendpoint.py @@ -1,5 +1,7 @@ +"""Module to process configuration from cli arguments and environment +variables. +""" #! env python3 - import argparse import os @@ -35,6 +37,7 @@ def config(env_args: os._Environ) -> argparse.Namespace: def main(): + """Process environment arguments/variables and set key.""" args = config(os.environ) if isinstance(args.key, list): key = args.key[0] diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index e69de29bb..72ebee739 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""__init__.py for integration tests.""" diff --git a/tests/integration/db.py b/tests/integration/db.py index 3b1f3169c..034865962 100644 --- a/tests/integration/db.py +++ b/tests/integration/db.py @@ -59,6 +59,8 @@ class ItemNotFound(Exception): class DynamoDBResource(threading.local): + """DynamoDBResource class subclassing threading.local""" + def __init__(self, **kwargs): conf = kwargs if not conf.get("endpoint_url"): @@ -87,7 +89,7 @@ def __getattr__(self, name): def get_latest_message_tablenames( self, prefix: str = "message", previous: int = 1 ) -> list[str]: - """Fetches the name of the last message table""" + """Fetch the name of the last message table.""" client = self._resource.meta.client paginator = client.get_paginator("list_tables") tables = [] @@ -102,11 +104,13 @@ def get_latest_message_tablenames( return tables[0 - previous :] def get_latest_message_tablename(self, prefix: str = "message") -> str: - """Fetches the name of the last message table""" + """Fetch the name of the last message table.""" return self.get_latest_message_tablenames(prefix=prefix, previous=1)[0] class DynamoDBTable(threading.local): + """DynamoDBTable class.""" + def __init__(self, ddb_resource: DynamoDBResource, *args, **kwargs) -> None: self._table = ddb_resource.Table(*args, **kwargs) @@ -118,13 +122,24 @@ def __getattr__(self, name): def generate_hash(key: bytes, payload: bytes) -> str: """Generate a HMAC for the uaid using the secret - :returns: HMAC hash and the nonce used as a tuple (nonce, hash). + :param key: key + :type: bytes + :param payload: payload + :type: bytes + :returns: A hexadecimal string of the HMAC hash and the nonce, used as a tuple (nonce, hash) + :rtype: str """ h = hmac.new(key=key, msg=payload, digestmod=hashlib.sha256) return h.hexdigest() def normalize_id(ident: uuid.UUID | str) -> str: + """Normalize and return ID as string + + :param ident: uuid.UUID or str identifier + :returns: string representation of UUID + :raises ValueError: raises an exception if UUID is invalid + """ if isinstance(ident, uuid.UUID): return str(ident) try: @@ -134,9 +149,10 @@ def normalize_id(ident: uuid.UUID | str) -> str: def base64url_encode(value: bytes | str) -> str: + """Encode an unpadded Base64 URL-encoded string per RFC 7515.""" if isinstance(value, str): value = bytes(value, "utf-8") - """Encodes an unpadded Base64 URL-encoded string per RFC 7515.""" + return base64.urlsafe_b64encode(value).strip(b"=").decode("utf-8") @@ -155,9 +171,7 @@ def base64url_encode(value: bytes | str) -> str: def get_month(delta: int = 0) -> datetime.date: - """Basic helper function to get a datetime.date object iterations months - ahead/behind of now. - """ + """Get a datetime.date object iterations months ahead/behind of now.""" new = last = datetime.date.today() # Move until we hit a new month, this avoids having to manually # check year changes as we push forward or backward since the Python @@ -308,7 +322,8 @@ def get_router_table( def track_provisioned(func: Callable[..., T]) -> Callable[..., T]: """Tracks provisioned exceptions and increments a metric for them named - after the function decorated""" + after the function decorated. + """ @wraps(func) def wrapper(self, *args, **kwargs): diff --git a/tests/integration/test_integration_all_rust.py b/tests/integration/test_integration_all_rust.py index 2bf35df58..cafa31e9d 100644 --- a/tests/integration/test_integration_all_rust.py +++ b/tests/integration/test_integration_all_rust.py @@ -1,6 +1,4 @@ -""" -Rust Connection and Endpoint Node Integration Tests -""" +"""Rust Connection and Endpoint Node Integration Tests.""" import base64 import copy @@ -95,6 +93,7 @@ def get_free_port() -> int: + """Get free port.""" port: int s = socket.socket(socket.AF_INET, type=socket.SOCK_STREAM) s.bind(("localhost", 0)) @@ -110,6 +109,7 @@ def get_free_port() -> int: def get_db_settings() -> str | dict[str, str | int | float] | None: + """Get database settings.""" env_var = os.environ.get("DB_SETTINGS") if env_var: if os.path.isfile(env_var): @@ -212,6 +212,7 @@ def __init__(self, url) -> None: } def __getattribute__(self, name: str): + """Turn functions into deferToThread functions.""" # Python fun to turn all functions into deferToThread functions f = object.__getattribute__(self, name) if name.startswith("__"): @@ -223,6 +224,7 @@ def __getattribute__(self, name: str): return f def connect(self, connection_port: int | None = None): + """Connect.""" url = self.url if connection_port: # pragma: nocover url = "ws://localhost:{}/".format(connection_port) @@ -230,6 +232,7 @@ def connect(self, connection_port: int | None = None): return self.ws.connected if self.ws else None def hello(self, uaid: str | None = None, services: list[str] | None = None): + """Hello verification.""" if not self.ws: raise Exception("WebSocket client not available as expected") @@ -261,6 +264,7 @@ def hello(self, uaid: str | None = None, services: list[str] | None = None): return result def broadcast_subscribe(self, services: list[str]): + """Broadcast subscribe.""" if not self.ws: raise Exception("WebSocket client not available as expected") @@ -269,6 +273,7 @@ def broadcast_subscribe(self, services: list[str]): self.ws.send(msg) def register(self, chid: str | None = None, key=None, status=200): + """Register.""" if not self.ws: raise Exception("WebSocket client not available as expected") @@ -286,6 +291,7 @@ def register(self, chid: str | None = None, key=None, status=200): return result def unregister(self, chid): + """Unregister.""" msg = json.dumps(dict(messageType="unregister", channelID=chid)) log.debug("Send: %s", msg) self.ws.send(msg) @@ -294,6 +300,7 @@ def unregister(self, chid): return result def delete_notification(self, channel, message=None, status=204): + """Delete notification.""" messages = self.messages[channel] if not message: message = random.choice(messages) @@ -317,6 +324,7 @@ def send_notification( topic=None, headers=None, ): + """Send notification.""" if not channel: channel = random.choice(list(self.channels.keys())) @@ -375,6 +383,7 @@ def send_notification( return resp def get_notification(self, timeout=1): + """Get notification.""" orig_timeout = self.ws.gettimeout() self.ws.settimeout(timeout) try: @@ -387,6 +396,7 @@ def get_notification(self, timeout=1): self.ws.settimeout(orig_timeout) def get_broadcast(self, timeout=1): # pragma: nocover + """Get broadcast.""" orig_timeout = self.ws.gettimeout() self.ws.settimeout(timeout) try: @@ -402,6 +412,7 @@ def get_broadcast(self, timeout=1): # pragma: nocover self.ws.settimeout(orig_timeout) def ping(self): + """Test ping.""" log.debug("Send: %s", "{}") self.ws.send("{}") result = self.ws.recv() @@ -410,6 +421,7 @@ def ping(self): return result def ack(self, channel, version): + """Acknowledge message send.""" msg = json.dumps( dict( messageType="ack", @@ -420,13 +432,15 @@ def ack(self, channel, version): self.ws.send(msg) def disconnect(self): + """Disconnect""" self.ws.close() def sleep(self, duration: int): # pragma: nocover + """Sleep wrapper function.""" time.sleep(duration) def wait_for(self, func): - """Waits several seconds for a function to return True""" + """Wait several seconds for a function to return True""" times = 0 while not func(): # pragma: nocover time.sleep(1) @@ -440,6 +454,7 @@ def _get_vapid( payload: dict[str, str | int] | None = None, endpoint: str | None = None, ) -> dict[str, str | bytes]: + """Get vapid key.""" global CONNECTION_CONFIG if endpoint is None: @@ -465,12 +480,14 @@ def _get_vapid( def enqueue_output(out, queue): + """Enqueue output.""" for line in iter(out.readline, b""): queue.put(line) out.close() def print_lines_in_queues(queues, prefix): + """Print lines in queues to stdout.""" for queue in queues: is_empty = False while not is_empty: @@ -516,6 +533,8 @@ def max_logs(endpoint=None, conn=None): """ def max_logs_decorator(func): + """Decorate max_logs.""" + def wrapper(self, *args, **kwargs): if endpoint is not None: self.max_endpoint_logs = endpoint @@ -530,24 +549,30 @@ def wrapper(self, *args, **kwargs): @app.get("/v1/broadcasts") def broadcast_handler(): + """Broadcast handler setup.""" assert bottle.request.headers["Authorization"] == MOCK_MP_TOKEN MOCK_MP_POLLED.set() return dict(broadcasts=MOCK_MP_SERVICES) @app.post("/api/1/envelope/") -def sentry_handler(): +def sentry_handler() -> dict[str, str]: + """Sentry handler configuration.""" headers, item_headers, payload = bottle.request.body.read().splitlines() MOCK_SENTRY_QUEUE.put(json.loads(payload)) return {"id": "fc6d8c0c43fc4630ad850ee518f1b9d0"} class CustomClient(Client): + """Custom Client for testing.""" + def send_bad_data(self): + """Set `bad-data`""" self.ws.send("bad-data") def kill_process(process): + """Kill child processes.""" # This kinda sucks, but its the only way to nuke the child procs if process is None: return @@ -559,6 +584,7 @@ def kill_process(process): def get_rust_binary_path(binary): + """Get Rust binary path.""" global STRICT_LOG_COUNTS rust_bin = root_dir + "/target/release/{}".format(binary) @@ -578,6 +604,7 @@ def get_rust_binary_path(binary): def write_config_to_env(config, prefix): + """Write configurations to env.""" for key, val in config.items(): new_key = prefix + key log.debug("✍ config {} => {}".format(new_key, val)) @@ -585,6 +612,7 @@ def write_config_to_env(config, prefix): def capture_output_to_queue(output_stream): + """Capture output to log queue.""" log_queue = Queue() t = Thread(target=enqueue_output, args=(output_stream, log_queue)) t.daemon = True # thread dies with the program @@ -593,6 +621,7 @@ def capture_output_to_queue(output_stream): def setup_bt(): + """Set up BigTable emulator.""" global BT_PROCESS, BT_DB_SETTINGS log.debug("🐍🟢 Starting bigtable emulator") BT_PROCESS = subprocess.Popen("gcloud beta emulators bigtable start".split(" ")) @@ -617,6 +646,7 @@ def setup_bt(): def setup_dynamodb(): + """Set up DynamoDB.""" global DDB_PROCESS log.debug("🐍🟢 Starting dynamodb") @@ -643,6 +673,7 @@ def setup_dynamodb(): def setup_mock_server(): + """Set up mock server.""" global MOCK_SERVER_THREAD MOCK_SERVER_THREAD = Thread(target=app.run, kwargs=dict(port=MOCK_SERVER_PORT, debug=True)) @@ -654,6 +685,7 @@ def setup_mock_server(): def setup_connection_server(connection_binary): + """Set up connection server from config.""" global CN_SERVER, BT_PROCESS, DDB_PROCESS # NOTE: @@ -688,6 +720,7 @@ def setup_connection_server(connection_binary): def setup_megaphone_server(connection_binary): + """Set up megaphone server from configuration.""" global CN_MP_SERVER url = os.getenv("AUTOPUSH_MP_SERVER") @@ -710,6 +743,7 @@ def setup_megaphone_server(connection_binary): def setup_endpoint_server(): + """Set up endpoint server from configuration.""" global CONNECTION_CONFIG, EP_SERVER, BT_PROCESS # Set up environment @@ -749,6 +783,9 @@ def setup_endpoint_server(): def setup_module(): + """Set up module including BigTable or Dynamo + and connection, endpoint and megaphone servers. + """ global CN_SERVER, CN_QUEUES, CN_MP_SERVER, MOCK_SERVER_THREAD, STRICT_LOG_COUNTS, RUST_LOG if "SKIP_INTEGRATION" in os.environ: # pragma: nocover @@ -781,6 +818,7 @@ def setup_module(): def teardown_module(): + """Teardown module for dynamo, bigtable, and servers.""" if DDB_PROCESS: os.unsetenv("AWS_LOCAL_DYNAMODB") log.debug("🐍🔴 Stopping dynamodb") @@ -798,6 +836,8 @@ def teardown_module(): class TestRustWebPush(unittest.TestCase): + """Test class for Rust Web Push.""" + # Max log lines allowed to be emitted by each node type max_endpoint_logs = 8 max_conn_logs = 3 @@ -807,16 +847,19 @@ class TestRustWebPush(unittest.TestCase): } def tearDown(self): + """Tear down and log processing.""" process_logs(self) while not MOCK_SENTRY_QUEUE.empty(): MOCK_SENTRY_QUEUE.get_nowait() def host_endpoint(self, client): + """Return host endpoint.""" parsed = urlparse(list(client.channels.values())[0]) return "{}://{}".format(parsed.scheme, parsed.netloc) @inlineCallbacks def quick_register(self): + """Quick register.""" log.debug("🐍#### Connecting to ws://localhost:{}/".format(CONNECTION_PORT)) client = Client("ws://localhost:{}/".format(CONNECTION_PORT)) yield client.connect() @@ -827,6 +870,7 @@ def quick_register(self): @inlineCallbacks def shut_down(self, client=None): + """Shut down client.""" if client: yield client.disconnect() @@ -837,6 +881,7 @@ def _ws_url(self): @inlineCallbacks @max_logs(conn=4) def test_sentry_output_autoconnect(self): + """Test sentry output for autoconnect.""" if os.getenv("SKIP_SENTRY"): SkipTest("Skipping sentry test") return @@ -860,6 +905,7 @@ def test_sentry_output_autoconnect(self): @inlineCallbacks @max_logs(endpoint=1) def test_sentry_output_autoendpoint(self): + """Test sentry output for autoendpoint.""" if os.getenv("SKIP_SENTRY"): SkipTest("Skipping sentry test") return @@ -880,6 +926,7 @@ def test_sentry_output_autoendpoint(self): @max_logs(conn=4) def test_no_sentry_output(self): + """Test for no Sentry output.""" if os.getenv("SKIP_SENTRY"): SkipTest("Skipping sentry test") return @@ -896,6 +943,7 @@ def test_no_sentry_output(self): @inlineCallbacks def test_hello_echo(self): + """Test hello echo.""" client = Client(self._ws_url) yield client.connect() result = yield client.hello() @@ -905,6 +953,7 @@ def test_hello_echo(self): @inlineCallbacks def test_hello_with_bad_prior_uaid(self): + """Test hello with bard prior uaid.""" non_uaid = uuid.uuid4().hex client = Client(self._ws_url) yield client.connect() @@ -916,6 +965,7 @@ def test_hello_with_bad_prior_uaid(self): @inlineCallbacks def test_basic_delivery(self): + """Test basic delivery.""" data = str(uuid.uuid4()) client: Client = yield self.quick_register() result = yield client.send_notification(data=data) @@ -928,6 +978,7 @@ def test_basic_delivery(self): @inlineCallbacks def test_topic_basic_delivery(self): + """Test topic basic delivery.""" data = str(uuid.uuid4()) client = yield self.quick_register() result = yield client.send_notification(data=data, topic="Inbox") @@ -940,6 +991,7 @@ def test_topic_basic_delivery(self): @inlineCallbacks def test_topic_replacement_delivery(self): + """Test topic replacement delivery.""" data = str(uuid.uuid4()) data2 = str(uuid.uuid4()) client = yield self.quick_register() @@ -962,6 +1014,7 @@ def test_topic_replacement_delivery(self): @inlineCallbacks @max_logs(conn=4) def test_topic_no_delivery_on_reconnect(self): + """Test topic no delivery on reconnect.""" data = str(uuid.uuid4()) client = yield self.quick_register() yield client.disconnect() @@ -987,6 +1040,7 @@ def test_topic_no_delivery_on_reconnect(self): @inlineCallbacks def test_basic_delivery_with_vapid(self): + """Test basic delivery with vapid.""" data = str(uuid.uuid4()) client = yield self.quick_register() vapid_info = _get_vapid(payload=self.vapid_payload) @@ -1000,6 +1054,7 @@ def test_basic_delivery_with_vapid(self): @inlineCallbacks def test_basic_delivery_with_invalid_vapid(self): + """Test basic delivery with invalid vapid.""" data = str(uuid.uuid4()) client = yield self.quick_register() vapid_info = _get_vapid(payload=self.vapid_payload, endpoint=self.host_endpoint(client)) @@ -1009,6 +1064,7 @@ def test_basic_delivery_with_invalid_vapid(self): @inlineCallbacks def test_basic_delivery_with_invalid_vapid_exp(self): + """Test basic delivery with invalid vapid exp.""" data = str(uuid.uuid4()) client = yield self.quick_register() vapid_info = _get_vapid( @@ -1024,6 +1080,7 @@ def test_basic_delivery_with_invalid_vapid_exp(self): @inlineCallbacks def test_basic_delivery_with_invalid_vapid_auth(self): + """Test basic delivery with invalid vapid auth.""" data = str(uuid.uuid4()) client = yield self.quick_register() vapid_info = _get_vapid( @@ -1036,6 +1093,7 @@ def test_basic_delivery_with_invalid_vapid_auth(self): @inlineCallbacks def test_basic_delivery_with_invalid_signature(self): + """Test basic delivery with invalid signature.""" data = str(uuid.uuid4()) client = yield self.quick_register() vapid_info = _get_vapid( @@ -1050,6 +1108,7 @@ def test_basic_delivery_with_invalid_signature(self): @inlineCallbacks def test_basic_delivery_with_invalid_vapid_ckey(self): + """Test basic delivery with invalid vapid ckey.""" data = str(uuid.uuid4()) client = yield self.quick_register() vapid_info = _get_vapid(payload=self.vapid_payload, endpoint=self.host_endpoint(client)) @@ -1059,6 +1118,7 @@ def test_basic_delivery_with_invalid_vapid_ckey(self): @inlineCallbacks def test_delivery_repeat_without_ack(self): + """Test delivery repeat without ack.""" data = str(uuid.uuid4()) client = yield self.quick_register() yield client.disconnect() @@ -1080,6 +1140,7 @@ def test_delivery_repeat_without_ack(self): @inlineCallbacks def test_repeat_delivery_with_disconnect_without_ack(self): + """Test repeat delivery with disconnect without ack.""" data = str(uuid.uuid4()) client = yield self.quick_register() result = yield client.send_notification(data=data) @@ -1095,6 +1156,7 @@ def test_repeat_delivery_with_disconnect_without_ack(self): @inlineCallbacks def test_multiple_delivery_repeat_without_ack(self): + """Test multiple delivery repeat without ack.""" data = str(uuid.uuid4()) data2 = str(uuid.uuid4()) client = yield self.quick_register() @@ -1124,6 +1186,7 @@ def test_multiple_delivery_repeat_without_ack(self): @inlineCallbacks def test_topic_expired(self): + """Test topic expired.""" data = str(uuid.uuid4()) client = yield self.quick_register() yield client.disconnect() @@ -1142,6 +1205,7 @@ def test_topic_expired(self): @inlineCallbacks @max_logs(conn=4) def test_multiple_delivery_with_single_ack(self): + """Test multiple delivery with single ack.""" data = b"\x16*\xec\xb4\xc7\xac\xb1\xa8\x1e" + str(uuid.uuid4()).encode() data2 = b":\xd8^\xac\xc7\xac\xb1\xa8\x1e" + str(uuid.uuid4()).encode() client = yield self.quick_register() @@ -1182,6 +1246,7 @@ def test_multiple_delivery_with_single_ack(self): @inlineCallbacks def test_multiple_delivery_with_multiple_ack(self): + """Test multiple delivery with multiple ack.""" data = b"\x16*\xec\xb4\xc7\xac\xb1\xa8\x1e" + str(uuid.uuid4()).encode() # "FirstMessage" data2 = b":\xd8^\xac\xc7\xac\xb1\xa8\x1e" + str(uuid.uuid4()).encode() # "OtherMessage" client = yield self.quick_register() @@ -1210,6 +1275,7 @@ def test_multiple_delivery_with_multiple_ack(self): @inlineCallbacks def test_no_delivery_to_unregistered(self): + """Test no delivery to unregistered.""" data = str(uuid.uuid4()) client: Client = yield self.quick_register() assert client.channels @@ -1231,6 +1297,7 @@ def test_no_delivery_to_unregistered(self): @inlineCallbacks def test_ttl_0_connected(self): + """Test TTL 0 connected.""" data = str(uuid.uuid4()) client = yield self.quick_register() result = yield client.send_notification(data=data, ttl=0) @@ -1244,6 +1311,7 @@ def test_ttl_0_connected(self): @inlineCallbacks def test_ttl_0_not_connected(self): + """Test TTL 0 not connected.""" data = str(uuid.uuid4()) client = yield self.quick_register() yield client.disconnect() @@ -1256,6 +1324,7 @@ def test_ttl_0_not_connected(self): @inlineCallbacks def test_ttl_expired(self): + """Test TTL expired.""" data = str(uuid.uuid4()) client = yield self.quick_register() yield client.disconnect() @@ -1270,6 +1339,7 @@ def test_ttl_expired(self): @inlineCallbacks @max_logs(endpoint=28) def test_ttl_batch_expired_and_good_one(self): + """Test TTL batch expired with one good result.""" data = str(uuid.uuid4()).encode() data2 = base64.urlsafe_b64decode("0012") + str(uuid.uuid4()).encode() print(data2) @@ -1297,6 +1367,7 @@ def test_ttl_batch_expired_and_good_one(self): @inlineCallbacks @max_logs(endpoint=28) def test_ttl_batch_partly_expired_and_good_one(self): + """Test TTL batch partly expired with one good result.""" data = str(uuid.uuid4()) data1 = str(uuid.uuid4()) data2 = str(uuid.uuid4()) @@ -1333,6 +1404,7 @@ def test_ttl_batch_partly_expired_and_good_one(self): @inlineCallbacks def test_message_without_crypto_headers(self): + """Test message without crypto headers.""" data = str(uuid.uuid4()) client = yield self.quick_register() result = yield client.send_notification(data=data, use_header=False, status=400) @@ -1341,6 +1413,7 @@ def test_message_without_crypto_headers(self): @inlineCallbacks def test_empty_message_without_crypto_headers(self): + """Test empty message without crypto headers.""" client = yield self.quick_register() result = yield client.send_notification(use_header=False) assert result is not None @@ -1363,6 +1436,7 @@ def test_empty_message_without_crypto_headers(self): @inlineCallbacks def test_empty_message_with_crypto_headers(self): + """Test empty message with crypto headers.""" client = yield self.quick_register() result = yield client.send_notification() assert result is not None @@ -1441,6 +1515,7 @@ def test_delete_saved_notification(self): @inlineCallbacks def test_with_key(self): + """Test with key.""" private_key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p) claims = { "aud": "http://localhost:{}".format(ENDPOINT_PORT), @@ -1468,6 +1543,7 @@ def test_with_key(self): @inlineCallbacks def test_with_bad_key(self): + """Test with bad key.""" chid = str(uuid.uuid4()) client = Client("ws://localhost:{}/".format(CONNECTION_PORT)) yield client.connect() @@ -1480,6 +1556,7 @@ def test_with_bad_key(self): @inlineCallbacks @max_logs(endpoint=44) def test_msg_limit(self): + """Test message limit.""" self.skipTest("known broken") client = yield self.quick_register() uaid = client.uaid @@ -1501,6 +1578,7 @@ def test_msg_limit(self): @inlineCallbacks def test_can_ping(self): + """Test can ping.""" client = yield self.quick_register() yield client.ping() assert client.ws.connected @@ -1539,14 +1617,18 @@ def test_internal_endpoints(self): class TestRustWebPushBroadcast(unittest.TestCase): + """Test class for Rust Web Push Broadcast.""" + max_endpoint_logs = 4 max_conn_logs = 1 def tearDown(self): + """Tear down.""" process_logs(self) @inlineCallbacks def quick_register(self, connection_port=None): + """Connect and register client.""" conn_port = connection_port or MP_CONNECTION_PORT client = Client("ws://localhost:{}/".format(conn_port)) yield client.connect() @@ -1556,6 +1638,7 @@ def quick_register(self, connection_port=None): @inlineCallbacks def shut_down(self, client=None): + """Shut down client connection.""" if client: yield client.disconnect() @@ -1565,6 +1648,7 @@ def _ws_url(self): @inlineCallbacks def test_broadcast_update_on_connect(self): + """Test broadcast update on connect.""" global MOCK_MP_SERVICES MOCK_MP_SERVICES = {"kinto:123": "ver1"} MOCK_MP_POLLED.clear() @@ -1589,6 +1673,7 @@ def test_broadcast_update_on_connect(self): @inlineCallbacks def test_broadcast_update_on_connect_with_errors(self): + """Test broadcast update on connect with errors.""" global MOCK_MP_SERVICES MOCK_MP_SERVICES = {"kinto:123": "ver1"} MOCK_MP_POLLED.clear() @@ -1606,6 +1691,7 @@ def test_broadcast_update_on_connect_with_errors(self): @inlineCallbacks def test_broadcast_subscribe(self): + """Test broadcast subscribe.""" global MOCK_MP_SERVICES MOCK_MP_SERVICES = {"kinto:123": "ver1"} MOCK_MP_POLLED.clear() @@ -1634,6 +1720,7 @@ def test_broadcast_subscribe(self): @inlineCallbacks def test_broadcast_subscribe_with_errors(self): + """Test that broadcast returns expected errors.""" global MOCK_MP_SERVICES MOCK_MP_SERVICES = {"kinto:123": "ver1"} MOCK_MP_POLLED.clear() @@ -1656,6 +1743,7 @@ def test_broadcast_subscribe_with_errors(self): @inlineCallbacks def test_broadcast_no_changes(self): + """Test to ensure there are no changes from broadcast.""" global MOCK_MP_SERVICES MOCK_MP_SERVICES = {"kinto:123": "ver1"} MOCK_MP_POLLED.clear() diff --git a/tests/load/locustfiles/args.py b/tests/load/locustfiles/args.py index 96c4d2d64..def0911ee 100644 --- a/tests/load/locustfiles/args.py +++ b/tests/load/locustfiles/args.py @@ -1,3 +1,4 @@ +"""Load test arguments.""" # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. @@ -19,6 +20,7 @@ def parse_wait_time(val: str): raise ValueError("Invalid wait_time") -def float_or_int(val: str): +def float_or_int(val: str) -> int | float: + """Parse string value into float or integer.""" float_val: float = float(val) return int(float_val) if float_val.is_integer() else float_val diff --git a/tests/load/locustfiles/exceptions.py b/tests/load/locustfiles/exceptions.py index 0b9123c9a..ed4f022e8 100644 --- a/tests/load/locustfiles/exceptions.py +++ b/tests/load/locustfiles/exceptions.py @@ -1,3 +1,4 @@ +"""Custom exceptions for load tests.""" # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. diff --git a/tests/load/locustfiles/load.py b/tests/load/locustfiles/load.py index a9b7f332a..6af5b3b0d 100644 --- a/tests/load/locustfiles/load.py +++ b/tests/load/locustfiles/load.py @@ -27,7 +27,7 @@ def __init__(self, max_run_time: int, max_users: int): ) def calculate_users(self, run_time: int) -> int: - """Determined the number of active users given a run time. + """Determine the number of active users given a run time. Returns: int: The number of users diff --git a/tests/load/locustfiles/locustfile.py b/tests/load/locustfiles/locustfile.py index a75641ef3..66c007ca6 100644 --- a/tests/load/locustfiles/locustfile.py +++ b/tests/load/locustfiles/locustfile.py @@ -1,9 +1,8 @@ +"""Performance test module.""" # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -"""Performance test module.""" - import base64 import json import logging @@ -61,6 +60,8 @@ def _(environment, **kwargs): class AutopushUser(FastHttpUser): + """AutopushUser class.""" + REST_HEADERS: dict[str, str] = {"TTL": "60", "Content-Encoding": "aes128gcm"} WEBSOCKET_HEADERS: dict[str, str] = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.13; rv:61.0) " @@ -79,14 +80,15 @@ def __init__(self, environment) -> None: self.ws_greenlet: Greenlet | None = None def wait_time(self): + """Return the autopush wait time.""" return self.environment.autopush_wait_time(self) def on_start(self) -> Any: - """Called when a User starts running.""" + """Call when a User starts running.""" self.ws_greenlet = gevent.spawn(self.connect) def on_stop(self) -> Any: - """Called when a User stops running.""" + """Call when a User stops running.""" if self.ws: for channel_id in self.channels.keys(): self.send_unregister(self.ws, channel_id) @@ -96,7 +98,7 @@ def on_stop(self) -> Any: gevent.kill(self.ws_greenlet) def on_ws_open(self, ws: WebSocket) -> None: - """Called when opening a WebSocket. + """Call when opening a WebSocket. Args: ws: WebSocket class object @@ -104,7 +106,7 @@ def on_ws_open(self, ws: WebSocket) -> None: self.send_hello(ws) def on_ws_message(self, ws: WebSocket, data: str) -> None: - """Called when received data from a WebSocket. + """Call when received data from a WebSocket. Args: ws: WebSocket class object @@ -121,7 +123,7 @@ def on_ws_message(self, ws: WebSocket, data: str) -> None: del self.channels[message.channelID] def on_ws_error(self, ws: WebSocket, error: Exception) -> None: - """Called when there is a WebSocket error or if an exception is raised in a WebSocket + """Call when there is a WebSocket error or if an exception is raised in a WebSocket callback function. Args: @@ -141,7 +143,7 @@ def on_ws_error(self, ws: WebSocket, error: Exception) -> None: def on_ws_close( self, ws: WebSocket, close_status_code: int | None, close_msg: str | None ) -> None: - """Called when closing a WebSocket. + """Call when closing a WebSocket. Args: ws: WebSocket class object @@ -153,7 +155,7 @@ def on_ws_close( @task(weight=98) def send_notification(self): - """Sends a notification to a registered endpoint while connected to Autopush.""" + """Send a notification to a registered endpoint while connected to Autopush.""" if not self.ws or not self.channels: logger.debug("Task 'send_notification' skipped.") return @@ -163,7 +165,7 @@ def send_notification(self): @task(weight=1) def subscribe(self): - """Subscribes a user to an Autopush channel.""" + """Subscribe a user to an Autopush channel.""" if not self.ws: logger.debug("Task 'subscribe' skipped.") return @@ -173,7 +175,7 @@ def subscribe(self): @task(weight=1) def unsubscribe(self): - """Unsubscribes a user from an Autopush channel.""" + """Unsubscribe a user from an Autopush channel.""" if not self.ws or not self.channels: logger.debug("Task 'unsubscribe' skipped.") return @@ -182,7 +184,7 @@ def unsubscribe(self): self.send_unregister(self.ws, channel_id) def connect(self) -> None: - """Creates the WebSocketApp that will run indefinitely.""" + """Create the WebSocketApp that will run indefinitely.""" if not self.host: raise LocustError("'host' value is unavailable.") diff --git a/tests/load/locustfiles/models.py b/tests/load/locustfiles/models.py index 5b36d98f2..49aa3ecaf 100644 --- a/tests/load/locustfiles/models.py +++ b/tests/load/locustfiles/models.py @@ -1,7 +1,6 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. - """Load test models module.""" from typing import Any, Literal