Skip to content

Commit

Permalink
add FlyteLaunchPlan and Notification
Browse files Browse the repository at this point in the history
Signed-off-by: cosmicBboy <[email protected]>
  • Loading branch information
cosmicBboy committed Apr 1, 2021
1 parent df001ed commit 77f2d9d
Show file tree
Hide file tree
Showing 2 changed files with 381 additions and 0 deletions.
279 changes: 279 additions & 0 deletions flytekit/control_plane/launch_plan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
import uuid as _uuid
from typing import Any, List

import six as _six

from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.mixins import launchable as _launchable_mixin
from flytekit.configuration import sdk as _sdk_config
from flytekit.control_plane import identifier as _identifier
from flytekit.control_plane import interface as _interface
from flytekit.control_plane import nodes as _nodes
from flytekit.control_plane import notifications as _notifications
from flytekit.control_plane import workflow_execution as _workflow_execution
from flytekit.engines.flyte import engine as _flyte_engine
from flytekit.models import common as _common_models
from flytekit.models import execution as _execution_models
from flytekit.models import identifier as _identifier_model
from flytekit.models import interface as _interface_models
from flytekit.models import launch_plan as _launch_plan_models
from flytekit.models import literals as _literal_models


class FlyteLaunchPlan(
_launchable_mixin.LaunchableEntity, _launch_plan_models.LaunchPlanSpec,
):
def __init__(self, *args, **kwargs):
super(FlyteLaunchPlan, self).__init__(*args, **kwargs)
# Set all the attributes we expect this class to have
self._id = None

# The interface is not set explicitly unless fetched in an engine context
self._interface = None

@classmethod
def promote_from_model(cls, model: _launch_plan_models.LaunchPlanSpec) -> "FlyteLaunchPlan":
return cls(
workflow_id=_identifier.Identifier.promote_from_model(model.workflow_id),
default_inputs=_interface_models.ParameterMap(
{
k: _promises.Input.promote_from_model(v).rename_and_return_reference(k)
for k, v in _six.iteritems(model.default_inputs.parameters)
}
),
fixed_inputs=model.fixed_inputs,
entity_metadata=model.entity_metadata,
labels=model.labels,
annotations=model.annotations,
auth_role=model.auth_role,
raw_output_data_config=model.raw_output_data_config,
)

@_exception_scopes.system_entry_point
def register(self, project, domain, name, version):
# NOTE: does this need to be implemented in the control plane?
pass

@classmethod
@_exception_scopes.system_entry_point
def fetch(cls, project: str, domain: str, name: str, version: str) -> "FlyteLaunchPlan":
"""
This function uses the engine loader to call create a hydrated task from Admin.
:param project:
:param domain:
:param name:
:param version:
"""
from flytekit.control_plane import workflow as _workflow

launch_plan_id = _identifier.Identifier(
_identifier_model.ResourceType.LAUNCH_PLAN, project, domain, name, version
)

lp = _flyte_engine.get_client().get_launch_plan(launch_plan_id)
flyte_lp = cls.promote_from_model(lp.spec)
flyte_lp._id = lp.id

# TODO: Add a test for this, and this function as a whole
wf_id = flyte_lp.workflow_id
lp_wf = _workflow.FlyteWorkflow.fetch(wf_id.project, wf_id.domain, wf_id.name, wf_id.version)
flyte_lp._interface = lp_wf.interface
flyte_lp._has_registered = True
return flyte_lp

@_exception_scopes.system_entry_point
def serialize(self):
"""
Serializing a launch plan should produce an object similar to what the registration step produces,
in preparation for actual registration to Admin.
:rtype: flyteidl.admin.launch_plan_pb2.LaunchPlan
"""
# NOTE: does this need to be implemented in the control plane?
pass

@property
def id(self) -> _identifier.Identifier:
return self._id

@property
def is_scheduled(self) -> bool:
if self.entity_metadata.schedule.cron_expression:
return True
elif self.entity_metadata.schedule.rate and self.entity_metadata.schedule.rate.value:
return True
elif self.entity_metadata.schedule.cron_schedule and self.entity_metadata.schedule.cron_schedule.schedule:
return True
else:
return False

@property
def auth_role(self) -> _common_models.AuthRole:
fixed_auth = super(FlyteLaunchPlan, self).auth_role
if fixed_auth is not None and (
fixed_auth.assumable_iam_role is not None or fixed_auth.kubernetes_service_account is not None
):
return fixed_auth

assumable_iam_role = _auth_config.ASSUMABLE_IAM_ROLE.get()
kubernetes_service_account = _auth_config.KUBERNETES_SERVICE_ACCOUNT.get()

if not (assumable_iam_role or kubernetes_service_account):
_logging.warning(
"Using deprecated `role` from config. Please update your config to use `assumable_iam_role` instead"
)
assumable_iam_role = _sdk_config.ROLE.get()
return _common_models.AuthRole(
assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account,
)

@property
def workflow_id(self) -> _identifier.Identifier:
return self._workflow_id

@property
def interface(self) -> _interface.TypedInterface:
"""
The interface is not technically part of the admin.LaunchPlanSpec in the IDL, however the workflow ID is, and
from the workflow ID, fetch will fill in the interface. This is nice because then you can __call__ the=
object and get a node.
"""
return self._interface

@property
def resource_type(self) -> _identifier_model.ResourceType:
return _identifier_model.ResourceType.LAUNCH_PLAN

@property
def entity_type_text(self) -> str:
return "Launch Plan"

@property
def raw_output_data_config(self) -> _common_models.RawOutputDataConfig:
raw_output_data_config = super(FlyteLaunchPlan, self).raw_output_data_config
if raw_output_data_config is not None and raw_output_data_config.output_location_prefix != "":
return raw_output_data_config

# If it was not set explicitly then let's use the value found in the configuration.
return _common_models.RawOutputDataConfig(_auth_config.RAW_OUTPUT_DATA_PREFIX.get())

@_exception_scopes.system_entry_point
def validate(self):
# TODO: Validate workflow is satisfied
pass

@_exception_scopes.system_entry_point
def update(self, state: _launch_plan_models.LaunchPlanState):
if not self.id:
raise _user_exceptions.FlyteAssertion(
"Failed to update launch plan because the launch plan's ID is not set. Please call register to fetch "
"or register the identifier first"
)
return _flyte_engine.get_client().update_launch_plan(self.id, state)

@_deprecated(reason="Use launch_with_literals instead", version="0.9.0")
def execute_with_literals(
self,
project,
domain,
literal_inputs,
name=None,
notification_overrides=None,
label_overrides=None,
annotation_overrides=None,
):
"""
Deprecated.
"""
return self.launch_with_literals(
project, domain, literal_inputs, name, notification_overrides, label_overrides, annotation_overrides,
)

@_exception_scopes.system_entry_point
def launch_with_literals(
self,
project: str,
domain: str,
literal_inputs: _literal_models.LiteralMap,
name: str = None,
notification_overrides: List[_notifications.Notification] = None,
label_overrides: _common_models.Labels = None,
annotation_overrides: _common_models.Annotations = None,
) -> _workflow_execution.FlyteWorkflowExecution:
"""
Executes the launch plan and returns the execution identifier. This version of execution is meant for when
you already have a LiteralMap of inputs.
:param project:
:param domain:
:param literal_inputs: Inputs to the execution.
:param name: If specified, an execution will be created with this name. Note: the name must
be unique within the context of the project and domain.
:param notification_overrides: If specified, these are the notifications that will be honored for this
execution. An empty list signals to disable all notifications.
:param label_overrides:
:param annotation_overrides:
"""
# Kubernetes requires names starting with an alphabet for some resources.
name = name or "f" + _uuid.uuid4().hex[:19]
disable_all = notification_overrides == []
if disable_all:
notification_overrides = None
else:
notification_overrides = _uuid.NotificationList(notification_overrides or [])
disable_all = None

client = _flyte_engine.get_client()
try:
exec_id = client.create_execution(
project,
domain,
name,
_execution_models.ExecutionSpec(
self.id,
_execution_models.ExecutionMetadata(
_execution_models.ExecutionMetadata.ExecutionMode.MANUAL,
"sdk", # TODO: get principle
0, # TODO: Detect nesting
),
notifications=notification_overrides,
disable_all=disable_all,
labels=label_overrides,
annotations=annotation_overrides,
),
literal_inputs,
)
except _user_exceptions.FlyteEntityAlreadyExistsException:
exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name)
return _workflow_execution.FlyteWorkflowExecution.promote_from_model(client.get_execution(exec_id))

@_exception_scopes.system_entry_point
def __call__(self, *args, **input_map: Any) -> _nodes.FlyteNode:
r"""
:param args: Do not specify. Kwargs only are supported for this function.
:param input_map: Map of inputs. Can be statically defined or OutputReference links.
"""
if len(args) > 0:
raise _user_exceptions.FlyteAssertion(
"When adding a launchplan as a node in a workflow, all inputs must be specified with kwargs only. We "
"detected {} positional args.".format(len(args))
)

# Take the default values from the launch plan
default_inputs = {k: v.sdk_default for k, v in _six.iteritems(self.default_inputs.parameters) if not v.required}
default_inputs.update(input_map)

# TODO: implement control_plan.interface.TypedInterface.create_bindings_for_inputs method
bindings, upstream_nodes = self.interface.create_bindings_for_inputs(default_inputs)

return _nodes.FlyteNode(
id=None,
metadata=_workflow_models.NodeMetadata("", _datetime.timedelta(), _literal_models.RetryStrategy(0)),
bindings=sorted(bindings, key=lambda b: b.var),
upstream_nodes=upstream_nodes,
flyte_launch_plan=self,
)

def __repr__(self) -> str:
return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface} WF ID: {self.workflow_id})"
102 changes: 102 additions & 0 deletions flytekit/control_plane/notifications.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import List

from flyteidl.admin import common_pb2 as _common_pb2

from flytekit.models import common as _common_model
from flytekit.models import execution as _execution_model
from flytekit.models.core import execution as _core_execution_model


Phases = List[_core_execution_model.WorkflowExecutionPhase]


class Notification(_common_model.Notification):

VALID_PHASES = {
_execution_model.WorkflowExecutionPhase.ABORTED,
_execution_model.WorkflowExecutionPhase.FAILED,
_execution_model.WorkflowExecutionPhase.SUCCEEDED,
_execution_model.WorkflowExecutionPhase.TIMED_OUT,
}

def __init__(self, phases: Phases, email=None, pager_duty=None, slack=None):
"""
:param list[int] phases: A required list of phases for which to fire the event. Events can only be fired for
terminal phases. Phases should be as defined in: flytekit.models.core.execution.WorkflowExecutionPhase
"""
self._validate_phases(phases)
super(Notification, self).__init__(phases, email=email, pager_duty=pager_duty, slack=slack)

def _validate_phases(self, phases: Phases):
"""
:param phases:
"""
if len(phases) == 0:
raise _user_exceptions.FlyteAssertion("You must specify at least one phase for a notification.")
for phase in phases:
if phase not in self.VALID_PHASES:
raise _user_exceptions.FlyteValueException(
phase,
self.VALID_PHASES,
additional_message="Notifications can only be specified on terminal states.",
)

@classmethod
def from_flyte_idl(cls, p: _common_pb2.Notification) -> "Notification":
"""
:param p: FlyteIDL Notification
"""
if p.HasField("email"):
return cls(p.phases, p.email.recipients_email)
elif p.HasField("pager_duty"):
return cls(p.phases, p.pager_duty.recipients_email)
else:
return cls(p.phases, p.slack.recipients_email)


class PagerDuty(Notification):
def __init__(self, phases: Phases, recipients_email: List[str]):
"""
:param phases: A required list of phases for which to fire the event. Events can only be fired for terminal
phases.
:param recipients_email: A required non-empty list of recipients for the notification.
"""
super(PagerDuty, self).__init__(phases, pager_duty=_common_model.PagerDutyNotification(recipients_email))

@classmethod
def promote_from_model(cls, base_model: _common_model.Notification) -> Notification:
return cls(base_model.phases, base_model.pager_duty.recipients_email)


class Email(Notification):
def __init__(self, phases: Phases, recipients_email: List[str]):
"""
:param phases: A required list of phases for which to fire the event. Events can only be fired for terminal
phases.
:param recipients_email: A required non-empty list of recipients for the notification.
"""
super(Email, self).__init__(phases, email=_common_model.EmailNotification(recipients_email))

@classmethod
def promote_from_model(cls, base_model: _common_model.Notification) -> "Notification":
"""
:param base_model:
"""
return cls(base_model.phases, base_model.email.recipients_email)


class Slack(Notification):
def __init__(self, phases: Phases, recipients_email: List[str]):
"""
:param phases: A required list of phases for which to fire the event. Events can only be fired for terminal
phases.
:param recipients_email: A required non-empty list of recipients for the notification.
"""
super(Slack, self).__init__(phases, slack=_common_model.SlackNotification(recipients_email))

@classmethod
def promote_from_model(cls, base_model: _common_model.Notification) -> "Notification":
"""
:param base_model:
"""
return cls(base_model.phases, base_model.slack.recipients_email)

0 comments on commit 77f2d9d

Please sign in to comment.