-
Notifications
You must be signed in to change notification settings - Fork 300
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add FlyteLaunchPlan and Notification
Signed-off-by: cosmicBboy <[email protected]>
- Loading branch information
1 parent
df001ed
commit 77f2d9d
Showing
2 changed files
with
381 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,279 @@ | ||
import uuid as _uuid | ||
from typing import Any, List | ||
|
||
import six as _six | ||
|
||
from flytekit.common.exceptions import scopes as _exception_scopes | ||
from flytekit.common.exceptions import user as _user_exceptions | ||
from flytekit.common.mixins import launchable as _launchable_mixin | ||
from flytekit.configuration import sdk as _sdk_config | ||
from flytekit.control_plane import identifier as _identifier | ||
from flytekit.control_plane import interface as _interface | ||
from flytekit.control_plane import nodes as _nodes | ||
from flytekit.control_plane import notifications as _notifications | ||
from flytekit.control_plane import workflow_execution as _workflow_execution | ||
from flytekit.engines.flyte import engine as _flyte_engine | ||
from flytekit.models import common as _common_models | ||
from flytekit.models import execution as _execution_models | ||
from flytekit.models import identifier as _identifier_model | ||
from flytekit.models import interface as _interface_models | ||
from flytekit.models import launch_plan as _launch_plan_models | ||
from flytekit.models import literals as _literal_models | ||
|
||
|
||
class FlyteLaunchPlan( | ||
_launchable_mixin.LaunchableEntity, _launch_plan_models.LaunchPlanSpec, | ||
): | ||
def __init__(self, *args, **kwargs): | ||
super(FlyteLaunchPlan, self).__init__(*args, **kwargs) | ||
# Set all the attributes we expect this class to have | ||
self._id = None | ||
|
||
# The interface is not set explicitly unless fetched in an engine context | ||
self._interface = None | ||
|
||
@classmethod | ||
def promote_from_model(cls, model: _launch_plan_models.LaunchPlanSpec) -> "FlyteLaunchPlan": | ||
return cls( | ||
workflow_id=_identifier.Identifier.promote_from_model(model.workflow_id), | ||
default_inputs=_interface_models.ParameterMap( | ||
{ | ||
k: _promises.Input.promote_from_model(v).rename_and_return_reference(k) | ||
for k, v in _six.iteritems(model.default_inputs.parameters) | ||
} | ||
), | ||
fixed_inputs=model.fixed_inputs, | ||
entity_metadata=model.entity_metadata, | ||
labels=model.labels, | ||
annotations=model.annotations, | ||
auth_role=model.auth_role, | ||
raw_output_data_config=model.raw_output_data_config, | ||
) | ||
|
||
@_exception_scopes.system_entry_point | ||
def register(self, project, domain, name, version): | ||
# NOTE: does this need to be implemented in the control plane? | ||
pass | ||
|
||
@classmethod | ||
@_exception_scopes.system_entry_point | ||
def fetch(cls, project: str, domain: str, name: str, version: str) -> "FlyteLaunchPlan": | ||
""" | ||
This function uses the engine loader to call create a hydrated task from Admin. | ||
:param project: | ||
:param domain: | ||
:param name: | ||
:param version: | ||
""" | ||
from flytekit.control_plane import workflow as _workflow | ||
|
||
launch_plan_id = _identifier.Identifier( | ||
_identifier_model.ResourceType.LAUNCH_PLAN, project, domain, name, version | ||
) | ||
|
||
lp = _flyte_engine.get_client().get_launch_plan(launch_plan_id) | ||
flyte_lp = cls.promote_from_model(lp.spec) | ||
flyte_lp._id = lp.id | ||
|
||
# TODO: Add a test for this, and this function as a whole | ||
wf_id = flyte_lp.workflow_id | ||
lp_wf = _workflow.FlyteWorkflow.fetch(wf_id.project, wf_id.domain, wf_id.name, wf_id.version) | ||
flyte_lp._interface = lp_wf.interface | ||
flyte_lp._has_registered = True | ||
return flyte_lp | ||
|
||
@_exception_scopes.system_entry_point | ||
def serialize(self): | ||
""" | ||
Serializing a launch plan should produce an object similar to what the registration step produces, | ||
in preparation for actual registration to Admin. | ||
:rtype: flyteidl.admin.launch_plan_pb2.LaunchPlan | ||
""" | ||
# NOTE: does this need to be implemented in the control plane? | ||
pass | ||
|
||
@property | ||
def id(self) -> _identifier.Identifier: | ||
return self._id | ||
|
||
@property | ||
def is_scheduled(self) -> bool: | ||
if self.entity_metadata.schedule.cron_expression: | ||
return True | ||
elif self.entity_metadata.schedule.rate and self.entity_metadata.schedule.rate.value: | ||
return True | ||
elif self.entity_metadata.schedule.cron_schedule and self.entity_metadata.schedule.cron_schedule.schedule: | ||
return True | ||
else: | ||
return False | ||
|
||
@property | ||
def auth_role(self) -> _common_models.AuthRole: | ||
fixed_auth = super(FlyteLaunchPlan, self).auth_role | ||
if fixed_auth is not None and ( | ||
fixed_auth.assumable_iam_role is not None or fixed_auth.kubernetes_service_account is not None | ||
): | ||
return fixed_auth | ||
|
||
assumable_iam_role = _auth_config.ASSUMABLE_IAM_ROLE.get() | ||
kubernetes_service_account = _auth_config.KUBERNETES_SERVICE_ACCOUNT.get() | ||
|
||
if not (assumable_iam_role or kubernetes_service_account): | ||
_logging.warning( | ||
"Using deprecated `role` from config. Please update your config to use `assumable_iam_role` instead" | ||
) | ||
assumable_iam_role = _sdk_config.ROLE.get() | ||
return _common_models.AuthRole( | ||
assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, | ||
) | ||
|
||
@property | ||
def workflow_id(self) -> _identifier.Identifier: | ||
return self._workflow_id | ||
|
||
@property | ||
def interface(self) -> _interface.TypedInterface: | ||
""" | ||
The interface is not technically part of the admin.LaunchPlanSpec in the IDL, however the workflow ID is, and | ||
from the workflow ID, fetch will fill in the interface. This is nice because then you can __call__ the= | ||
object and get a node. | ||
""" | ||
return self._interface | ||
|
||
@property | ||
def resource_type(self) -> _identifier_model.ResourceType: | ||
return _identifier_model.ResourceType.LAUNCH_PLAN | ||
|
||
@property | ||
def entity_type_text(self) -> str: | ||
return "Launch Plan" | ||
|
||
@property | ||
def raw_output_data_config(self) -> _common_models.RawOutputDataConfig: | ||
raw_output_data_config = super(FlyteLaunchPlan, self).raw_output_data_config | ||
if raw_output_data_config is not None and raw_output_data_config.output_location_prefix != "": | ||
return raw_output_data_config | ||
|
||
# If it was not set explicitly then let's use the value found in the configuration. | ||
return _common_models.RawOutputDataConfig(_auth_config.RAW_OUTPUT_DATA_PREFIX.get()) | ||
|
||
@_exception_scopes.system_entry_point | ||
def validate(self): | ||
# TODO: Validate workflow is satisfied | ||
pass | ||
|
||
@_exception_scopes.system_entry_point | ||
def update(self, state: _launch_plan_models.LaunchPlanState): | ||
if not self.id: | ||
raise _user_exceptions.FlyteAssertion( | ||
"Failed to update launch plan because the launch plan's ID is not set. Please call register to fetch " | ||
"or register the identifier first" | ||
) | ||
return _flyte_engine.get_client().update_launch_plan(self.id, state) | ||
|
||
@_deprecated(reason="Use launch_with_literals instead", version="0.9.0") | ||
def execute_with_literals( | ||
self, | ||
project, | ||
domain, | ||
literal_inputs, | ||
name=None, | ||
notification_overrides=None, | ||
label_overrides=None, | ||
annotation_overrides=None, | ||
): | ||
""" | ||
Deprecated. | ||
""" | ||
return self.launch_with_literals( | ||
project, domain, literal_inputs, name, notification_overrides, label_overrides, annotation_overrides, | ||
) | ||
|
||
@_exception_scopes.system_entry_point | ||
def launch_with_literals( | ||
self, | ||
project: str, | ||
domain: str, | ||
literal_inputs: _literal_models.LiteralMap, | ||
name: str = None, | ||
notification_overrides: List[_notifications.Notification] = None, | ||
label_overrides: _common_models.Labels = None, | ||
annotation_overrides: _common_models.Annotations = None, | ||
) -> _workflow_execution.FlyteWorkflowExecution: | ||
""" | ||
Executes the launch plan and returns the execution identifier. This version of execution is meant for when | ||
you already have a LiteralMap of inputs. | ||
:param project: | ||
:param domain: | ||
:param literal_inputs: Inputs to the execution. | ||
:param name: If specified, an execution will be created with this name. Note: the name must | ||
be unique within the context of the project and domain. | ||
:param notification_overrides: If specified, these are the notifications that will be honored for this | ||
execution. An empty list signals to disable all notifications. | ||
:param label_overrides: | ||
:param annotation_overrides: | ||
""" | ||
# Kubernetes requires names starting with an alphabet for some resources. | ||
name = name or "f" + _uuid.uuid4().hex[:19] | ||
disable_all = notification_overrides == [] | ||
if disable_all: | ||
notification_overrides = None | ||
else: | ||
notification_overrides = _uuid.NotificationList(notification_overrides or []) | ||
disable_all = None | ||
|
||
client = _flyte_engine.get_client() | ||
try: | ||
exec_id = client.create_execution( | ||
project, | ||
domain, | ||
name, | ||
_execution_models.ExecutionSpec( | ||
self.id, | ||
_execution_models.ExecutionMetadata( | ||
_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, | ||
"sdk", # TODO: get principle | ||
0, # TODO: Detect nesting | ||
), | ||
notifications=notification_overrides, | ||
disable_all=disable_all, | ||
labels=label_overrides, | ||
annotations=annotation_overrides, | ||
), | ||
literal_inputs, | ||
) | ||
except _user_exceptions.FlyteEntityAlreadyExistsException: | ||
exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name) | ||
return _workflow_execution.FlyteWorkflowExecution.promote_from_model(client.get_execution(exec_id)) | ||
|
||
@_exception_scopes.system_entry_point | ||
def __call__(self, *args, **input_map: Any) -> _nodes.FlyteNode: | ||
r""" | ||
:param args: Do not specify. Kwargs only are supported for this function. | ||
:param input_map: Map of inputs. Can be statically defined or OutputReference links. | ||
""" | ||
if len(args) > 0: | ||
raise _user_exceptions.FlyteAssertion( | ||
"When adding a launchplan as a node in a workflow, all inputs must be specified with kwargs only. We " | ||
"detected {} positional args.".format(len(args)) | ||
) | ||
|
||
# Take the default values from the launch plan | ||
default_inputs = {k: v.sdk_default for k, v in _six.iteritems(self.default_inputs.parameters) if not v.required} | ||
default_inputs.update(input_map) | ||
|
||
# TODO: implement control_plan.interface.TypedInterface.create_bindings_for_inputs method | ||
bindings, upstream_nodes = self.interface.create_bindings_for_inputs(default_inputs) | ||
|
||
return _nodes.FlyteNode( | ||
id=None, | ||
metadata=_workflow_models.NodeMetadata("", _datetime.timedelta(), _literal_models.RetryStrategy(0)), | ||
bindings=sorted(bindings, key=lambda b: b.var), | ||
upstream_nodes=upstream_nodes, | ||
flyte_launch_plan=self, | ||
) | ||
|
||
def __repr__(self) -> str: | ||
return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface} WF ID: {self.workflow_id})" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from typing import List | ||
|
||
from flyteidl.admin import common_pb2 as _common_pb2 | ||
|
||
from flytekit.models import common as _common_model | ||
from flytekit.models import execution as _execution_model | ||
from flytekit.models.core import execution as _core_execution_model | ||
|
||
|
||
Phases = List[_core_execution_model.WorkflowExecutionPhase] | ||
|
||
|
||
class Notification(_common_model.Notification): | ||
|
||
VALID_PHASES = { | ||
_execution_model.WorkflowExecutionPhase.ABORTED, | ||
_execution_model.WorkflowExecutionPhase.FAILED, | ||
_execution_model.WorkflowExecutionPhase.SUCCEEDED, | ||
_execution_model.WorkflowExecutionPhase.TIMED_OUT, | ||
} | ||
|
||
def __init__(self, phases: Phases, email=None, pager_duty=None, slack=None): | ||
""" | ||
:param list[int] phases: A required list of phases for which to fire the event. Events can only be fired for | ||
terminal phases. Phases should be as defined in: flytekit.models.core.execution.WorkflowExecutionPhase | ||
""" | ||
self._validate_phases(phases) | ||
super(Notification, self).__init__(phases, email=email, pager_duty=pager_duty, slack=slack) | ||
|
||
def _validate_phases(self, phases: Phases): | ||
""" | ||
:param phases: | ||
""" | ||
if len(phases) == 0: | ||
raise _user_exceptions.FlyteAssertion("You must specify at least one phase for a notification.") | ||
for phase in phases: | ||
if phase not in self.VALID_PHASES: | ||
raise _user_exceptions.FlyteValueException( | ||
phase, | ||
self.VALID_PHASES, | ||
additional_message="Notifications can only be specified on terminal states.", | ||
) | ||
|
||
@classmethod | ||
def from_flyte_idl(cls, p: _common_pb2.Notification) -> "Notification": | ||
""" | ||
:param p: FlyteIDL Notification | ||
""" | ||
if p.HasField("email"): | ||
return cls(p.phases, p.email.recipients_email) | ||
elif p.HasField("pager_duty"): | ||
return cls(p.phases, p.pager_duty.recipients_email) | ||
else: | ||
return cls(p.phases, p.slack.recipients_email) | ||
|
||
|
||
class PagerDuty(Notification): | ||
def __init__(self, phases: Phases, recipients_email: List[str]): | ||
""" | ||
:param phases: A required list of phases for which to fire the event. Events can only be fired for terminal | ||
phases. | ||
:param recipients_email: A required non-empty list of recipients for the notification. | ||
""" | ||
super(PagerDuty, self).__init__(phases, pager_duty=_common_model.PagerDutyNotification(recipients_email)) | ||
|
||
@classmethod | ||
def promote_from_model(cls, base_model: _common_model.Notification) -> Notification: | ||
return cls(base_model.phases, base_model.pager_duty.recipients_email) | ||
|
||
|
||
class Email(Notification): | ||
def __init__(self, phases: Phases, recipients_email: List[str]): | ||
""" | ||
:param phases: A required list of phases for which to fire the event. Events can only be fired for terminal | ||
phases. | ||
:param recipients_email: A required non-empty list of recipients for the notification. | ||
""" | ||
super(Email, self).__init__(phases, email=_common_model.EmailNotification(recipients_email)) | ||
|
||
@classmethod | ||
def promote_from_model(cls, base_model: _common_model.Notification) -> "Notification": | ||
""" | ||
:param base_model: | ||
""" | ||
return cls(base_model.phases, base_model.email.recipients_email) | ||
|
||
|
||
class Slack(Notification): | ||
def __init__(self, phases: Phases, recipients_email: List[str]): | ||
""" | ||
:param phases: A required list of phases for which to fire the event. Events can only be fired for terminal | ||
phases. | ||
:param recipients_email: A required non-empty list of recipients for the notification. | ||
""" | ||
super(Slack, self).__init__(phases, slack=_common_model.SlackNotification(recipients_email)) | ||
|
||
@classmethod | ||
def promote_from_model(cls, base_model: _common_model.Notification) -> "Notification": | ||
""" | ||
:param base_model: | ||
""" | ||
return cls(base_model.phases, base_model.slack.recipients_email) |