From 9a45aa75bb8877b60669fe3a4fc9507a0b71136a Mon Sep 17 00:00:00 2001 From: Sorin Sbarnea Date: Wed, 28 Aug 2024 14:06:18 +0100 Subject: [PATCH] type: make code respect mypy strict mode Fixes: #258 --- .pre-commit-config.yaml | 4 +-- .../event_filter/dashes_to_underscores.py | 5 +++- .../eda/plugins/event_filter/json_filter.py | 13 +++++----- extensions/eda/plugins/event_filter/noop.py | 4 ++- .../plugins/event_filter/normalize_keys.py | 9 ++++--- .../eda/plugins/event_source/alertmanager.py | 4 +-- .../plugins/event_source/aws_cloudtrail.py | 25 +++++++++++++------ .../eda/plugins/event_source/aws_sqs_queue.py | 4 +-- .../plugins/event_source/azure_service_bus.py | 6 ++--- extensions/eda/plugins/event_source/file.py | 8 +++--- .../eda/plugins/event_source/file_watch.py | 8 +++--- .../eda/plugins/event_source/generic.py | 12 +++++---- .../eda/plugins/event_source/journald.py | 2 +- extensions/eda/plugins/event_source/kafka.py | 6 ++--- .../eda/plugins/event_source/pg_listener.py | 10 ++++---- extensions/eda/plugins/event_source/range.py | 4 +-- extensions/eda/plugins/event_source/tick.py | 4 +-- .../eda/plugins/event_source/url_check.py | 4 +-- .../eda/plugins/event_source/webhook.py | 16 ++++++++---- plugins/module_utils/controller.py | 23 +++++++++++------ pyproject.toml | 8 +++--- requirements.txt | 2 +- tests/integration/conftest.py | 2 +- .../event_source_kafka/test_kafka_source.py | 3 ++- .../test_url_check_source.py | 9 ++++--- .../test_webhook_source.py | 17 ++++++++----- tests/integration/utils.py | 8 +++--- .../event_filter/test_insert_hosts_to_meta.py | 6 +++-- .../unit/event_filter/test_normalize_keys.py | 6 ++++- tests/unit/event_source/test_alertmanager.py | 2 +- .../event_source/test_azure_service_bus.py | 2 +- tests/unit/event_source/test_generic.py | 4 +-- tests/unit/event_source/test_kafka.py | 4 ++- tests/unit/event_source/test_pg_listener.py | 7 ++---- tests/unit/event_source/test_webhook.py | 2 +- 35 files changed, 152 insertions(+), 101 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 575abd09..89f0180f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -72,7 +72,7 @@ repos: - types-botocore - types-mock - types-requests - - watchdog + - watchdog>=5.0.0 - xxhash - repo: https://github.com/astral-sh/ruff-pre-commit @@ -113,7 +113,7 @@ repos: - pyyaml - requests - types-aiobotocore - - watchdog + - watchdog>=5.0.0 - xxhash - repo: local hooks: diff --git a/extensions/eda/plugins/event_filter/dashes_to_underscores.py b/extensions/eda/plugins/event_filter/dashes_to_underscores.py index d6a5ab0d..fa70d7c7 100644 --- a/extensions/eda/plugins/event_filter/dashes_to_underscores.py +++ b/extensions/eda/plugins/event_filter/dashes_to_underscores.py @@ -10,9 +10,12 @@ """ import multiprocessing as mp +from typing import Any -def main(event: dict, overwrite: bool = True) -> dict: # noqa: FBT001, FBT002 +def main( + event: dict[str, Any], overwrite: bool = True +) -> dict[str, Any]: # noqa: FBT001, FBT002 """Change dashes in keys to underscores.""" logger = mp.get_logger() logger.info("dashes_to_underscores") diff --git a/extensions/eda/plugins/event_filter/json_filter.py b/extensions/eda/plugins/event_filter/json_filter.py index f5860124..551f7160 100644 --- a/extensions/eda/plugins/event_filter/json_filter.py +++ b/extensions/eda/plugins/event_filter/json_filter.py @@ -16,21 +16,22 @@ from __future__ import annotations import fnmatch +from typing import Any, Optional -def _matches_include_keys(include_keys: list, string: str) -> bool: +def _matches_include_keys(include_keys: list[str], string: str) -> bool: return any(fnmatch.fnmatch(string, pattern) for pattern in include_keys) -def _matches_exclude_keys(exclude_keys: list, string: str) -> bool: +def _matches_exclude_keys(exclude_keys: list[str], string: str) -> bool: return any(fnmatch.fnmatch(string, pattern) for pattern in exclude_keys) def main( - event: dict, - exclude_keys: list | None = None, - include_keys: list | None = None, -) -> dict: + event: dict[str, Any], + exclude_keys: Optional[list[str]] = None, + include_keys: Optional[list[str]] = None, +) -> dict[str, Any]: """Filter keys out of events.""" if exclude_keys is None: exclude_keys = [] diff --git a/extensions/eda/plugins/event_filter/noop.py b/extensions/eda/plugins/event_filter/noop.py index 57c32df9..8f0756fb 100644 --- a/extensions/eda/plugins/event_filter/noop.py +++ b/extensions/eda/plugins/event_filter/noop.py @@ -1,6 +1,8 @@ """noop.py: An event filter that does nothing to the input.""" +from typing import Any -def main(event: dict) -> dict: + +def main(event: dict[str, Any]) -> dict[str, Any]: """Return the input.""" return event diff --git a/extensions/eda/plugins/event_filter/normalize_keys.py b/extensions/eda/plugins/event_filter/normalize_keys.py index fbc86c1e..58256c39 100644 --- a/extensions/eda/plugins/event_filter/normalize_keys.py +++ b/extensions/eda/plugins/event_filter/normalize_keys.py @@ -40,11 +40,14 @@ import logging import multiprocessing as mp import re +from typing import Any normalize_regex = re.compile("[^0-9a-zA-Z_]+") -def main(event: dict, overwrite: bool = True) -> dict: # noqa: FBT001, FBT002 +def main( + event: dict[str, Any], overwrite: bool = True +) -> dict[str, Any]: # noqa: FBT001, FBT002 """Change keys that contain non-alphanumeric characters to underscores.""" logger = mp.get_logger() logger.info("normalize_keys") @@ -52,10 +55,10 @@ def main(event: dict, overwrite: bool = True) -> dict: # noqa: FBT001, FBT002 def _normalize_embedded_keys( - obj: dict, + obj: dict[str, Any], overwrite: bool, # noqa: FBT001 logger: logging.Logger, -) -> dict: +) -> dict[str, Any]: if isinstance(obj, dict): new_dict = {} original_keys = list(obj.keys()) diff --git a/extensions/eda/plugins/event_source/alertmanager.py b/extensions/eda/plugins/event_source/alertmanager.py index b4efd5d6..38c27fbe 100644 --- a/extensions/eda/plugins/event_source/alertmanager.py +++ b/extensions/eda/plugins/event_source/alertmanager.py @@ -114,7 +114,7 @@ def clean_host(host: str) -> str: return host -async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: +async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: """Receive events via alertmanager webhook.""" app = web.Application() app["queue"] = queue @@ -144,7 +144,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: class MockQueue(asyncio.Queue[Any]): """A fake queue.""" - async def put(self: "MockQueue", event: dict) -> None: + async def put(self: "MockQueue", event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/extensions/eda/plugins/event_source/aws_cloudtrail.py b/extensions/eda/plugins/event_source/aws_cloudtrail.py index 6f6b680e..8ef18b2a 100644 --- a/extensions/eda/plugins/event_source/aws_cloudtrail.py +++ b/extensions/eda/plugins/event_source/aws_cloudtrail.py @@ -40,13 +40,15 @@ from botocore.client import BaseClient -def _cloudtrail_event_to_dict(event: dict) -> dict: +def _cloudtrail_event_to_dict(event: dict[str, Any]) -> dict[str, Any]: event["EventTime"] = event["EventTime"].isoformat() event["CloudTrailEvent"] = json.loads(event["CloudTrailEvent"]) return event -def _get_events(events: list[dict], last_event_ids: list[str]) -> list: +def _get_events( + events: list[dict[str, Any]], last_event_ids: list[str] +) -> tuple[list[dict[str, Any]], Any, list[str]]: event_time = None event_ids = [] result = [] @@ -60,13 +62,22 @@ def _get_events(events: list[dict], last_event_ids: list[str]) -> list: elif event_time == event["EventTime"]: event_ids.append(event["EventId"]) result.append(event) - return [result, event_time, event_ids] + return result, event_time, event_ids -async def _get_cloudtrail_events(client: BaseClient, params: dict) -> list[dict]: +async def _get_cloudtrail_events( + client: BaseClient, params: dict[str, Any] +) -> list[dict[str, Any]]: paginator = client.get_paginator("lookup_events") results = await paginator.paginate(**params).build_full_result() - return results.get("Events", []) + events = results.get("Events", []) + # type guards: + if not isinstance(events, list): + raise ValueError("Events is not a list") + for event in events: + if not isinstance(event, dict): + raise ValueError("Event is not a dictionary") + return events ARGS_MAPPING = { @@ -75,7 +86,7 @@ async def _get_cloudtrail_events(client: BaseClient, params: dict) -> list[dict] } -async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: +async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: """Receive events via AWS CloudTrail.""" delay = int(args.get("delay_seconds", 10)) @@ -131,7 +142,7 @@ def connection_args(args: dict[str, Any]) -> dict[str, Any]: class MockQueue(asyncio.Queue[Any]): """A fake queue.""" - async def put(self: "MockQueue", event: dict) -> None: + async def put(self: "MockQueue", event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/extensions/eda/plugins/event_source/aws_sqs_queue.py b/extensions/eda/plugins/event_source/aws_sqs_queue.py index 66c85d7c..7fcfc04a 100644 --- a/extensions/eda/plugins/event_source/aws_sqs_queue.py +++ b/extensions/eda/plugins/event_source/aws_sqs_queue.py @@ -31,7 +31,7 @@ # pylint: disable=too-many-locals -async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: +async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: """Receive events via an AWS SQS queue.""" logger = logging.getLogger() @@ -117,7 +117,7 @@ def connection_args(args: dict[str, Any]) -> dict[str, Any]: class MockQueue(asyncio.Queue[Any]): """A fake queue.""" - async def put(self: "MockQueue", event: dict) -> None: + async def put(self: "MockQueue", event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/extensions/eda/plugins/event_source/azure_service_bus.py b/extensions/eda/plugins/event_source/azure_service_bus.py index f00d14c5..1bfca1c6 100644 --- a/extensions/eda/plugins/event_source/azure_service_bus.py +++ b/extensions/eda/plugins/event_source/azure_service_bus.py @@ -27,7 +27,7 @@ def receive_events( loop: asyncio.events.AbstractEventLoop, - queue: asyncio.Queue, + queue: asyncio.Queue[Any], args: dict[str, Any], # pylint: disable=W0621 ) -> None: """Receive events from service bus.""" @@ -53,7 +53,7 @@ def receive_events( async def main( - queue: asyncio.Queue, + queue: asyncio.Queue[Any], args: dict[str, Any], # pylint: disable=W0621 ) -> None: """Receive events from service bus in a loop.""" @@ -69,7 +69,7 @@ async def main( class MockQueue(asyncio.Queue[Any]): """A fake queue.""" - def put_nowait(self: "MockQueue", event: dict) -> None: + def put_nowait(self: "MockQueue", event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/extensions/eda/plugins/event_source/file.py b/extensions/eda/plugins/event_source/file.py index 2a45e213..96c301ff 100644 --- a/extensions/eda/plugins/event_source/file.py +++ b/extensions/eda/plugins/event_source/file.py @@ -24,7 +24,7 @@ from watchdog.observers import Observer -def send_facts(queue: Queue, filename: Union[str, bytes]) -> None: +def send_facts(queue: Queue[Any], filename: Union[str, bytes]) -> None: """Send facts to the queue.""" if isinstance(filename, bytes): filename = str(filename, "utf-8") @@ -50,7 +50,7 @@ def send_facts(queue: Queue, filename: Union[str, bytes]) -> None: coroutine = queue.put(item) # noqa: F841 -def main(queue: Queue, args: dict) -> None: +def main(queue: Queue[Any], args: dict[str, Any]) -> None: """Load facts from YAML files initially and when the file changes.""" files = [pathlib.Path(f).resolve().as_posix() for f in args.get("files", [])] @@ -62,7 +62,7 @@ def main(queue: Queue, args: dict) -> None: _observe_files(queue, files) -def _observe_files(queue: Queue, files: list[str]) -> None: +def _observe_files(queue: Queue[Any], files: list[str]) -> None: class Handler(RegexMatchingEventHandler): """A handler for file events.""" @@ -104,7 +104,7 @@ def on_moved(self: "Handler", event: FileSystemEvent) -> None: class MockQueue(Queue[Any]): """A fake queue.""" - async def put(self: "MockQueue", event: dict) -> None: + async def put(self: "MockQueue", event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/extensions/eda/plugins/event_source/file_watch.py b/extensions/eda/plugins/event_source/file_watch.py index 0bcf8e55..8c2d499b 100644 --- a/extensions/eda/plugins/event_source/file_watch.py +++ b/extensions/eda/plugins/event_source/file_watch.py @@ -28,8 +28,8 @@ def watch( loop: asyncio.events.AbstractEventLoop, - queue: asyncio.Queue, - args: dict, + queue: asyncio.Queue[Any], + args: dict[str, Any], ) -> None: """Watch for changes and put events on the queue.""" root_path = args["path"] @@ -96,7 +96,7 @@ def on_moved(self: "Handler", event: FileSystemEvent) -> None: observer.join() -async def main(queue: asyncio.Queue, args: dict) -> None: +async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: """Watch for changes to a file and put events on the queue.""" loop = asyncio.get_event_loop() @@ -110,7 +110,7 @@ async def main(queue: asyncio.Queue, args: dict) -> None: class MockQueue(asyncio.Queue[Any]): """A fake queue.""" - def put_nowait(self: "MockQueue", event: dict) -> None: + def put_nowait(self: "MockQueue", event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/extensions/eda/plugins/event_source/generic.py b/extensions/eda/plugins/event_source/generic.py index 43228e3d..536e5db5 100644 --- a/extensions/eda/plugins/event_source/generic.py +++ b/extensions/eda/plugins/event_source/generic.py @@ -100,7 +100,9 @@ class DelayArgs: class Generic: """Generic source plugin to generate different events.""" - def __init__(self: Generic, queue: asyncio.Queue, args: dict[str, Any]) -> None: + def __init__( + self: Generic, queue: asyncio.Queue[Any], args: dict[str, Any] + ) -> None: """Insert event data into the queue.""" self.queue = queue field_names = [f.name for f in fields(Args)] @@ -164,7 +166,7 @@ async def __call__(self: Generic) -> None: await asyncio.sleep(self.delay_args.shutdown_after) - async def _post_event(self: Generic, event: dict, index: int) -> None: + async def _post_event(self: Generic, event: dict[str, Any], index: int) -> None: data = self._create_data(index) data.update(event) @@ -189,7 +191,7 @@ async def _load_payload_from_file(self: Generic) -> None: def _create_data( self: Generic, index: int, - ) -> dict: + ) -> dict[str, Any]: data: dict[str, str | int] = {} if self.my_args.create_index: data[self.my_args.create_index] = index @@ -206,7 +208,7 @@ def _create_data( async def main( # pylint: disable=R0914 - queue: asyncio.Queue, + queue: asyncio.Queue[Any], args: dict[str, Any], ) -> None: """Call the Generic Source Plugin.""" @@ -218,7 +220,7 @@ async def main( # pylint: disable=R0914 class MockQueue(asyncio.Queue[Any]): """A fake queue.""" - async def put(self: MockQueue, event: dict) -> None: + async def put(self: MockQueue, event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/extensions/eda/plugins/event_source/journald.py b/extensions/eda/plugins/event_source/journald.py index 67087e0e..1e0aa355 100644 --- a/extensions/eda/plugins/event_source/journald.py +++ b/extensions/eda/plugins/event_source/journald.py @@ -33,7 +33,7 @@ from systemd import journal # type: ignore -async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: # noqa: D417 +async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: # noqa: D417 """Read journal entries and add them to the provided queue. Args: diff --git a/extensions/eda/plugins/event_source/kafka.py b/extensions/eda/plugins/event_source/kafka.py index 21282457..100cfa8f 100644 --- a/extensions/eda/plugins/event_source/kafka.py +++ b/extensions/eda/plugins/event_source/kafka.py @@ -45,7 +45,7 @@ async def main( # pylint: disable=R0914 - queue: asyncio.Queue, + queue: asyncio.Queue[Any], args: dict[str, Any], ) -> None: """Receive events via a kafka topic.""" @@ -116,7 +116,7 @@ async def main( # pylint: disable=R0914 async def receive_msg( - queue: asyncio.Queue, + queue: asyncio.Queue[Any], kafka_consumer: AIOKafkaConsumer, encoding: str, ) -> None: @@ -161,7 +161,7 @@ async def receive_msg( class MockQueue(asyncio.Queue[Any]): """A fake queue.""" - async def put(self: "MockQueue", event: dict) -> None: + async def put(self: "MockQueue", event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/extensions/eda/plugins/event_source/pg_listener.py b/extensions/eda/plugins/event_source/pg_listener.py index 2cd66aa9..dab6931b 100644 --- a/extensions/eda/plugins/event_source/pg_listener.py +++ b/extensions/eda/plugins/event_source/pg_listener.py @@ -82,13 +82,13 @@ def __init__(self: "MissingChunkKeyError", key: str) -> None: super().__init__(f"Chunked payload is missing required {key}") -def _validate_chunked_payload(payload: dict) -> None: +def _validate_chunked_payload(payload: dict[str, Any]) -> None: for key in REQUIRED_CHUNK_KEYS: if key not in payload: raise MissingChunkKeyError(key) -async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: +async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: """Listen for events from a channel.""" for key in REQUIRED_KEYS: if key not in args: @@ -119,8 +119,8 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: async def _handle_chunked_message( data: dict[str, Any], - chunked_cache: dict, - queue: asyncio.Queue, + chunked_cache: dict[str, Any], + queue: asyncio.Queue[Any], ) -> None: message_uuid = data[MESSAGE_CHUNKED_UUID] number_of_chunks = data[MESSAGE_CHUNK_COUNT] @@ -172,7 +172,7 @@ async def _handle_chunked_message( class MockQueue(asyncio.Queue[Any]): """A fake queue.""" - async def put(self: "MockQueue", event: dict) -> None: + async def put(self: "MockQueue", event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/extensions/eda/plugins/event_source/range.py b/extensions/eda/plugins/event_source/range.py index 1862c51d..f2619733 100644 --- a/extensions/eda/plugins/event_source/range.py +++ b/extensions/eda/plugins/event_source/range.py @@ -18,7 +18,7 @@ from typing import Any -async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: +async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: """Generate events with an increasing index i with a limit.""" delay = args.get("delay", 0) @@ -33,7 +33,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: class MockQueue(asyncio.Queue[Any]): """A fake queue.""" - async def put(self: "MockQueue", event: dict) -> None: + async def put(self: "MockQueue", event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/extensions/eda/plugins/event_source/tick.py b/extensions/eda/plugins/event_source/tick.py index da6227dd..4a71f871 100644 --- a/extensions/eda/plugins/event_source/tick.py +++ b/extensions/eda/plugins/event_source/tick.py @@ -19,7 +19,7 @@ from typing import Any -async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: +async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: """Generate events with an increasing index i and a time between ticks.""" delay = args.get("delay", 1) @@ -34,7 +34,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: class MockQueue(asyncio.Queue[Any]): """A fake queue.""" - async def put(self: "MockQueue", event: dict) -> None: + async def put(self: "MockQueue", event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/extensions/eda/plugins/event_source/url_check.py b/extensions/eda/plugins/event_source/url_check.py index cfefce51..fb3c3898 100644 --- a/extensions/eda/plugins/event_source/url_check.py +++ b/extensions/eda/plugins/event_source/url_check.py @@ -27,7 +27,7 @@ OK = 200 -async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: +async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: """Poll a set of URLs and send events with status.""" urls = args.get("urls", []) delay = int(args.get("delay", 1)) @@ -72,7 +72,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: class MockQueue(asyncio.Queue[Any]): """A fake queue.""" - async def put(self: "MockQueue", event: dict) -> None: + async def put(self: "MockQueue", event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/extensions/eda/plugins/event_source/webhook.py b/extensions/eda/plugins/event_source/webhook.py index aa93e1c5..124b87aa 100644 --- a/extensions/eda/plugins/event_source/webhook.py +++ b/extensions/eda/plugins/event_source/webhook.py @@ -44,7 +44,7 @@ import logging import ssl import typing -from typing import Any +from typing import Any, Awaitable from aiohttp import web @@ -117,7 +117,10 @@ async def _hmac_verify(request: web.Request) -> bool: @web.middleware -async def bearer_auth(request: web.Request, handler: Callable) -> web.StreamResponse: +async def bearer_auth( + request: web.Request, + handler: Callable[[web.Request], Awaitable[web.StreamResponse]], +) -> web.StreamResponse: """Verify authorization is Bearer type.""" try: _parse_token(request) @@ -130,7 +133,10 @@ async def bearer_auth(request: web.Request, handler: Callable) -> web.StreamResp @web.middleware -async def hmac_verify(request: web.Request, handler: Callable) -> web.StreamResponse: +async def hmac_verify( + request: web.Request, + handler: Callable[[web.Request], Awaitable[web.StreamResponse]], +) -> web.StreamResponse: """Verify event's HMAC signature.""" hmac_verified = await _hmac_verify(request) if not hmac_verified: @@ -166,7 +172,7 @@ def _get_ssl_context(args: dict[str, Any]) -> ssl.SSLContext | None: return context -async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: +async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: """Receive events via webhook.""" if "port" not in args: msg = "Missing required argument: port" @@ -224,7 +230,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: class MockQueue(asyncio.Queue[Any]): """A fake queue.""" - async def put(self: MockQueue, event: dict) -> None: + async def put(self: MockQueue, event: dict[str, Any]) -> None: """Print the event.""" print(event) # noqa: T201 diff --git a/plugins/module_utils/controller.py b/plugins/module_utils/controller.py index 12b01e07..72b25377 100644 --- a/plugins/module_utils/controller.py +++ b/plugins/module_utils/controller.py @@ -108,7 +108,9 @@ def resolve_name_to_id( ) -> Optional[str]: result = self.get_exactly_one(endpoint, name, **kwargs) if result: - return result["id"] + if isinstance(result["id"], str): + return result["id"] + raise EDAError("The endpoint did not provide a string id") return None def get_one_or_many( @@ -116,7 +118,7 @@ def get_one_or_many( endpoint: str, name: Optional[str] = None, **kwargs: Any, - ) -> List[Any]: + ) -> List[dict[str, Any]]: new_kwargs = kwargs.copy() if name: @@ -140,11 +142,18 @@ def get_one_or_many( if response.json["count"] == 0: return [] - return response.json["results"] + # type safeguard + results = response.json["results"] + if not isinstance(results, list): + raise EDAError("The endpoint did not provide a list of dictionaries") + for result in results: + if not isinstance(result, dict): + raise EDAError("The endpoint did not provide a list of dictionaries") + return results def create_if_needed( self, - new_item: dict, + new_item: dict[str, Any], endpoint: str, item_type: str = "unknown", ) -> dict[str, bool]: @@ -222,7 +231,7 @@ def objects_could_be_different( self, old: dict[str, Any], new: dict[str, Any], - field_set: Optional[set] = None, + field_set: Optional[set[str]] = None, warning: bool = False, ) -> bool: if field_set is None: @@ -248,7 +257,7 @@ def objects_could_be_different( def update_if_needed( self, - existing_item: dict, + existing_item: dict[str, Any], new_item: dict[str, Any], endpoint: str, item_type: str, @@ -302,7 +311,7 @@ def update_if_needed( def create_or_update_if_needed( self, existing_item: dict[str, Any], - new_item: dict, + new_item: dict[str, Any], endpoint: str, item_type: str = "unknown", ) -> dict[str, bool]: diff --git a/pyproject.toml b/pyproject.toml index cede5caa..b4951207 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,14 +39,14 @@ color_output = true error_summary = true # TODO: Remove temporary skips and close https://github.com/ansible/event-driven-ansible/issues/258 -# strict = true -# disallow_untyped_calls = true +strict = true +disallow_untyped_calls = true disallow_untyped_defs = true # disallow_any_generics = true -# disallow_any_unimported = True +# disallow_any_unimported = true # warn_redundant_casts = True # warn_return_any = True -# warn_unused_configs = True +warn_unused_configs = true # site-packages is here to help vscode mypy integration getting confused exclude = "(build|dist|test/local-content|site-packages|~/.pyenv|examples/playbooks/collections|plugins/modules)" diff --git a/requirements.txt b/requirements.txt index ffc4fc05..3d25bbaa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,5 +10,5 @@ kafka-python; python_version < "3.12" kafka-python-ng; python_version >= "3.12" psycopg[binary,pool] # extras needed to avoid install failure on macos-aarch64 systemd-python; sys_platform != 'darwin' -watchdog +watchdog>=5.0.0 # types xxhash diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index eab426e0..18783d51 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -5,7 +5,7 @@ @pytest.fixture(scope="function") -def subprocess_teardown() -> Iterator[Callable]: +def subprocess_teardown() -> Iterator[Callable[[Popen[bytes]], None]]: processes: list[Popen[bytes]] = [] def _teardown(process: Popen[bytes]) -> None: diff --git a/tests/integration/event_source_kafka/test_kafka_source.py b/tests/integration/event_source_kafka/test_kafka_source.py index 92603d2a..b7735c9d 100644 --- a/tests/integration/event_source_kafka/test_kafka_source.py +++ b/tests/integration/event_source_kafka/test_kafka_source.py @@ -6,7 +6,8 @@ import pytest from kafka import KafkaProducer -from ..utils import TESTS_PATH, CLIRunner +from .. import TESTS_PATH +from ..utils import CLIRunner @pytest.fixture(scope="session") diff --git a/tests/integration/event_source_url_check/test_url_check_source.py b/tests/integration/event_source_url_check/test_url_check_source.py index 2a242822..c6012fd1 100644 --- a/tests/integration/event_source_url_check/test_url_check_source.py +++ b/tests/integration/event_source_url_check/test_url_check_source.py @@ -5,7 +5,8 @@ import pytest -from ..utils import DEFAULT_TEST_TIMEOUT, TESTS_PATH, CLIRunner +from .. import TESTS_PATH +from ..utils import DEFAULT_TEST_TIMEOUT, CLIRunner EVENT_SOURCE_DIR = os.path.dirname(__file__) @@ -41,7 +42,7 @@ def init_webserver() -> Generator[Any, Any, Any]: ) def test_url_check_source_sanity( init_webserver: None, - subprocess_teardown: Callable, + subprocess_teardown: Callable[..., None], endpoint: str, expected_resp_data: str, ) -> None: @@ -67,7 +68,9 @@ def test_url_check_source_sanity( @pytest.mark.timeout(timeout=DEFAULT_TEST_TIMEOUT, method="signal") -def test_url_check_source_error_handling(subprocess_teardown: Callable) -> None: +def test_url_check_source_error_handling( + subprocess_teardown: Callable[..., None], +) -> None: """ Ensure the url check source plugin responds correctly when the desired HTTP server is unreachable diff --git a/tests/integration/event_source_webhook/test_webhook_source.py b/tests/integration/event_source_webhook/test_webhook_source.py index a9b20a88..c4b161f9 100644 --- a/tests/integration/event_source_webhook/test_webhook_source.py +++ b/tests/integration/event_source_webhook/test_webhook_source.py @@ -7,10 +7,11 @@ import pytest import requests -from ..utils import TESTS_PATH, CLIRunner +from .. import TESTS_PATH +from ..utils import CLIRunner -def wait_for_events(proc: subprocess.Popen, timeout: float = 15.0) -> None: +def wait_for_events(proc: subprocess.Popen[bytes], timeout: float = 15.0) -> None: """ Wait for events to be processed by ansible-rulebook, or timeout. Requires the process to be running in debug mode. @@ -33,7 +34,9 @@ def wait_for_events(proc: subprocess.Popen, timeout: float = 15.0) -> None: pytest.param(5001, id="custom_port"), ], ) -def test_webhook_source_sanity(subprocess_teardown: Callable, port: int) -> None: +def test_webhook_source_sanity( + subprocess_teardown: Callable[..., None], port: int +) -> None: """ Check the successful execution, response and shutdown of the webhook source plugin. @@ -73,7 +76,9 @@ def test_webhook_source_sanity(subprocess_teardown: Callable, port: int) -> None assert proc.returncode == 0 -def test_webhook_source_with_busy_port(subprocess_teardown: Callable) -> None: +def test_webhook_source_with_busy_port( + subprocess_teardown: Callable[..., None], +) -> None: """ Ensure the CLI responds correctly if the desired port is already in use. @@ -91,7 +96,7 @@ def test_webhook_source_with_busy_port(subprocess_teardown: Callable) -> None: assert proc2.returncode == 1 -def test_webhook_source_hmac_sanity(subprocess_teardown: Callable) -> None: +def test_webhook_source_hmac_sanity(subprocess_teardown: Callable[..., None]) -> None: """ Check the successful execution, response and shutdown of the webhook source plugin. @@ -140,7 +145,7 @@ def test_webhook_source_hmac_sanity(subprocess_teardown: Callable) -> None: def test_webhook_source_with_unsupported_hmac_algo( - subprocess_teardown: Callable, + subprocess_teardown: Callable[..., None], ) -> None: """ Ensure the CLI responds correctly if the desired HMAC algorithm is not supported. diff --git a/tests/integration/utils.py b/tests/integration/utils.py index 2b5ab529..7c320952 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -1,7 +1,7 @@ import os import subprocess from dataclasses import dataclass -from typing import List, Optional +from typing import Any, List, Optional from . import TESTS_PATH @@ -25,7 +25,7 @@ class CLIRunner: verbose: bool = False debug: bool = False timeout: float = 10.0 - env: Optional[dict] = None + env: Optional[dict[str, str]] = None def __post_init__(self) -> None: self.env = os.environ.copy() if self.env is None else self.env @@ -54,7 +54,7 @@ def _process_args(self) -> List[str]: return args - def run(self) -> subprocess.CompletedProcess: + def run(self) -> subprocess.CompletedProcess[Any]: args = self._process_args() print("Running command: ", " ".join(args)) return subprocess.run( @@ -66,7 +66,7 @@ def run(self) -> subprocess.CompletedProcess: env=self.env, ) - def run_in_background(self) -> subprocess.Popen: + def run_in_background(self) -> subprocess.Popen[bytes]: args = self._process_args() print("Running command: ", " ".join(args)) return subprocess.Popen( diff --git a/tests/unit/event_filter/test_insert_hosts_to_meta.py b/tests/unit/event_filter/test_insert_hosts_to_meta.py index 54b87dd0..13325dc8 100644 --- a/tests/unit/event_filter/test_insert_hosts_to_meta.py +++ b/tests/unit/event_filter/test_insert_hosts_to_meta.py @@ -41,8 +41,10 @@ @pytest.mark.parametrize("data, args, expected_hosts", EVENT_DATA_1) -def test_find_hosts(data: dict, args: dict, expected_hosts: list) -> None: - data = hosts_main(data, **args) +def test_find_hosts( + data: dict[str, Any], args: dict[str, str], expected_hosts: list[str] +) -> None: + data = hosts_main(data, **args) # type: ignore if expected_hosts: assert data["meta"]["hosts"] == expected_hosts else: diff --git a/tests/unit/event_filter/test_normalize_keys.py b/tests/unit/event_filter/test_normalize_keys.py index e30f0429..bddd53aa 100644 --- a/tests/unit/event_filter/test_normalize_keys.py +++ b/tests/unit/event_filter/test_normalize_keys.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from extensions.eda.plugins.event_filter.normalize_keys import main as normalize_main @@ -42,6 +44,8 @@ @pytest.mark.parametrize("event, overwrite, updated_event", TEST_DATA_1) -def test_normalize_keys(event: dict, overwrite: bool, updated_event: dict) -> None: +def test_normalize_keys( + event: dict[str, Any], overwrite: bool, updated_event: dict[str, Any] +) -> None: data = normalize_main(event, overwrite) assert data == updated_event diff --git a/tests/unit/event_source/test_alertmanager.py b/tests/unit/event_source/test_alertmanager.py index b4d7db6c..72ce669f 100644 --- a/tests/unit/event_source/test_alertmanager.py +++ b/tests/unit/event_source/test_alertmanager.py @@ -7,7 +7,7 @@ from extensions.eda.plugins.event_source.alertmanager import main as alert_main -async def start_server(queue: asyncio.Queue, args: dict[str, Any]) -> None: +async def start_server(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: await alert_main(queue, args) diff --git a/tests/unit/event_source/test_azure_service_bus.py b/tests/unit/event_source/test_azure_service_bus.py index 2084c1ce..d9c499ca 100644 --- a/tests/unit/event_source/test_azure_service_bus.py +++ b/tests/unit/event_source/test_azure_service_bus.py @@ -7,7 +7,7 @@ from extensions.eda.plugins.event_source.azure_service_bus import main as azure_main -class MockQueue(asyncio.Queue): +class MockQueue(asyncio.Queue[Any]): def __init__(self) -> None: self.queue: list[Any] = [] diff --git a/tests/unit/event_source/test_generic.py b/tests/unit/event_source/test_generic.py index e82e3682..909dc6ec 100644 --- a/tests/unit/event_source/test_generic.py +++ b/tests/unit/event_source/test_generic.py @@ -1,4 +1,4 @@ -""" Tests for generic source plugin """ +"""Tests for generic source plugin""" import asyncio import os @@ -144,7 +144,7 @@ def test_generic_blob() -> None: @pytest.mark.parametrize("time_format,expected_type", TEST_TIME_FORMATS) -def test_generic_timestamps(time_format: list, expected_type: type) -> None: +def test_generic_timestamps(time_format: list[str], expected_type: type) -> None: """Test receiving events with timestamps.""" myqueue = _MockQueue() event = {"name": "fred"} diff --git a/tests/unit/event_source/test_kafka.py b/tests/unit/event_source/test_kafka.py index 3f19f7e3..d6adaf8e 100644 --- a/tests/unit/event_source/test_kafka.py +++ b/tests/unit/event_source/test_kafka.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import json from typing import Any @@ -40,7 +42,7 @@ async def __anext__(self) -> MagicMock: raise StopAsyncIteration -class MockConsumer(AsyncMock): +class MockConsumer(AsyncMock): # type: ignore[misc] def __aiter__(self) -> AsyncIterator: return AsyncIterator() diff --git a/tests/unit/event_source/test_pg_listener.py b/tests/unit/event_source/test_pg_listener.py index c0ed0947..0a2bbf67 100644 --- a/tests/unit/event_source/test_pg_listener.py +++ b/tests/unit/event_source/test_pg_listener.py @@ -1,4 +1,4 @@ -""" Tests for pg_listener source plugin """ +"""Tests for pg_listener source plugin""" import asyncio import json @@ -81,11 +81,8 @@ def _to_chunks(payload: str, result: list[str]) -> None: ] -from typing import List - - @pytest.mark.parametrize("events", TEST_PAYLOADS) -def test_receive_from_pg_listener(events: List[dict]) -> None: +def test_receive_from_pg_listener(events: list[dict[str, Any]]) -> None: """Test receiving different payloads from pg notify.""" notify_payload: list[str] = [] myqueue = _MockQueue() diff --git a/tests/unit/event_source/test_webhook.py b/tests/unit/event_source/test_webhook.py index 1d274360..0f60b263 100644 --- a/tests/unit/event_source/test_webhook.py +++ b/tests/unit/event_source/test_webhook.py @@ -9,7 +9,7 @@ from extensions.eda.plugins.event_source.webhook import main as webhook_main -async def start_server(queue: asyncio.Queue, args: dict[str, Any]) -> None: +async def start_server(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: await webhook_main(queue, args)