Fix remaining ruff issues
bzwei committed Jun 22, 2023
1 parent 9c17bd9 commit 55d18b3
Showing 17 changed files with 141 additions and 94 deletions.
@@ -12,7 +12,7 @@
import multiprocessing as mp


def main(event: dict, overwrite: bool = True) -> dict:
def main(event: dict, overwrite: bool = True) -> dict: # noqa: FBT001, FBT002
"""Change dashes in keys to underscores."""
logger = mp.get_logger()
logger.info("dashes_to_underscores")
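Note on the FBT suppressions above: ruff's flake8-boolean-trap rules (FBT001/FBT002) flag boolean positional parameters and boolean defaults. The plugin entry points keep their established signature, so the commit silences the rules with noqa rather than changing the API. A minimal sketch of the pattern the rules target and the keyword-only alternative (hypothetical function names, not from this repository):

    def toggle(enabled: bool = True) -> str:  # FBT001/FBT002: boolean positional parameter with a boolean default
        return "on" if enabled else "off"

    def toggle_kw(*, enabled: bool = True) -> str:  # keyword-only parameter satisfies both rules
        return "on" if enabled else "off"

    print(toggle(False))             # call site gives no hint what False means
    print(toggle_kw(enabled=False))  # intent is explicit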
6 changes: 4 additions & 2 deletions extensions/eda/plugins/event_filter/insert_hosts_to_meta.py
@@ -24,15 +24,17 @@
"""

from __future__ import annotations

from typing import Any

import dpath


def main(
event: dict[str, Any],
host_path: str = None,
host_separator: str = None,
host_path: str | None = None,
host_separator: str | None = None,
path_separator: str = ".",
) -> dict[str, Any]:
"""Extract hosts from event data and insert into meta dict."""
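The host_path and host_separator changes address implicit-Optional annotations (ruff reports these as RUF013): a parameter that defaults to None must declare the None in its type. The union syntax stays compatible with older interpreters because the module enables from __future__ import annotations. An illustrative sketch with a hypothetical function:

    from __future__ import annotations  # allows "str | None" syntax on Python < 3.10

    def find(path: str | None = None) -> str:  # "path: str = None" would be flagged as implicit Optional
        return path or "."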
8 changes: 7 additions & 1 deletion extensions/eda/plugins/event_filter/json_filter.py
@@ -13,6 +13,8 @@
"""

from __future__ import annotations

import fnmatch


@@ -24,7 +26,11 @@ def _matches_exclude_keys(exclude_keys: list, s: str) -> bool:
return any(fnmatch.fnmatch(s, pattern) for pattern in exclude_keys)


def main(event: dict, exclude_keys: list = None, include_keys: list = None) -> dict:
def main(
event: dict,
exclude_keys: list | None = None,
include_keys: list | None = None,
) -> dict:
"""Filter keys out of events."""
if exclude_keys is None:
exclude_keys = []
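json_filter matches key names against shell-style wildcards via fnmatch, the mechanism _matches_exclude_keys wraps. A quick usage example (the sample patterns are illustrative):

    import fnmatch

    exclude_keys = ["*_secret", "internal*"]
    print(any(fnmatch.fnmatch("api_secret", p) for p in exclude_keys))  # True
    print(any(fnmatch.fnmatch("hostname", p) for p in exclude_keys))    # False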
19 changes: 14 additions & 5 deletions extensions/eda/plugins/event_filter/normalize_keys.py
@@ -37,35 +37,44 @@
"""

import logging
import multiprocessing as mp
import re

normalize_regex = re.compile("[^0-9a-zA-Z_]+")


def main(event: dict, overwrite: bool = True) -> dict:
def main(event: dict, overwrite: bool = True) -> dict: # noqa: FBT001, FBT002
"""Change keys that contain non-alphanumeric characters to underscores."""
logger = mp.get_logger()
logger.info("normalize_keys")
return _normalize_embedded_keys(event, overwrite, logger)


def _normalize_embedded_keys(obj: dict, overwrite: bool, logger) -> dict:
def _normalize_embedded_keys(
obj: dict,
overwrite: bool, # noqa: FBT001
logger: logging.Logger,
) -> dict:
if isinstance(obj, dict):
new_dict = {}
original_keys = list(obj.keys())
for key in original_keys:
new_key = normalize_regex.sub("_", key)
if new_key == key or new_key not in original_keys:
new_dict[new_key] = _normalize_embedded_keys(
obj[key], overwrite, logger,
obj[key],
overwrite,
logger,
)
elif new_key in original_keys and overwrite:
new_dict[new_key] = _normalize_embedded_keys(
obj[key], overwrite, logger,
obj[key],
overwrite,
logger,
)
logger.warning("Replacing existing key %s", new_key)
return new_dict
elif isinstance(obj, list):
if isinstance(obj, list):
return [_normalize_embedded_keys(item, overwrite, logger) for item in obj]
return obj
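normalize_regex collapses every run of characters outside [0-9a-zA-Z_] into a single underscore, and _normalize_embedded_keys applies it recursively to nested dicts and lists. The substitution on its own:

    import re

    normalize_regex = re.compile("[^0-9a-zA-Z_]+")
    print(normalize_regex.sub("_", "app-name.version"))  # -> app_name_version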
13 changes: 8 additions & 5 deletions extensions/eda/plugins/event_source/alertmanager.py
@@ -33,6 +33,7 @@
"""

import asyncio
import logging
from typing import Any

from aiohttp import web
@@ -42,7 +43,7 @@


@routes.get("/")
async def status(request: web.Request) -> web.Response:
async def status(_request: web.Request) -> web.Response:
"""Return status of a web request."""
return web.Response(status=200, text="up")

@@ -96,7 +97,9 @@ async def webhook(request: web.Request) -> web.Response:
{
"alert": alert,
"meta": {
"endpoint": endpoint, "headers": dict(request.headers), "hosts": hosts,
"endpoint": endpoint,
"headers": dict(request.headers),
"hosts": hosts,
},
},
)
@@ -130,7 +133,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
try:
await asyncio.Future()
except asyncio.CancelledError:
print("Plugin Task Cancelled")
logging.getLogger().info("Plugin Task Cancelled")
finally:
await runner.cleanup()

@@ -141,8 +144,8 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
class MockQueue:
"""A fake queue."""

async def put(self, event: dict) -> None:
async def put(self: "MockQueue", event: dict) -> None:
"""Print the event."""
print(event) # noqa: T201
print(event) # noqa: T201

asyncio.run(main(MockQueue(), {}))
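Replacing the bare print in library code with a logger call clears ruff's T201; the print left in the __main__ demo block is intentional and keeps its noqa. A minimal sketch of the logging pattern (logger name and message are illustrative):

    import logging

    logging.basicConfig(level=logging.INFO)
    logging.getLogger(__name__).info("Plugin Task Cancelled")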
15 changes: 8 additions & 7 deletions extensions/eda/plugins/event_source/aws_cloudtrail.py
@@ -36,6 +36,7 @@
from datetime import datetime
from typing import Any

from aiobotocore.client import BaseClient
from aiobotocore.session import get_session


@@ -45,7 +46,7 @@ def _cloudtrail_event_to_dict(event: dict) -> dict:
return event


def get_events(events, last_event_ids):
def _get_events(events: list[dict], last_event_ids: list) -> list:
event_time = None
event_ids = []
result = []
@@ -62,7 +63,7 @@ def get_events(events, last_event_ids):
return result, event_time, event_ids


async def get_cloudtrail_events(client, params):
async def _get_cloudtrail_events(client: BaseClient, params: dict) -> list[dict]:
paginator = client.get_paginator("lookup_events")
results = await paginator.paginate(**params).build_full_result()
return results.get("Events", [])
@@ -84,17 +85,17 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
if args.get(k) is not None:
params[v] = args.get(k)

params["StartTime"] = datetime.utcnow()
params["StartTime"] = datetime.utcnow() # noqa: DTZ003

async with session.create_client("cloudtrail", **connection_args(args)) as client:
event_time = None
event_ids = []
while True:
events = await get_cloudtrail_events(client, params)
events = await _get_cloudtrail_events(client, params)
if event_time is not None:
params["StartTime"] = event_time

events, c_event_time, c_event_ids = get_events(events, event_ids)
events, c_event_time, c_event_ids = _get_events(events, event_ids)
for event in events:
await queue.put(_cloudtrail_event_to_dict(event))

@@ -130,8 +131,8 @@ def connection_args(args: dict[str, Any]) -> dict[str, Any]:
class MockQueue:
"""A fake queue."""

async def put(self, event: dict) -> None:
async def put(self: "MockQueue", event: dict) -> None:
"""Print the event."""
print(event) # noqa: T201
print(event) # noqa: T201

asyncio.run(main(MockQueue(), {}))
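The noqa: DTZ003 keeps the existing datetime.utcnow() call and its naive-UTC timestamp rather than switching to a timezone-aware value; the rule itself asks for the latter. A sketch of the aware alternative, shown for comparison only:

    from datetime import datetime, timezone

    naive = datetime.utcnow()              # DTZ003: naive datetime, implicitly UTC
    aware = datetime.now(tz=timezone.utc)  # timezone-aware equivalent
    print(naive.isoformat(), aware.isoformat())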
6 changes: 3 additions & 3 deletions extensions/eda/plugins/event_source/aws_sqs_queue.py
@@ -49,7 +49,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
err.response["Error"]["Code"]
== "AWS.SimpleQueueService.NonExistentQueue"
):
raise ValueError("Queue %s does not exist" % queue_name)
raise ValueError("Queue %s does not exist" % queue_name) from None
raise

queue_url = response["QueueUrl"]
@@ -108,8 +108,8 @@ def connection_args(args: dict[str, Any]) -> dict[str, Any]:
class MockQueue:
"""A fake queue."""

async def put(self, event: dict) -> None:
async def put(self: "MockQueue", event: dict) -> None:
"""Print the event."""
print(event) # noqa: T201
print(event) # noqa: T201

asyncio.run(main(MockQueue(), {"region": "us-east-1", "name": "eda"}))
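The from None addition addresses ruff's B904: a raise inside an except block should either chain the original exception (from err) or suppress it explicitly (from None); here the original ClientError is suppressed. A minimal sketch of the chaining form, with hypothetical names:

    import json

    def parse_config(raw: str) -> dict:
        try:
            return json.loads(raw)
        except json.JSONDecodeError as err:
            msg = "config is not valid JSON"
            raise ValueError(msg) from err  # or "from None" to hide the original traceback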
11 changes: 7 additions & 4 deletions extensions/eda/plugins/event_source/azure_service_bus.py
@@ -26,7 +26,9 @@


def receive_events(
loop: asyncio.events.AbstractEventLoop, queue: asyncio.Queue, args: dict[str, Any],
loop: asyncio.events.AbstractEventLoop,
queue: asyncio.Queue,
args: dict[str, Any],
) -> None:
"""Receive events from service bus."""
servicebus_client = ServiceBusClient.from_connection_string(
@@ -44,7 +46,8 @@ def receive_events(
body = json.loads(body)

loop.call_soon_threadsafe(
queue.put_nowait, {"body": body, "meta": meta},
queue.put_nowait,
{"body": body, "meta": meta},
)
receiver.complete_message(msg)

@@ -63,9 +66,9 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
class MockQueue:
"""A fake queue."""

async def put_nowait(self, event: dict) -> None:
async def put_nowait(self: "MockQueue", event: dict) -> None:
"""Print the event."""
print(event) # noqa: T201
print(event) # noqa: T201

args = {
"conn_str": "Endpoint=sb://foo.servicebus.windows.net/",
40 changes: 21 additions & 19 deletions extensions/eda/plugins/event_source/file.py
@@ -15,69 +15,71 @@
"""

import os
import pathlib

import yaml
from watchdog.events import RegexMatchingEventHandler
from watchdog.observers import Observer


def send_facts(queue, filename: str) -> None:
def send_facts(queue, filename: str) -> None: # noqa: ANN001
"""Send facts to the queue."""
with open(filename) as f:
with pathlib.Path(filename).open() as f:
data = yaml.safe_load(f.read())
if data is None:
return
if isinstance(data, dict):
queue.put(data)
else:
if not isinstance(data, list):
msg = f"Unsupported facts type, expects a list of dicts found {type(data)}"
raise Exception(
msg,
msg = (
"Unsupported facts type, expects a list of dicts found "
f"{type(data)}"
)
raise TypeError(msg)
if not all(bool(isinstance(item, dict)) for item in data):
msg = f"Unsupported facts type, expects a list of dicts found {data}"
raise Exception(
msg,
)
raise TypeError(msg)
for item in data:
queue.put(item)


def main(queue, args: dict) -> None:
def main(queue, args: dict) -> None: # noqa: ANN001
"""Load facts from YAML files initially and when the file changes."""
files = [os.path.abspath(f) for f in args.get("files", [])]
files = [pathlib.Path(f).resolve().as_posix() for f in args.get("files", [])]

if not files:
return

for filename in files:
send_facts(queue, filename)
_observe_files(queue, files)


def _observe_files(queue, files: list[str]) -> None: # noqa: ANN001
class Handler(RegexMatchingEventHandler):
def __init__(self, **kwargs) -> None:
def __init__(self: "Handler", **kwargs) -> None: # noqa: ANN003
RegexMatchingEventHandler.__init__(self, **kwargs)

def on_created(self, event: dict) -> None:
def on_created(self: "Handler", event: dict) -> None:
if event.src_path in files:
send_facts(queue, event.src_path)

def on_deleted(self, event: dict) -> None:
def on_deleted(self: "Handler", event: dict) -> None:
pass

def on_modified(self, event: dict) -> None:
def on_modified(self: "Handler", event: dict) -> None:
if event.src_path in files:
send_facts(queue, event.src_path)

def on_moved(self, event: dict) -> None:
def on_moved(self: "Handler", event: dict) -> None:
pass

observer = Observer()
handler = Handler()

for filename in files:
directory = os.path.dirname(filename)
directory = pathlib.Path(filename).parent
observer.schedule(handler, directory, recursive=False)

observer.start()
@@ -94,8 +96,8 @@ def on_moved(self, event: dict) -> None:
class MockQueue:
"""A fake queue."""

async def put(self, event: dict) -> None:
async def put(self: "MockQueue", event: dict) -> None:
"""Print the event."""
print(event) # noqa: T201
print(event) # noqa: T201

main(MockQueue(), {"files": ["facts.yml"]})
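The file.py changes replace os.path helpers and the open() builtin with pathlib, which is what ruff's flake8-use-pathlib (PTH) rules ask for. A side-by-side sketch of the substitutions (the sample path is illustrative):

    import pathlib

    p = pathlib.Path("facts.yml")
    print(p.resolve())       # replaces os.path.abspath(...)
    print(p.parent)          # replaces os.path.dirname(...)
    if p.exists():
        with p.open() as f:  # replaces open(...)
            print(f.read())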
