add more comprehensive typing and check with mypy #144
base: develop
@@ -1,3 +1,20 @@
 [flake8]
 max-line-length=120
 ignore: E301, E401
+
+[mypy]
+pretty = true
+show_error_codes = true
+show_error_context = true
+warn_redundant_casts = true
+warn_unused_ignores = true
+warn_unreachable = true
+
+[mypy-boto3.*]
+ignore_missing_imports = true
+
+[mypy-botocore.*]
+ignore_missing_imports = true
+
+[mypy-django.*]
+ignore_missing_imports = true
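For context on two of the stricter flags in this config: `warn_redundant_casts` and `warn_unused_ignores` make mypy report suppressions that are no longer necessary. A minimal illustration (hypothetical file, not part of this diff):

```python
from typing import cast

x: int = 3
y = cast(int, x)   # flagged under warn_redundant_casts: Redundant cast to "int"
z: int = 0  # type: ignore  # flagged under warn_unused_ignores: unused "type: ignore" comment
```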
@@ -1,6 +1,8 @@
-from collections.abc import Mapping
-from datetime import date, datetime
+from collections.abc import Mapping, MutableMapping
+from datetime import date, datetime, timedelta
+from logging import LogRecord
 from operator import itemgetter
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
 import sys, json, logging, time, threading, warnings, functools, platform
 import queue
 
@@ -9,15 +11,15 @@
 from botocore.exceptions import ClientError
 
 
-def _json_serialize_default(o):
+def _json_serialize_default(o: Any) -> Any:
     """
     A standard 'default' json serializer function that will serialize datetime objects as ISO format.
     """
     if isinstance(o, (date, datetime)):
         return o.isoformat()
 
 
-def _boto_debug_filter(record):
+def _boto_debug_filter(record: LogRecord) -> bool:
     # Filter debug log messages from botocore and its dependency, urllib3.
     # This is required to avoid message storms any time we send logs.
     if record.name.startswith("botocore") and record.levelname == "DEBUG":
@@ -27,7 +29,7 @@ def _boto_debug_filter(record):
     return True
 
 
-def _boto_filter(record):
+def _boto_filter(record: LogRecord) -> bool:
     # Filter log messages from botocore and its dependency, urllib3.
     # This is required to avoid an infinite loop when shutting down.
     if record.name.startswith("botocore"):
@@ -90,25 +92,28 @@ class CloudWatchLogFormatter(logging.Formatter):
     for more details about the 'default' parameter. By default, watchtower uses a serializer that formats datetime
     objects into strings using the `datetime.isoformat()` method, with no other customizations.
     """
-    add_log_record_attrs = tuple()
+    add_log_record_attrs: Tuple[str, ...] = ()
 
-    def __init__(self, *args, json_serialize_default: callable = None, add_log_record_attrs: tuple = None, **kwargs):
+    def __init__(self,
+                 *args,
+                 json_serialize_default: Optional[Callable[[Any], Any]] = None,
+                 add_log_record_attrs: Optional[Tuple[str, ...]] = None,
+                 **kwargs):
         super().__init__(*args, **kwargs)
-        self.json_serialize_default = _json_serialize_default
-        if json_serialize_default is not None:
-            self.json_serialize_default = json_serialize_default
+        self.json_serialize_default = json_serialize_default or _json_serialize_default
         if add_log_record_attrs is not None:
             self.add_log_record_attrs = add_log_record_attrs
 
-    def format(self, message):
+    def format(self, message: LogRecord) -> str:
+        msg: Union[str, MutableMapping] = message.msg
         if self.add_log_record_attrs:
-            msg = message.msg if isinstance(message.msg, Mapping) else {"msg": message.msg}
+            if not isinstance(msg, Mapping):
+                msg = {"msg": msg}
             for field in self.add_log_record_attrs:
                 if field != "msg":
                     msg[field] = getattr(message, field)
-            message.msg = msg
-        if isinstance(message.msg, Mapping):
-            return json.dumps(message.msg, default=self.json_serialize_default)
+        if isinstance(msg, Mapping):
+            return json.dumps(msg, default=self.json_serialize_default)
         return super().format(message)
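A quick illustration of the reworked `format()` path, which now builds the dict locally instead of mutating `message.msg` (a sketch assuming this branch of watchtower is installed; the record values are arbitrary):

```python
import logging
from watchtower import CloudWatchLogFormatter

# With add_log_record_attrs set, a plain string msg is wrapped in a dict,
# the named record attributes are merged in, and the result is JSON-encoded
# instead of going through logging.Formatter.format().
fmt = CloudWatchLogFormatter(add_log_record_attrs=("levelname", "lineno"))
record = logging.LogRecord(name="demo", level=logging.INFO, pathname="demo.py",
                           lineno=42, msg="hello", args=(), exc_info=None)
print(fmt.format(record))
# expected output: {"msg": "hello", "levelname": "INFO", "lineno": 42}
```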
@@ -172,29 +177,36 @@ class CloudWatchLogHandler(logging.Handler):
     # extra size of meta information with each messages
     EXTRA_MSG_PAYLOAD_SIZE = 26
 
+    queues: Dict[str, queue.Queue]
+    sequence_tokens: Dict[str, Optional[str]]
+    threads: List[threading.Thread]
+
     def __init__(self,
                  log_group_name: str = __name__,
                  log_stream_name: str = "{machine_name}/{program_name}/{logger_name}",
                  use_queues: bool = True,
-                 send_interval: int = 60,
+                 send_interval: Union[int, timedelta] = 60,
                  max_batch_size: int = 1024 * 1024,
                  max_batch_count: int = 10000,
-                 boto3_client: botocore.client.BaseClient = None,
-                 boto3_profile_name: str = None,
+                 boto3_client: Optional[botocore.client.BaseClient] = None,
+                 boto3_profile_name: Optional[str] = None,
                  create_log_group: bool = True,
-                 json_serialize_default: callable = None,
-                 log_group_retention_days: int = None,
+                 json_serialize_default: Optional[Callable[[Any], Any]] = None,
+                 log_group_retention_days: Optional[int] = None,
                  create_log_stream: bool = True,
                  max_message_size: int = 256 * 1024,
-                 log_group=None,
-                 stream_name=None,
+                 log_group: Optional[str] = None,
+                 stream_name: Optional[str] = None,
                  *args,
                  **kwargs):
         super().__init__(*args, **kwargs)
         self.log_group_name = log_group_name
         self.log_stream_name = log_stream_name
         self.use_queues = use_queues
-        self.send_interval = send_interval
+        if isinstance(send_interval, timedelta):
+            self.send_interval = send_interval.total_seconds()
+        else:
+            self.send_interval = send_interval
         self.json_serialize_default = json_serialize_default or _json_serialize_default
         self.max_batch_size = max_batch_size
         self.max_batch_count = max_batch_count
@@ -236,14 +248,14 @@ def __init__(self,
                 logGroupName=self.log_group_name,
                 retentionInDays=self.log_group_retention_days)
 
-    def _at_fork_reinit(self):
+    def _at_fork_reinit(self) -> None:
         # This was added in Python 3.9 and should only be called with a recent
         # version of Python. An older version will attempt to call createLock
         # instead.
-        super()._at_fork_reinit()
+        super()._at_fork_reinit()  # type: ignore
         self._init_state()
 
-    def _init_state(self):
+    def _init_state(self) -> None:
         self.queues, self.sequence_tokens = {}, {}
         self.threads = []
         self.creating_log_stream, self.shutting_down = False, False
@@ -254,7 +266,7 @@ def _paginate(self, boto3_paginator, *args, **kwargs):
             for value in page.get(result_key.parsed.get("value"), []):
                 yield value
 
-    def _ensure_log_group(self):
+    def _ensure_log_group(self) -> None:
         try:
             paginator = self.cwl_client.get_paginator("describe_log_groups")
             for log_group in self._paginate(paginator, logGroupNamePrefix=self.log_group_name):
@@ -264,7 +276,7 @@ def _ensure_log_group(self):
             pass
         self._idempotent_call("create_log_group", logGroupName=self.log_group_name)
 
-    def _idempotent_call(self, method, *args, **kwargs):
+    def _idempotent_call(self, method: str, *args, **kwargs) -> None:
         method_callable = getattr(self.cwl_client, method)
         try:
             method_callable(*args, **kwargs)
@@ -273,10 +285,10 @@ def _idempotent_call(self, method, *args, **kwargs):
             pass
 
     @functools.lru_cache(maxsize=0)
-    def _get_machine_name(self):
+    def _get_machine_name(self) -> str:
         return platform.node()
 
-    def _get_stream_name(self, message):
+    def _get_stream_name(self, message: LogRecord) -> str:
         return self.log_stream_name.format(
             machine_name=self._get_machine_name(),
             program_name=sys.argv[0],
@@ -285,7 +297,7 @@ def _get_stream_name(self, message):
             strftime=datetime.utcnow()
         )
 
-    def _submit_batch(self, batch, log_stream_name, max_retries=5):
+    def _submit_batch(self, batch: Sequence[dict], log_stream_name: str, max_retries: int = 5) -> None:
         if len(batch) < 1:
             return
         sorted_batch = sorted(batch, key=itemgetter('timestamp'), reverse=False)
@@ -325,23 +337,23 @@ def _submit_batch(self, batch, log_stream_name, max_retries=5):
                     finally:
                         self.creating_log_stream = False
                 else:
-                    warnings.warn("Failed to deliver logs: {}".format(e), WatchtowerWarning)
+                    warnings.warn(f"Failed to deliver logs: {e}", WatchtowerWarning)
             except Exception as e:
-                warnings.warn("Failed to deliver logs: {}".format(e), WatchtowerWarning)
+                warnings.warn(f"Failed to deliver logs: {e}", WatchtowerWarning)
 
         # response can be None only when all retries have been exhausted
         if response is None or "rejectedLogEventsInfo" in response:
-            warnings.warn("Failed to deliver logs: {}".format(response), WatchtowerWarning)
+            warnings.warn(f"Failed to deliver logs: {response}", WatchtowerWarning)
         elif "nextSequenceToken" in response:
             # According to https://github.com/kislyuk/watchtower/issues/134, nextSequenceToken may sometimes be absent
             # from the response
             self.sequence_tokens[log_stream_name] = response["nextSequenceToken"]
 
-    def createLock(self):
+    def createLock(self) -> None:
         super().createLock()
         self._init_state()
 
-    def emit(self, message):
+    def emit(self, message: LogRecord) -> None:
         if self.creating_log_stream:
             return  # Avoid infinite recursion when asked to log a message as our own side effect
@@ -375,28 +387,32 @@ def emit(self, message):
         except Exception:
             self.handleError(message)
 
-    def _dequeue_batch(self, my_queue, stream_name, send_interval, max_batch_size, max_batch_count, max_message_size):
-        msg = None
+    def _dequeue_batch(self, my_queue: queue.Queue, stream_name: str, send_interval: int, max_batch_size: int,
+                       max_batch_count: int, max_message_size: int) -> None:
+        msg: Union[dict, int, None] = None
         max_message_body_size = max_message_size - CloudWatchLogHandler.EXTRA_MSG_PAYLOAD_SIZE
         assert max_message_body_size > 0
 
-        def size(_msg):
+        def size(_msg: Union[dict, int]) -> int:
             return (len(_msg["message"]) if isinstance(_msg, dict) else 1) + CloudWatchLogHandler.EXTRA_MSG_PAYLOAD_SIZE
 
-        def truncate(_msg2):
+        def truncate(_msg2: dict) -> dict:
             warnings.warn("Log message size exceeds CWL max payload size, truncated", WatchtowerWarning)
             _msg2["message"] = _msg2["message"][:max_message_body_size]
             return _msg2
 
         # See https://boto3.readthedocs.io/en/latest/reference/services/logs.html#CloudWatchLogs.Client.put_log_events
         while msg != self.END:
-            cur_batch = [] if msg is None or msg == self.FLUSH else [msg]
+            cur_batch: List[dict] = [] if msg is None or msg == self.FLUSH else [cast(dict, msg)]

Review comment on the line above:

> This looks almost all good to go, except let's please remove all of the casts and asserts. The typing information should decorate the function signatures, not be part of the execution. If mypy fails to check this method without calls to cast and asserts, then we should disable mypy checking for those lines, or possibly split up the method into multiple methods.

Reply:

> I'll look at the PR to see what makes sense. In case you aren't familiar,
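For readers following this thread: `typing.cast` is a no-op at runtime, but the review's preference can also be met with a targeted suppression, which is one of the options the comment names. A cast-free sketch of the batch-seeding logic (hypothetical helper and sentinel value, not code from this PR):

```python
from typing import List, Optional, Union

FLUSH = 1  # stand-in for the handler's FLUSH sentinel (hypothetical value)

def seed_batch(msg: Optional[Union[dict, int]]) -> List[dict]:
    """Start a new batch from the carried-over message, skipping sentinels."""
    if msg is None or msg == FLUSH:
        return []
    # mypy cannot narrow `int` out of the Union via the == check above,
    # so suppress just this line instead of inserting a runtime cast().
    return [msg]  # type: ignore[list-item]
```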
             cur_batch_size = sum(map(size, cur_batch))
             cur_batch_msg_count = len(cur_batch)
             cur_batch_deadline = time.time() + send_interval
             while True:
                 try:
-                    msg = my_queue.get(block=True, timeout=max(0, cur_batch_deadline - time.time()))
+                    msg = cast(Union[dict, int],
+                               my_queue.get(block=True, timeout=max(0, cur_batch_deadline - time.time())))
                     if size(msg) > max_message_body_size:
+                        assert isinstance(msg, dict)  # size always < max_message_body_size if not `dict`
                         msg = truncate(msg)
                 except queue.Empty:
                     # If the queue is empty, we don't want to reprocess the previous message
@@ -413,12 +429,13 @@ def truncate(_msg2):
                         my_queue.task_done()
                     break
                 elif msg:
+                    assert isinstance(msg, dict)  # mypy can't handle all the or expressions filtering out sentinels
                     cur_batch_size += size(msg)
                     cur_batch_msg_count += 1
                     cur_batch.append(msg)
                     my_queue.task_done()
 
-    def flush(self):
+    def flush(self) -> None:
         """
         Send any queued messages to CloudWatch. This method does nothing if ``use_queues`` is set to False.
         """
@@ -431,7 +448,7 @@ def flush(self):
         for q in self.queues.values():
             q.join()
 
-    def close(self):
+    def close(self) -> None:
         """
         Send any queued messages to CloudWatch and prevent further processing of messages.
         This method does nothing if ``use_queues`` is set to False.
@@ -450,6 +467,6 @@ def close(self):
             q.join()
         super().close()
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         name = self.__class__.__name__
         return f"{name}(log_group_name='{self.log_group_name}', log_stream_name='{self.log_stream_name}')"
Review comment:

> This was updated to match the text, but the `:type:` was previously just `int`.
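The `:type:` field referenced here is the Sphinx-style docstring annotation, which now has to advertise both accepted types. A sketch of the pattern (hypothetical helper with illustrative wording, not the PR's docstring):

```python
from datetime import timedelta
from typing import Union

def to_seconds(send_interval: Union[int, timedelta] = 60) -> float:
    """
    :param send_interval: Interval at which queued messages are sent.
    :type send_interval: Union[int, timedelta]
    """
    if isinstance(send_interval, timedelta):
        return send_interval.total_seconds()
    return float(send_interval)
```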