From d0089176870b0834d30ca39ef20808a1ad5036fa Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 5 Jan 2023 13:34:43 -0800 Subject: [PATCH] Signal use (#1398) Signed-off-by: Yee Hing Tong --- flytekit/clients/raw.py | 17 +++++ flytekit/remote/entities.py | 29 +++++--- flytekit/remote/remote.py | 67 ++++++++++++++++++- flytekit/tools/translator.py | 5 ++ .../types/structured/structured_dataset.py | 2 +- tests/flytekit/unit/clients/test_raw.py | 8 ++- tests/flytekit/unit/core/test_gate.py | 35 +++++++++- tests/flytekit/unit/core/test_signal.py | 42 ++++++++++++ 8 files changed, 191 insertions(+), 14 deletions(-) create mode 100644 tests/flytekit/unit/core/test_signal.py diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 7c4439d83d..6c8f54e9ce 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -10,11 +10,13 @@ import grpc import requests as _requests from flyteidl.admin.project_pb2 import ProjectListRequest +from flyteidl.admin.signal_pb2 import SignalList, SignalListRequest, SignalSetRequest, SignalSetResponse from flyteidl.service import admin_pb2_grpc as _admin_service from flyteidl.service import auth_pb2 from flyteidl.service import auth_pb2_grpc as auth_service from flyteidl.service import dataproxy_pb2 as _dataproxy_pb2 from flyteidl.service import dataproxy_pb2_grpc as dataproxy_service +from flyteidl.service import signal_pb2_grpc as signal_service from flyteidl.service.dataproxy_pb2_grpc import DataProxyServiceStub from google.protobuf.json_format import MessageToJson as _MessageToJson @@ -145,6 +147,7 @@ def __init__(self, cfg: PlatformConfig, **kwargs): ) self._stub = _admin_service.AdminServiceStub(self._channel) self._auth_stub = auth_service.AuthMetadataServiceStub(self._channel) + self._signal = signal_service.SignalServiceStub(self._channel) try: resp = self._auth_stub.GetPublicClientConfig(auth_pb2.PublicClientAuthConfigRequest()) self._public_client_config = resp @@ -406,6 +409,20 @@ def get_task(self, 
get_object_request): """ return self._stub.GetTask(get_object_request, metadata=self._metadata) + @_handle_rpc_error(retry=True) + def set_signal(self, signal_set_request: SignalSetRequest) -> SignalSetResponse: + """ + This sets a signal + """ + return self._signal.SetSignal(signal_set_request, metadata=self._metadata) + + @_handle_rpc_error(retry=True) + def list_signals(self, signal_list_request: SignalListRequest) -> SignalList: + """ + This lists signals + """ + return self._signal.ListSignals(signal_list_request, metadata=self._metadata) + #################################################################################################################### # # Workflow Endpoints diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py index 0c745c11bb..c9de5aea33 100644 --- a/flytekit/remote/entities.py +++ b/flytekit/remote/entities.py @@ -334,6 +334,12 @@ def promote_from_model( return cls(new_if_else_block), converted_sub_workflows +class FlyteGateNode(_workflow_model.GateNode): + @classmethod + def promote_from_model(cls, model: _workflow_model.GateNode): + return cls(model.signal, model.sleep, model.approve) + + class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node): """A class encapsulating a remote Flyte node.""" @@ -343,22 +349,23 @@ def __init__( upstream_nodes, bindings, metadata, - task_node: FlyteTaskNode = None, - workflow_node: FlyteWorkflowNode = None, - branch_node: FlyteBranchNode = None, + task_node: Optional[FlyteTaskNode] = None, + workflow_node: Optional[FlyteWorkflowNode] = None, + branch_node: Optional[FlyteBranchNode] = None, + gate_node: Optional[FlyteGateNode] = None, ): - if not task_node and not workflow_node and not branch_node: + if not task_node and not workflow_node and not branch_node and not gate_node: raise _user_exceptions.FlyteAssertion( - "An Flyte node must have one of task|workflow|branch entity specified at once" + "An Flyte node must have one of task|workflow|branch|gate entity specified 
at once" ) - # todo: wip - flyte_branch_node is a hack, it should be a Condition, but backing out a Condition object from - # the compiled IfElseBlock is cumbersome, shouldn't do it if we can get away with it. + # TODO: Revisit flyte_branch_node and flyte_gate_node, should they be another type like Condition instead + # of a node? if task_node: self._flyte_entity = task_node.flyte_task elif workflow_node: self._flyte_entity = workflow_node.flyte_workflow or workflow_node.flyte_launch_plan else: - self._flyte_entity = branch_node + self._flyte_entity = branch_node or gate_node super(FlyteNode, self).__init__( id=id, @@ -369,6 +376,7 @@ def __init__( task_node=task_node, workflow_node=workflow_node, branch_node=branch_node, + gate_node=gate_node, ) self._upstream = upstream_nodes @@ -412,7 +420,7 @@ def promote_from_model( remote_logger.warning(f"Should not call promote from model on a start node or end node {model}") return None, converted_sub_workflows - flyte_task_node, flyte_workflow_node, flyte_branch_node = None, None, None + flyte_task_node, flyte_workflow_node, flyte_branch_node, flyte_gate_node = None, None, None, None if model.task_node is not None: if model.task_node.reference_id not in tasks: raise RuntimeError( @@ -435,6 +443,8 @@ def promote_from_model( tasks, converted_sub_workflows, ) + elif model.gate_node is not None: + flyte_gate_node = FlyteGateNode.promote_from_model(model.gate_node) else: raise _system_exceptions.FlyteSystemException( f"Bad Node model, neither task nor workflow detected, node: {model}" @@ -459,6 +469,7 @@ def promote_from_model( task_node=flyte_task_node, workflow_node=flyte_workflow_node, branch_node=flyte_branch_node, + gate_node=flyte_gate_node, ), converted_sub_workflows, ) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 6473d46ec9..edd899d081 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -17,6 +17,7 @@ from dataclasses import asdict, dataclass from datetime import 
datetime, timedelta +from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest from flyteidl.core import literals_pb2 as literals_pb2 from flytekit import Literal @@ -40,11 +41,12 @@ from flytekit.models import launch_plan as launch_plan_models from flytekit.models import literals as literal_models from flytekit.models import task as task_models +from flytekit.models import types as type_models from flytekit.models.admin import common as admin_common_models from flytekit.models.admin import workflow as admin_workflow_models from flytekit.models.admin.common import Sort from flytekit.models.core import workflow as workflow_model -from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier +from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier from flytekit.models.core.workflow import NodeMetadata from flytekit.models.execution import ( ExecutionMetadata, @@ -350,6 +352,69 @@ def fetch_execution(self, project: str = None, domain: str = None, name: str = N # Listing Entities # ###################### + def list_signals( + self, + execution_name: str, + project: typing.Optional[str] = None, + domain: typing.Optional[str] = None, + limit: int = 100, + filters: typing.Optional[typing.List[filter_models.Filter]] = None, + ) -> typing.List[Signal]: + """ + :param execution_name: The name of the execution. This is the tailend of the URL when looking at the workflow execution. + :param project: The execution project, will default to the Remote's default project. + :param domain: The execution domain, will default to the Remote's default domain. 
+        :param limit: The number of signals to fetch
+        :param filters: Optional list of filters
+        """
+        wf_exec_id = WorkflowExecutionIdentifier(
+            project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
+        )
+        filter_str = filter_models.FilterList(filters).to_flyte_idl() if filters else None
+        req = SignalListRequest(workflow_execution_id=wf_exec_id.to_flyte_idl(), limit=limit, filters=filter_str)
+        resp = self.client.list_signals(req)
+        return resp.signals
+
+    def set_signal(
+        self,
+        signal_id: str,
+        execution_name: str,
+        value: typing.Union[literal_models.Literal, typing.Any],
+        project: typing.Optional[str] = None,
+        domain: typing.Optional[str] = None,
+        python_type: typing.Optional[typing.Type] = None,
+        literal_type: typing.Optional[type_models.LiteralType] = None,
+    ):
+        """
+        :param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call.
+        :param execution_name: The name of the execution. This is the tail-end of the URL when looking
+            at the workflow execution.
+        :param value: This is either a Literal or a Python value which FlyteRemote will invoke the TypeEngine to
+            convert into a Literal. This argument is only valid for wait_for_input type signals.
+        :param project: The execution project, will default to the Remote's default project.
+        :param domain: The execution domain, will default to the Remote's default domain.
+        :param python_type: Provide a python type to help with conversion if the value you provided is not a Literal.
+        :param literal_type: Provide a Flyte literal type to help with conversion if the value you provided
+            is not a Literal
+        """
+        wf_exec_id = WorkflowExecutionIdentifier(
+            project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
+        )
+        if isinstance(value, Literal):
+            remote_logger.debug(f"Using provided {value} as existing Literal value")
+            lit = value
+        else:
+            lt = literal_type or (
+                TypeEngine.to_literal_type(python_type) if python_type else TypeEngine.to_literal_type(type(value))
+            )
+            lit = TypeEngine.to_literal(self.context, value, python_type or type(value), lt)
+            remote_logger.debug(f"Converted {value} to literal {lit} using literal type {lt}")
+
+        req = SignalSetRequest(id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=lit.to_flyte_idl())
+
+        # Response is empty currently, nothing to give back to the user.
+        self.client.set_signal(req)
+
     def recent_executions(
         self,
         project: typing.Optional[str] = None,
diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py
index f0ad5e96c6..ec2bdb0cb9 100644
--- a/flytekit/tools/translator.py
+++ b/flytekit/tools/translator.py
@@ -658,6 +658,11 @@ def get_serializable(
     elif isinstance(entity, BranchNode):
         cp_entity = get_serializable_branch_node(entity_mapping, settings, entity, options)
 
+    elif isinstance(entity, GateNode):
+        # No serialization is implemented for a bare GateNode in this branch; fail with an explicit
+        # error instead of dropping into an interactive ipdb breakpoint inside library code.
+        raise NotImplementedError(f"Serializing a standalone GateNode is not supported: {entity}")
+
     elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow):
         if entity.should_register:
             if isinstance(entity, FlyteTask):
diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py
index ec6b367c20..0e4649203a 100644
--- a/flytekit/types/structured/structured_dataset.py
+++ b/flytekit/types/structured/structured_dataset.py
@@ -484,7 +484,7 @@ def register_for_protocol(
         if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT:
             if h.python_type in cls.DEFAULT_FORMATS and not override:
                 if
cls.DEFAULT_FORMATS[h.python_type] != h.supported_format: - logger.debug( + logger.info( f"Not using handler {h} with format {h.supported_format} as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified." ) else: diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index b3f1807b96..10a7e09333 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -40,12 +40,13 @@ def get_admin_stub_mock() -> mock.MagicMock: return auth_stub_mock +@mock.patch("flytekit.clients.raw.signal_service") @mock.patch("flytekit.clients.raw.dataproxy_service") @mock.patch("flytekit.clients.raw.auth_service") @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw.grpc.insecure_channel") @mock.patch("flytekit.clients.raw.grpc.secure_channel") -def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy): +def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal): mock_secure_channel.return_value = True mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True @@ -73,6 +74,7 @@ def test_refresh_credentials_from_command(mock_call_to_external_process, mock_ad mock_set_access_token.assert_called_with(token, client.public_client_config.authorization_metadata_key) +@mock.patch("flytekit.clients.raw.signal_service") @mock.patch("flytekit.clients.raw.dataproxy_service") @mock.patch("flytekit.clients.raw.get_basic_authorization_header") @mock.patch("flytekit.clients.raw.get_token") @@ -88,6 +90,7 @@ def test_refresh_client_credentials_aka_basic( mock_get_token, mock_get_basic_header, mock_dataproxy, + mock_signal, ): mock_secure_channel.return_value = True mock_channel.return_value = True @@ -112,12 +115,13 @@ def test_refresh_client_credentials_aka_basic( assert client._metadata[0][0] == "authorization" 
+@mock.patch("flytekit.clients.raw.signal_service") @mock.patch("flytekit.clients.raw.dataproxy_service") @mock.patch("flytekit.clients.raw.auth_service") @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw.grpc.insecure_channel") @mock.patch("flytekit.clients.raw.grpc.secure_channel") -def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy): +def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal): mock_secure_channel.return_value = True mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True diff --git a/tests/flytekit/unit/core/test_gate.py b/tests/flytekit/unit/core/test_gate.py index a4689ed814..c92e1c9e19 100644 --- a/tests/flytekit/unit/core/test_gate.py +++ b/tests/flytekit/unit/core/test_gate.py @@ -13,7 +13,8 @@ from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow -from flytekit.tools.translator import get_serializable +from flytekit.remote.entities import FlyteWorkflow +from flytekit.tools.translator import gather_dependent_entities, get_serializable default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = SerializationSettings( @@ -290,3 +291,35 @@ def cond_wf(a: int) -> float: x = cond_wf(a=3) assert x == 6 assert stdin.read() == "" + + +def test_promote(): + @task + def t1(a: int) -> int: + return a + 5 + + @task + def t2(a: int) -> int: + return a + 6 + + @workflow + def wf(a: int) -> typing.Tuple[int, int, int]: + zzz = sleep(timedelta(seconds=10)) + x = t1(a=a) + s1 = wait_for_input("my-signal-name", timeout=timedelta(hours=1), expected_type=bool) + s2 = wait_for_input("my-signal-name-2", timeout=timedelta(hours=2), expected_type=int) + z = t1(a=5) + y = t2(a=s2) + q = t2(a=approve(y, "approvalfory", timeout=timedelta(hours=2))) + zzz >> x + x >> s1 + s1 >> z + + return y, z, q + + entries = 
OrderedDict() + wf_spec = get_serializable(entries, serialization_settings, wf) + tts, wf_specs, lp_specs = gather_dependent_entities(entries) + + fwf = FlyteWorkflow.promote_from_model(wf_spec.template, tasks=tts) + assert fwf.template.nodes[2].gate_node is not None diff --git a/tests/flytekit/unit/core/test_signal.py b/tests/flytekit/unit/core/test_signal.py new file mode 100644 index 0000000000..a37da8955f --- /dev/null +++ b/tests/flytekit/unit/core/test_signal.py @@ -0,0 +1,42 @@ +from flyteidl.admin.signal_pb2 import Signal, SignalList +from mock import MagicMock + +from flytekit.configuration import Config +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +from flytekit.models.core.identifier import SignalIdentifier, WorkflowExecutionIdentifier +from flytekit.remote.remote import FlyteRemote + + +def test_remote_list_signals(): + ctx = FlyteContextManager.current_context() + wfeid = WorkflowExecutionIdentifier("p", "d", "execid") + signal_id = SignalIdentifier(signal_id="sigid", execution_id=wfeid).to_flyte_idl() + lt = TypeEngine.to_literal_type(int) + signal = Signal( + id=signal_id, + type=lt.to_flyte_idl(), + value=TypeEngine.to_literal(ctx, 3, int, lt).to_flyte_idl(), + ) + + mock_client = MagicMock() + mock_client.list_signals.return_value = SignalList(signals=[signal], token="") + + remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + remote._client = mock_client + res = remote.list_signals("execid", "p", "d", limit=10) + assert len(res) == 1 + + +def test_remote_set_signal(): + mock_client = MagicMock() + + def checker(request): + assert request.id.signal_id == "sigid" + assert request.value.scalar.primitive.integer == 3 + + mock_client.set_signal.side_effect = checker + + remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + remote._client = mock_client + remote.set_signal("sigid", "execid", 3)