From 2a3aaf9a572236997f39d857eae4b1cc6e03bf7d Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 3 Jun 2020 14:57:22 -0700 Subject: [PATCH] Implement launch single task execution (#115) --- flytekit/__init__.py | 2 +- flytekit/clis/flyte_cli/main.py | 46 ++++++++++- flytekit/clis/sdk_in_container/launch_plan.py | 6 +- flytekit/common/launch_plan.py | 42 ++++++---- .../mixins/{executable.py => launchable.py} | 42 ++++++++-- flytekit/common/tasks/task.py | 47 ++++++++++- flytekit/common/workflow.py | 9 +-- flytekit/engines/common.py | 26 ++++++- flytekit/engines/flyte/engine.py | 78 ++++++++++++++++++- flytekit/engines/unit/engine.py | 4 + flytekit/models/common.py | 53 +++++++++++++ flytekit/models/execution.py | 13 +++- flytekit/models/launch_plan.py | 17 ++-- setup.py | 2 +- .../unit/common_tests/test_launch_plan.py | 12 +-- .../unit/engines/flyte/test_engine.py | 6 +- tests/flytekit/unit/models/test_common.py | 14 ++++ .../flytekit/unit/models/test_launch_plan.py | 14 ---- 18 files changed, 357 insertions(+), 76 deletions(-) rename flytekit/common/mixins/{executable.py => launchable.py} (66%) diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 409eb5113f..928d698d7d 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -2,4 +2,4 @@ import flytekit.plugins -__version__ = '0.8.2' +__version__ = '0.9.0b0' diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 2c573ab960..389de517df 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -18,6 +18,7 @@ parse_args_into_dict as _parse_args_into_dict from flytekit.common import utils as _utils, launch_plan as _launch_plan_common from flytekit.common.core import identifier as _identifier +from flytekit.common.tasks import task as _tasks_common from flytekit.common.types import helpers as _type_helpers from flytekit.common.utils import load_proto_from_file as _load_proto_from_file from flytekit.configuration import platform as _platform_config @@ -680,6 +681,49 @@ def get_task(urn, host, insecure): _click.echo("") +@_flyte_cli.command('launch-task', cls=_FlyteSubCommand) +@_project_option +@_domain_option +@_optional_name_option +@_host_option +@_insecure_option +@_urn_option +@_click.argument('task_args', nargs=-1, type=_click.UNPROCESSED) +def launch_task(project, domain, name, host, insecure, urn, task_args): + """ + Kick off a single task execution. Note that the {project, domain, name} specified in the command line + will be for the execution. The project/domain for the task are specified in the urn. + + Use a -- to separate arguments to this cli, and arguments to the task. + e.g. + $ flyte-cli -h localhost:30081 -p flyteexamples -d development launch-task \ + -u tsk:flyteexamples:development:some-task:abc123 -- input=hi \ + other-input=123 moreinput=qwerty + + These arguments are then collected, and passed into the `task_args` variable as a Tuple[Text]. + Users should use the get-task command to ascertain the names of inputs to use. + """ + _welcome_message() + + with _platform_config.URL.get_patcher(host), _platform_config.INSECURE.get_patcher(_tt(insecure)): + task_id = _identifier.Identifier.from_python_std(urn) + task = _tasks_common.SdkTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version) + + text_args = _parse_args_into_dict(task_args) + inputs = {} + for var_name, variable in _six.iteritems(task.interface.inputs): + sdk_type = _type_helpers.get_sdk_type_from_literal_type(variable.type) + if var_name in text_args and text_args[var_name] is not None: + inputs[var_name] = sdk_type.from_string(text_args[var_name]).to_python_std() + + # TODO: Implement notification overrides + # TODO: Implement label overrides + # TODO: Implement annotation overrides + execution = task.launch(project, domain, inputs=inputs, name=name) + _click.secho("Launched execution: {}".format(_tt(execution.id)), fg='blue') + _click.echo("") + + ######################################################################################################################## # # Workflow Commands @@ -1060,7 +1104,7 @@ def execute_launch_plan(project, domain, name, host, insecure, urn, principal, v # TODO: Implement notification overrides # TODO: Implement label overrides # TODO: Implement annotation overrides - execution = lp.execute_with_literals(project, domain, inputs, name=name) + execution = lp.launch_with_literals(project, domain, inputs, name=name) _click.secho("Launched execution: {}".format(_tt(execution.id)), fg='blue') _click.echo("") diff --git a/flytekit/clis/sdk_in_container/launch_plan.py b/flytekit/clis/sdk_in_container/launch_plan.py index b3965f60d7..de5bb7d497 100644 --- a/flytekit/clis/sdk_in_container/launch_plan.py +++ b/flytekit/clis/sdk_in_container/launch_plan.py @@ -8,7 +8,6 @@ from flytekit.clis.sdk_in_container import constants as _constants from flytekit.common import utils as _utils from flytekit.common.launch_plan import SdkLaunchPlan as _SdkLaunchPlan -from flytekit.common.mixins import executable as _executable_mixins from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag, \ IMAGE as _IMAGE from flytekit.models import launch_plan as _launch_plan_model @@ -31,7 +30,8 @@ def list_commands(self, ctx): pkgs = ctx.obj[_constants.CTX_PACKAGES] # Discover all launch plans by loading the modules for m, k, lp in iterate_registerable_entities_in_order( - pkgs, include_entities={_executable_mixins.ExecutableEntity}, detect_unreferenced_entities=False): + pkgs, include_entities={_SdkLaunchPlan}, + detect_unreferenced_entities=False): safe_name = _utils.fqdn(m.__name__, k, entity_type=lp.resource_type) commands.append(safe_name) lps[safe_name] = lp @@ -52,7 +52,7 @@ def get_command(self, ctx, lp_argument): launch_plan = ctx.obj['lps'][lp_argument] else: for m, k, lp in iterate_registerable_entities_in_order( - pkgs, include_entities={_executable_mixins.ExecutableEntity}, detect_unreferenced_entities=False): + pkgs, include_entities={_SdkLaunchPlan}, detect_unreferenced_entities=False): safe_name = _utils.fqdn(m.__name__, k, entity_type=lp.resource_type) if lp_argument == safe_name: launch_plan = lp diff --git a/flytekit/common/launch_plan.py b/flytekit/common/launch_plan.py index 28352d2c7a..9c7fdbbeef 100644 --- a/flytekit/common/launch_plan.py +++ b/flytekit/common/launch_plan.py @@ -5,14 +5,15 @@ from flytekit.common.core import identifier as _identifier from flytekit.common.exceptions import scopes as _exception_scopes, user as _user_exceptions -from flytekit.common.mixins import registerable as _registerable, hash as _hash_mixin, executable as _executable_mixin +from flytekit.common.mixins import registerable as _registerable, hash as _hash_mixin, launchable as _launchable_mixin from flytekit.common.types import helpers as _type_helpers -from flytekit.configuration import sdk as _sdk_config, internal as _internal_config, auth as _auth_config +from flytekit.configuration import sdk as _sdk_config, auth as _auth_config from flytekit.engines import loader as _engine_loader from flytekit.models import launch_plan as _launch_plan_models, schedule as _schedule_model, interface as \ _interface_models, literals as _literal_models, common as _common_models from flytekit.models.core import identifier as _identifier_model, workflow as _workflow_models import datetime as _datetime +from deprecated import deprecated as _deprecated import logging as _logging import six as _six import uuid as _uuid @@ -22,7 +23,7 @@ class SdkLaunchPlan( _six.with_metaclass( _sdk_bases.ExtendedSdkType, _launch_plan_models.LaunchPlanSpec, - _executable_mixin.ExecutableEntity, + _launchable_mixin.LaunchableEntity, ) ): def __init__(self, *args, **kwargs): @@ -51,7 +52,7 @@ def promote_from_model(cls, model): entity_metadata=model.entity_metadata, labels=model.labels, annotations=model.annotations, - auth=model.auth, + auth_role=model.auth_role, ) @classmethod @@ -100,11 +101,11 @@ def is_scheduled(self): return False @property - def auth(self): + def auth_role(self): """ - :rtype: flytekit.models.LaunchPlan.Auth + :rtype: flytekit.models.common.AuthRole """ - fixed_auth = super(SdkLaunchPlan, self).auth + fixed_auth = super(SdkLaunchPlan, self).auth_role if fixed_auth is not None and\ (fixed_auth.assumable_iam_role is not None or fixed_auth.kubernetes_service_account is not None): return fixed_auth @@ -116,8 +117,8 @@ def auth(self): _logging.warning("Using deprecated `role` from config. " "Please update your config to use `assumable_iam_role` instead") assumable_iam_role = _sdk_config.ROLE.get() - return _launch_plan_models.Auth(assumable_iam_role=assumable_iam_role, - kubernetes_service_account=kubernetes_service_account) + return _common_models.AuthRole(assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account) @property def interface(self): @@ -172,10 +173,19 @@ def _python_std_input_map_to_literal_map(self, inputs): } ) - @_exception_scopes.system_entry_point + @_deprecated(reason="Use launch_with_literals instead", version='0.9.0') def execute_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None, label_overrides=None, annotation_overrides=None): """ + Deprecated. + """ + return self.launch_with_literals(project, domain, literal_inputs, name, notification_overrides, label_overrides, + annotation_overrides) + + @_exception_scopes.system_entry_point + def launch_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None, + label_overrides=None, annotation_overrides=None): + """ Executes the launch plan and returns the execution identifier. This version of execution is meant for when you already have a LiteralMap of inputs. @@ -193,7 +203,7 @@ def execute_with_literals(self, project, domain, literal_inputs, name=None, noti """ # Kubernetes requires names starting with an alphabet for some resources. name = name or "f" + _uuid.uuid4().hex[:19] - execution = _engine_loader.get_engine().get_launch_plan(self).execute( + execution = _engine_loader.get_engine().get_launch_plan(self).launch( project, domain, name, @@ -258,7 +268,7 @@ def __init__( notifications=None, labels=None, annotations=None, - auth=None, + auth_role=None, ): """ :param flytekit.common.workflow.SdkWorkflow sdk_workflow: @@ -273,16 +283,16 @@ def __init__( :param flytekit.models.common.Annotations annotations: Any custom kubernetes annotations to apply to workflows executed by this launch plan. Any custom kubernetes annotations to apply to workflows executed by this launch plan. - :param flytekit.models.launch_plan.Auth auth: The auth method with which to execute the workflow. + :param flytekit.models.common.Authrole auth_role: The auth method with which to execute the workflow. """ - if role and auth: + if role and auth_role: raise ValueError("Cannot set both role and auth. Role is deprecated, use auth instead.") fixed_inputs = fixed_inputs or {} default_inputs = default_inputs or {} if role: - auth = _launch_plan_models.Auth(assumable_iam_role=role) + auth_role = _common_models.AuthRole(assumable_iam_role=role) # The constructor for SdkLaunchPlan sets the id to None anyways so we don't bother passing in an ID. The ID # should be set in one of three places, @@ -306,7 +316,7 @@ def __init__( ), labels or _common_models.Labels({}), annotations or _common_models.Annotations({}), - auth, + auth_role, ) self._interface = _interface.TypedInterface( {k: v.var for k, v in _six.iteritems(default_inputs)}, diff --git a/flytekit/common/mixins/executable.py b/flytekit/common/mixins/launchable.py similarity index 66% rename from flytekit/common/mixins/executable.py rename to flytekit/common/mixins/launchable.py index 1f4d79f51c..6a42d98084 100644 --- a/flytekit/common/mixins/executable.py +++ b/flytekit/common/mixins/launchable.py @@ -2,14 +2,15 @@ import abc as _abc import six as _six +from deprecated import deprecated as _deprecated -class ExecutableEntity(_six.with_metaclass(_abc.ABCMeta, object)): - def execute(self, project, domain, inputs=None, name=None, notification_overrides=None, label_overrides=None, - annotation_overrides=None): +class LaunchableEntity(_six.with_metaclass(_abc.ABCMeta, object)): + def launch(self, project, domain, inputs=None, name=None, notification_overrides=None, label_overrides=None, + annotation_overrides=None): """ - Executes the entity and returns the execution identifier. This version of execution is meant for when - inputs are specified as Python native types/structures. + Creates a remote execution from the entity and returns the execution identifier. + This version of launch is meant for when inputs are specified as Python native types/structures. :param Text project: :param Text domain: @@ -35,13 +36,29 @@ def execute(self, project, domain, inputs=None, name=None, notification_override annotation_overrides=annotation_overrides, ) + @_deprecated(reason="Use launch instead", version='0.9.0') + def execute(self, project, domain, inputs=None, name=None, notification_overrides=None, label_overrides=None, + annotation_overrides=None): + """ + Deprecated. + """ + return self.launch( + project, + domain, + inputs=inputs, + name=name, + notification_overrides=notification_overrides, + label_overrides=label_overrides, + annotation_overrides=annotation_overrides, + ) + @_abc.abstractmethod def _python_std_input_map_to_literal_map(self, inputs): pass @_abc.abstractmethod - def execute_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None, - label_overrides=None, annotation_overrides=None): + def launch_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None, + label_overrides=None, annotation_overrides=None): """ Executes the entity and returns the execution identifier. This version of execution is meant for when you already have a LiteralMap of inputs. @@ -56,6 +73,15 @@ def execute_with_literals(self, project, domain, literal_inputs, name=None, noti notifications. :param flytekit.models.common.Labels label_overrides: :param flytekit.models.common.Annotations annotation_overrides: - :rtype: flytekit.models.core.identifier.WorkflowExecutionIdentifier + :rtype: flytekit.models.core.identifier.WorkflowExecutionIdentifier: """ pass + + @_deprecated(reason="Use launch_with_literals instead", version='0.9.0') + def execute_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None, + label_overrides=None, annotation_overrides=None): + """ + Deprecated. + """ + return self.launch_with_literals(project, domain, literal_inputs, name, notification_overrides, label_overrides, + annotation_overrides) diff --git a/flytekit/common/tasks/task.py b/flytekit/common/tasks/task.py index e375e76b04..0ac0626cc1 100644 --- a/flytekit/common/tasks/task.py +++ b/flytekit/common/tasks/task.py @@ -4,15 +4,18 @@ import six as _six -from flytekit.common import interface as _interfaces, nodes as _nodes, sdk_bases as _sdk_bases +from flytekit.common import ( + interface as _interfaces, nodes as _nodes, sdk_bases as _sdk_bases, workflow_execution as _workflow_execution +) from flytekit.common.core import identifier as _identifier from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.mixins import registerable as _registerable, hash as _hash_mixin +from flytekit.common.mixins import registerable as _registerable, hash as _hash_mixin, launchable as _launchable_mixin from flytekit.configuration import internal as _internal_config from flytekit.engines import loader as _engine_loader from flytekit.models import common as _common_model, task as _task_model from flytekit.models.core import workflow as _workflow_model, identifier as _identifier_model from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.types import helpers as _type_helpers class SdkTask( @@ -21,6 +24,7 @@ class SdkTask( _hash_mixin.HashOnReferenceMixin, _task_model.TaskTemplate, _registerable.RegisterableEntity, + _launchable_mixin.LaunchableEntity, ) ): @@ -252,3 +256,42 @@ def __repr__(self): task_type=self.type, interface=self.interface ) + + def _python_std_input_map_to_literal_map(self, inputs): + """ + :param dict[Text,Any] inputs: A dictionary of Python standard inputs that will be type-checked and compiled + to a LiteralMap + :rtype: flytekit.models.literals.LiteralMap + """ + return _type_helpers.pack_python_std_map_to_literal_map(inputs, { + k: _type_helpers.get_sdk_type_from_literal_type(v.type) + for k, v in _six.iteritems(self.interface.inputs) + }) + + @_exception_scopes.system_entry_point + def launch_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None, + label_overrides=None, annotation_overrides=None): + """ + Launches a single task execution and returns the execution identifier. + :param Text project: + :param Text domain: + :param flytekit.models.literals.LiteralMap literal_inputs: Inputs to the execution. + :param Text name: [Optional] If specified, an execution will be created with this name. Note: the name must + be unique within the context of the project and domain. + :param list[flytekit.common.notifications.Notification] notification_overrides: [Optional] If specified, these + are the notifications that will be honored for this execution. An empty list signals to disable all + notifications. + :param flytekit.models.common.Labels label_overrides: + :param flytekit.models.common.Annotations annotation_overrides: + :rtype: flytekit.common.workflow_execution.SdkWorkflowExecution + """ + execution = _engine_loader.get_engine().get_task(self).launch( + project, + domain, + name=name, + inputs=literal_inputs, + notification_overrides=notification_overrides, + label_overrides=label_overrides, + annotation_overrides=annotation_overrides, + ) + return _workflow_execution.SdkWorkflowExecution.promote_from_model(execution) diff --git a/flytekit/common/workflow.py b/flytekit/common/workflow.py index 2269740ab4..49accd584c 100644 --- a/flytekit/common/workflow.py +++ b/flytekit/common/workflow.py @@ -14,8 +14,7 @@ from flytekit.common.types import helpers as _type_helpers from flytekit.configuration import internal as _internal_config from flytekit.engines import loader as _engine_loader -from flytekit.models import interface as _interface_models, literals as _literal_models, \ - launch_plan as _launch_plan_models +from flytekit.models import interface as _interface_models, literals as _literal_models, common as _common_models from flytekit.models.core import workflow as _workflow_models, identifier as _identifier_model from flytekit.common.exceptions import system as _system_exceptions from flytekit.common import constants as _constants @@ -344,8 +343,8 @@ class provided should be a subclass of flytekit.common.launch_plan.SdkLaunchPlan if role: assumable_iam_role = role - auth = _launch_plan_models.Auth(assumable_iam_role=assumable_iam_role, - kubernetes_service_account=kubernetes_service_account) + auth_role = _common_models.AuthRole(assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account) return (cls or _launch_plan.SdkRunnableLaunchPlan)( sdk_workflow=self, @@ -358,7 +357,7 @@ class provided should be a subclass of flytekit.common.launch_plan.SdkLaunchPlan notifications=notifications, labels=labels, annotations=annotations, - auth=auth, + auth_role=auth_role, ) @_exception_scopes.system_entry_point diff --git a/flytekit/engines/common.py b/flytekit/engines/common.py index 017597092c..ee7b3e709a 100644 --- a/flytekit/engines/common.py +++ b/flytekit/engines/common.py @@ -183,7 +183,7 @@ def get_child_executions(self, filters=None): pass -class BaseLaunchPlanExecutor(_six.with_metaclass(_common_models.FlyteABCMeta, object)): +class BaseLaunchPlanLauncher(_six.with_metaclass(_common_models.FlyteABCMeta, object)): def __init__(self, sdk_launch_plan): """ @@ -207,8 +207,8 @@ def register(self, identifier): pass @_abc.abstractmethod - def execute(self, project, domain, name, inputs, notification_overrides=None, label_overrides=None, - annotation_overrides=None): + def launch(self, project, domain, name, inputs, notification_overrides=None, label_overrides=None, + annotation_overrides=None): """ Registers the launch plan and returns the identifier. :param Text project: @@ -261,6 +261,24 @@ def register(self, identifier): """ pass + @_abc.abstractmethod + def launch(self, project, domain, name=None, inputs=None, notification_overrides=None, + label_overrides=None, annotation_overrides=None, auth_role=None): + """ + Executes the task as a single task execution and returns the identifier. + :param Text project: + :param Text domain: + :param Text name: + :param flytekit.models.literals.LiteralMap inputs: The inputs to pass + :param list[flytekit.models.common.Notification] notification_overrides: If specified, override the + notifications. + :param flytekit.models.common.Labels label_overrides: + :param flytekit.models.common.Annotations annotation_overrides: + :param flytekit.models.common.AuthRole auth_role: + :rtype: flytekit.models.execution.Execution + """ + pass + class BaseExecutionEngineFactory(_six.with_metaclass(_common_models.FlyteABCMeta, object)): """ @@ -287,7 +305,7 @@ def get_task(self, sdk_task): def get_launch_plan(self, sdk_launch_plan): """ :param flytekit.common.launch_plan.SdkLaunchPlan sdk_launch_plan: - :rtype: BaseLaunchPlanExecutor + :rtype: BaseLaunchPlanLauncher """ pass diff --git a/flytekit/engines/flyte/engine.py b/flytekit/engines/flyte/engine.py index 774aca18a7..49d8c24fab 100644 --- a/flytekit/engines/flyte/engine.py +++ b/flytekit/engines/flyte/engine.py @@ -4,6 +4,7 @@ import os as _os import traceback as _traceback from datetime import datetime as _datetime +from deprecated import deprecated as _deprecated import six as _six from flyteidl.core import literals_pb2 as _literals_pb2 @@ -14,7 +15,9 @@ _iterate_task_executions from flytekit.common import utils as _common_utils, constants as _constants from flytekit.common.exceptions import user as _user_exceptions, scopes as _exception_scopes -from flytekit.configuration import platform as _platform_config, internal as _internal_config, sdk as _sdk_config +from flytekit.configuration import ( + platform as _platform_config, internal as _internal_config, sdk as _sdk_config, auth as _auth_config, +) from flytekit.engines import common as _common_engine from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.interfaces.stats.taggable import get_stats as _get_stats @@ -156,7 +159,7 @@ def fetch_workflow(self, workflow_id): ).client.get_workflow(workflow_id) -class FlyteLaunchPlan(_common_engine.BaseLaunchPlanExecutor): +class FlyteLaunchPlan(_common_engine.BaseLaunchPlanLauncher): def register(self, identifier): client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client @@ -168,10 +171,18 @@ def register(self, identifier): except _user_exceptions.FlyteEntityAlreadyExistsException: pass + @_deprecated(reason="Use launch instead", version='0.9.0') def execute(self, project, domain, name, inputs, notification_overrides=None, label_overrides=None, annotation_overrides=None): """ - Executes the launch plan. + Deprecated. Use launch instead. + """ + return self.launch(project, domain, name, inputs, notification_overrides, label_overrides, annotation_overrides) + + def launch(self, project, domain, name, inputs, notification_overrides=None, label_overrides=None, + annotation_overrides=None): + """ + Creates a workflow execution using parameters specified in the launch plan. :param Text project: :param Text domain: :param Text name: @@ -339,6 +350,67 @@ def execute(self, inputs, context=None): ) _data_proxy.Data.put_data(temp_dir.name, context['output_prefix'], is_multipart=True) + def launch(self, project, domain, name=None, inputs=None, notification_overrides=None, label_overrides=None, + annotation_overrides=None, auth_role=None): + """ + Executes the task as a single task execution and returns the identifier. + :param Text project: + :param Text domain: + :param Text name: + :param flytekit.models.literals.LiteralMap inputs: The inputs to pass + :param list[flytekit.models.common.Notification] notification_overrides: If specified, override the + notifications. + :param flytekit.models.common.Labels label_overrides: + :param flytekit.models.common.Annotations annotation_overrides: + :param flytekit.models.common.AuthRole auth_role: + :rtype: flytekit.models.execution.Execution + """ + disable_all = (notification_overrides == []) + if disable_all: + notification_overrides = None + else: + notification_overrides = _execution_models.NotificationList( + notification_overrides or [] + ) + disable_all = None + + if not auth_role: + assumable_iam_role = _auth_config.ASSUMABLE_IAM_ROLE.get() + kubernetes_service_account = _auth_config.KUBERNETES_SERVICE_ACCOUNT.get() + + if not (assumable_iam_role or kubernetes_service_account): + _logging.warning("Using deprecated `role` from config. " + "Please update your config to use `assumable_iam_role` instead") + assumable_iam_role = _sdk_config.ROLE.get() + auth_role = _common_models.AuthRole(assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account) + + try: + # TODO(katrogan): Add handling to register the underlying task if it's not already. + client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client + exec_id = client.create_execution( + project, + domain, + name, + _execution_models.ExecutionSpec( + self.sdk_task.id, + _execution_models.ExecutionMetadata( + _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, + 'sdk', # TODO: get principle + 0 # TODO: Detect nesting + ), + notifications=notification_overrides, + disable_all=disable_all, + labels=label_overrides, + annotations=annotation_overrides, + auth_role=auth_role, + ), + inputs, + ) + except _user_exceptions.FlyteEntityAlreadyExistsException: + exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name) + return client.get_execution(exec_id) + class FlyteWorkflowExecution(_common_engine.BaseWorkflowExecution): diff --git a/flytekit/engines/unit/engine.py b/flytekit/engines/unit/engine.py index f81410ad70..ec6ffd1af5 100644 --- a/flytekit/engines/unit/engine.py +++ b/flytekit/engines/unit/engine.py @@ -128,6 +128,10 @@ def _transform_for_user_output(self, outputs): def register(self, identifier, version): raise _user_exceptions.FlyteAssertion("You cannot register unit test tasks.") + def launch(self, project, domain, name=None, inputs=None, notification_overrides=None, label_overrides=None, + annotation_overrides=None, auth_role=None): + raise _user_exceptions.FlyteAssertion("You cannot launch unit test tasks.") + class ReturnOutputsTask(UnitTestEngineTask): def _transform_for_user_output(self, outputs): diff --git a/flytekit/models/common.py b/flytekit/models/common.py index efc9393408..7b2376e028 100644 --- a/flytekit/models/common.py +++ b/flytekit/models/common.py @@ -421,3 +421,56 @@ def from_flyte_idl(cls, pb): :rtype: UrlBlob """ return cls(url=pb.url, bytes=pb.bytes) + + +class AuthRole(FlyteIdlEntity): + def __init__(self, assumable_iam_role=None, kubernetes_service_account=None): + """ + At most one of assumable_iam_role or kubernetes_service_account can be set. + :param Text assumable_iam_role: IAM identity with set permissions policies. + :param Text kubernetes_service_account: Provides an identity for workflow execution resources. Flyte deployment + administrators are responsible for handling permissions as they relate to the service account. + """ + if assumable_iam_role and kubernetes_service_account: + raise ValueError("Only one of assumable_iam_role or kubernetes_service_account can be set") + self._assumable_iam_role = assumable_iam_role + self._kubernetes_service_account = kubernetes_service_account + + @property + def assumable_iam_role(self): + """ + The IAM role to execute the workflow with + :rtype: Text + """ + return self._assumable_iam_role + + @property + def kubernetes_service_account(self): + """ + The kubernetes service account to execute the workflow with + :rtype: Text + """ + return self._kubernetes_service_account + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.launch_plan_pb2.Auth + """ + return _common_pb2.AuthRole( + assumable_iam_role=self.assumable_iam_role if self.assumable_iam_role else None, + kubernetes_service_account=self.kubernetes_service_account if self.kubernetes_service_account else None, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.admin.launch_plan_pb2.Auth pb2_object: + :rtype: Auth + """ + return cls( + assumable_iam_role=pb2_object.assumable_iam_role if pb2_object.HasField("assumable_iam_role") else None, + kubernetes_service_account=pb2_object.kubernetes_service_account if + pb2_object.HasField("kubernetes_service_account") else None, + ) + + diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 0ab80c96e7..e6fed7abb6 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -76,7 +76,7 @@ def from_flyte_idl(cls, pb2_object): class ExecutionSpec(_common_models.FlyteIdlEntity): def __init__(self, launch_plan, metadata, notifications=None, disable_all=None, labels=None, - annotations=None): + annotations=None, auth_role=None): """ :param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute :param ExecutionMetadata metadata: The metadata to be associated with this execution @@ -84,6 +84,7 @@ def __init__(self, launch_plan, metadata, notifications=None, disable_all=None, :param bool disable_all: If true, all notifications should be disabled. :param flytekit.models.common.Labels labels: Labels to apply to the execution. :param flytekit.models.common.Annotations annotations: Annotations to apply to the execution + :param flytekit.models.common.AuthRole auth_role: The authorization method with which to execute the workflow. """ self._launch_plan = launch_plan @@ -92,6 +93,7 @@ def __init__(self, launch_plan, metadata, notifications=None, disable_all=None, self._disable_all = disable_all self._labels = labels or _common_models.Labels({}) self._annotations = annotations or _common_models.Annotations({}) + self._auth_role = auth_role or _common_models.AuthRole() @property def launch_plan(self): @@ -136,6 +138,13 @@ def annotations(self): """ return self._annotations + @property + def auth_role(self): + """ + :rtype: flytekit.models.common.AuthRole + """ + return self._auth_role + def to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.ExecutionSpec @@ -147,6 +156,7 @@ def to_flyte_idl(self): disable_all=self.disable_all, labels=self.labels.to_flyte_idl(), annotations=self.annotations.to_flyte_idl(), + auth_role=self._auth_role.to_flyte_idl() if self.auth_role else None, ) @classmethod @@ -162,6 +172,7 @@ def from_flyte_idl(cls, p): disable_all=p.disable_all if p.HasField("disable_all") else None, labels=_common_models.Labels.from_flyte_idl(p.labels), annotations=_common_models.Annotations.from_flyte_idl(p.annotations), + auth_role=_common_models.AuthRole.from_flyte_idl(p.auth_role), ) diff --git a/flytekit/models/launch_plan.py b/flytekit/models/launch_plan.py index 22e91a60bb..341d37a9b8 100644 --- a/flytekit/models/launch_plan.py +++ b/flytekit/models/launch_plan.py @@ -58,6 +58,7 @@ def from_flyte_idl(cls, pb2_object): class Auth(_common.FlyteIdlEntity): def __init__(self, assumable_iam_role=None, kubernetes_service_account=None): """ + DEPRECATED. Do not use. Use flytekit.models.common.AuthRole instead At most one of assumable_iam_role or kubernetes_service_account can be set. :param Text assumable_iam_role: IAM identity with set permissions policies. :param Text kubernetes_service_account: Provides an identity for workflow execution resources. Flyte deployment @@ -108,7 +109,7 @@ def from_flyte_idl(cls, pb2_object): class LaunchPlanSpec(_common.FlyteIdlEntity): - def __init__(self, workflow_id, entity_metadata, default_inputs, fixed_inputs, labels, annotations, auth): + def __init__(self, workflow_id, entity_metadata, default_inputs, fixed_inputs, labels, annotations, auth_role): """ The spec for a Launch Plan. @@ -120,7 +121,7 @@ def __init__(self, workflow_id, entity_metadata, default_inputs, fixed_inputs, l Any custom kubernetes labels to apply to workflows executed by this launch plan. :param flyteidl.admin.common_pb2.Annotations annotations: Any custom kubernetes annotations to apply to workflows executed by this launch plan. - :param flytekit.models.launch_plan.Auth auth: The auth method with which to execute the workflow. + :param flytekit.models.common.Auth auth_role: The auth method with which to execute the workflow. """ self._workflow_id = workflow_id self._entity_metadata = entity_metadata @@ -128,7 +129,7 @@ def __init__(self, workflow_id, entity_metadata, default_inputs, fixed_inputs, l self._fixed_inputs = fixed_inputs self._labels = labels self._annotations = annotations - self._auth = auth + self._auth_role = auth_role @property def workflow_id(self): @@ -178,12 +179,12 @@ def annotations(self): return self._annotations @property - def auth(self): + def auth_role(self): """ The authorization method with which to execute the workflow. - :return: flytekit.models.launch_plan.Auth + :return: flytekit.models.common.Auth """ - return self._auth + return self._auth_role def to_flyte_idl(self): """ @@ -196,7 +197,7 @@ def to_flyte_idl(self): fixed_inputs=self.fixed_inputs.to_flyte_idl(), labels=self.labels.to_flyte_idl(), annotations=self.annotations.to_flyte_idl(), - auth=self.auth.to_flyte_idl(), + auth_role=self.auth_role.to_flyte_idl(), ) @classmethod @@ -212,7 +213,7 @@ def from_flyte_idl(cls, pb2_object): fixed_inputs=_literals.LiteralMap.from_flyte_idl(pb2_object.fixed_inputs), labels=_common.Labels.from_flyte_idl(pb2_object.labels), annotations=_common.Annotations.from_flyte_idl(pb2_object.annotations), - auth=Auth.from_flyte_idl(pb2_object.auth), + auth_role=_common.AuthRole.from_flyte_idl(pb2_object.auth_role), ) diff --git a/setup.py b/setup.py index 30edd543aa..afef1f4edf 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ "flyteidl>=0.17.32,<1.0.0", "click>=6.6,<8.0", "croniter>=0.3.20,<4.0.0", - "deprecation>=2.0,<3.0", + "deprecated>=1.0,<2.0", "boto3>=1.4.4,<2.0", "python-dateutil<2.8.1,>=2.1", "grpcio>=1.3.0,<2.0", diff --git a/tests/flytekit/unit/common_tests/test_launch_plan.py b/tests/flytekit/unit/common_tests/test_launch_plan.py index 456ba59557..8cc83daa83 100644 --- a/tests/flytekit/unit/common_tests/test_launch_plan.py +++ b/tests/flytekit/unit/common_tests/test_launch_plan.py @@ -21,7 +21,7 @@ def test_default_assumable_iam_role(): } ) lp = workflow_to_test.create_launch_plan() - assert lp.auth.assumable_iam_role == 'arn:aws:iam::ABC123:role/my-flyte-role' + assert lp.auth_role.assumable_iam_role == 'arn:aws:iam::ABC123:role/my-flyte-role' def test_hard_coded_assumable_iam_role(): @@ -33,7 +33,7 @@ def test_hard_coded_assumable_iam_role(): } ) lp = workflow_to_test.create_launch_plan(assumable_iam_role='override') - assert lp.auth.assumable_iam_role == 'override' + assert lp.auth_role.assumable_iam_role == 'override' def test_default_deprecated_role(): @@ -48,7 +48,7 @@ def test_default_deprecated_role(): } ) lp = workflow_to_test.create_launch_plan() - assert lp.auth.assumable_iam_role == 'arn:aws:iam::ABC123:role/my-flyte-role' + assert lp.auth_role.assumable_iam_role == 'arn:aws:iam::ABC123:role/my-flyte-role' def test_hard_coded_deprecated_role(): @@ -60,7 +60,7 @@ def test_hard_coded_deprecated_role(): } ) lp = workflow_to_test.create_launch_plan(role='override') - assert lp.auth.assumable_iam_role == 'override' + assert lp.auth_role.assumable_iam_role == 'override' def test_kubernetes_service_account(): @@ -72,7 +72,7 @@ def test_kubernetes_service_account(): } ) lp = workflow_to_test.create_launch_plan(kubernetes_service_account='kube-service-acct') - assert lp.auth.kubernetes_service_account == 'kube-service-acct' + assert lp.auth_role.kubernetes_service_account == 'kube-service-acct' def test_fixed_inputs(): @@ -314,7 +314,7 @@ def test_serialize(): s = lp.serialize() assert s.workflow_id == _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v").to_flyte_idl() - assert s.auth.assumable_iam_role == 'iam_role' + assert s.auth_role.assumable_iam_role == 'iam_role' assert s.default_inputs.parameters['default_input'].default.scalar.primitive.integer == 5 diff --git a/tests/flytekit/unit/engines/flyte/test_engine.py b/tests/flytekit/unit/engines/flyte/test_engine.py index ba1e750faa..3d56636eb1 100644 --- a/tests/flytekit/unit/engines/flyte/test_engine.py +++ b/tests/flytekit/unit/engines/flyte/test_engine.py @@ -115,7 +115,7 @@ def test_execution_notification_overrides(mock_client_factory): ) ) - engine.FlyteLaunchPlan(m).execute( + engine.FlyteLaunchPlan(m).launch( 'xp', 'xd', 'xn', literals.LiteralMap({}), notification_overrides=[] ) @@ -161,7 +161,7 @@ def test_execution_notification_soft_overrides(mock_client_factory): notification = _common_models.Notification([0, 1, 2], email=_common_models.EmailNotification(["me@place.com"])) - engine.FlyteLaunchPlan(m).execute( + engine.FlyteLaunchPlan(m).launch( 'xp', 'xd', 'xn', literals.LiteralMap({}), notification_overrides=[notification] ) @@ -252,7 +252,7 @@ def test_execution_annotation_overrides(mock_client_factory): ) annotations = _common_models.Annotations({"my": "annotation"}) - engine.FlyteLaunchPlan(m).execute( + engine.FlyteLaunchPlan(m).launch( 'xp', 'xd', 'xn', literals.LiteralMap({}), notification_overrides=[], annotation_overrides=annotations ) diff --git a/tests/flytekit/unit/models/test_common.py b/tests/flytekit/unit/models/test_common.py index 4a0bd203b4..bab0cb0124 100644 --- a/tests/flytekit/unit/models/test_common.py +++ b/tests/flytekit/unit/models/test_common.py @@ -66,3 +66,17 @@ def test_annotations(): assert obj.values == {"my": "annotation"} obj2 = _common.Annotations.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj + + +def test_auth_role(): + obj = _common.AuthRole(assumable_iam_role="rollie-pollie") + assert obj.assumable_iam_role == "rollie-pollie" + assert not obj.kubernetes_service_account + obj2 = _common.AuthRole.from_flyte_idl(obj.to_flyte_idl()) + assert obj == obj2 + + obj = _common.AuthRole(kubernetes_service_account="service-account-name") + assert obj.kubernetes_service_account == "service-account-name" + assert not obj.assumable_iam_role + obj2 = _common.AuthRole.from_flyte_idl(obj.to_flyte_idl()) + assert obj == obj2 diff --git a/tests/flytekit/unit/models/test_launch_plan.py b/tests/flytekit/unit/models/test_launch_plan.py index 4d1df37755..c1683440ca 100644 --- a/tests/flytekit/unit/models/test_launch_plan.py +++ b/tests/flytekit/unit/models/test_launch_plan.py @@ -33,17 +33,3 @@ def test_lp_closure(): assert obj == obj2 assert obj2.expected_inputs == parameter_map assert obj2.expected_outputs == variable_map - - -def test_auth(): - obj = launch_plan.Auth(assumable_iam_role="rollie-pollie") - assert obj.assumable_iam_role == "rollie-pollie" - assert not obj.kubernetes_service_account - obj2 = launch_plan.Auth.from_flyte_idl(obj.to_flyte_idl()) - assert obj == obj2 - - obj = launch_plan.Auth(kubernetes_service_account="service-account-name") - assert obj.kubernetes_service_account == "service-account-name" - assert not obj.assumable_iam_role - obj2 = launch_plan.Auth.from_flyte_idl(obj.to_flyte_idl()) - assert obj == obj2