Skip to content

Commit

Permalink
make sure workers emit stop events (#16002)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Nov 13, 2024
1 parent fc6eff4 commit b4b5cae
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 14 deletions.
51 changes: 49 additions & 2 deletions flows/worker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
import asyncio
import subprocess
import sys
from threading import Thread
from typing import List

from pydantic_extra_types.pendulum_dt import DateTime

from prefect.events import Event
from prefect.events.clients import get_events_subscriber
from prefect.events.filters import EventFilter, EventNameFilter, EventOccurredFilter


async def watch_worker_events(events: List[Event]):
"""Watch for worker start/stop events and collect them"""
async with get_events_subscriber(
filter=EventFilter(
event=EventNameFilter(prefix=["prefect.worker."]),
occurred=EventOccurredFilter(since=DateTime.now()),
)
) as events_subscriber:
async for event in events_subscriber:
events.append(event)


def run_event_listener(events: List[Event]):
"""Run the async event listener in a thread"""
asyncio.run(watch_worker_events(events))


# Checks to make sure that collections are loaded prior to attempting to start a worker
def main():
events: List[Event] = []

listener_thread = Thread(target=run_event_listener, args=(events,), daemon=True)
listener_thread.start()

subprocess.check_call(
["python", "-m", "pip", "install", "prefect-kubernetes>=0.5.0"],
stdout=sys.stdout,
Expand Down Expand Up @@ -52,11 +82,28 @@ def main():
stderr=sys.stderr,
)
subprocess.check_call(
["prefect", "work-pool", "delete", "test-worker-pool"],
["prefect", "--no-prompt", "work-pool", "delete", "test-worker-pool"],
stdout=sys.stdout,
stderr=sys.stderr,
)

worker_events = [e for e in events if e.event.startswith("prefect.worker.")]
assert (
len(worker_events) == 2
), f"Expected 2 worker events, got {len(worker_events)}"

start_events = [e for e in worker_events if e.event == "prefect.worker.started"]
stop_events = [e for e in worker_events if e.event == "prefect.worker.stopped"]

assert len(start_events) == 1, "Expected 1 worker start event"
assert len(stop_events) == 1, "Expected 1 worker stop event"

print("Captured expected worker start and stop events!")

assert (
stop_events[0].follows == start_events[0].id
), "Stop event should follow start event"


if __name__ == "__main__":
main()
22 changes: 13 additions & 9 deletions src/prefect/events/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,18 @@ def includes(self, event: Event) -> bool:

class EventNameFilter(EventDataFilter):
prefix: Optional[List[str]] = Field(
None, description="Only include events matching one of these prefixes"
default=None, description="Only include events matching one of these prefixes"
)
exclude_prefix: Optional[List[str]] = Field(
None, description="Exclude events matching one of these prefixes"
default=None, description="Exclude events matching one of these prefixes"
)

name: Optional[List[str]] = Field(
None, description="Only include events matching one of these names exactly"
default=None,
description="Only include events matching one of these names exactly",
)
exclude_name: Optional[List[str]] = Field(
None, description="Exclude events matching one of these names exactly"
default=None, description="Exclude events matching one of these names exactly"
)

def includes(self, event: Event) -> bool:
Expand Down Expand Up @@ -230,24 +231,27 @@ class EventFilter(EventDataFilter):
description="Filter criteria for when the events occurred",
)
event: Optional[EventNameFilter] = Field(
None,
default=None,
description="Filter criteria for the event name",
)
any_resource: Optional[EventAnyResourceFilter] = Field(
None, description="Filter criteria for any resource involved in the event"
default=None,
description="Filter criteria for any resource involved in the event",
)
resource: Optional[EventResourceFilter] = Field(
None, description="Filter criteria for the resource of the event"
default=None,
description="Filter criteria for the resource of the event",
)
related: Optional[EventRelatedFilter] = Field(
None, description="Filter criteria for the related resources of the event"
default=None,
description="Filter criteria for the related resources of the event",
)
id: EventIDFilter = Field(
default_factory=lambda: EventIDFilter(id=[]),
description="Filter criteria for the events' ID",
)

order: EventOrder = Field(
EventOrder.DESC,
default=EventOrder.DESC,
description="The order to return filtered events",
)
9 changes: 7 additions & 2 deletions src/prefect/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,9 +649,14 @@ async def teardown(self, *exc_info):
for scope in self._scheduled_task_scopes:
scope.cancel()

await self._exit_stack.__aexit__(*exc_info)
# Emit stopped event before closing client
if self._started_event:
await self._emit_worker_stopped_event(self._started_event)
try:
await self._emit_worker_stopped_event(self._started_event)
except Exception:
self._logger.exception("Failed to emit worker stopped event")

await self._exit_stack.__aexit__(*exc_info)
self._runs_task_group = None
self._client = None

Expand Down
2 changes: 1 addition & 1 deletion tests/workers/test_base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,7 @@ async def test_job_configuration_from_template_overrides_with_remote_variables()

class ArbitraryJobConfiguration(BaseJobConfiguration):
var1: str
env: Dict[str, str]
env: Dict[str, str] = Field(default_factory=dict)

config = await ArbitraryJobConfiguration.from_template_and_values(
base_job_template=template,
Expand Down

0 comments on commit b4b5cae

Please sign in to comment.