Skip to content

Commit

Permalink
Signal use (#1398)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Jan 5, 2023
1 parent 0dcda1d commit d008917
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 14 deletions.
17 changes: 17 additions & 0 deletions flytekit/clients/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import grpc
import requests as _requests
from flyteidl.admin.project_pb2 import ProjectListRequest
from flyteidl.admin.signal_pb2 import SignalList, SignalListRequest, SignalSetRequest, SignalSetResponse
from flyteidl.service import admin_pb2_grpc as _admin_service
from flyteidl.service import auth_pb2
from flyteidl.service import auth_pb2_grpc as auth_service
from flyteidl.service import dataproxy_pb2 as _dataproxy_pb2
from flyteidl.service import dataproxy_pb2_grpc as dataproxy_service
from flyteidl.service import signal_pb2_grpc as signal_service
from flyteidl.service.dataproxy_pb2_grpc import DataProxyServiceStub
from google.protobuf.json_format import MessageToJson as _MessageToJson

Expand Down Expand Up @@ -145,6 +147,7 @@ def __init__(self, cfg: PlatformConfig, **kwargs):
)
self._stub = _admin_service.AdminServiceStub(self._channel)
self._auth_stub = auth_service.AuthMetadataServiceStub(self._channel)
self._signal = signal_service.SignalServiceStub(self._channel)
try:
resp = self._auth_stub.GetPublicClientConfig(auth_pb2.PublicClientAuthConfigRequest())
self._public_client_config = resp
Expand Down Expand Up @@ -406,6 +409,20 @@ def get_task(self, get_object_request):
"""
return self._stub.GetTask(get_object_request, metadata=self._metadata)

@_handle_rpc_error(retry=True)
def set_signal(self, signal_set_request: SignalSetRequest) -> SignalSetResponse:
"""
This sets a signal
"""
return self._signal.SetSignal(signal_set_request, metadata=self._metadata)

@_handle_rpc_error(retry=True)
def list_signals(self, signal_list_request: SignalListRequest) -> SignalList:
"""
This lists signals
"""
return self._signal.ListSignals(signal_list_request, metadata=self._metadata)

####################################################################################################################
#
# Workflow Endpoints
Expand Down
29 changes: 20 additions & 9 deletions flytekit/remote/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,12 @@ def promote_from_model(
return cls(new_if_else_block), converted_sub_workflows


class FlyteGateNode(_workflow_model.GateNode):
@classmethod
def promote_from_model(cls, model: _workflow_model.GateNode):
return cls(model.signal, model.sleep, model.approve)


class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node):
"""A class encapsulating a remote Flyte node."""

Expand All @@ -343,22 +349,23 @@ def __init__(
upstream_nodes,
bindings,
metadata,
task_node: FlyteTaskNode = None,
workflow_node: FlyteWorkflowNode = None,
branch_node: FlyteBranchNode = None,
task_node: Optional[FlyteTaskNode] = None,
workflow_node: Optional[FlyteWorkflowNode] = None,
branch_node: Optional[FlyteBranchNode] = None,
gate_node: Optional[FlyteGateNode] = None,
):
if not task_node and not workflow_node and not branch_node:
if not task_node and not workflow_node and not branch_node and not gate_node:
raise _user_exceptions.FlyteAssertion(
"An Flyte node must have one of task|workflow|branch entity specified at once"
"An Flyte node must have one of task|workflow|branch|gate entity specified at once"
)
# todo: wip - flyte_branch_node is a hack, it should be a Condition, but backing out a Condition object from
# the compiled IfElseBlock is cumbersome, shouldn't do it if we can get away with it.
# TODO: Revisit flyte_branch_node and flyte_gate_node, should they be another type like Condition instead
# of a node?
if task_node:
self._flyte_entity = task_node.flyte_task
elif workflow_node:
self._flyte_entity = workflow_node.flyte_workflow or workflow_node.flyte_launch_plan
else:
self._flyte_entity = branch_node
self._flyte_entity = branch_node or gate_node

super(FlyteNode, self).__init__(
id=id,
Expand All @@ -369,6 +376,7 @@ def __init__(
task_node=task_node,
workflow_node=workflow_node,
branch_node=branch_node,
gate_node=gate_node,
)
self._upstream = upstream_nodes

Expand Down Expand Up @@ -412,7 +420,7 @@ def promote_from_model(
remote_logger.warning(f"Should not call promote from model on a start node or end node {model}")
return None, converted_sub_workflows

flyte_task_node, flyte_workflow_node, flyte_branch_node = None, None, None
flyte_task_node, flyte_workflow_node, flyte_branch_node, flyte_gate_node = None, None, None, None
if model.task_node is not None:
if model.task_node.reference_id not in tasks:
raise RuntimeError(
Expand All @@ -435,6 +443,8 @@ def promote_from_model(
tasks,
converted_sub_workflows,
)
elif model.gate_node is not None:
flyte_gate_node = FlyteGateNode.promote_from_model(model.gate_node)
else:
raise _system_exceptions.FlyteSystemException(
f"Bad Node model, neither task nor workflow detected, node: {model}"
Expand All @@ -459,6 +469,7 @@ def promote_from_model(
task_node=flyte_task_node,
workflow_node=flyte_workflow_node,
branch_node=flyte_branch_node,
gate_node=flyte_gate_node,
),
converted_sub_workflows,
)
Expand Down
67 changes: 66 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta

from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest
from flyteidl.core import literals_pb2 as literals_pb2

from flytekit import Literal
Expand All @@ -40,11 +41,12 @@
from flytekit.models import launch_plan as launch_plan_models
from flytekit.models import literals as literal_models
from flytekit.models import task as task_models
from flytekit.models import types as type_models
from flytekit.models.admin import common as admin_common_models
from flytekit.models.admin import workflow as admin_workflow_models
from flytekit.models.admin.common import Sort
from flytekit.models.core import workflow as workflow_model
from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier
from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier
from flytekit.models.core.workflow import NodeMetadata
from flytekit.models.execution import (
ExecutionMetadata,
Expand Down Expand Up @@ -350,6 +352,69 @@ def fetch_execution(self, project: str = None, domain: str = None, name: str = N
# Listing Entities #
######################

def list_signals(
self,
execution_name: str,
project: typing.Optional[str] = None,
domain: typing.Optional[str] = None,
limit: int = 100,
filters: typing.Optional[typing.List[filter_models.Filter]] = None,
) -> typing.List[Signal]:
"""
:param execution_name: The name of the execution. This is the tailend of the URL when looking at the workflow execution.
:param project: The execution project, will default to the Remote's default project.
:param domain: The execution domain, will default to the Remote's default domain.
:param limit: The number of signals to fetch
:param filters: Optional list of filters
"""
wf_exec_id = WorkflowExecutionIdentifier(
project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
)
req = SignalListRequest(workflow_execution_id=wf_exec_id.to_flyte_idl(), limit=limit, filters=filters)
resp = self.client.list_signals(req)
s = resp.signals
return s

def set_signal(
self,
signal_id: str,
execution_name: str,
value: typing.Union[literal_models.Literal, typing.Any],
project: typing.Optional[str] = None,
domain: typing.Optional[str] = None,
python_type: typing.Optional[typing.Type] = None,
literal_type: typing.Optional[type_models.LiteralType] = None,
):
"""
:param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call.
:param execution_name: The name of the execution. This is the tail-end of the URL when looking
at the workflow execution.
:param value: This is either a Literal or a Python value which FlyteRemote will invoke the TypeEngine to
convert into a Literal. This argument is only value for wait_for_input type signals.
:param project: The execution project, will default to the Remote's default project.
:param domain: The execution domain, will default to the Remote's default domain.
:param python_type: Provide a python type to help with conversion if the value you provided is not a Literal.
:param literal_type: Provide a Flyte literal type to help with conversion if the value you provided
is not a Literal
"""
wf_exec_id = WorkflowExecutionIdentifier(
project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
)
if isinstance(value, Literal):
remote_logger.debug(f"Using provided {value} as existing Literal value")
lit = value
else:
lt = literal_type or (
TypeEngine.to_literal_type(python_type) if python_type else TypeEngine.to_literal_type(type(value))
)
lit = TypeEngine.to_literal(self.context, value, python_type or type(value), lt)
remote_logger.debug(f"Converted {value} to literal {lit} using literal type {lt}")

req = SignalSetRequest(id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=lit.to_flyte_idl())

# Response is empty currently, nothing to give back to the user.
self.client.set_signal(req)

def recent_executions(
self,
project: typing.Optional[str] = None,
Expand Down
5 changes: 5 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,11 @@ def get_serializable(
elif isinstance(entity, BranchNode):
cp_entity = get_serializable_branch_node(entity_mapping, settings, entity, options)

elif isinstance(entity, GateNode):
import ipdb

ipdb.set_trace()

elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow):
if entity.should_register:
if isinstance(entity, FlyteTask):
Expand Down
2 changes: 1 addition & 1 deletion flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def register_for_protocol(
if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT:
if h.python_type in cls.DEFAULT_FORMATS and not override:
if cls.DEFAULT_FORMATS[h.python_type] != h.supported_format:
logger.debug(
logger.info(
f"Not using handler {h} with format {h.supported_format} as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified."
)
else:
Expand Down
8 changes: 6 additions & 2 deletions tests/flytekit/unit/clients/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ def get_admin_stub_mock() -> mock.MagicMock:
return auth_stub_mock


@mock.patch("flytekit.clients.raw.signal_service")
@mock.patch("flytekit.clients.raw.dataproxy_service")
@mock.patch("flytekit.clients.raw.auth_service")
@mock.patch("flytekit.clients.raw._admin_service")
@mock.patch("flytekit.clients.raw.grpc.insecure_channel")
@mock.patch("flytekit.clients.raw.grpc.secure_channel")
def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy):
def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal):
mock_secure_channel.return_value = True
mock_channel.return_value = True
mock_admin.AdminServiceStub.return_value = True
Expand Down Expand Up @@ -73,6 +74,7 @@ def test_refresh_credentials_from_command(mock_call_to_external_process, mock_ad
mock_set_access_token.assert_called_with(token, client.public_client_config.authorization_metadata_key)


@mock.patch("flytekit.clients.raw.signal_service")
@mock.patch("flytekit.clients.raw.dataproxy_service")
@mock.patch("flytekit.clients.raw.get_basic_authorization_header")
@mock.patch("flytekit.clients.raw.get_token")
Expand All @@ -88,6 +90,7 @@ def test_refresh_client_credentials_aka_basic(
mock_get_token,
mock_get_basic_header,
mock_dataproxy,
mock_signal,
):
mock_secure_channel.return_value = True
mock_channel.return_value = True
Expand All @@ -112,12 +115,13 @@ def test_refresh_client_credentials_aka_basic(
assert client._metadata[0][0] == "authorization"


@mock.patch("flytekit.clients.raw.signal_service")
@mock.patch("flytekit.clients.raw.dataproxy_service")
@mock.patch("flytekit.clients.raw.auth_service")
@mock.patch("flytekit.clients.raw._admin_service")
@mock.patch("flytekit.clients.raw.grpc.insecure_channel")
@mock.patch("flytekit.clients.raw.grpc.secure_channel")
def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy):
def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal):
mock_secure_channel.return_value = True
mock_channel.return_value = True
mock_admin.AdminServiceStub.return_value = True
Expand Down
35 changes: 34 additions & 1 deletion tests/flytekit/unit/core/test_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.tools.translator import get_serializable
from flytekit.remote.entities import FlyteWorkflow
from flytekit.tools.translator import gather_dependent_entities, get_serializable

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = SerializationSettings(
Expand Down Expand Up @@ -290,3 +291,35 @@ def cond_wf(a: int) -> float:
x = cond_wf(a=3)
assert x == 6
assert stdin.read() == ""


def test_promote():
@task
def t1(a: int) -> int:
return a + 5

@task
def t2(a: int) -> int:
return a + 6

@workflow
def wf(a: int) -> typing.Tuple[int, int, int]:
zzz = sleep(timedelta(seconds=10))
x = t1(a=a)
s1 = wait_for_input("my-signal-name", timeout=timedelta(hours=1), expected_type=bool)
s2 = wait_for_input("my-signal-name-2", timeout=timedelta(hours=2), expected_type=int)
z = t1(a=5)
y = t2(a=s2)
q = t2(a=approve(y, "approvalfory", timeout=timedelta(hours=2)))
zzz >> x
x >> s1
s1 >> z

return y, z, q

entries = OrderedDict()
wf_spec = get_serializable(entries, serialization_settings, wf)
tts, wf_specs, lp_specs = gather_dependent_entities(entries)

fwf = FlyteWorkflow.promote_from_model(wf_spec.template, tasks=tts)
assert fwf.template.nodes[2].gate_node is not None
42 changes: 42 additions & 0 deletions tests/flytekit/unit/core/test_signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from flyteidl.admin.signal_pb2 import Signal, SignalList
from mock import MagicMock

from flytekit.configuration import Config
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
from flytekit.models.core.identifier import SignalIdentifier, WorkflowExecutionIdentifier
from flytekit.remote.remote import FlyteRemote


def test_remote_list_signals():
ctx = FlyteContextManager.current_context()
wfeid = WorkflowExecutionIdentifier("p", "d", "execid")
signal_id = SignalIdentifier(signal_id="sigid", execution_id=wfeid).to_flyte_idl()
lt = TypeEngine.to_literal_type(int)
signal = Signal(
id=signal_id,
type=lt.to_flyte_idl(),
value=TypeEngine.to_literal(ctx, 3, int, lt).to_flyte_idl(),
)

mock_client = MagicMock()
mock_client.list_signals.return_value = SignalList(signals=[signal], token="")

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client
res = remote.list_signals("execid", "p", "d", limit=10)
assert len(res) == 1


def test_remote_set_signal():
mock_client = MagicMock()

def checker(request):
assert request.id.signal_id == "sigid"
assert request.value.scalar.primitive.integer == 3

mock_client.set_signal.side_effect = checker

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client
remote.set_signal("sigid", "execid", 3)

0 comments on commit d008917

Please sign in to comment.