Skip to content

Commit

Permalink
Access current update info with ID inside update handler (#544)
Browse files Browse the repository at this point in the history
Fixes #542
  • Loading branch information
cretz authored Jun 6, 2024
1 parent 2d65d82 commit 58d6951
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 0 deletions.
5 changes: 5 additions & 0 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,11 @@ def _apply_do_update(
# inside the task, since the update may not be defined until after we have started the workflow - for example
# if an update is in the first WFT & is also registered dynamically at the top of workflow code.
async def run_update() -> None:
# Set the current update for the life of this task
temporalio.workflow._set_current_update_info(
temporalio.workflow.UpdateInfo(id=job.id, name=job.name)
)

command = self._add_command()
command.update_response.protocol_instance_id = job.protocol_instance_id
past_validation = False
Expand Down
37 changes: 37 additions & 0 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
import contextvars
import inspect
import logging
import threading
Expand Down Expand Up @@ -424,6 +425,17 @@ class ParentInfo:
workflow_id: str


@dataclass(frozen=True)
class UpdateInfo:
"""Information about a workflow update."""

id: str
"""Update ID."""

name: str
"""Update type name."""


class _Runtime(ABC):
@staticmethod
def current() -> _Runtime:
Expand Down Expand Up @@ -654,6 +666,31 @@ async def workflow_wait_condition(
...


_current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar(
"__temporal_current_update_info"
)


def _set_current_update_info(info: UpdateInfo) -> None:
_current_update_info.set(info)


def current_update_info() -> Optional[UpdateInfo]:
"""Info for the current update if any.
This is powered by :py:mod:`contextvars` so it is only valid within the
update handler and coroutines/tasks it has started.
.. warning::
This API is experimental
Returns:
Info for the current update handler the code calling this is executing
within if any.
"""
return _current_update_info.get(None)


def deprecate_patch(id: str) -> None:
"""Mark a patch as deprecated.
Expand Down
75 changes: 75 additions & 0 deletions tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4927,3 +4927,78 @@ async def test_workflow_wait_utility(client: Client):
task_queue=worker.task_queue,
)
assert len(result) == 10


@workflow.defn
class CurrentUpdateWorkflow:
def __init__(self) -> None:
self._pending_get_update_id_tasks: List[asyncio.Task[str]] = []

@workflow.run
async def run(self) -> List[str]:
# Confirm no update info
assert not workflow.current_update_info()

# Wait for all tasks to come in, then return the full set
await workflow.wait_condition(
lambda: len(self._pending_get_update_id_tasks) == 5
)
assert not workflow.current_update_info()
return list(await asyncio.gather(*self._pending_get_update_id_tasks))

@workflow.update
async def do_update(self) -> str:
# Check that simple helper awaited has the ID
info = workflow.current_update_info()
assert info
assert info.name == "do_update"
assert info.id == await self.get_update_id()

# Also schedule the task and wait for it in the main workflow to confirm
# it still gets the update ID
self._pending_get_update_id_tasks.append(
asyncio.create_task(self.get_update_id())
)

# Re-fetch and return
info = workflow.current_update_info()
assert info
return info.id

@do_update.validator
def do_update_validator(self) -> None:
info = workflow.current_update_info()
assert info
assert info.name == "do_update"

async def get_update_id(self) -> str:
await asyncio.sleep(0.01)
info = workflow.current_update_info()
assert info
return info.id


async def test_workflow_current_update(client: Client, env: WorkflowEnvironment):
if env.supports_time_skipping:
pytest.skip(
"Java test server: https://github.com/temporalio/sdk-java/issues/1903"
)
async with new_worker(client, CurrentUpdateWorkflow) as worker:
handle = await client.start_workflow(
CurrentUpdateWorkflow.run,
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
)
update_ids = await asyncio.gather(
handle.execute_update(CurrentUpdateWorkflow.do_update, id="update1"),
handle.execute_update(CurrentUpdateWorkflow.do_update, id="update2"),
handle.execute_update(CurrentUpdateWorkflow.do_update, id="update3"),
handle.execute_update(CurrentUpdateWorkflow.do_update, id="update4"),
handle.execute_update(CurrentUpdateWorkflow.do_update, id="update5"),
)
assert {"update1", "update2", "update3", "update4", "update5"} == set(
update_ids
)
assert {"update1", "update2", "update3", "update4", "update5"} == set(
await handle.result()
)

0 comments on commit 58d6951

Please sign in to comment.