Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce exception hooks and improve typing #9

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions ergate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@

__all__ = [
"Ergate",
"ErgateClient",
"Depends",
"ErgateError",
"InvalidDefinitionError",
"AbortJob",
"SkipNSteps",
"ValidationError",
"Job",
"JobStateStoreClientProtocol",
"JobStateStoreWorkerProtocol",
"JobStatus",
"QueueProtocol",
Expand Down
38 changes: 27 additions & 11 deletions ergate/app.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,57 @@
from __future__ import annotations

from collections.abc import Callable
from contextlib import ExitStack
from typing import Callable, ContextManager, Generic, TypeVar
from typing import Generic, TypeVar

from .handler import ErrorHookHandler
from .job import Job
from .job_runner import JobRunner
from .job_state_store import JobStateStoreWorkerProtocol
from .queue import QueueProtocol
from .types import ExceptionHook, Lifespan
from .workflow import Workflow
from .workflow_registry import WorkflowRegistry

T = TypeVar("T", bound=Job)
AppType = TypeVar("AppType", bound="Ergate")
JobType = TypeVar("JobType", bound=Job)


class Ergate(Generic[T]):
class Ergate(Generic[JobType]):
def __init__(
self,
queue: QueueProtocol[T],
job_state_store: JobStateStoreWorkerProtocol[T],
lifespan: Callable[[Ergate], ContextManager[None]] | None = None,
self: AppType,
queue: QueueProtocol[JobType],
job_state_store: JobStateStoreWorkerProtocol[JobType],
lifespan: Lifespan[AppType] | None = None,
error_hook_handler: ErrorHookHandler[JobType] | None = None,
) -> None:
self.lifespan = lifespan
self.error_hook_handler: ErrorHookHandler[JobType] = (
error_hook_handler or ErrorHookHandler()
)

self.workflow_registry = WorkflowRegistry()
self.job_runner = JobRunner(
self.job_runner: JobRunner[JobType] = JobRunner(
queue,
self.workflow_registry,
job_state_store,
self.on_error,
self.error_hook_handler,
)

def on_error(self, exc: Exception) -> None: ...
def exception_hook(
self, *exceptions: type[Exception]
) -> Callable[[ExceptionHook[JobType]], ExceptionHook[JobType]]:
def decorator(func: ExceptionHook[JobType]) -> ExceptionHook[JobType]:
for exception in exceptions:
self.error_hook_handler.register(exception, func)
return func

return decorator

def register_workflow(self, workflow: Workflow) -> None:
self.workflow_registry.register(workflow)

def run(self) -> None:
def run(self: AppType) -> None:
with ExitStack() as stack:
if self.lifespan:
stack.enter_context(self.lifespan(self))
Expand Down
35 changes: 35 additions & 0 deletions ergate/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Any, Generic, TypeVar

from .job import Job
from .log import LOG
from .types import ExceptionHook

JobType = TypeVar("JobType", bound=Job)


class ErrorHookHandler(Generic[JobType]):
def __init__(self) -> None:
self._hooks: dict[Any, ExceptionHook[JobType]] = {}

def register(
self,
exception: type[Exception],
hook: ExceptionHook[JobType],
) -> None:
self._hooks[exception] = hook

def get_hook(self, exception: Exception) -> ExceptionHook[JobType] | None:
for cls in type(exception).__mro__:
if cls in self._hooks:
return self._hooks[cls]
return None

def notify(self, job: JobType, exception: Exception) -> None:
hook = self.get_hook(exception)
if not hook:
return

try:
hook(job, exception)
except Exception:
LOG.warning("Error hook raised an exception", exc_info=True)
5 changes: 3 additions & 2 deletions ergate/job.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import traceback
from typing import Any

from pydantic import BaseModel, Field
Expand Down Expand Up @@ -28,10 +29,10 @@ def mark_running(self, step_name: str) -> None:
self.status = JobStatus.RUNNING
self.step_name = step_name

def mark_failed(self, exception_traceback: str) -> None:
def mark_failed(self, exception: Exception) -> None:
self.status = JobStatus.FAILED
self.step_name = None
self.exception_traceback = exception_traceback
self.exception_traceback = traceback.format_exc()

def mark_n_steps_completed(
self,
Expand Down
41 changes: 18 additions & 23 deletions ergate/job_runner.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,40 @@
import traceback
from typing import Callable, TypeVar
from typing import Generic, TypeVar

from .exceptions import AbortJob, SkipNSteps
from .handler import ErrorHookHandler
from .interrupt import DelayedKeyboardInterrupt
from .job import Job
from .job_state_store import JobStateStoreWorkerProtocol
from .log import LOG
from .queue import QueueProtocol
from .workflow_registry import WorkflowRegistry

T = TypeVar("T", bound=Job)
JobType = TypeVar("JobType", bound=Job)


class JobRunner:
class JobRunner(Generic[JobType]):
def __init__(
self,
queue: QueueProtocol[T],
queue: QueueProtocol[JobType],
workflow_registry: WorkflowRegistry,
job_state_store: JobStateStoreWorkerProtocol[T],
on_error_callback: Callable[[Exception], None],
job_state_store: JobStateStoreWorkerProtocol[JobType],
error_hook_handler: ErrorHookHandler[JobType],
) -> None:
self.queue = queue
self.workflow_registry = workflow_registry
self.job_state_store = job_state_store
self.on_error_callback = on_error_callback
self.error_hook_handler = error_hook_handler

def _run_job(self, job: T) -> None:
def _run_job(self, job: JobType) -> None:
input_value = job.get_input_value()

workflow = self.workflow_registry[job.workflow_name]
step_to_run = workflow[job.steps_completed]

job.mark_running(step_to_run.name)
self.job_state_store.update(job)

try:
workflow = self.workflow_registry[job.workflow_name]
step_to_run = workflow[job.steps_completed]
job.mark_running(step_to_run.name)
self.job_state_store.update(job)
LOG.info("Running %s - input value: %s", str(step_to_run), input_value)
retval = step_to_run(input_value, job.user_context)
except AbortJob as exc:
Expand All @@ -43,8 +45,8 @@ def _run_job(self, job: T) -> None:
job.mark_n_steps_completed(exc.n + 1, exc.retval, len(workflow))
except Exception as exc:
LOG.exception("Job raised an exception")
job.mark_failed(traceback.format_exc())
self.on_error_callback(exc)
job.mark_failed(exc)
self.error_hook_handler.notify(job, exc)
else:
LOG.info("Step completed successfully - return value: %s", retval)
job.mark_n_steps_completed(1, retval, len(workflow))
Expand All @@ -62,17 +64,10 @@ def run(self) -> None:
job = self.queue.get_one()
except KeyboardInterrupt:
return
except Exception as exc:
self.on_error_callback(exc)
raise

LOG.info("Job acquired")
try:
with DelayedKeyboardInterrupt():
try:
self._run_job(job)
except Exception as exc:
self.on_error_callback(exc)
raise
self._run_job(job)
except KeyboardInterrupt:
return
6 changes: 3 additions & 3 deletions ergate/job_state_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from .job import Job

T = TypeVar("T", bound=Job, contravariant=True)
JobType = TypeVar("JobType", bound=Job, contravariant=True)


class JobStateStoreWorkerProtocol(Protocol[T]):
def update(self, job: T) -> None: ...
class JobStateStoreWorkerProtocol(Protocol[JobType]):
def update(self, job: JobType) -> None: ...
8 changes: 4 additions & 4 deletions ergate/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from .job import Job

T = TypeVar("T", bound=Job)
JobType = TypeVar("JobType", bound=Job)


class QueueProtocol(Protocol[T]):
def put(self, job: T) -> None: ...
def get_one(self) -> T: ...
class QueueProtocol(Protocol[JobType]):
def put(self, job: JobType) -> None: ...
def get_one(self) -> JobType: ...
9 changes: 9 additions & 0 deletions ergate/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from collections.abc import Callable
from typing import ContextManager, TypeVar

AppType = TypeVar("AppType")
JobType = TypeVar("JobType")

Lifespan = Callable[[AppType], ContextManager[None]] | None

ExceptionHook = Callable[[JobType, Exception], None]