diff --git a/flytekit/control_plane/launch_plan.py b/flytekit/control_plane/launch_plan.py new file mode 100644 index 00000000000..a57019068e5 --- /dev/null +++ b/flytekit/control_plane/launch_plan.py @@ -0,0 +1,279 @@ +import uuid as _uuid +from typing import Any, List + +import six as _six + +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.mixins import launchable as _launchable_mixin +from flytekit.configuration import sdk as _sdk_config +from flytekit.control_plane import identifier as _identifier +from flytekit.control_plane import interface as _interface +from flytekit.control_plane import nodes as _nodes +from flytekit.control_plane import notifications as _notifications +from flytekit.control_plane import workflow_execution as _workflow_execution +from flytekit.engines.flyte import engine as _flyte_engine +from flytekit.models import common as _common_models +from flytekit.models import execution as _execution_models +from flytekit.models import identifier as _identifier_model +from flytekit.models import interface as _interface_models +from flytekit.models import launch_plan as _launch_plan_models +from flytekit.models import literals as _literal_models + + +class FlyteLaunchPlan( + _launchable_mixin.LaunchableEntity, _launch_plan_models.LaunchPlanSpec, +): + def __init__(self, *args, **kwargs): + super(FlyteLaunchPlan, self).__init__(*args, **kwargs) + # Set all the attributes we expect this class to have + self._id = None + + # The interface is not set explicitly unless fetched in an engine context + self._interface = None + + @classmethod + def promote_from_model(cls, model: _launch_plan_models.LaunchPlanSpec) -> "FlyteLaunchPlan": + return cls( + workflow_id=_identifier.Identifier.promote_from_model(model.workflow_id), + default_inputs=_interface_models.ParameterMap( + { + k: _promises.Input.promote_from_model(v).rename_and_return_reference(k) + for k, v in _six.iteritems(model.default_inputs.parameters) + } + ), + fixed_inputs=model.fixed_inputs, + entity_metadata=model.entity_metadata, + labels=model.labels, + annotations=model.annotations, + auth_role=model.auth_role, + raw_output_data_config=model.raw_output_data_config, + ) + + @_exception_scopes.system_entry_point + def register(self, project, domain, name, version): + # NOTE: does this need to be implemented in the control plane? + pass + + @classmethod + @_exception_scopes.system_entry_point + def fetch(cls, project: str, domain: str, name: str, version: str) -> "FlyteLaunchPlan": + """ + This function uses the engine loader to call create a hydrated task from Admin. + :param project: + :param domain: + :param name: + :param version: + """ + from flytekit.control_plane import workflow as _workflow + + launch_plan_id = _identifier.Identifier( + _identifier_model.ResourceType.LAUNCH_PLAN, project, domain, name, version + ) + + lp = _flyte_engine.get_client().get_launch_plan(launch_plan_id) + flyte_lp = cls.promote_from_model(lp.spec) + flyte_lp._id = lp.id + + # TODO: Add a test for this, and this function as a whole + wf_id = flyte_lp.workflow_id + lp_wf = _workflow.FlyteWorkflow.fetch(wf_id.project, wf_id.domain, wf_id.name, wf_id.version) + flyte_lp._interface = lp_wf.interface + flyte_lp._has_registered = True + return flyte_lp + + @_exception_scopes.system_entry_point + def serialize(self): + """ + Serializing a launch plan should produce an object similar to what the registration step produces, + in preparation for actual registration to Admin. + + :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlan + """ + # NOTE: does this need to be implemented in the control plane? + pass + + @property + def id(self) -> _identifier.Identifier: + return self._id + + @property + def is_scheduled(self) -> bool: + if self.entity_metadata.schedule.cron_expression: + return True + elif self.entity_metadata.schedule.rate and self.entity_metadata.schedule.rate.value: + return True + elif self.entity_metadata.schedule.cron_schedule and self.entity_metadata.schedule.cron_schedule.schedule: + return True + else: + return False + + @property + def auth_role(self) -> _common_models.AuthRole: + fixed_auth = super(FlyteLaunchPlan, self).auth_role + if fixed_auth is not None and ( + fixed_auth.assumable_iam_role is not None or fixed_auth.kubernetes_service_account is not None + ): + return fixed_auth + + assumable_iam_role = _auth_config.ASSUMABLE_IAM_ROLE.get() + kubernetes_service_account = _auth_config.KUBERNETES_SERVICE_ACCOUNT.get() + + if not (assumable_iam_role or kubernetes_service_account): + _logging.warning( + "Using deprecated `role` from config. Please update your config to use `assumable_iam_role` instead" + ) + assumable_iam_role = _sdk_config.ROLE.get() + return _common_models.AuthRole( + assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + ) + + @property + def workflow_id(self) -> _identifier.Identifier: + return self._workflow_id + + @property + def interface(self) -> _interface.TypedInterface: + """ + The interface is not technically part of the admin.LaunchPlanSpec in the IDL, however the workflow ID is, and + from the workflow ID, fetch will fill in the interface. This is nice because then you can __call__ the= + object and get a node. + """ + return self._interface + + @property + def resource_type(self) -> _identifier_model.ResourceType: + return _identifier_model.ResourceType.LAUNCH_PLAN + + @property + def entity_type_text(self) -> str: + return "Launch Plan" + + @property + def raw_output_data_config(self) -> _common_models.RawOutputDataConfig: + raw_output_data_config = super(FlyteLaunchPlan, self).raw_output_data_config + if raw_output_data_config is not None and raw_output_data_config.output_location_prefix != "": + return raw_output_data_config + + # If it was not set explicitly then let's use the value found in the configuration. + return _common_models.RawOutputDataConfig(_auth_config.RAW_OUTPUT_DATA_PREFIX.get()) + + @_exception_scopes.system_entry_point + def validate(self): + # TODO: Validate workflow is satisfied + pass + + @_exception_scopes.system_entry_point + def update(self, state: _launch_plan_models.LaunchPlanState): + if not self.id: + raise _user_exceptions.FlyteAssertion( + "Failed to update launch plan because the launch plan's ID is not set. Please call register to fetch " + "or register the identifier first" + ) + return _flyte_engine.get_client().update_launch_plan(self.id, state) + + @_deprecated(reason="Use launch_with_literals instead", version="0.9.0") + def execute_with_literals( + self, + project, + domain, + literal_inputs, + name=None, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + ): + """ + Deprecated. + """ + return self.launch_with_literals( + project, domain, literal_inputs, name, notification_overrides, label_overrides, annotation_overrides, + ) + + @_exception_scopes.system_entry_point + def launch_with_literals( + self, + project: str, + domain: str, + literal_inputs: _literal_models.LiteralMap, + name: str = None, + notification_overrides: List[_notifications.Notification] = None, + label_overrides: _common_models.Labels = None, + annotation_overrides: _common_models.Annotations = None, + ) -> _workflow_execution.FlyteWorkflowExecution: + """ + Executes the launch plan and returns the execution identifier. This version of execution is meant for when + you already have a LiteralMap of inputs. + + :param project: + :param domain: + :param literal_inputs: Inputs to the execution. + :param name: If specified, an execution will be created with this name. Note: the name must + be unique within the context of the project and domain. + :param notification_overrides: If specified, these are the notifications that will be honored for this + execution. An empty list signals to disable all notifications. + :param label_overrides: + :param annotation_overrides: + """ + # Kubernetes requires names starting with an alphabet for some resources. + name = name or "f" + _uuid.uuid4().hex[:19] + disable_all = notification_overrides == [] + if disable_all: + notification_overrides = None + else: + notification_overrides = _uuid.NotificationList(notification_overrides or []) + disable_all = None + + client = _flyte_engine.get_client() + try: + exec_id = client.create_execution( + project, + domain, + name, + _execution_models.ExecutionSpec( + self.id, + _execution_models.ExecutionMetadata( + _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, + "sdk", # TODO: get principle + 0, # TODO: Detect nesting + ), + notifications=notification_overrides, + disable_all=disable_all, + labels=label_overrides, + annotations=annotation_overrides, + ), + literal_inputs, + ) + except _user_exceptions.FlyteEntityAlreadyExistsException: + exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name) + return _workflow_execution.FlyteWorkflowExecution.promote_from_model(client.get_execution(exec_id)) + + @_exception_scopes.system_entry_point + def __call__(self, *args, **input_map: Any) -> _nodes.FlyteNode: + r""" + :param args: Do not specify. Kwargs only are supported for this function. + :param input_map: Map of inputs. Can be statically defined or OutputReference links. + """ + if len(args) > 0: + raise _user_exceptions.FlyteAssertion( + "When adding a launchplan as a node in a workflow, all inputs must be specified with kwargs only. We " + "detected {} positional args.".format(len(args)) + ) + + # Take the default values from the launch plan + default_inputs = {k: v.sdk_default for k, v in _six.iteritems(self.default_inputs.parameters) if not v.required} + default_inputs.update(input_map) + + # TODO: implement control_plan.interface.TypedInterface.create_bindings_for_inputs method + bindings, upstream_nodes = self.interface.create_bindings_for_inputs(default_inputs) + + return _nodes.FlyteNode( + id=None, + metadata=_workflow_models.NodeMetadata("", _datetime.timedelta(), _literal_models.RetryStrategy(0)), + bindings=sorted(bindings, key=lambda b: b.var), + upstream_nodes=upstream_nodes, + flyte_launch_plan=self, + ) + + def __repr__(self) -> str: + return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface} WF ID: {self.workflow_id})" diff --git a/flytekit/control_plane/notifications.py b/flytekit/control_plane/notifications.py new file mode 100644 index 00000000000..20f709e99f2 --- /dev/null +++ b/flytekit/control_plane/notifications.py @@ -0,0 +1,102 @@ +from typing import List + +from flyteidl.admin import common_pb2 as _common_pb2 + +from flytekit.models import common as _common_model +from flytekit.models import execution as _execution_model +from flytekit.models.core import execution as _core_execution_model + + +Phases = List[_core_execution_model.WorkflowExecutionPhase] + + +class Notification(_common_model.Notification): + + VALID_PHASES = { + _execution_model.WorkflowExecutionPhase.ABORTED, + _execution_model.WorkflowExecutionPhase.FAILED, + _execution_model.WorkflowExecutionPhase.SUCCEEDED, + _execution_model.WorkflowExecutionPhase.TIMED_OUT, + } + + def __init__(self, phases: Phases, email=None, pager_duty=None, slack=None): + """ + :param list[int] phases: A required list of phases for which to fire the event. Events can only be fired for + terminal phases. Phases should be as defined in: flytekit.models.core.execution.WorkflowExecutionPhase + """ + self._validate_phases(phases) + super(Notification, self).__init__(phases, email=email, pager_duty=pager_duty, slack=slack) + + def _validate_phases(self, phases: Phases): + """ + :param phases: + """ + if len(phases) == 0: + raise _user_exceptions.FlyteAssertion("You must specify at least one phase for a notification.") + for phase in phases: + if phase not in self.VALID_PHASES: + raise _user_exceptions.FlyteValueException( + phase, + self.VALID_PHASES, + additional_message="Notifications can only be specified on terminal states.", + ) + + @classmethod + def from_flyte_idl(cls, p: _common_pb2.Notification) -> "Notification": + """ + :param p: FlyteIDL Notification + """ + if p.HasField("email"): + return cls(p.phases, p.email.recipients_email) + elif p.HasField("pager_duty"): + return cls(p.phases, p.pager_duty.recipients_email) + else: + return cls(p.phases, p.slack.recipients_email) + + +class PagerDuty(Notification): + def __init__(self, phases: Phases, recipients_email: List[str]): + """ + :param phases: A required list of phases for which to fire the event. Events can only be fired for terminal + phases. + :param recipients_email: A required non-empty list of recipients for the notification. + """ + super(PagerDuty, self).__init__(phases, pager_duty=_common_model.PagerDutyNotification(recipients_email)) + + @classmethod + def promote_from_model(cls, base_model: _common_model.Notification) -> Notification: + return cls(base_model.phases, base_model.pager_duty.recipients_email) + + +class Email(Notification): + def __init__(self, phases: Phases, recipients_email: List[str]): + """ + :param phases: A required list of phases for which to fire the event. Events can only be fired for terminal + phases. + :param recipients_email: A required non-empty list of recipients for the notification. + """ + super(Email, self).__init__(phases, email=_common_model.EmailNotification(recipients_email)) + + @classmethod + def promote_from_model(cls, base_model: _common_model.Notification) -> "Notification": + """ + :param base_model: + """ + return cls(base_model.phases, base_model.email.recipients_email) + + +class Slack(Notification): + def __init__(self, phases: Phases, recipients_email: List[str]): + """ + :param phases: A required list of phases for which to fire the event. Events can only be fired for terminal + phases. + :param recipients_email: A required non-empty list of recipients for the notification. + """ + super(Slack, self).__init__(phases, slack=_common_model.SlackNotification(recipients_email)) + + @classmethod + def promote_from_model(cls, base_model: _common_model.Notification) -> "Notification": + """ + :param base_model: + """ + return cls(base_model.phases, base_model.slack.recipients_email)