Skip to content

Commit

Permalink
fix: move time limit logic to a mixin and respect real time
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinGauk committed Oct 23, 2024
1 parent 4102754 commit cff49f6
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 79 deletions.
1 change: 1 addition & 0 deletions questionpy_server/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class Worker(ABC):
"""Interface for worker implementations."""

def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None) -> None:
super().__init__()
self.package = package
self.limits = limits
self.state = WorkerState.NOT_RUNNING
Expand Down
17 changes: 13 additions & 4 deletions questionpy_server/worker/exception.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
# This file is part of the QuestionPy Server. (https://questionpy.org)
# The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md.
# (c) Technische Universität Berlin, innoCampus <[email protected]>
from questionpy_server.worker.runtime.messages import BaseWorkerError


class WorkerNotRunningError(Exception):
class WorkerNotRunningError(BaseWorkerError):
pass


class WorkerStartError(Exception):
class WorkerStartError(BaseWorkerError):
pass


class WorkerCPUTimeLimitExceededError(Exception):
pass
class WorkerCPUTimeLimitExceededError(BaseWorkerError):
def __init__(self, limit: float):
self.limit = limit
super().__init__(f"Worker has exceeded its CPU time limit of {limit} seconds and was killed.")


class WorkerRealTimeLimitExceededError(BaseWorkerError):
def __init__(self, limit: float):
self.limit = limit
super().__init__(f"Worker has exceeded its real time limit of {limit} seconds and was killed.")


class StaticFileSizeMismatchError(Exception):
Expand Down
116 changes: 102 additions & 14 deletions questionpy_server/worker/impl/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,29 @@
import asyncio
import contextlib
import logging
from abc import ABC
import time
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar
from zipfile import ZipFile

from questionpy_common.api.attempt import AttemptModel, AttemptScoredModel, AttemptStartedModel
from questionpy_common.constants import DIST_DIR
from questionpy_common.elements import OptionsFormDefinition
from questionpy_common.environment import RequestUser, WorkerResourceLimits
from questionpy_common.environment import RequestUser
from questionpy_common.manifest import Manifest, PackageFile
from questionpy_server.models import QuestionCreated
from questionpy_server.utils.manifest import ComparableManifest
from questionpy_server.worker import PackageFileData, Worker, WorkerState
from questionpy_server.worker.exception import StaticFileSizeMismatchError, WorkerNotRunningError, WorkerStartError
from questionpy_server.worker.exception import (
StaticFileSizeMismatchError,
WorkerCPUTimeLimitExceededError,
WorkerNotRunningError,
WorkerRealTimeLimitExceededError,
WorkerStartError,
)
from questionpy_server.worker.runtime.messages import (
BaseWorkerError,
CreateQuestionFromOptions,
Exit,
GetOptionsForm,
Expand All @@ -37,7 +45,6 @@
from questionpy_server.worker.runtime.package_location import (
DirPackageLocation,
FunctionPackageLocation,
PackageLocation,
ZipPackageLocation,
)

Expand All @@ -64,14 +71,17 @@ class BaseWorker(Worker, ABC):
"""Base class implementing some common functionality of workers."""

_worker_type = "unknown"
_init_worker_timeout = 2
_load_qpy_package_timeout = 4

def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None) -> None:
super().__init__(package, limits)
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

self._observe_task: asyncio.Task | None = None

self._connection: ServerToWorkerConnection | None = None
self._expected_incoming_messages: list[tuple[MessageIds, asyncio.Future]] = []
self._receive_messages_exception: BaseException | None = None

async def _initialize(self) -> None:
"""Initializes an already running worker and starts the observe task.
Expand All @@ -88,11 +98,14 @@ async def _initialize(self) -> None:
worker_type=self._worker_type,
),
InitWorker.Response,
self._init_worker_timeout,
)
await self.send_and_wait_for_response(
LoadQPyPackage(location=self.package, main=True), LoadQPyPackage.Response
LoadQPyPackage(location=self.package, main=True),
LoadQPyPackage.Response,
self._load_qpy_package_timeout,
)
except WorkerNotRunningError as e:
except BaseWorkerError as e:
msg = "Worker has exited before or during initialization."
raise WorkerStartError(msg) from e

Expand All @@ -101,7 +114,9 @@ def send(self, message: MessageToWorker) -> None:
raise WorkerNotRunningError
self._connection.send_message(message)

async def send_and_wait_for_response(self, message: MessageToWorker, expected_response_message: type[_M]) -> _M:
async def send_and_wait_for_response(
self, message: MessageToWorker, expected_response_message: type[_M], timeout: float | None = None
) -> _M:
self.send(message)
fut = asyncio.get_running_loop().create_future()
self._expected_incoming_messages.append((expected_response_message.message_id, fut))
Expand Down Expand Up @@ -138,7 +153,8 @@ async def _receive_messages(self) -> None:
finally:
for _, future in self._expected_incoming_messages:
if not future.done():
future.set_exception(WorkerNotRunningError())
exc = self._receive_messages_exception or WorkerNotRunningError()
future.set_exception(exc)
self._expected_incoming_messages = []

def _get_observation_tasks(self) -> Sequence[asyncio.Task]:
Expand All @@ -153,10 +169,14 @@ def _get_observation_tasks(self) -> Sequence[asyncio.Task]:

async def _observe(self) -> None:
"""Observes the tasks returned by _get_observation_tasks."""
pending: Sequence[asyncio.Task] = []
pending: set[asyncio.Task]
try:
tasks = self._get_observation_tasks()
_, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for task in done:
with contextlib.suppress(asyncio.CancelledError):
if exc := task.exception():
self._receive_messages_exception = exc
finally:
self.state = WorkerState.NOT_RUNNING

Expand All @@ -170,7 +190,7 @@ async def _observe(self) -> None:
async def stop(self, timeout: float) -> None:
try:
self.send(Exit())
except WorkerNotRunningError:
except BaseWorkerError:
# No need to stop it then.
return

Expand Down Expand Up @@ -292,3 +312,71 @@ async def get_static_file(self, path: str) -> PackageFileData:

async def get_static_file_index(self) -> dict[str, PackageFile]:
return (await self.get_manifest()).static_files


class LimitTimeUsageMixin(Worker, ABC):
"""Implements a CPU and real time usage limit for a worker.
_limit_cpu_time_usage needs to be added to the return value of :meth:`BaseWorker._get_observation_tasks`.
The worker will be killed when the CPU time limit is exceeded or the worker took more than three times
the cpu limit in real time.
"""

_real_time_limit_factor = 3

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._cur_cpu_time_limit: float = 0
self._request_started_cpu_time: float | None = None
self._request_started_time: float | None = None
self._request_started_event = asyncio.Event()

@abstractmethod
def _get_cpu_time(self) -> float:
"""Get worker's current CPU time (user and system time).
Returns:
CPU time in seconds
"""

def _set_time_limit(self, limit: float) -> None:
"""Set a CPU and real time limit.
The real time limit is the CPU time limit * :meth:`LimitTimeUsageMixin._real_time_limit_factor`.
Args:
limit: CPU time limit in seconds
"""
self._cur_cpu_time_limit = limit
self._request_started_cpu_time = self._get_cpu_time()
self._request_started_time = time.time()
self._request_started_event.set()

def _reset_time_limit(self) -> None:
self._cur_cpu_time_limit = 0
self._request_started_cpu_time = None
self._request_started_time = None
self._request_started_event.clear()

async def _limit_cpu_time_usage(self) -> None:
"""Ensures that the worker will be killed when it is taking too much time. Executed as a task."""
while True:
await self._request_started_event.wait()

# CPU-time is always less or equal to real time (when single-threaded).
await asyncio.sleep(self._cur_cpu_time_limit)

# Check if the start time is still set. Probably the request was already processed or
# maybe another request started meanwhile.
while self._request_started_cpu_time is not None and self._request_started_time is not None:
remaining_cpu_time = self._request_started_cpu_time + self._cur_cpu_time_limit - self._get_cpu_time()
if remaining_cpu_time <= 0:
raise WorkerCPUTimeLimitExceededError(self._cur_cpu_time_limit)

remaining_time = (
self._request_started_time + (self._cur_cpu_time_limit * self._real_time_limit_factor) - time.time()
)
if remaining_time <= 0:
raise WorkerRealTimeLimitExceededError(self._cur_cpu_time_limit * self._real_time_limit_factor)

await asyncio.sleep(max(min(remaining_cpu_time, remaining_time), 0.05))
69 changes: 25 additions & 44 deletions questionpy_server/worker/impl/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import logging
import math
import sys
from asyncio import StreamReader
from collections.abc import Sequence
Expand All @@ -16,8 +17,8 @@
from questionpy_common.environment import WorkerResourceLimits
from questionpy_server.worker import WorkerResources
from questionpy_server.worker.connection import ServerToWorkerConnection
from questionpy_server.worker.exception import WorkerCPUTimeLimitExceededError, WorkerNotRunningError, WorkerStartError
from questionpy_server.worker.impl._base import BaseWorker
from questionpy_server.worker.exception import WorkerNotRunningError, WorkerStartError
from questionpy_server.worker.impl._base import BaseWorker, LimitTimeUsageMixin
from questionpy_server.worker.runtime.messages import MessageToServer, MessageToWorker
from questionpy_server.worker.runtime.package_location import PackageLocation

Expand Down Expand Up @@ -71,7 +72,7 @@ def flush(self) -> None:
self._skipped_bytes = 0


class SubprocessWorker(BaseWorker):
class SubprocessWorker(BaseWorker, LimitTimeUsageMixin):
"""Worker implementation running in a non-sandboxed subprocess."""

_worker_type = "process"
Expand All @@ -80,7 +81,7 @@ class SubprocessWorker(BaseWorker):
_runtime_main = ["-m", "questionpy_server.worker.runtime"]

def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None):
super().__init__(package, limits)
super().__init__(package=package, limits=limits)

self._proc: Process | None = None
self._stderr_buffer: _StderrBuffer | None = None
Expand Down Expand Up @@ -111,48 +112,16 @@ async def start(self) -> None:
# Whether initialization was successful or not, flush the logs.
self._stderr_buffer.flush()

async def _limit_cpu_time_usage(self, expected_response_message: type[_T]) -> None:
if not self._proc or self._proc.returncode is not None:
raise WorkerNotRunningError

if self.limits is None:
return

psutil_proc = psutil.Process(self._proc.pid)

# Get current cpu times and calculate the maximum cpu time for the current call.
cpu_times = psutil_proc.cpu_times()
max_cpu_time = cpu_times.user + cpu_times.system + self.limits.max_cpu_time_seconds_per_call

# CPU-time is always less or equal to real time.
await asyncio.sleep(self.limits.max_cpu_time_seconds_per_call)

while True:
cpu_times = psutil_proc.cpu_times()
remaining_time = max_cpu_time - (cpu_times.user + cpu_times.system)
if remaining_time <= 0:
break
await asyncio.sleep(max(remaining_time, 0.05))

# Set the exception and kill the process.
for future in [
fut
for expected_id, fut in self._expected_incoming_messages
if expected_id == expected_response_message.message_id
]:
future.set_exception(WorkerCPUTimeLimitExceededError)
self._expected_incoming_messages.remove((expected_response_message.message_id, future))

await self.kill()

async def send_and_wait_for_response(self, message: MessageToWorker, expected_response_message: type[_T]) -> _T:
timeout = asyncio.create_task(
self._limit_cpu_time_usage(expected_response_message), name="limit cpu time usage"
)
async def send_and_wait_for_response(
self, message: MessageToWorker, expected_response_message: type[_T], timeout: float | None = None
) -> _T:
try:
return await super().send_and_wait_for_response(message, expected_response_message)
if timeout is None:
timeout = self.limits.max_cpu_time_seconds_per_call if self.limits else math.inf
self._set_time_limit(timeout)
return await super().send_and_wait_for_response(message, expected_response_message, timeout)
finally:
timeout.cancel()
self._reset_time_limit()
# Write worker's stderr to log after every exchange.
if self._stderr_buffer:
self._stderr_buffer.flush()
Expand All @@ -176,8 +145,20 @@ def _get_observation_tasks(self) -> Sequence[asyncio.Task]:
*super()._get_observation_tasks(),
asyncio.create_task(self._proc.wait(), name="wait for worker process"),
asyncio.create_task(self._stderr_buffer.read_stderr(), name="receive stderr from worker"),
asyncio.create_task(self._limit_cpu_time_usage(), name="limit cpu time usage"),
)

async def kill(self) -> None:
if self._proc and self._proc.returncode is None:
self._proc.kill()

# Make sure that all resources of the subprocesses are getting cleaned.
await self._proc.wait()

def _get_cpu_time(self) -> float:
if not self._proc or self._proc.returncode is not None:
raise WorkerNotRunningError

psutil_proc = psutil.Process(self._proc.pid)
cpu_times = psutil_proc.cpu_times()
return cpu_times.user + cpu_times.system
2 changes: 1 addition & 1 deletion questionpy_server/worker/impl/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class ThreadWorker(BaseWorker):
_worker_type = "thread"

def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None) -> None:
super().__init__(package, limits)
super().__init__(package=package, limits=limits)

self._pipe: DuplexPipe | None = None

Expand Down
2 changes: 1 addition & 1 deletion questionpy_server/worker/runtime/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Callable, Generator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import NoReturn, cast, TypeAlias, TypeVar
from typing import NoReturn, TypeAlias, TypeVar, cast

from questionpy_common.api.qtype import QuestionTypeInterface
from questionpy_common.environment import (
Expand Down
8 changes: 6 additions & 2 deletions questionpy_server/worker/runtime/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,13 @@ def __init__(self, message_id: int, length: int):
super().__init__(f"Received unknown message with id {message_id} and length {length}.")


class WorkerMemoryLimitExceededError(Exception):
class BaseWorkerError(Exception):
pass


class WorkerUnknownError(Exception):
class WorkerMemoryLimitExceededError(BaseWorkerError):
pass


class WorkerUnknownError(BaseWorkerError):
pass
Loading

0 comments on commit cff49f6

Please sign in to comment.