From cff49f6e5a879d2ec8a5998af3787f581d0f93aa Mon Sep 17 00:00:00 2001 From: Martin Gauk Date: Wed, 9 Oct 2024 16:08:27 +0200 Subject: [PATCH] fix: move time limit logic to a mixin and respect real time --- questionpy_server/worker/__init__.py | 1 + questionpy_server/worker/exception.py | 17 ++- questionpy_server/worker/impl/_base.py | 116 +++++++++++++++--- questionpy_server/worker/impl/subprocess.py | 69 ++++------- questionpy_server/worker/impl/thread.py | 2 +- questionpy_server/worker/runtime/manager.py | 2 +- questionpy_server/worker/runtime/messages.py | 8 +- .../worker/impl/test_subprocess.py | 69 +++++++++-- 8 files changed, 205 insertions(+), 79 deletions(-) diff --git a/questionpy_server/worker/__init__.py b/questionpy_server/worker/__init__.py index 7a91647..aaa81a8 100644 --- a/questionpy_server/worker/__init__.py +++ b/questionpy_server/worker/__init__.py @@ -54,6 +54,7 @@ class Worker(ABC): """Interface for worker implementations.""" def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None) -> None: + super().__init__() self.package = package self.limits = limits self.state = WorkerState.NOT_RUNNING diff --git a/questionpy_server/worker/exception.py b/questionpy_server/worker/exception.py index 75f52f6..74b3437 100644 --- a/questionpy_server/worker/exception.py +++ b/questionpy_server/worker/exception.py @@ -1,18 +1,27 @@ # This file is part of the QuestionPy Server. (https://questionpy.org) # The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md. 
# (c) Technische Universität Berlin, innoCampus +from questionpy_server.worker.runtime.messages import BaseWorkerError -class WorkerNotRunningError(Exception): +class WorkerNotRunningError(BaseWorkerError): pass -class WorkerStartError(Exception): +class WorkerStartError(BaseWorkerError): pass -class WorkerCPUTimeLimitExceededError(Exception): - pass +class WorkerCPUTimeLimitExceededError(BaseWorkerError): + def __init__(self, limit: float): + self.limit = limit + super().__init__(f"Worker has exceeded its CPU time limit of {limit} seconds and was killed.") + + +class WorkerRealTimeLimitExceededError(BaseWorkerError): + def __init__(self, limit: float): + self.limit = limit + super().__init__(f"Worker has exceeded its real time limit of {limit} seconds and was killed.") class StaticFileSizeMismatchError(Exception): diff --git a/questionpy_server/worker/impl/_base.py b/questionpy_server/worker/impl/_base.py index 73dcbce..fd72611 100644 --- a/questionpy_server/worker/impl/_base.py +++ b/questionpy_server/worker/impl/_base.py @@ -5,21 +5,29 @@ import asyncio import contextlib import logging -from abc import ABC +import time +from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from zipfile import ZipFile from questionpy_common.api.attempt import AttemptModel, AttemptScoredModel, AttemptStartedModel from questionpy_common.constants import DIST_DIR from questionpy_common.elements import OptionsFormDefinition -from questionpy_common.environment import RequestUser, WorkerResourceLimits +from questionpy_common.environment import RequestUser from questionpy_common.manifest import Manifest, PackageFile from questionpy_server.models import QuestionCreated from questionpy_server.utils.manifest import ComparableManifest from questionpy_server.worker import PackageFileData, Worker, WorkerState -from questionpy_server.worker.exception import StaticFileSizeMismatchError, 
WorkerNotRunningError, WorkerStartError +from questionpy_server.worker.exception import ( + StaticFileSizeMismatchError, + WorkerCPUTimeLimitExceededError, + WorkerNotRunningError, + WorkerRealTimeLimitExceededError, + WorkerStartError, +) from questionpy_server.worker.runtime.messages import ( + BaseWorkerError, CreateQuestionFromOptions, Exit, GetOptionsForm, @@ -37,7 +45,6 @@ from questionpy_server.worker.runtime.package_location import ( DirPackageLocation, FunctionPackageLocation, - PackageLocation, ZipPackageLocation, ) @@ -64,14 +71,17 @@ class BaseWorker(Worker, ABC): """Base class implementing some common functionality of workers.""" _worker_type = "unknown" + _init_worker_timeout = 2 + _load_qpy_package_timeout = 4 - def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None) -> None: - super().__init__(package, limits) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._observe_task: asyncio.Task | None = None self._connection: ServerToWorkerConnection | None = None self._expected_incoming_messages: list[tuple[MessageIds, asyncio.Future]] = [] + self._receive_messages_exception: BaseException | None = None async def _initialize(self) -> None: """Initializes an already running worker and starts the observe task. @@ -88,11 +98,14 @@ async def _initialize(self) -> None: worker_type=self._worker_type, ), InitWorker.Response, + self._init_worker_timeout, ) await self.send_and_wait_for_response( - LoadQPyPackage(location=self.package, main=True), LoadQPyPackage.Response + LoadQPyPackage(location=self.package, main=True), + LoadQPyPackage.Response, + self._load_qpy_package_timeout, ) - except WorkerNotRunningError as e: + except BaseWorkerError as e: msg = "Worker has exited before or during initialization." 
raise WorkerStartError(msg) from e @@ -101,7 +114,9 @@ def send(self, message: MessageToWorker) -> None: raise WorkerNotRunningError self._connection.send_message(message) - async def send_and_wait_for_response(self, message: MessageToWorker, expected_response_message: type[_M]) -> _M: + async def send_and_wait_for_response( + self, message: MessageToWorker, expected_response_message: type[_M], timeout: float | None = None + ) -> _M: self.send(message) fut = asyncio.get_running_loop().create_future() self._expected_incoming_messages.append((expected_response_message.message_id, fut)) @@ -138,7 +153,8 @@ async def _receive_messages(self) -> None: finally: for _, future in self._expected_incoming_messages: if not future.done(): - future.set_exception(WorkerNotRunningError()) + exc = self._receive_messages_exception or WorkerNotRunningError() + future.set_exception(exc) self._expected_incoming_messages = [] def _get_observation_tasks(self) -> Sequence[asyncio.Task]: @@ -153,10 +169,14 @@ def _get_observation_tasks(self) -> Sequence[asyncio.Task]: async def _observe(self) -> None: """Observes the tasks returned by _get_observation_tasks.""" - pending: Sequence[asyncio.Task] = [] + pending: set[asyncio.Task] try: tasks = self._get_observation_tasks() - _, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + for task in done: + with contextlib.suppress(asyncio.CancelledError): + if exc := task.exception(): + self._receive_messages_exception = exc finally: self.state = WorkerState.NOT_RUNNING @@ -170,7 +190,7 @@ async def _observe(self) -> None: async def stop(self, timeout: float) -> None: try: self.send(Exit()) - except WorkerNotRunningError: + except BaseWorkerError: # No need to stop it then. 
return @@ -292,3 +312,71 @@ async def get_static_file(self, path: str) -> PackageFileData: async def get_static_file_index(self) -> dict[str, PackageFile]: return (await self.get_manifest()).static_files + + +class LimitTimeUsageMixin(Worker, ABC): + """Implements a CPU and real time usage limit for a worker. + + _limit_cpu_time_usage needs to be added to the return value of :meth:`BaseWorker._get_observation_tasks`. + The worker will be killed when the CPU time limit is exceeded or the worker took more than three times + the cpu limit in real time. + """ + + _real_time_limit_factor = 3 + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._cur_cpu_time_limit: float = 0 + self._request_started_cpu_time: float | None = None + self._request_started_time: float | None = None + self._request_started_event = asyncio.Event() + + @abstractmethod + def _get_cpu_time(self) -> float: + """Get worker's current CPU time (user and system time). + + Returns: + CPU time in seconds + """ + + def _set_time_limit(self, limit: float) -> None: + """Set a CPU and real time limit. + + The real time limit is the CPU time limit * :meth:`LimitTimeUsageMixin._real_time_limit_factor`. + + Args: + limit: CPU time limit in seconds + """ + self._cur_cpu_time_limit = limit + self._request_started_cpu_time = self._get_cpu_time() + self._request_started_time = time.time() + self._request_started_event.set() + + def _reset_time_limit(self) -> None: + self._cur_cpu_time_limit = 0 + self._request_started_cpu_time = None + self._request_started_time = None + self._request_started_event.clear() + + async def _limit_cpu_time_usage(self) -> None: + """Ensures that the worker will be killed when it is taking too much time. Executed as a task.""" + while True: + await self._request_started_event.wait() + + # CPU-time is always less or equal to real time (when single-threaded). + await asyncio.sleep(self._cur_cpu_time_limit) + + # Check if the start time is still set. 
Probably the request was already processed or + # maybe another request started meanwhile. + while self._request_started_cpu_time is not None and self._request_started_time is not None: + remaining_cpu_time = self._request_started_cpu_time + self._cur_cpu_time_limit - self._get_cpu_time() + if remaining_cpu_time <= 0: + raise WorkerCPUTimeLimitExceededError(self._cur_cpu_time_limit) + + remaining_time = ( + self._request_started_time + (self._cur_cpu_time_limit * self._real_time_limit_factor) - time.time() + ) + if remaining_time <= 0: + raise WorkerRealTimeLimitExceededError(self._cur_cpu_time_limit * self._real_time_limit_factor) + + await asyncio.sleep(max(min(remaining_cpu_time, remaining_time), 0.05)) diff --git a/questionpy_server/worker/impl/subprocess.py b/questionpy_server/worker/impl/subprocess.py index 522725f..fb14e07 100644 --- a/questionpy_server/worker/impl/subprocess.py +++ b/questionpy_server/worker/impl/subprocess.py @@ -4,6 +4,7 @@ import asyncio import logging +import math import sys from asyncio import StreamReader from collections.abc import Sequence @@ -16,8 +17,8 @@ from questionpy_common.environment import WorkerResourceLimits from questionpy_server.worker import WorkerResources from questionpy_server.worker.connection import ServerToWorkerConnection -from questionpy_server.worker.exception import WorkerCPUTimeLimitExceededError, WorkerNotRunningError, WorkerStartError -from questionpy_server.worker.impl._base import BaseWorker +from questionpy_server.worker.exception import WorkerNotRunningError, WorkerStartError +from questionpy_server.worker.impl._base import BaseWorker, LimitTimeUsageMixin from questionpy_server.worker.runtime.messages import MessageToServer, MessageToWorker from questionpy_server.worker.runtime.package_location import PackageLocation @@ -71,7 +72,7 @@ def flush(self) -> None: self._skipped_bytes = 0 -class SubprocessWorker(BaseWorker): +class SubprocessWorker(BaseWorker, LimitTimeUsageMixin): """Worker implementation 
running in a non-sandboxed subprocess.""" _worker_type = "process" @@ -80,7 +81,7 @@ class SubprocessWorker(BaseWorker): _runtime_main = ["-m", "questionpy_server.worker.runtime"] def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None): - super().__init__(package, limits) + super().__init__(package=package, limits=limits) self._proc: Process | None = None self._stderr_buffer: _StderrBuffer | None = None @@ -111,48 +112,16 @@ async def start(self) -> None: # Whether initialization was successful or not, flush the logs. self._stderr_buffer.flush() - async def _limit_cpu_time_usage(self, expected_response_message: type[_T]) -> None: - if not self._proc or self._proc.returncode is not None: - raise WorkerNotRunningError - - if self.limits is None: - return - - psutil_proc = psutil.Process(self._proc.pid) - - # Get current cpu times and calculate the maximum cpu time for the current call. - cpu_times = psutil_proc.cpu_times() - max_cpu_time = cpu_times.user + cpu_times.system + self.limits.max_cpu_time_seconds_per_call - - # CPU-time is always less or equal to real time. - await asyncio.sleep(self.limits.max_cpu_time_seconds_per_call) - - while True: - cpu_times = psutil_proc.cpu_times() - remaining_time = max_cpu_time - (cpu_times.user + cpu_times.system) - if remaining_time <= 0: - break - await asyncio.sleep(max(remaining_time, 0.05)) - - # Set the exception and kill the process. 
- for future in [ - fut - for expected_id, fut in self._expected_incoming_messages - if expected_id == expected_response_message.message_id - ]: - future.set_exception(WorkerCPUTimeLimitExceededError) - self._expected_incoming_messages.remove((expected_response_message.message_id, future)) - - await self.kill() - - async def send_and_wait_for_response(self, message: MessageToWorker, expected_response_message: type[_T]) -> _T: - timeout = asyncio.create_task( - self._limit_cpu_time_usage(expected_response_message), name="limit cpu time usage" - ) + async def send_and_wait_for_response( + self, message: MessageToWorker, expected_response_message: type[_T], timeout: float | None = None + ) -> _T: try: - return await super().send_and_wait_for_response(message, expected_response_message) + if timeout is None: + timeout = self.limits.max_cpu_time_seconds_per_call if self.limits else math.inf + self._set_time_limit(timeout) + return await super().send_and_wait_for_response(message, expected_response_message, timeout) finally: - timeout.cancel() + self._reset_time_limit() # Write worker's stderr to log after every exchange. if self._stderr_buffer: self._stderr_buffer.flush() @@ -176,8 +145,20 @@ def _get_observation_tasks(self) -> Sequence[asyncio.Task]: *super()._get_observation_tasks(), asyncio.create_task(self._proc.wait(), name="wait for worker process"), asyncio.create_task(self._stderr_buffer.read_stderr(), name="receive stderr from worker"), + asyncio.create_task(self._limit_cpu_time_usage(), name="limit cpu time usage"), ) async def kill(self) -> None: if self._proc and self._proc.returncode is None: self._proc.kill() + + # Make sure that all resources of the subprocess are cleaned up.
+ await self._proc.wait() + + def _get_cpu_time(self) -> float: + if not self._proc or self._proc.returncode is not None: + raise WorkerNotRunningError + + psutil_proc = psutil.Process(self._proc.pid) + cpu_times = psutil_proc.cpu_times() + return cpu_times.user + cpu_times.system diff --git a/questionpy_server/worker/impl/thread.py b/questionpy_server/worker/impl/thread.py index 777869d..cf2c086 100644 --- a/questionpy_server/worker/impl/thread.py +++ b/questionpy_server/worker/impl/thread.py @@ -65,7 +65,7 @@ class ThreadWorker(BaseWorker): _worker_type = "thread" def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None) -> None: - super().__init__(package, limits) + super().__init__(package=package, limits=limits) self._pipe: DuplexPipe | None = None diff --git a/questionpy_server/worker/runtime/manager.py b/questionpy_server/worker/runtime/manager.py index eedcf35..980f537 100644 --- a/questionpy_server/worker/runtime/manager.py +++ b/questionpy_server/worker/runtime/manager.py @@ -5,7 +5,7 @@ from collections.abc import Callable, Generator from contextlib import contextmanager from dataclasses import dataclass -from typing import NoReturn, cast, TypeAlias, TypeVar +from typing import NoReturn, TypeAlias, TypeVar, cast from questionpy_common.api.qtype import QuestionTypeInterface from questionpy_common.environment import ( diff --git a/questionpy_server/worker/runtime/messages.py b/questionpy_server/worker/runtime/messages.py index 03fd080..1b7d398 100644 --- a/questionpy_server/worker/runtime/messages.py +++ b/questionpy_server/worker/runtime/messages.py @@ -269,9 +269,13 @@ def __init__(self, message_id: int, length: int): super().__init__(f"Received unknown message with id {message_id} and length {length}.") -class WorkerMemoryLimitExceededError(Exception): +class BaseWorkerError(Exception): pass -class WorkerUnknownError(Exception): +class WorkerMemoryLimitExceededError(BaseWorkerError): + pass + + +class 
WorkerUnknownError(BaseWorkerError): pass diff --git a/tests/questionpy_server/worker/impl/test_subprocess.py b/tests/questionpy_server/worker/impl/test_subprocess.py index 15235ce..8713267 100644 --- a/tests/questionpy_server/worker/impl/test_subprocess.py +++ b/tests/questionpy_server/worker/impl/test_subprocess.py @@ -1,21 +1,27 @@ # This file is part of the QuestionPy Server. (https://questionpy.org) # The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus -import math import resource +from collections.abc import Iterator +from contextlib import contextmanager +from time import process_time, sleep, time from unittest.mock import patch import psutil import pytest from questionpy_common.constants import MiB -from questionpy_common.environment import WorkerResourceLimits from questionpy_server import WorkerPool -from questionpy_server.worker import Worker -from questionpy_server.worker.exception import WorkerCPUTimeLimitExceededError +from questionpy_server.worker.exception import ( + WorkerCPUTimeLimitExceededError, + WorkerRealTimeLimitExceededError, + WorkerStartError, +) +from questionpy_server.worker.impl._base import BaseWorker, LimitTimeUsageMixin from questionpy_server.worker.impl.subprocess import SubprocessWorker -from questionpy_server.worker.runtime.package_location import PackageLocation +from questionpy_server.worker.runtime.manager import WorkerManager from tests.conftest import PACKAGE +from tests.questionpy_server.worker.impl.conftest import patch_worker_pool @pytest.fixture @@ -34,12 +40,49 @@ async def test_should_apply_limits(pool: WorkerPool) -> None: assert soft == hard == 200 * MiB -async def test_should_raise_timout_error(pool: WorkerPool) -> None: - def worker_init(self: Worker, package: PackageLocation, _: WorkerResourceLimits | None) -> None: - self.package = package - # Set the cpu time limit to a small float greater than zero. 
- self.limits = WorkerResourceLimits(200 * MiB, math.ulp(0)) - - with pytest.raises(WorkerCPUTimeLimitExceededError), patch.object(Worker, "__init__", worker_init): - async with pool.get_worker(PACKAGE, 1, 1): +@contextmanager +def _make_get_manifest_busy_wait() -> Iterator[None]: + def busy_wait(self: WorkerManager) -> None: + wait_until = process_time() + 10 + while wait_until > process_time(): pass + + with patch.object(WorkerManager, "bootstrap", busy_wait): + yield + + +async def test_should_raise_cpu_timout_error(pool: WorkerPool) -> None: + with patch_worker_pool(pool, _make_get_manifest_busy_wait): + start_time = time() + # Change the timeout for faster testing. + with pytest.raises(WorkerStartError) as exc_info, patch.object(BaseWorker, "_init_worker_timeout", 0.05): + async with pool.get_worker(PACKAGE, 1, 1): + pass + assert isinstance(exc_info.value.__cause__, WorkerCPUTimeLimitExceededError) + assert 0.05 < (time() - start_time) < 0.5 + + +@contextmanager +def _make_get_manifest_sleep() -> Iterator[None]: + def _sleep(self: WorkerManager) -> None: + sleep(10) + + with patch.object(WorkerManager, "bootstrap", _sleep): + yield + + +async def test_should_raise_real_timout_error(pool: WorkerPool) -> None: + with patch_worker_pool(pool, _make_get_manifest_sleep): + # The timeout should not be too short, because the Python interpreter also needs some time to start up, which + # is accounted for the init worker step. Otherwise, a WorkerCPUTimeLimitExceededError is raised. + start_time = time() + with ( + pytest.raises(WorkerStartError) as exc_info, + # Change the timeout and factor for faster testing. + patch.object(BaseWorker, "_init_worker_timeout", 0.6), + patch.object(LimitTimeUsageMixin, "_real_time_limit_factor", 1.0), + ): + async with pool.get_worker(PACKAGE, 1, 1) as worker: + await worker.get_manifest() + assert isinstance(exc_info.value.__cause__, WorkerRealTimeLimitExceededError) + assert 0.6 < (time() - start_time) < 2.0