Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mypy and application bootstrap test to boefjes #2460

Merged
merged 4 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ repos:
additional_dependencies: ['types-PyYAML', 'types-requests', 'types-cachetools', 'types-retry', 'pydantic', 'pynacl']
exclude: |
(?x)(
^boefjes/ |
^boefjes/boefjes/plugins |
^keiko/templates |
^mula/whitelist\.py$ |
^octopoes/ |
Expand Down
7 changes: 4 additions & 3 deletions boefjes/boefjes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime, timezone
from enum import Enum
from typing import List, Optional
from uuid import UUID

from fastapi import Depends, FastAPI, HTTPException, Response
from pydantic import BaseModel, ConfigDict, Field
Expand Down Expand Up @@ -49,7 +50,7 @@ def run():


class BoefjeInput(BaseModel):
task_id: str
task_id: UUID
output_url: str
boefje_meta: BoefjeMeta
model_config = ConfigDict(extra="forbid")
Expand Down Expand Up @@ -90,7 +91,7 @@ async def root():

@app.get("/api/v0/tasks/{task_id}", response_model=BoefjeInput)
async def boefje_input(
task_id: str,
task_id: UUID,
scheduler_client: SchedulerAPIClient = Depends(get_scheduler_client),
local_repository: LocalPluginRepository = Depends(get_local_repository),
):
Expand All @@ -107,7 +108,7 @@ async def boefje_input(

@app.post("/api/v0/tasks/{task_id}")
async def boefje_output(
task_id: str,
task_id: UUID,
boefje_output: BoefjeOutput,
scheduler_client: SchedulerAPIClient = Depends(get_scheduler_client),
bytes_client: BytesAPIClient = Depends(get_bytes_client),
Expand Down
37 changes: 22 additions & 15 deletions boefjes/boefjes/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import signal
import sys
import time
from queue import Queue
from typing import Dict, List, Optional, Tuple

from pydantic import ValidationError
Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(

self.task_queue = manager.Queue() # multiprocessing.Queue() will not work on macOS, see mp.Queue.qsize()
self.handling_tasks = manager.dict()
self.workers = []
self.workers: List[mp.Process] = []

logger.setLevel(log_level)

Expand Down Expand Up @@ -77,7 +78,7 @@ def run(self, queue_type: WorkerManager.Queue) -> None:

raise

def _fill_queue(self, task_queue: mp.Queue, queue_type: WorkerManager.Queue):
def _fill_queue(self, task_queue: Queue, queue_type: WorkerManager.Queue):
if task_queue.qsize() > self.settings.pool_size:
time.sleep(self.settings.worker_heartbeat)
return
Expand Down Expand Up @@ -141,15 +142,20 @@ def _check_workers(self) -> None:
new_workers = []

for worker in self.workers:
if not worker._closed and worker.is_alive():
new_workers.append(worker)
continue
closed = False

try:
if worker.is_alive():
new_workers.append(worker)
continue
except ValueError:
closed = True # worker is closed, so we create a new one

logger.warning(
"Worker[pid=%s, %s] not alive, creating new worker...", worker.pid, _format_exit_code(worker.exitcode)
)

if not worker._closed: # Closed workers do not have a pid, so cleaning up would fail
if not closed: # Closed workers do not have a pid, so cleaning up would fail
self._cleanup_pending_worker_task(worker)
worker.close()

Expand Down Expand Up @@ -198,9 +204,12 @@ def exit(self, queue_type: WorkerManager.Queue, signum: Optional[int] = None):
killed_workers = []

for worker in self.workers: # Send all signals before joining, speeding up shutdowns
if not worker._closed and worker.is_alive():
worker.kill()
killed_workers.append(worker)
try:
if worker.is_alive():
worker.kill()
killed_workers.append(worker)
except ValueError:
pass # worker is already closed

for worker in killed_workers:
worker.join()
Expand All @@ -215,8 +224,8 @@ def exit(self, queue_type: WorkerManager.Queue, signum: Optional[int] = None):
sys.exit()


def _format_exit_code(exitcode: Optional[int]) -> str:
    """Render a worker process exit code as a short string for log messages.

    Args:
        exitcode: The process exit code. ``None`` and non-negative values are
            reported verbatim; a negative value ``-N`` is interpreted as
            "terminated by signal N".

    Returns:
        ``"exitcode=<value>"`` for ``None`` or non-negative codes, otherwise
        ``"signal=<NAME>"`` (or ``"signal=<N>"`` when signal N has no named
        member in ``signal.Signals``).
    """
    if exitcode is None or exitcode >= 0:
        return f"exitcode={exitcode}"

    try:
        return f"signal={signal.Signals(-exitcode).name}"
    except ValueError:
        # Not every signal number is a member of the Signals enum; this helper
        # only formats log output, so fall back to the raw number instead of
        # raising from inside a logging call.
        return f"signal={-exitcode}"
Expand Down Expand Up @@ -256,10 +265,8 @@ def get_runtime_manager(settings: Settings, queue: WorkerManager.Queue, log_leve
if queue is WorkerManager.Queue.BOEFJES:
item_handler = BoefjeHandler(LocalBoefjeJobRunner(local_repository), local_repository, bytes_api_client)
else:
item_handler = NormalizerHandler(
LocalNormalizerJobRunner(local_repository),
bytes_api_client,
settings.scan_profile_whitelist,
item_handler = NormalizerHandler( # type: ignore
LocalNormalizerJobRunner(local_repository), bytes_api_client, settings.scan_profile_whitelist
)

return SchedulerWorkerManager(
Expand Down
4 changes: 2 additions & 2 deletions boefjes/boefjes/clients/bytes_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import typing
from functools import wraps
from typing import Any, Callable, Dict, Set, Union
from typing import Any, Callable, Dict, FrozenSet, Union
from uuid import UUID

import requests
Expand Down Expand Up @@ -103,7 +103,7 @@ def save_normalizer_meta(self, normalizer_meta: NormalizerMeta) -> None:
self._verify_response(response)

@retry_with_login
def save_raw(self, boefje_meta_id: str, raw: bytes, mime_types: Set[str] = frozenset()) -> UUID:
def save_raw(self, boefje_meta_id: str, raw: bytes, mime_types: FrozenSet[str] = frozenset()) -> UUID:
headers = {"content-type": "application/octet-stream"}
headers.update(self.headers)

Expand Down
4 changes: 2 additions & 2 deletions boefjes/boefjes/clients/scheduler_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ def push_item(self, queue_id: str, p_item: QueuePrioritizedItem) -> None:
response = self._session.post(f"{self.base_url}/queues/{queue_id}/push", data=p_item.json())
self._verify_response(response)

def patch_task(self, task_id: str, status: TaskStatus) -> None:
def patch_task(self, task_id: uuid.UUID, status: TaskStatus) -> None:
response = self._session.patch(f"{self.base_url}/tasks/{task_id}", json={"status": status.value})
self._verify_response(response)

def get_task(self, task_id: str) -> Task:
def get_task(self, task_id: uuid.UUID) -> Task:
response = self._session.get(f"{self.base_url}/tasks/{task_id}")
self._verify_response(response)

Expand Down
2 changes: 1 addition & 1 deletion boefjes/boefjes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Settings(BaseSettings):
"1.1.1.1", description="Name server used for remote DNS resolution in the boefje runner"
)

scan_profile_whitelist: Dict[str, conint(strict=True, ge=0, le=4)] = Field(
scan_profile_whitelist: Dict[str, conint(strict=True, ge=0, le=4)] = Field( # type: ignore
default_factory=dict,
description="Whitelist for normalizer ids allowed to produce scan profiles, including a maximum level.",
examples=['{"kat_external_db_normalize": 3, "kat_dns_normalize": 1}'],
Expand Down
4 changes: 2 additions & 2 deletions boefjes/boefjes/docker_boefjes_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ def run(self) -> None:

stderr_mime_types = boefjes.plugins.models._default_mime_types(self.boefje_meta.boefje)

task_id = str(self.boefje_meta.id)
task_id = self.boefje_meta.id
self.scheduler_client.patch_task(task_id, TaskStatus.RUNNING)
self.boefje_meta.started_at = datetime.now(timezone.utc)

try:
input_url = str(settings.api).rstrip("/") + f"/api/v0/tasks/{task_id}"
container_logs = self.docker_client.containers.run(
image=self.boefje_resource.oci_image,
name="kat_boefje_" + task_id,
name="kat_boefje_" + str(task_id),
command=input_url,
stdout=False,
stderr=True,
Expand Down
2 changes: 1 addition & 1 deletion boefjes/boefjes/job_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def __init__(
):
self.job_runner = job_runner
self.bytes_client: BytesAPIClient = bytes_client
self.whitelist = whitelist
self.whitelist = whitelist or {}
self.octopoes_factory = octopoes_factory

def handle(self, normalizer_meta: NormalizerMeta) -> None:
Expand Down
2 changes: 1 addition & 1 deletion boefjes/boefjes/katalogus/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ def get_plugin(self, repository: Repository, plugin_id: str) -> PluginType:
res = self._session.get(f"{repository.base_url}/plugins/{plugin_id}")
res.raise_for_status()

return PluginType.model_validate_json(res.content)
return PluginType.model_validate_json(res.content) # type: ignore
4 changes: 2 additions & 2 deletions boefjes/boefjes/katalogus/dependencies/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def get_all(self, organisation_id: str) -> List[PluginType]:
all_plugins = self._plugins_for_repos(self.repository_storage.get_all().values(), organisation_id)

flat = []
flat: List[PluginType] = []

for plugins in all_plugins.values():
flat.extend(plugins.values())
Expand Down Expand Up @@ -161,7 +161,7 @@ def update_by_id(self, repository_id: str, plugin_id: str, organisation_id: str,
def _plugins_for_repos(
self, repositories: Iterable[Repository], organisation_id: str
) -> Dict[str, Dict[str, PluginType]]:
plugins = {}
plugins: Dict[str, Dict[str, PluginType]] = {}

for repository in repositories:
if repository.id == RESERVED_LOCAL_ID:
Expand Down
6 changes: 3 additions & 3 deletions boefjes/boefjes/katalogus/local_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import pkgutil
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from boefjes.katalogus.models import PluginType
from boefjes.plugins.models import (
Expand All @@ -22,8 +22,8 @@
class LocalPluginRepository:
def __init__(self, path: Path):
self.path = path
self._cached_boefjes = None
self._cached_normalizers = None
self._cached_boefjes: Optional[Dict[str, Any]] = None
self._cached_normalizers: Optional[Dict[str, Any]] = None

def get_all(self) -> List[PluginType]:
all_plugins = [boefje_resource.boefje for boefje_resource in self.resolve_boefjes().values()]
Expand Down
8 changes: 4 additions & 4 deletions boefjes/boefjes/katalogus/storage/memory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Dict, List, Optional

from boefjes.katalogus.models import RESERVED_LOCAL_ID, Organisation, Repository
from boefjes.katalogus.storage.interfaces import (
Expand All @@ -19,7 +19,7 @@


class OrganisationStorageMemory(OrganisationStorage):
def __init__(self, defaults: Dict[str, Organisation] = None):
def __init__(self, defaults: Optional[Dict[str, Organisation]] = None):
self._data = organisations if defaults is None else defaults

def get_by_id(self, organisation_id: str) -> Organisation:
Expand All @@ -39,7 +39,7 @@ class RepositoryStorageMemory(RepositoryStorage):
def __init__(
self,
organisation_id: str,
defaults: Dict[str, Repository] = None,
defaults: Optional[Dict[str, Repository]] = None,
):
self._data = repositories.setdefault(organisation_id, {}) if defaults is None else defaults
self._organisation_id = organisation_id
Expand Down Expand Up @@ -84,7 +84,7 @@ class PluginStatesStorageMemory(PluginEnabledStorage):
def __init__(
self,
organisation: str,
defaults: Dict[str, bool] = None,
defaults: Optional[Dict[str, bool]] = None,
):
self._data = plugins_state.setdefault(organisation, {}) if defaults is None else defaults
self._organisation = organisation
Expand Down
2 changes: 1 addition & 1 deletion boefjes/boefjes/katalogus/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class PaginationParameters(BaseModel):
    """Pagination settings with defaults for both fields."""

    # Defaults to 0.
    offset: int = 0
    # Always an int (not Optional); defaults to the module-level LIMIT.
    limit: int = LIMIT


class FilterParameters(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions boefjes/boefjes/runtime_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, Tuple, Union
from typing import Dict, List, Tuple, Union

from boefjes.job_models import BoefjeMeta, NormalizerMeta, NormalizerOutput

Expand All @@ -10,7 +10,7 @@ def handle(self, item: Union[BoefjeMeta, NormalizerMeta]):


class BoefjeJobRunner:
    """Base interface for boefje job runners; ``run`` must be overridden."""

    def run(self, boefje_meta: BoefjeMeta, environment: Dict[str, str]) -> List[Tuple[set, Union[bytes, str]]]:
        """Run the job described by ``boefje_meta`` with the given ``environment``.

        Returns:
            A list of ``(set, payload)`` tuples where ``payload`` is ``bytes``
            or ``str``.
            NOTE(review): the set presumably holds mime types -- confirm
            against the concrete runners.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()


Expand Down
2 changes: 1 addition & 1 deletion boefjes/boefjes/sql/plugin_enabled_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_all_enabled(self, organisation_id: str) -> Dict[str, List[str]]:
.filter(PluginStateInDB.enabled)
)

per_repository = {}
per_repository: Dict[str, List[str]] = {}

for state in query.all():
if state.repository.id not in per_repository:
Expand Down
4 changes: 2 additions & 2 deletions boefjes/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def patch_task(self, task_id: UUID, status: TaskStatus) -> None:
with self.log_path.open("a") as f:
f.write(f"{task_id},{status.value}\n")

task = self._task_from_id(task_id) if task_id not in self._tasks else self._tasks[str(task_id)]
task = self._task_from_id(task_id) if str(task_id) not in self._tasks else self._tasks[str(task_id)]
task.status = status
self._tasks[str(task_id)] = task

Expand All @@ -76,7 +76,7 @@ def get_all_patched_tasks(self) -> List[Tuple[str, ...]]:
return [tuple(x.strip().split(",")) for x in f]

def get_task(self, task_id: UUID) -> Task:
    """Return the stored task for ``task_id``, or a newly built one if none is stored."""
    # ``_tasks`` is keyed by the string form of the id, so the membership test
    # must use the same string key; testing the UUID object itself never matches.
    key = str(task_id)
    if key in self._tasks:
        return self._tasks[key]

    return self._task_from_id(task_id)

def _task_from_id(self, task_id: UUID):
return Task(
Expand Down
8 changes: 7 additions & 1 deletion boefjes/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import pytest

from boefjes.app import SchedulerWorkerManager
from boefjes.app import SchedulerWorkerManager, get_runtime_manager
from boefjes.config import Settings
from boefjes.runtime_interfaces import WorkerManager
from tests.conftest import MockHandler, MockSchedulerClient
from tests.loading import get_dummy_data
Expand Down Expand Up @@ -151,3 +152,8 @@ def test_null(manager: SchedulerWorkerManager, tmp_path: Path, item_handler: Moc
assert len(patched_tasks) == 3
assert patched_tasks[0] == ("70da7d4f-f41f-4940-901b-d98a92e9014b", "completed")
assert patched_tasks[2] == ("70da7d4f-f41f-4940-901b-d98a92e9014b", "completed")


def test_create_manager():
    """Smoke test: a runtime manager can be constructed for each queue type."""
    queue_types = (WorkerManager.Queue.BOEFJES, WorkerManager.Queue.NORMALIZERS)

    for queue_type in queue_types:
        # Fresh default Settings per call, as in the two explicit calls this replaces.
        get_runtime_manager(Settings(), queue_type, "INFO")