Skip to content

Commit

Permalink
DONE
Browse files Browse the repository at this point in the history
  • Loading branch information
longquanzheng committed Nov 21, 2024
1 parent 0190592 commit f2b7c41
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 26 deletions.
55 changes: 29 additions & 26 deletions iwf/state_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from iwf_api.models import WorkflowConditionalClose, WorkflowConditionalCloseType

from iwf.errors import WorkflowDefinitionError

if typing.TYPE_CHECKING:
from iwf.registry import Registry
from iwf.workflow_state import WorkflowState
Expand Down Expand Up @@ -76,14 +74,20 @@ def multi_next_states(
# ensure only the State APIs are not executed in parallel.
@classmethod
def force_complete_if_internal_channel_empty_or_else(
cls, internal_channel_name: str, output: Any = None
cls,
internal_channel_name: str,
workflow_complete_output: Any = None, # if channel is empty, complete the workflow with the output
or_else_state: Union[
str, type[WorkflowState]
] = None, # if channel is NOT empty, go to this state with the state input
state_input: Any = None,
) -> StateDecision:
return StateDecision(
[],
[StateMovement.create(or_else_state, state_input)],
InternalConditionalClose(
WorkflowConditionalCloseType.FORCE_COMPLETE_ON_INTERNAL_CHANNEL_EMPTY,
internal_channel_name,
output,
workflow_complete_output,
),
)

Expand All @@ -93,14 +97,20 @@ def force_complete_if_internal_channel_empty_or_else(
# execution of state APIs.
@classmethod
def force_complete_if_signal_channel_empty_or_else(
cls, signal_channel_name: str, output: Any = None
cls,
signal_channel_name: str,
workflow_complete_output: Any = None, # if channel is empty, complete the workflow with the output
or_else_state: Union[
str, type[WorkflowState]
] = None, # if channel is NOT empty, go to this state with the state input
state_input: Any = None,
) -> StateDecision:
return StateDecision(
[],
[StateMovement.create(or_else_state, state_input)],
InternalConditionalClose(
WorkflowConditionalCloseType.FORCE_COMPLETE_ON_SIGNAL_CHANNEL_EMPTY,
signal_channel_name,
output,
workflow_complete_output,
),
)

Expand All @@ -111,23 +121,16 @@ def force_complete_if_signal_channel_empty_or_else(
def _to_idl_state_decision(
decision: StateDecision, wf_type: str, registry: Registry, encoder: ObjectEncoder
) -> IdlStateDecision:
idl_decision = IdlStateDecision()
if len(decision.next_states) > 0:
return IdlStateDecision(
[
_to_idl_state_movement(movement, wf_type, registry, encoder)
for movement in decision.next_states
]
)
else:
internal_conditional_close = decision.conditional_close
if internal_conditional_close is None:
raise WorkflowDefinitionError(
"must have either next states or conditional close"
)

conditional_close = WorkflowConditionalClose(
conditional_close_type=internal_conditional_close.conditional_close_type,
channel_name=internal_conditional_close.channel_name,
close_input=encoder.encode(internal_conditional_close.close_input),
idl_decision.next_states = [
_to_idl_state_movement(movement, wf_type, registry, encoder)
for movement in decision.next_states
]
if decision.conditional_close is not None:
idl_decision.conditional_close = WorkflowConditionalClose(
conditional_close_type=decision.conditional_close.conditional_close_type,
channel_name=decision.conditional_close.channel_name,
close_input=encoder.encode(decision.conditional_close.close_input),
)
return IdlStateDecision([], conditional_close)
return idl_decision
120 changes: 120 additions & 0 deletions iwf/tests/test_conditional_complete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import inspect
import time
import unittest

from iwf.client import Client
from iwf.command_request import (
CommandRequest,
InternalChannelCommand,
SignalChannelCommand,
)
from iwf.command_results import CommandResults
from iwf.communication import Communication
from iwf.communication_schema import CommunicationMethod, CommunicationSchema
from iwf.persistence import Persistence
from iwf.persistence_schema import PersistenceField, PersistenceSchema
from iwf.rpc import rpc
from iwf.state_decision import StateDecision
from iwf.state_schema import StateSchema
from iwf.tests.worker_server import registry
from iwf.workflow import ObjectWorkflow
from iwf.workflow_context import WorkflowContext
from iwf.workflow_state import T, WorkflowState

test_signal_channel = "test-1"
test_internal_channel = "test-2"

da_counter = "counter"


class WaitState(WorkflowState[bool]):
def wait_until(
self,
ctx: WorkflowContext,
use_signal: T,
persistence: Persistence,
communication: Communication,
) -> CommandRequest:
if use_signal:
return CommandRequest.for_all_command_completed(
SignalChannelCommand.by_name(test_signal_channel),
)
else:
return CommandRequest.for_all_command_completed(
InternalChannelCommand.by_name(test_signal_channel),
)

def execute(
self,
ctx: WorkflowContext,
use_signal: T,
command_results: CommandResults,
persistence: Persistence,
communication: Communication,
) -> StateDecision:
counter = persistence.get_data_attribute(da_counter)
if counter is None:
counter = 0
counter += 1
persistence.set_data_attribute(da_counter, counter)

if ctx.state_execution_id == "WaitState-1":
# wait for 3 seconds so that the channel can have a new message
time.sleep(3)
if use_signal:
return StateDecision.force_complete_if_signal_channel_empty_or_else(
test_signal_channel, counter, WaitState, use_signal
)
else:
return StateDecision.force_complete_if_internal_channel_empty_or_else(
test_internal_channel, counter, WaitState, use_signal
)


class ConditionalCompleteWorkflow(ObjectWorkflow):
def get_communication_schema(self) -> CommunicationSchema:
return CommunicationSchema.create(
CommunicationMethod.signal_channel_def(test_signal_channel, int),
CommunicationMethod.internal_channel_def(test_internal_channel, int),
)

def get_persistence_schema(self) -> PersistenceSchema:
return PersistenceSchema.create(
PersistenceField.data_attribute_def(da_counter, int),
)

def get_workflow_states(self) -> StateSchema:
return StateSchema.with_starting_state(WaitState())

@rpc()
def test_rpc_publish_channel(self, com: Communication):
com.publish_to_internal_channel(test_internal_channel, 0)


class TestConditionalComplete(unittest.TestCase):
@classmethod
def setUpClass(cls):
wf = ConditionalCompleteWorkflow()
registry.add_workflow(wf)
cls.client = Client(registry)

def test_signal_conditional_complete(self):
self.do_test_conditional_workflow(True)

def do_test_conditional_workflow(self, use_signal: bool):
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"
self.client.start_workflow(ConditionalCompleteWorkflow, wf_id, 10, use_signal)

for x in range(3):
if use_signal:
self.client.signal_workflow(wf_id, test_signal_channel, 123)
else:
self.client.invoke_rpc(
wf_id, ConditionalCompleteWorkflow.test_rpc_publish_channel
)
if x == 0:
# wait for a second so that the workflow is in execute state
time.sleep(1)

res = self.client.get_simple_workflow_result_with_wait(wf_id)
assert res == 3
1 change: 1 addition & 0 deletions iwf/tests/worker_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def handle_rpc():
# the WebUI will be able to show you the error with stacktrace
@_flask_app.errorhandler(Exception)
def internal_error(exception):
print("encounter errors:", exception)
return traceback.format_exc(), 500


Expand Down

0 comments on commit f2b7c41

Please sign in to comment.