Skip to content

Commit

Permalink
IWF-356: Add atomic checking decision for channels (#45)
Browse files Browse the repository at this point in the history
* IWF-356: Add atomic checking decision for channels

* DONE

* DONE

* DONE

* DONE

* Fix mypy

* Fix bug

* fix env

* fix import

* IWF-363: Support dynamic internal channel by prefix MVP version (#47)

* Support dynamic internal channel by prefix

* fix ignore

* Update iwf/command_results.py

Co-authored-by: Samuel Caçador <[email protected]>

---------

Co-authored-by: Samuel Caçador <[email protected]>

---------

Co-authored-by: Samuel Caçador <[email protected]>
  • Loading branch information
longquanzheng and samuel27m authored Nov 22, 2024
1 parent c94432a commit ecf41da
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
__pycache__/
*.py[cod]
*$py.class
requirements.*

# C extensions
*.so
Expand Down
20 changes: 17 additions & 3 deletions iwf/command_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import Any, Union

from iwf.errors import WorkflowDefinitionError
from iwf.iwf_api.models import (
ChannelRequestStatus,
CommandResults as IdlCommandResults,
Expand Down Expand Up @@ -57,12 +58,25 @@ def from_idl_command_results(

if not isinstance(idl_results.inter_state_channel_results, Unset):
for inter in idl_results.inter_state_channel_results:
val_type = internal_channel_types.get(inter.channel_name)
if val_type is None:
# fallback to assume it's prefix
# TODO use is_prefix to implement like Java SDK
for name, t in internal_channel_types.items():
if inter.channel_name.startswith(name):
val_type = t
break
if val_type is None:
raise WorkflowDefinitionError(
"internal channel is not registered: " + inter.channel_name
)

encoded = object_encoder.decode(inter.value, val_type)

results.internal_channel_commands.append(
InternalChannelCommandResult(
inter.channel_name,
object_encoder.decode(
inter.value, internal_channel_types.get(inter.channel_name)
),
encoded,
inter.request_status,
inter.command_id,
)
Expand Down
10 changes: 8 additions & 2 deletions iwf/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,18 @@ def trigger_state_execution(self, state: Union[str, type], state_input: Any = No
self._state_movements.append(movement)

def publish_to_internal_channel(self, channel_name: str, value: Any = None):
if channel_name not in self._type_store:
registered_type = self._type_store.get(channel_name)

if registered_type is None:
for name, t in self._type_store.items():
if channel_name.startswith(name):
registered_type = t

if registered_type is None:
raise WorkflowDefinitionError(
f"InternalChannel channel_name is not defined {channel_name}"
)

registered_type = self._type_store.get(channel_name)
if (
value is not None
and registered_type is not None
Expand Down
11 changes: 9 additions & 2 deletions iwf/communication_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,24 @@ class CommunicationMethod:
name: str
method_type: CommunicationMethodType
value_type: Optional[type]
is_prefix: bool

@classmethod
def signal_channel_def(cls, name: str, value_type: type):
return CommunicationMethod(
name, CommunicationMethodType.SignalChannel, value_type
name, CommunicationMethodType.SignalChannel, value_type, False
)

@classmethod
def internal_channel_def(cls, name: str, value_type: type):
return CommunicationMethod(
name, CommunicationMethodType.InternalChannel, value_type
name, CommunicationMethodType.InternalChannel, value_type, False
)

@classmethod
def internal_channel_def_by_prefix(cls, name_prefix: str, value_type: type):
return CommunicationMethod(
name_prefix, CommunicationMethodType.InternalChannel, value_type, True
)


Expand Down
2 changes: 2 additions & 0 deletions iwf/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def _register_internal_channels(self, wf: ObjectWorkflow):
for method in wf.get_communication_schema().communication_methods:
if method.method_type == CommunicationMethodType.InternalChannel:
types[method.name] = method.value_type
# TODO use is_prefix to implement like Java SDK
#
self._internal_channel_type_store[wf_type] = types

def _register_signal_channels(self, wf: ObjectWorkflow):
Expand Down
77 changes: 74 additions & 3 deletions iwf/state_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import typing

from iwf.iwf_api.models import WorkflowConditionalClose, WorkflowConditionalCloseType

if typing.TYPE_CHECKING:
from iwf.registry import Registry
from iwf.workflow_state import WorkflowState
Expand All @@ -16,10 +18,21 @@
from iwf.state_movement import StateMovement, _to_idl_state_movement


@dataclass
class InternalConditionalClose:
conditional_close_type: WorkflowConditionalCloseType

channel_name: str

close_input: Any = None


@dataclass
class StateDecision:
next_states: List[StateMovement]

conditional_close: typing.Optional[InternalConditionalClose] = None

dead_end: typing.ClassVar[StateDecision]

@classmethod
Expand Down Expand Up @@ -50,16 +63,74 @@ def multi_next_states(
]
return StateDecision(next_list)

# Atomically force complete the workflow if internal channel is empty, otherwise trigger the state movements from the current thread
# This is to ensure all the messages in the channel are processed before completing the workflow, otherwise messages may be lost.
# Without this atomic API, if just checking the channel emptiness in the State WaitUntil, a workflow may receive new messages during the
# execution of state APIs.
#
# Note that it's only for internal messages published from RPCs.
# It doesn't cover the cases that internal messages are published from other State APIs.
# If you do want to use other State APIs to publish messages to the channel at the same time, you can use persistence locking to
# ensure only the State APIs are not executed in parallel.
@classmethod
def force_complete_if_internal_channel_empty_or_else(
cls,
internal_channel_name: str,
workflow_complete_output: Any = None, # if channel is empty, complete the workflow with the output
or_else_state: Union[
str, type[WorkflowState]
] = "", # required not empty -- if channel is NOT empty, go to this state with the state input
state_input: Any = None,
) -> StateDecision:
return StateDecision(
[StateMovement.create(or_else_state, state_input)],
InternalConditionalClose(
WorkflowConditionalCloseType.FORCE_COMPLETE_ON_INTERNAL_CHANNEL_EMPTY,
internal_channel_name,
workflow_complete_output,
),
)

# Atomically force complete the workflow if signal channel is empty, otherwise trigger the state movements from the current thread
# This is to ensure all the messages in the channel are processed before completing the workflow, otherwise messages may be lost.
# Without this atomic API, if just checking the channel emptiness in the State WaitUntil, a workflow may receive new messages during the
# execution of state APIs.
@classmethod
def force_complete_if_signal_channel_empty_or_else(
cls,
signal_channel_name: str,
workflow_complete_output: Any = None, # if channel is empty, complete the workflow with the output
or_else_state: Union[
str, type[WorkflowState]
] = "", # required not empty-- if channel is NOT empty, go to this state with the state input
state_input: Any = None,
) -> StateDecision:
return StateDecision(
[StateMovement.create(or_else_state, state_input)],
InternalConditionalClose(
WorkflowConditionalCloseType.FORCE_COMPLETE_ON_SIGNAL_CHANNEL_EMPTY,
signal_channel_name,
workflow_complete_output,
),
)


StateDecision.dead_end = StateDecision([StateMovement.dead_end])


def _to_idl_state_decision(
decision: StateDecision, wf_type: str, registry: Registry, encoder: ObjectEncoder
) -> IdlStateDecision:
return IdlStateDecision(
[
idl_decision = IdlStateDecision()
if len(decision.next_states) > 0:
idl_decision.next_states = [
_to_idl_state_movement(movement, wf_type, registry, encoder)
for movement in decision.next_states
]
)
if decision.conditional_close is not None:
idl_decision.conditional_close = WorkflowConditionalClose(
conditional_close_type=decision.conditional_close.conditional_close_type,
channel_name=decision.conditional_close.channel_name,
close_input=encoder.encode(decision.conditional_close.close_input),
)
return idl_decision
1 change: 1 addition & 0 deletions iwf/tests/iwf-service-env/.env
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ MYSQL_VERSION=8
POSTGRESQL_VERSION=13
TEMPORAL_VERSION=1.25
TEMPORAL_UI_VERSION=2.31.2
TEMPORAL_ADMIN_TOOLS_VERSION=1.25.2-tctl-1.18.1-cli-1.1.1
2 changes: 1 addition & 1 deletion iwf/tests/iwf-service-env/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ services:
environment:
- TEMPORAL_ADDRESS=temporal:7233
- TEMPORAL_CLI_ADDRESS=temporal:7233
image: temporalio/admin-tools:${TEMPORAL_VERSION}
image: temporalio/admin-tools:${TEMPORAL_ADMIN_TOOLS_VERSION}
networks:
- temporal-network
stdin_open: true
Expand Down
125 changes: 125 additions & 0 deletions iwf/tests/test_conditional_complete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import inspect
import time
import unittest

from iwf.client import Client
from iwf.command_request import (
CommandRequest,
InternalChannelCommand,
SignalChannelCommand,
)
from iwf.command_results import CommandResults
from iwf.communication import Communication
from iwf.communication_schema import CommunicationMethod, CommunicationSchema
from iwf.persistence import Persistence
from iwf.persistence_schema import PersistenceField, PersistenceSchema
from iwf.rpc import rpc
from iwf.state_decision import StateDecision
from iwf.state_schema import StateSchema
from iwf.tests.worker_server import registry
from iwf.workflow import ObjectWorkflow
from iwf.workflow_context import WorkflowContext
from iwf.workflow_state import T, WorkflowState

test_signal_channel = "test-1"
test_internal_channel = "test-2"

da_counter = "counter"


class WaitState(WorkflowState[bool]):
def wait_until(
self,
ctx: WorkflowContext,
use_signal: T,
persistence: Persistence,
communication: Communication,
) -> CommandRequest:
if use_signal:
return CommandRequest.for_all_command_completed(
SignalChannelCommand.by_name(test_signal_channel),
)
else:
return CommandRequest.for_all_command_completed(
InternalChannelCommand.by_name(test_internal_channel),
)

def execute(
self,
ctx: WorkflowContext,
use_signal: T,
command_results: CommandResults,
persistence: Persistence,
communication: Communication,
) -> StateDecision:
counter = persistence.get_data_attribute(da_counter)
if counter is None:
counter = 0
counter += 1
persistence.set_data_attribute(da_counter, counter)

if ctx.state_execution_id == "WaitState-1":
# wait for 3 seconds so that the channel can have a new message
time.sleep(3)
if use_signal:
return StateDecision.force_complete_if_signal_channel_empty_or_else(
test_signal_channel, counter, WaitState, use_signal
)
else:
return StateDecision.force_complete_if_internal_channel_empty_or_else(
test_internal_channel, counter, WaitState, use_signal
)


class ConditionalCompleteWorkflow(ObjectWorkflow):
def get_communication_schema(self) -> CommunicationSchema:
return CommunicationSchema.create(
CommunicationMethod.signal_channel_def(test_signal_channel, int),
CommunicationMethod.internal_channel_def(test_internal_channel, int),
)

def get_persistence_schema(self) -> PersistenceSchema:
return PersistenceSchema.create(
PersistenceField.data_attribute_def(da_counter, int),
)

def get_workflow_states(self) -> StateSchema:
return StateSchema.with_starting_state(WaitState())

@rpc()
def test_rpc_publish_channel(self, com: Communication):
com.publish_to_internal_channel(test_internal_channel, 0)


class TestConditionalComplete(unittest.TestCase):
@classmethod
def setUpClass(cls):
wf = ConditionalCompleteWorkflow()
registry.add_workflow(wf)

def test_internal_channel_conditional_complete(self):
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"
self.do_test_conditional_workflow(wf_id, False)

def test_signal_channel_conditional_complete(self):
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"
self.do_test_conditional_workflow(wf_id, True)

def do_test_conditional_workflow(self, wf_id: str, use_signal: bool):
self.client = Client(registry)

self.client.start_workflow(ConditionalCompleteWorkflow, wf_id, 10, use_signal)

for x in range(3):
if use_signal:
self.client.signal_workflow(wf_id, test_signal_channel, 123)
else:
self.client.invoke_rpc(
wf_id, ConditionalCompleteWorkflow.test_rpc_publish_channel
)
if x == 0:
# wait for a second so that the workflow is in execute state
time.sleep(1)

res = self.client.get_simple_workflow_result_with_wait(wf_id)
assert res == 3
9 changes: 9 additions & 0 deletions iwf/tests/test_internal_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
test_channel_name3 = "test-internal-channel-3"
test_channel_name4 = "test-internal-channel-4"

test_channel_name_prefix = "test-internal-channel-prefix-"


class InitState(WorkflowState[None]):
def execute(
Expand All @@ -46,6 +48,9 @@ def wait_until(
) -> CommandRequest:
communication.publish_to_internal_channel(test_channel_name3, 123)
communication.publish_to_internal_channel(test_channel_name4, "str-value")
communication.publish_to_internal_channel(
test_channel_name_prefix + "abc", "str-value-for-prefix"
)
return CommandRequest.for_any_command_completed(
InternalChannelCommand.by_name(test_channel_name1),
InternalChannelCommand.by_name(test_channel_name2),
Expand Down Expand Up @@ -90,6 +95,7 @@ def wait_until(
return CommandRequest.for_all_command_completed(
InternalChannelCommand.by_name(test_channel_name3),
InternalChannelCommand.by_name(test_channel_name4),
InternalChannelCommand.by_name(test_channel_name_prefix + "abc"),
)

def execute(
Expand All @@ -116,6 +122,9 @@ def get_communication_schema(self) -> CommunicationSchema:
CommunicationMethod.internal_channel_def(test_channel_name2, type(None)),
CommunicationMethod.internal_channel_def(test_channel_name3, int),
CommunicationMethod.internal_channel_def(test_channel_name4, str),
CommunicationMethod.internal_channel_def_by_prefix(
test_channel_name_prefix, str
),
)


Expand Down
Loading

0 comments on commit ecf41da

Please sign in to comment.