From 8e1812e0fc6827a5722b55d790953c21b21e2fbd Mon Sep 17 00:00:00 2001
From: "Terence D. Honles"
Date: Thu, 29 Apr 2021 19:16:12 -0700
Subject: [PATCH] add more comprehensive typing and check with mypy

---
 Makefile               |  9 ++++-
 setup.cfg              | 17 +++++++++
 watchtower/__init__.py | 78 +++++++++++++++++++++++++++++-------------
 3 files changed, 79 insertions(+), 25 deletions(-)

diff --git a/Makefile b/Makefile
index bf94a82..2c100a0 100644
--- a/Makefile
+++ b/Makefile
@@ -1,13 +1,20 @@
 SHELL=/bin/bash
 
 test_deps:
-	pip install coverage flake8 wheel pyyaml boto3
+	pip install \
+		boto3 \
+		coverage \
+		flake8 \
+		mypy \
+		pyyaml \
+		wheel
 
 lint: test_deps
 	flake8
 
 test: test_deps lint
 	coverage run --source=watchtower ./test/test.py
+	mypy watchtower
 
 docs:
 	sphinx-build docs docs/html
diff --git a/setup.cfg b/setup.cfg
index 1a6ca11..874e8d6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,3 +1,20 @@
 [flake8]
 max-line-length=120
 ignore: E301, E401
+
+[mypy]
+pretty = true
+show_error_codes = true
+show_error_context = true
+warn_redundant_casts = true
+warn_unused_ignores = true
+warn_unreachable = true
+
+[mypy-boto3.*]
+ignore_missing_imports = true
+
+[mypy-botocore.*]
+ignore_missing_imports = true
+
+[mypy-django.*]
+ignore_missing_imports = true
diff --git a/watchtower/__init__.py b/watchtower/__init__.py
index 5659d23..a19b47b 100644
--- a/watchtower/__init__.py
+++ b/watchtower/__init__.py
@@ -1,7 +1,8 @@
 from collections.abc import Mapping
 from datetime import date, datetime, timedelta
+from logging import LogRecord
 from operator import itemgetter
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Union, TYPE_CHECKING
 
 import json, logging, time, threading, warnings
 import queue
@@ -10,7 +11,30 @@
 from botocore.exceptions import ClientError
 
 
-def _idempotent_create(client, method, *args, **kwargs):
+if TYPE_CHECKING:
+    from typing_extensions import Protocol
+
+    class LogClient(Protocol):
+        exceptions: Any
+
+        def create_log_group(self, logGroupName: str) -> None:
+            ...
+
+        def create_log_stream(self, logGroupName: str, logStreamName: str) -> None:
+            ...
+
+        def put_log_events(self, **kwargs: Any) -> dict:
+            ...
+
+        def put_retention_policy(self, logGroupName: str, retentionInDays: Optional[int]) -> None:
+            ...
+
+    class Session(Protocol):
+        def client(self, service: str, endpoint_url: Optional[str] = None) -> LogClient:
+            ...
+
+
+def _idempotent_create(client: LogClient, method: str, *args: Any, **kwargs: Any) -> Any:
     method_callable = getattr(client, method)
     try:
         method_callable(*args, **kwargs)
@@ -18,7 +42,7 @@ def _idempotent_create(client, method, *args, **kwargs):
         pass
 
 
-def _json_serialize_default(o):
+def _json_serialize_default(o: Any) -> Any:
     """
     A standard 'default' json serializer function that will serialize datetime objects as ISO format.
     """
@@ -26,7 +50,7 @@ def _json_serialize_default(o):
         return o.isoformat()
 
 
-def _boto_debug_filter(record):
+def _boto_debug_filter(record: LogRecord) -> bool:
     # Filter debug log messages from botocore and its dependency, urllib3.
     # This is required to avoid message storms any time we send logs.
     if record.name.startswith("botocore") and record.levelname == "DEBUG":
@@ -36,7 +60,7 @@ def _boto_debug_filter(record):
     return True
 
 
-def _boto_filter(record):
+def _boto_filter(record: LogRecord) -> bool:
     # Filter log messages from botocore and its dependency, urllib3.
     # This is required to avoid an infinite loop when shutting down.
if record.name.startswith("botocore"): @@ -102,7 +126,7 @@ class CloudWatchLogHandler(logging.Handler): EXTRA_MSG_PAYLOAD_SIZE = 26 @staticmethod - def _get_session(boto3_session, boto3_profile_name): + def _get_session(boto3_session: Optional[Session], boto3_profile_name: Optional[str]) -> Session: if boto3_session: return boto3_session @@ -111,13 +135,13 @@ def _get_session(boto3_session, boto3_profile_name): return boto3 - def __init__(self, log_group: str = __name__, stream_name: str = None, use_queues: bool = True, + def __init__(self, log_group: str = __name__, stream_name: Optional[str] = None, use_queues: bool = True, send_interval: Union[int, timedelta] = 60, max_batch_size: int = 1024 * 1024, max_batch_count: int = 10000, boto3_session: Optional[boto3.session.Session] = None, - boto3_profile_name: str = None, create_log_group: bool = True, - log_group_retention_days: Optional[int]=None, - create_log_stream: bool = True, json_serialize_default: Callable[[Any], Any] = None, - max_message_size: int = 256 * 1024, endpoint_url: str = None, *args, **kwargs): + boto3_profile_name: Optional[str] = None, create_log_group: bool = True, + log_group_retention_days: Optional[int] = None, create_log_stream: bool = True, + json_serialize_default: Optional[Callable[[Any], Any]] = None, max_message_size: int = 256 * 1024, + endpoint_url: Optional[str] = None, *args, **kwargs): super().__init__(*args, **kwargs) self.log_group = log_group self.stream_name = stream_name @@ -130,8 +154,9 @@ def __init__(self, log_group: str = __name__, stream_name: str = None, use_queue self.max_batch_size = max_batch_size self.max_batch_count = max_batch_count self.max_message_size = max_message_size - self.queues, self.sequence_tokens = {}, {} - self.threads = [] + self.queues: Dict[str, queue.Queue] = {} + self.sequence_tokens: Dict[str, Optional[str]] = {} + self.threads: List[threading.Thread] = [] self.creating_log_stream, self.shutting_down = False, False self.create_log_stream = create_log_stream self.log_group_retention_days = log_group_retention_days @@ -150,7 +175,7 @@ def __init__(self, log_group: str = __name__, stream_name: str = None, use_queue self.addFilter(_boto_debug_filter) - def _submit_batch(self, batch, stream_name, max_retries=5): + def _submit_batch(self, batch: Sequence[dict], stream_name: str, max_retries: int = 5) -> None: if len(batch) < 1: return sorted_batch = sorted(batch, key=itemgetter('timestamp'), reverse=False) @@ -203,7 +228,7 @@ def _submit_batch(self, batch, stream_name, max_retries=5): # from the response self.sequence_tokens[stream_name] = response["nextSequenceToken"] - def emit(self, message): + def emit(self, message: LogRecord) -> None: if self.creating_log_stream: return # Avoid infinite recursion when asked to log a message as our own side effect stream_name = self.stream_name @@ -214,7 +239,7 @@ def emit(self, message): if stream_name not in self.sequence_tokens: self.sequence_tokens[stream_name] = None - if isinstance(message.msg, Mapping): + if isinstance(cast(Union[dict, str], message.msg), Mapping): message.msg = json.dumps(message.msg, default=self.json_serialize_default) cwl_message = dict(timestamp=int(message.created * 1000), message=self.format(message)) @@ -235,28 +260,32 @@ def emit(self, message): else: self._submit_batch([cwl_message], stream_name) - def batch_sender(self, my_queue, stream_name, send_interval, max_batch_size, max_batch_count, max_message_size): - msg = None + def batch_sender(self, my_queue: queue.Queue, stream_name: str, send_interval: 
int, max_batch_size: int, + max_batch_count: int, max_message_size: int) -> None: + msg: Union[dict, int, None] = None max_message_body_size = max_message_size - CloudWatchLogHandler.EXTRA_MSG_PAYLOAD_SIZE + assert max_message_body_size > 0 - def size(_msg): + def size(_msg: Union[dict, int]) -> int: return (len(_msg["message"]) if isinstance(_msg, dict) else 1) + CloudWatchLogHandler.EXTRA_MSG_PAYLOAD_SIZE - def truncate(_msg2): + def truncate(_msg2: dict) -> dict: warnings.warn("Log message size exceeds CWL max payload size, truncated", WatchtowerWarning) _msg2["message"] = _msg2["message"][:max_message_body_size] return _msg2 # See https://boto3.readthedocs.io/en/latest/reference/services/logs.html#CloudWatchLogs.Client.put_log_events while msg != self.END: - cur_batch = [] if msg is None or msg == self.FLUSH else [msg] + cur_batch: List[dict] = [] if msg is None or msg == self.FLUSH else [cast(dict, msg)] cur_batch_size = sum(map(size, cur_batch)) cur_batch_msg_count = len(cur_batch) cur_batch_deadline = time.time() + send_interval while True: try: - msg = my_queue.get(block=True, timeout=max(0, cur_batch_deadline - time.time())) + msg = cast(Union[dict, int], + my_queue.get(block=True, timeout=max(0, cur_batch_deadline - time.time()))) if size(msg) > max_message_body_size: + assert isinstance(msg, dict) # size always < max_message_body_size if not `dict` msg = truncate(msg) except queue.Empty: # If the queue is empty, we don't want to reprocess the previous message @@ -273,12 +302,13 @@ def truncate(_msg2): my_queue.task_done() break elif msg: + assert isinstance(msg, dict) # mypy can't handle all the or expressions filtering out sentinels cur_batch_size += size(msg) cur_batch_msg_count += 1 cur_batch.append(msg) my_queue.task_done() - def flush(self): + def flush(self) -> None: """ Send any queued messages to CloudWatch. This method does nothing if ``use_queues`` is set to False. """ @@ -291,7 +321,7 @@ def flush(self): for q in self.queues.values(): q.join() - def close(self): + def close(self) -> None: """ Send any queued messages to CloudWatch and prevent further processing of messages. This method does nothing if ``use_queues`` is set to False.
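
Notes on the typing patterns above (illustrative sketches, not code from
watchtower):

The `if TYPE_CHECKING:` block in watchtower/__init__.py defines Protocol
stubs that exist only for the type checker: boto3 builds its clients
dynamically, so there is no concrete class to annotate against, and the
Protocol describes just the methods watchtower actually calls. Because the
import happens only under TYPE_CHECKING, typing_extensions stays a
dev-only dependency. A minimal, self-contained sketch of the same idea
(the `Stamper` and `stamp_all` names here are hypothetical):

    from typing import TYPE_CHECKING, List

    if TYPE_CHECKING:
        # Evaluated only by the type checker, never at runtime.
        from typing_extensions import Protocol

        class Stamper(Protocol):
            # Structural type: any object with a matching .stamp() method
            # satisfies it, including dynamically generated clients.
            def stamp(self, name: str) -> str:
                ...

    def stamp_all(client: "Stamper", names: List[str]) -> List[str]:
        # The quoted annotation is resolved only by the checker, so Stamper
        # does not need to exist when this module is imported.
        return [client.stamp(n) for n in names]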
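
A second pattern worth noting: batch_sender() reads from a queue that
carries both dict log entries and int control sentinels (self.FLUSH and
self.END), so each get() comes back as a Union that mypy cannot narrow on
its own across the surrounding or-expressions; the added cast() and
assert isinstance() calls perform that narrowing explicitly. A reduced
sketch, assuming an int sentinel (the END value below is a stand-in, not
the actual CloudWatchLogHandler.END):

    import queue
    from typing import Dict, Union, cast

    END = -1  # stand-in sentinel; watchtower uses CloudWatchLogHandler.END

    q: "queue.Queue" = queue.Queue()
    q.put({"message": "hello"})
    q.put(END)

    while True:
        # An unparameterized Queue returns Any from get(); cast() records
        # the expected Union at no runtime cost.
        msg = cast(Union[Dict[str, str], int], q.get())
        if msg == END:
            break
        assert isinstance(msg, dict)  # narrows Union[dict, int] to dict
        print(len(msg["message"]))    # prints 5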