Skip to content

Commit

Permalink
fix: monitor cpu time usage from server process
Browse files Browse the repository at this point in the history
  • Loading branch information
janbritz authored and MartinGauk committed Oct 23, 2024
1 parent d98bbe5 commit 4102754
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 43 deletions.
40 changes: 39 additions & 1 deletion questionpy_server/worker/impl/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from questionpy_common.environment import WorkerResourceLimits
from questionpy_server.worker import WorkerResources
from questionpy_server.worker.connection import ServerToWorkerConnection
from questionpy_server.worker.exception import WorkerNotRunningError, WorkerStartError
from questionpy_server.worker.exception import WorkerCPUTimeLimitExceededError, WorkerNotRunningError, WorkerStartError
from questionpy_server.worker.impl._base import BaseWorker
from questionpy_server.worker.runtime.messages import MessageToServer, MessageToWorker
from questionpy_server.worker.runtime.package_location import PackageLocation
Expand Down Expand Up @@ -111,10 +111,48 @@ async def start(self) -> None:
# Whether initialization was successful or not, flush the logs.
self._stderr_buffer.flush()

async def _limit_cpu_time_usage(self, expected_response_message: type[_T]) -> None:
    """Watch the worker's CPU time during one call and kill the worker once the limit is used up.

    Meant to run as a background task next to a single request; ``send_and_wait_for_response``
    cancels it as soon as the request finishes. If the worker consumes more than
    ``limits.max_cpu_time_seconds_per_call`` seconds of CPU time (user + system) after this
    coroutine started, every pending future waiting for ``expected_response_message`` is failed
    with ``WorkerCPUTimeLimitExceededError`` and the worker process is killed.

    Args:
        expected_response_message: Message type whose ``message_id`` identifies the pending
            response futures that should receive the exception.

    Raises:
        WorkerNotRunningError: If there is no worker process or it has already exited.
    """
    if not self._proc or self._proc.returncode is not None:
        raise WorkerNotRunningError

    if self.limits is None:
        # No limits configured, nothing to enforce.
        return

    cpu_time_limit = self.limits.max_cpu_time_seconds_per_call

    try:
        psutil_proc = psutil.Process(self._proc.pid)

        # Get current cpu times and calculate the maximum cpu time for the current call.
        cpu_times = psutil_proc.cpu_times()
        max_cpu_time = cpu_times.user + cpu_times.system + cpu_time_limit

        # CPU time is always less than or equal to real time, so the limit cannot be reached earlier.
        await asyncio.sleep(cpu_time_limit)

        while True:
            cpu_times = psutil_proc.cpu_times()
            remaining_time = max_cpu_time - (cpu_times.user + cpu_times.system)
            if remaining_time <= 0:
                break
            # Never poll more often than every 50 ms.
            await asyncio.sleep(max(remaining_time, 0.05))
    except psutil.NoSuchProcess:
        # The worker exited while it was being watched, so there is nothing left to limit.
        # Returning here avoids an unretrieved exception in this (never awaited) background task.
        return

    # Set the exception and kill the process.
    message_id = expected_response_message.message_id
    for future in [
        fut for expected_id, fut in self._expected_incoming_messages if expected_id == message_id
    ]:
        # The response may have arrived at the very moment the limit was hit. Calling
        # set_exception() on a finished future raises InvalidStateError, which would skip the kill.
        if not future.done():
            future.set_exception(WorkerCPUTimeLimitExceededError)
        self._expected_incoming_messages.remove((message_id, future))

    await self.kill()

async def send_and_wait_for_response(self, message: MessageToWorker, expected_response_message: type[_T]) -> _T:
    """Send ``message`` to the worker and wait for its response while a CPU-time watchdog runs.

    A background task running ``_limit_cpu_time_usage`` is started before the request is
    delegated to the base implementation. Whether the exchange succeeds or fails, the watchdog
    is cancelled afterwards and the worker's buffered stderr output is flushed.

    Args:
        message: The message to send to the worker.
        expected_response_message: The message type expected as the response.

    Returns:
        Whatever the base implementation returns for this exchange.
    """
    cpu_watchdog = asyncio.create_task(
        self._limit_cpu_time_usage(expected_response_message),
        name="limit cpu time usage",
    )
    try:
        response = await super().send_and_wait_for_response(message, expected_response_message)
        return response
    finally:
        # The call is over (successfully or not), so the watchdog is no longer needed.
        cpu_watchdog.cancel()
        # Write worker's stderr to log after every exchange.
        if stderr_buffer := self._stderr_buffer:
            stderr_buffer.flush()
Expand Down
31 changes: 1 addition & 30 deletions questionpy_server/worker/runtime/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
# The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md.
# (c) Technische Universität Berlin, innoCampus <[email protected]>
import resource
import signal
from collections.abc import Callable, Generator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import wraps
from typing import Any, NoReturn, cast, TypeAlias, TypeVar
from typing import NoReturn, cast, TypeAlias, TypeVar

from questionpy_common.api.qtype import QuestionTypeInterface
from questionpy_common.environment import (
Expand All @@ -33,7 +31,6 @@
StartAttempt,
ViewAttempt,
WorkerError,
WorkerTimeLimitExceededError,
)
from questionpy_server.worker.runtime.package import ImportablePackage, load_package

Expand All @@ -57,28 +54,6 @@ def register_on_request_callback(self, callback: OnRequestCallback) -> None:
OnMessageCallback: TypeAlias = Callable[[M], MessageToServer]


def timeout_after(seconds: float) -> Callable[[OnMessageCallback], OnMessageCallback]:
    """Build a decorator that aborts a message callback after ``seconds`` on a virtual timer.

    Each call of the decorated callback installs a ``SIGVTALRM`` handler and arms
    ``ITIMER_VIRTUAL`` with ``seconds``. If the timer expires while the callback is running,
    the handler raises ``WorkerTimeLimitExceededError``. The timer is always disarmed once the
    callback returns or raises.

    Args:
        seconds: Timer value passed to ``signal.setitimer``.

    Returns:
        A decorator that wraps an ``OnMessageCallback``.
    """

    def _raise_time_limit_exceeded(*_: Any) -> NoReturn:
        raise WorkerTimeLimitExceededError

    def decorator(function: OnMessageCallback) -> OnMessageCallback:
        @wraps(function)
        def time_limited(msg: M) -> MessageToServer:
            # Create a timer that raises after the given amount of cpu time.
            signal.signal(signal.SIGVTALRM, _raise_time_limit_exceeded)
            signal.setitimer(signal.ITIMER_VIRTUAL, seconds)

            try:
                result = function(msg)
            finally:
                # Clear the timer.
                signal.setitimer(signal.ITIMER_VIRTUAL, 0)
            return result

        return time_limited

    return decorator


class WorkerManager:
def __init__(self, server_connection: WorkerToServerConnection):
self._connection: WorkerToServerConnection = server_connection
Expand Down Expand Up @@ -113,10 +88,6 @@ def bootstrap(self) -> None:
if self._limits:
# Limit memory usage.
resource.setrlimit(resource.RLIMIT_AS, (self._limits.max_memory, self._limits.max_memory))
# Limit cpu time usage.
timeout = timeout_after(self._limits.max_cpu_time_seconds_per_call)
for message_id, callback in self._message_dispatch.items():
self._message_dispatch[message_id] = timeout(callback)

self._connection.send_message(InitWorker.Response())

Expand Down
9 changes: 0 additions & 9 deletions questionpy_server/worker/runtime/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ class ErrorType(StrEnum):

UNKNOWN = auto()
MEMORY_EXCEEDED = auto()
TIME_LIMIT_EXCEEDED = auto()
QUESTION_STATE_INVALID = auto()

message_id: ClassVar[MessageIds] = MessageIds.ERROR
Expand All @@ -218,8 +217,6 @@ def from_exception(cls, error: Exception, cause: MessageToWorker) -> "WorkerErro
"""Get a WorkerError message from an exception."""
if isinstance(error, MemoryError):
error_type = WorkerError.ErrorType.MEMORY_EXCEEDED
elif isinstance(error, WorkerTimeLimitExceededError):
error_type = WorkerError.ErrorType.TIME_LIMIT_EXCEEDED
elif isinstance(error, InvalidQuestionStateError):
error_type = WorkerError.ErrorType.QUESTION_STATE_INVALID
else:
Expand All @@ -242,8 +239,6 @@ def to_exception(self) -> Exception:
error: Exception
if self.type == WorkerError.ErrorType.MEMORY_EXCEEDED:
error = WorkerMemoryLimitExceededError(self.message)
elif self.type == WorkerError.ErrorType.TIME_LIMIT_EXCEEDED:
error = WorkerTimeLimitExceededError(self.message)
elif self.type == WorkerError.ErrorType.QUESTION_STATE_INVALID:
error = InvalidQuestionStateError(self.message)
else:
Expand Down Expand Up @@ -278,9 +273,5 @@ class WorkerMemoryLimitExceededError(Exception):
pass


class WorkerTimeLimitExceededError(Exception):
    """Raised when a worker call runs past its time limit."""


class WorkerUnknownError(Exception):
    """Worker error without a more specific exception type.

    NOTE(review): presumably the fallback for errors of type ``UNKNOWN``; confirm against
    ``WorkerError.to_exception``.
    """
6 changes: 3 additions & 3 deletions tests/questionpy_server/worker/impl/test_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from questionpy_common.constants import MiB
from questionpy_common.environment import WorkerResourceLimits
from questionpy_server import WorkerPool
from questionpy_server.worker.impl import Worker
from questionpy_server.worker import Worker
from questionpy_server.worker.exception import WorkerCPUTimeLimitExceededError
from questionpy_server.worker.impl.subprocess import SubprocessWorker
from questionpy_server.worker.runtime.messages import WorkerTimeLimitExceededError
from questionpy_server.worker.runtime.package_location import PackageLocation
from tests.conftest import PACKAGE

Expand All @@ -40,6 +40,6 @@ def worker_init(self: Worker, package: PackageLocation, _: WorkerResourceLimits
# Set the cpu time limit to a small float greater than zero.
self.limits = WorkerResourceLimits(200 * MiB, math.ulp(0))

with pytest.raises(WorkerTimeLimitExceededError), patch.object(Worker, "__init__", worker_init):
with pytest.raises(WorkerCPUTimeLimitExceededError), patch.object(Worker, "__init__", worker_init):
async with pool.get_worker(PACKAGE, 1, 1):
pass

0 comments on commit 4102754

Please sign in to comment.