From 2d8924e02ca234e42ede7530ecff9fd78c7e6995 Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Wed, 28 Apr 2021 13:52:18 -0400 Subject: [PATCH] add new control plane classes (#425) * implement new control plane classes Signed-off-by: cosmicBboy * revert dep changes Signed-off-by: cosmicBboy * remove unneeded mock integration test files Signed-off-by: cosmicBboy * remove pytest.ini Signed-off-by: cosmicBboy * add integration tests to ci, update reqs Signed-off-by: cosmicBboy * add unit tests Signed-off-by: cosmicBboy * lint Signed-off-by: cosmicBboy * address comments @wild-endeavor Signed-off-by: cosmicBboy Signed-off-by: Max Hoffman --- .github/workflows/pythonbuild.yml | 5 + .gitignore | 2 + dev-requirements.in | 1 + dev-requirements.txt | 272 ++++++++++++++++- doc-requirements.txt | 19 +- docs/source/design/control_plane.rst | 4 +- flytekit/control_plane/component_nodes.py | 136 +++++++++ flytekit/control_plane/identifier.py | 137 +++++++++ flytekit/control_plane/interface.py | 24 ++ flytekit/control_plane/launch_plan.py | 196 ++++++++++++ flytekit/control_plane/nodes.py | 281 ++++++++++++++++++ flytekit/control_plane/tasks/__init__.py | 0 flytekit/control_plane/tasks/executions.py | 132 ++++++++ flytekit/control_plane/tasks/task.py | 95 ++++++ flytekit/control_plane/workflow.py | 167 +++++++++++ flytekit/control_plane/workflow_execution.py | 150 ++++++++++ requirements-spark2.txt | 15 +- requirements.txt | 15 +- .../control_plane/mock_flyte_repo/.gitignore | 1 + .../control_plane/mock_flyte_repo/README.md | 4 + .../control_plane/mock_flyte_repo/__init__.py | 0 .../mock_flyte_repo/in_container.mk | 24 ++ .../mock_flyte_repo/workflows/Dockerfile | 35 +++ .../mock_flyte_repo/workflows/Makefile | 208 +++++++++++++ .../mock_flyte_repo/workflows/__init__.py | 0 .../workflows/basic/__init__.py | 0 .../workflows/basic/basic_workflow.py | 54 ++++ .../workflows/basic/hello_world.py | 40 +++ .../mock_flyte_repo/workflows/requirements.in | 4 + .../workflows/requirements.txt | 136 +++++++++ .../mock_flyte_repo/workflows/sandbox.config | 7 + .../control_plane/test_workflow.py | 90 ++++++ tests/flytekit/unit/control_plane/__init__.py | 0 .../unit/control_plane/tasks/test_task.py | 34 +++ .../unit/control_plane/test_identifier.py | 77 +++++ .../unit/control_plane/test_workflow.py | 23 ++ 36 files changed, 2382 insertions(+), 6 deletions(-) create mode 100644 flytekit/control_plane/component_nodes.py create mode 100644 flytekit/control_plane/identifier.py create mode 100644 flytekit/control_plane/interface.py create mode 100644 flytekit/control_plane/launch_plan.py create mode 100644 flytekit/control_plane/nodes.py create mode 100644 flytekit/control_plane/tasks/__init__.py create mode 100644 flytekit/control_plane/tasks/executions.py create mode 100644 flytekit/control_plane/tasks/task.py create mode 100644 flytekit/control_plane/workflow.py create mode 100644 flytekit/control_plane/workflow_execution.py create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/.gitignore create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/README.md create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/__init__.py create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/in_container.mk create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Dockerfile create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Makefile create mode 100644 
tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/__init__.py create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/__init__.py create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/basic_workflow.py create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/hello_world.py create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.in create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.txt create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/sandbox.config create mode 100644 tests/flytekit/integration/control_plane/test_workflow.py create mode 100644 tests/flytekit/unit/control_plane/__init__.py create mode 100644 tests/flytekit/unit/control_plane/tasks/test_task.py create mode 100644 tests/flytekit/unit/control_plane/test_identifier.py create mode 100644 tests/flytekit/unit/control_plane/test_workflow.py diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index cc0d7f0a52e..f56f8b300f7 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -44,6 +44,11 @@ jobs: - name: Test with coverage run: | coverage run -m pytest tests/flytekit/unit tests/scripts plugins/tests + - name: Integration Tests with coverage + # https://github.com/actions/runner/issues/241#issuecomment-577360161 + shell: 'script -q -e -c "bash {0}"' + run: | + coverage run --append -m pytest tests/flytekit/integration - name: Codecov uses: codecov/codecov-action@v1 with: diff --git a/.gitignore b/.gitignore index 43cf5e71353..ca971b1a315 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,5 @@ dist .python-version _build/ docs/source/generated/ +.pytest-flyte +htmlcov diff --git a/dev-requirements.in b/dev-requirements.in index 6fa5ebb0805..dd38d0f7fc5 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -1,5 +1,6 @@ -c requirements.txt +git+git://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte black coverage[toml] flake8 diff --git a/dev-requirements.txt b/dev-requirements.txt index 9249420a938..1cc519a6f12 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -4,6 +4,10 @@ # # make dev-requirements.txt # +-e file:.#egg=flytekit + # via + # -c requirements.txt + # pytest-flyte appdirs==1.4.4 # via # -c requirements.txt @@ -11,18 +15,86 @@ appdirs==1.4.4 attrs==20.3.0 # via # -c requirements.txt + # jsonschema # pytest +<<<<<<< HEAD black==21.4b2 +======= + # pytest-docker + # scantree +bcrypt==3.2.0 + # via + # -c requirements.txt + # paramiko +black==20.8b1 +>>>>>>> add new control plane classes (#425) # via # -c requirements.txt # -r dev-requirements.in # flake8-black +cached-property==1.5.2 + # via docker-compose +certifi==2020.12.5 + # via + # -c requirements.txt + # requests +cffi==1.14.5 + # via + # -c requirements.txt + # bcrypt + # cryptography + # pynacl +chardet==4.0.0 + # via + # -c requirements.txt + # requests click==7.1.2 # via # -c requirements.txt # black + # flytekit coverage[toml]==5.5 # via -r dev-requirements.in +croniter==1.0.12 + # via + # -c requirements.txt + # flytekit +cryptography==3.4.7 + # via + # -c requirements.txt + # paramiko +dataclasses-json==0.5.2 + # via + # -c requirements.txt + # flytekit +decorator==5.0.7 + # via + # -c requirements.txt + # retry +deprecated==1.2.12 + # via + # -c requirements.txt + # flytekit +dirhash==0.2.1 + # via + # -c 
requirements.txt + # flytekit +distro==1.5.0 + # via docker-compose +docker-compose==1.29.1 + # via + # pytest-docker + # pytest-flyte +docker-image-py==0.1.10 + # via + # -c requirements.txt + # flytekit +docker[ssh]==5.0.0 + # via docker-compose +dockerpty==0.4.1 + # via docker-compose +docopt==0.6.2 + # via docker-compose flake8-black==0.2.1 # via -r dev-requirements.in flake8-isort==4.0.0 @@ -32,12 +104,57 @@ flake8==3.9.1 # -r dev-requirements.in # flake8-black # flake8-isort +flyteidl==0.18.38 + # via + # -c requirements.txt + # flytekit +grpcio==1.37.0 + # via + # -c requirements.txt + # flytekit +idna==2.10 + # via + # -c requirements.txt + # requests +importlib-metadata==4.0.1 + # via + # -c requirements.txt + # flake8 + # jsonschema + # keyring + # pluggy + # pytest iniconfig==1.1.1 # via pytest isort==5.8.0 # via # -r dev-requirements.in # flake8-isort +jinja2==2.11.3 + # via + # -c requirements.txt + # pytest-flyte +jsonschema==3.2.0 + # via + # -c requirements.txt + # docker-compose +keyring==23.0.1 + # via + # -c requirements.txt + # flytekit +markupsafe==1.1.1 + # via + # -c requirements.txt + # jinja2 +marshmallow-enum==1.5.1 + # via + # -c requirements.txt + # dataclasses-json +marshmallow==3.11.1 + # via + # -c requirements.txt + # dataclasses-json + # marshmallow-enum mccabe==0.6.1 # via flake8 mock==4.0.3 @@ -47,38 +164,155 @@ mypy-extensions==0.4.3 # -c requirements.txt # black # mypy + # typing-inspect mypy==0.812 # via -r dev-requirements.in +natsort==7.1.1 + # via + # -c requirements.txt + # flytekit +numpy==1.20.2 + # via + # -c requirements.txt + # pandas + # pyarrow packaging==20.9 # via # -c requirements.txt # pytest +pandas==1.2.4 + # via + # -c requirements.txt + # flytekit +paramiko==2.7.2 + # via + # -c requirements.txt + # docker pathspec==0.8.1 # via # -c requirements.txt # black + # scantree pluggy==0.13.1 # via pytest +protobuf==3.15.8 + # via + # -c requirements.txt + # flyteidl + # flytekit py==1.10.0 # via # -c requirements.txt # pytest + # retry +pyarrow==3.0.0 + # via + # -c requirements.txt + # flytekit pycodestyle==2.7.0 # via flake8 +pycparser==2.20 + # via + # -c requirements.txt + # cffi pyflakes==2.3.1 # via flake8 +pynacl==1.4.0 + # via + # -c requirements.txt + # paramiko pyparsing==2.4.7 # via # -c requirements.txt # packaging -pytest==6.2.3 +pyrsistent==0.17.3 + # via + # -c requirements.txt + # jsonschema +pytest-docker==0.10.1 + # via pytest-flyte +git+git://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte # via -r dev-requirements.in +pytest==6.2.3 + # via + # -r dev-requirements.in + # pytest-docker + # pytest-flyte +python-dateutil==2.8.1 + # via + # -c requirements.txt + # croniter + # flytekit + # pandas +python-dotenv==0.17.0 + # via docker-compose +pytimeparse==1.1.8 + # via + # -c requirements.txt + # flytekit +pytz==2018.4 + # via + # -c requirements.txt + # flytekit + # pandas +pyyaml==5.4.1 + # via + # -c requirements.txt + # docker-compose regex==2021.4.4 # via # -c requirements.txt # black + # docker-image-py +requests==2.25.1 + # via + # -c requirements.txt + # docker + # docker-compose + # flytekit + # responses +responses==0.13.2 + # via + # -c requirements.txt + # flytekit +retry==0.9.2 + # via + # -c requirements.txt + # flytekit +scantree==0.0.1 + # via + # -c requirements.txt + # dirhash +six==1.15.0 + # via + # -c requirements.txt + # bcrypt + # dockerpty + # flytekit + # grpcio + # jsonschema + # protobuf + # pynacl + # python-dateutil + # responses + # scantree + # websocket-client +sortedcontainers==2.3.0 + 
# via + # -c requirements.txt + # flytekit +statsd==3.3.0 + # via + # -c requirements.txt + # flytekit +stringcase==1.2.0 + # via + # -c requirements.txt + # dataclasses-json testfixtures==6.17.1 # via flake8-isort +texttable==1.6.3 + # via docker-compose toml==0.10.2 # via # -c requirements.txt @@ -97,4 +331,40 @@ typed-ast==1.4.3 typing-extensions==3.7.4.3 # via # -c requirements.txt +<<<<<<< HEAD +======= + # black + # importlib-metadata +>>>>>>> add new control plane classes (#425) # mypy + # typing-inspect +typing-inspect==0.6.0 + # via + # -c requirements.txt + # dataclasses-json +urllib3==1.25.11 + # via + # -c requirements.txt + # flytekit + # requests + # responses +websocket-client==0.58.0 + # via + # docker + # docker-compose +wheel==0.36.2 + # via + # -c requirements.txt + # flytekit +wrapt==1.12.1 + # via + # -c requirements.txt + # deprecated + # flytekit +zipp==3.4.1 + # via + # -c requirements.txt + # importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/doc-requirements.txt b/doc-requirements.txt index 5b6450538c8..b201fbca96b 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -99,10 +99,14 @@ entrypoints==0.3 # nbconvert # papermill <<<<<<< HEAD +<<<<<<< HEAD flyteidl==0.18.39 ======= flyteidl==0.18.37 >>>>>>> Sqlalchemy Task (#445) +======= +flyteidl==0.18.38 +>>>>>>> add new control plane classes (#425) # via flytekit furo==2021.4.11b34 # via -r doc-requirements.in @@ -121,7 +125,9 @@ idna==2.10 imagesize==1.2.0 # via sphinx importlib-metadata==4.0.1 - # via keyring + # via + # jsonschema + # keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -406,10 +412,21 @@ traitlets==5.0.5 <<<<<<< HEAD ======= typed-ast==1.4.3 +<<<<<<< HEAD # via black >>>>>>> Sqlalchemy Task (#445) typing-extensions==3.7.4.3 # via typing-inspect +======= + # via + # astroid + # black +typing-extensions==3.7.4.3 + # via + # black + # importlib-metadata + # typing-inspect +>>>>>>> add new control plane classes (#425) typing-inspect==0.6.0 # via dataclasses-json unidecode==1.2.0 diff --git a/docs/source/design/control_plane.rst b/docs/source/design/control_plane.rst index 7b8f49539b6..1d24b50b8de 100644 --- a/docs/source/design/control_plane.rst +++ b/docs/source/design/control_plane.rst @@ -3,9 +3,9 @@ ############################ Control Plane Objects ############################ -For those who require programmatic access to the control place, certain APIs are available through "control plane classes". +For those who require programmatic access to the control plane, certain APIs are available through "control plane classes". -.. note:: +.. warning:: The syntax of this section, while it will continue to work, is subject to change. 
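The control plane classes added in this patch are easiest to understand with a small usage sketch. The following is illustrative only — the project/domain/name/version strings are placeholders, and it assumes flytekit is configured to reach a running Flyte Admin:

```python
# Illustrative sketch of the new control plane classes (placeholder ids/names);
# assumes flytekit is configured to talk to a running Flyte Admin.
from flytekit.control_plane.launch_plan import FlyteLaunchPlan
from flytekit.control_plane.tasks.task import FlyteTask
from flytekit.models import literals as literal_models

# Fetch a registered task and inspect its typed interface.
task = FlyteTask.fetch("flytesnacks", "development", "workflows.basic.hello_world.say_hello", "v1")
print(task.interface)

# Fetch a launch plan (fetch also hydrates the underlying workflow's interface) and launch it.
lp = FlyteLaunchPlan.fetch("flytesnacks", "development", "workflows.basic.basic_workflow.my_wf", "v1")
execution = lp.launch_with_literals("flytesnacks", "development", literal_models.LiteralMap({}))
execution.sync()  # pull the latest closure and node executions from Admin
print(execution.is_complete)
```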
diff --git a/flytekit/control_plane/component_nodes.py b/flytekit/control_plane/component_nodes.py new file mode 100644 index 00000000000..10434ab8301 --- /dev/null +++ b/flytekit/control_plane/component_nodes.py @@ -0,0 +1,136 @@ +import logging as _logging +from typing import Dict + +from flytekit.common.exceptions import system as _system_exceptions +from flytekit.control_plane import identifier as _identifier +from flytekit.models import task as _task_model +from flytekit.models.core import workflow as _workflow_model + + +class FlyteTaskNode(_workflow_model.TaskNode): + def __init__(self, flyte_task: "flytekit.control_plane.tasks.task.FlyteTask"): + self._flyte_task = flyte_task + super(FlyteTaskNode, self).__init__(None) + + @property + def reference_id(self) -> _identifier.Identifier: + """A globally unique identifier for the task.""" + return self._flyte_task.id + + @property + def flyte_task(self) -> "flytekit.control_plane.tasks.task.FlyteTask": + return self._flyte_task + + @classmethod + def promote_from_model( + cls, + base_model: _workflow_model.TaskNode, + tasks: Dict[_identifier.Identifier, _task_model.TaskTemplate], + ) -> "FlyteTaskNode": + """ + Takes the IDL wrapper for a TaskNode and returns the hydrated flytekit object for it, fetching the + underlying task via the FlyteTask control plane class if necessary. + + :param base_model: + :param tasks: + """ + from flytekit.control_plane.tasks import task as _task + + if base_model.reference_id in tasks: + task = tasks[base_model.reference_id] + _logging.info(f"Found existing task template for {task.id}, will not retrieve from Admin") + flyte_task = _task.FlyteTask.promote_from_model(task) + return cls(flyte_task) + + # if not found, fetch it from Admin + _logging.debug(f"Fetching task template for {base_model.reference_id} from Admin") + return cls( + _task.FlyteTask.fetch( + base_model.reference_id.project, + base_model.reference_id.domain, + base_model.reference_id.name, + base_model.reference_id.version, + ) + ) + + +class FlyteWorkflowNode(_workflow_model.WorkflowNode): + def __init__( + self, + flyte_workflow: "flytekit.control_plane.workflow.FlyteWorkflow" = None, + flyte_launch_plan: "flytekit.control_plane.launch_plan.FlyteLaunchPlan" = None, + ): + if flyte_workflow and flyte_launch_plan: + raise _system_exceptions.FlyteSystemException( + "FlyteWorkflowNode cannot be called with both a workflow and a launchplan specified, please pick " 
workflow: {flyte_workflow} launchPlan: {flyte_launch_plan}", + ) + + self._flyte_workflow = flyte_workflow + self._flyte_launch_plan = flyte_launch_plan + super(FlyteWorkflowNode, self).__init__( + launchplan_ref=self._flyte_launch_plan.id if self._flyte_launch_plan else None, + sub_workflow_ref=self._flyte_workflow.id if self._flyte_workflow else None, + ) + + def __repr__(self) -> str: + if self.flyte_workflow is not None: + return f"FlyteWorkflowNode with workflow: {self.flyte_workflow}" + return f"FlyteWorkflowNode with launch plan: {self.flyte_launch_plan}" + + @property + def launchplan_ref(self) -> _identifier.Identifier: + """A globally unique identifier for the launch plan, which should map to Admin.""" + return self._flyte_launch_plan.id if self._flyte_launch_plan else None + + @property + def sub_workflow_ref(self): + return self._flyte_workflow.id if self._flyte_workflow else None + + @property + def flyte_launch_plan(self) -> "flytekit.control_plane.launch_plan.FlyteLaunchPlan": + return self._flyte_launch_plan + + @property + def flyte_workflow(self) -> "flytekit.control_plane.workflow.FlyteWorkflow": + return self._flyte_workflow + + @classmethod + def promote_from_model( + cls, + base_model: _workflow_model.WorkflowNode, + sub_workflows: Dict[_identifier.Identifier, _workflow_model.WorkflowTemplate], + tasks: Dict[_identifier.Identifier, _task_model.TaskTemplate], + ) -> "FlyteWorkflowNode": + from flytekit.control_plane import launch_plan as _launch_plan + from flytekit.control_plane import workflow as _workflow + + fetch_args = ( + base_model.reference.project, + base_model.reference.domain, + base_model.reference.name, + base_model.reference.version, + ) + + if base_model.launch_plan_ref is not None: + return cls(flyte_launch_plan=_launch_plan.FlyteLaunchPlan.fetch(*fetch_args)) + elif base_model.sub_workflow_ref is not None: + # the workflow tempaltes for sub-workflows should have been included in the original response + if base_model.reference in sub_workflows: + return cls( + flyte_workflow=_workflow.FlyteWorkflow.promote_from_model( + sub_workflows[base_model.reference], + sub_workflows=sub_workflows, + tasks=tasks, + ) + ) + + # If not found for some reason, fetch it from Admin again. The reason there is a warning here but not for + # tasks is because sub-workflows should always be passed along. Ideally subworkflows are never even + # registered with Admin, so fetching from Admin ideelly doesn't return anything + _logging.warning(f"Your subworkflow with id {base_model.reference} is not included in the promote call.") + return cls(flyte_workflow=_workflow.FlyteWorkflow.fetch(*fetch_args)) + + raise _system_exceptions.FlyteSystemException( + "Bad workflow node model, neither subworkflow nor launchplan specified." 
+ ) diff --git a/flytekit/control_plane/identifier.py b/flytekit/control_plane/identifier.py new file mode 100644 index 00000000000..611c9af6390 --- /dev/null +++ b/flytekit/control_plane/identifier.py @@ -0,0 +1,137 @@ +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.models.core import identifier as _core_identifier + + +class Identifier(_core_identifier.Identifier): + + _STRING_TO_TYPE_MAP = { + "lp": _core_identifier.ResourceType.LAUNCH_PLAN, + "wf": _core_identifier.ResourceType.WORKFLOW, + "tsk": _core_identifier.ResourceType.TASK, + } + _TYPE_TO_STRING_MAP = {v: k for k, v in _STRING_TO_TYPE_MAP.items()} + + @classmethod + def promote_from_model(cls, base_model: _core_identifier.Identifier) -> "Identifier": + return cls(base_model.resource_type, base_model.project, base_model.domain, base_model.name, base_model.version) + + @classmethod + def from_urn(cls, urn: str) -> "Identifier": + """ + Parses a string urn in the correct format into an identifier + """ + segments = urn.split(":") + if len(segments) != 5: + raise _user_exceptions.FlyteValueException( + urn, + "The provided string was not in a parseable format. The string for an identifier must be in the " + "format entity_type:project:domain:name:version.", + ) + + resource_type, project, domain, name, version = segments + + if resource_type not in cls._STRING_TO_TYPE_MAP: + raise _user_exceptions.FlyteValueException( + resource_type, + "The provided string could not be parsed. The first element of an identifier must be one of: " + f"{list(cls._STRING_TO_TYPE_MAP.keys())}. ", + ) + + return cls(cls._STRING_TO_TYPE_MAP[resource_type], project, domain, name, version) + + def __str__(self): + return ( + f"{type(self)._TYPE_TO_STRING_MAP.get(self.resource_type, '')}:" + f"{self.project}:" + f"{self.domain}:" + f"{self.name}:" + f"{self.version}" + ) + + +class WorkflowExecutionIdentifier(_core_identifier.WorkflowExecutionIdentifier): + @classmethod + def promote_from_model( + cls, base_model: _core_identifier.WorkflowExecutionIdentifier + ) -> "WorkflowExecutionIdentifier": + return cls(base_model.project, base_model.domain, base_model.name) + + @classmethod + def from_urn(cls, string: str) -> "WorkflowExecutionIdentifier": + """ + Parses a string in the correct format into an identifier + """ + segments = string.split(":") + if len(segments) != 4: + raise _user_exceptions.FlyteValueException( + string, + "The provided string was not in a parseable format. The string for an identifier must be in the format" + " ex:project:domain:name.", + ) + + resource_type, project, domain, name = segments + + if resource_type != "ex": + raise _user_exceptions.FlyteValueException( + resource_type, + "The provided string could not be parsed. 
The first element of an execution identifier must be 'ex'.", + ) + + return cls(project, domain, name) + + def __str__(self): + return f"ex:{self.project}:{self.domain}:{self.name}" + + +class TaskExecutionIdentifier(_core_identifier.TaskExecutionIdentifier): + @classmethod + def promote_from_model(cls, base_model: _core_identifier.TaskExecutionIdentifier) -> "TaskExecutionIdentifier": + return cls( + task_id=base_model.task_id, + node_execution_id=base_model.node_execution_id, + retry_attempt=base_model.retry_attempt, + ) + + @classmethod + def from_urn(cls, string: str) -> "TaskExecutionIdentifier": + """ + Parses a string in the correct format into an identifier + """ + segments = string.split(":") + if len(segments) != 10: + raise _user_exceptions.FlyteValueException( + string, + "The provided string was not in a parseable format. The string for an identifier must be in the format" + " te:exec_project:exec_domain:exec_name:node_id:task_project:task_domain:task_name:task_version:retry.", + ) + + resource_type, ep, ed, en, node_id, tp, td, tn, tv, retry = segments + + if resource_type != "te": + raise _user_exceptions.FlyteValueException( + resource_type, + "The provided string could not be parsed. The first element of a task execution identifier must be 'te'.", + ) + + return cls( + task_id=Identifier(_core_identifier.ResourceType.TASK, tp, td, tn, tv), + node_execution_id=_core_identifier.NodeExecutionIdentifier( + node_id=node_id, + execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en), + ), + retry_attempt=int(retry), + ) + + def __str__(self): + return ( + "te:" + f"{self.node_execution_id.execution_id.project}:" + f"{self.node_execution_id.execution_id.domain}:" + f"{self.node_execution_id.execution_id.name}:" + f"{self.node_execution_id.node_id}:" + f"{self.task_id.project}:" + f"{self.task_id.domain}:" + f"{self.task_id.name}:" + f"{self.task_id.version}:" + f"{self.retry_attempt}" + ) diff --git a/flytekit/control_plane/interface.py b/flytekit/control_plane/interface.py new file mode 100644 index 00000000000..1a7b2c6c15b --- /dev/null +++ b/flytekit/control_plane/interface.py @@ -0,0 +1,24 @@ +from typing import Any, Dict, List, Tuple + +from flytekit.control_plane import nodes as _nodes +from flytekit.models import interface as _interface_models +from flytekit.models import literals as _literal_models + + +class TypedInterface(_interface_models.TypedInterface): + @classmethod + def promote_from_model(cls, model): + """ + :param flytekit.models.interface.TypedInterface model: + :rtype: TypedInterface + """ + return cls(model.inputs, model.outputs) + + def create_bindings_for_inputs( + self, map_of_bindings: Dict[str, Any] + ) -> Tuple[List[_literal_models.Binding], List[_nodes.FlyteNode]]: + """ + :param map_of_bindings: this can be scalar primitives, it can be node output references, lists, etc. 
+ :raises: flytekit.common.exceptions.user.FlyteAssertion + """ + return [], [] diff --git a/flytekit/control_plane/launch_plan.py b/flytekit/control_plane/launch_plan.py new file mode 100644 index 00000000000..ca189fd33dc --- /dev/null +++ b/flytekit/control_plane/launch_plan.py @@ -0,0 +1,196 @@ +import uuid as _uuid +from typing import Any, List + +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.control_plane import identifier as _identifier +from flytekit.control_plane import interface as _interface +from flytekit.control_plane import nodes as _nodes +from flytekit.control_plane import workflow_execution as _workflow_execution +from flytekit.engines.flyte import engine as _flyte_engine +from flytekit.models import common as _common_models +from flytekit.models import execution as _execution_models +from flytekit.models import interface as _interface_models +from flytekit.models import launch_plan as _launch_plan_models +from flytekit.models import literals as _literal_models +from flytekit.models.core import identifier as _identifier_model + + +class FlyteLaunchPlan(_launch_plan_models.LaunchPlanSpec): + def __init__(self, *args, **kwargs): + super(FlyteLaunchPlan, self).__init__(*args, **kwargs) + # Set all the attributes we expect this class to have + self._id = None + + # The interface is not set explicitly unless fetched in an engine context + self._interface = None + + @classmethod + def promote_from_model(cls, model: _launch_plan_models.LaunchPlanSpec) -> "FlyteLaunchPlan": + return cls( + workflow_id=_identifier.Identifier.promote_from_model(model.workflow_id), + default_inputs=_interface_models.ParameterMap(model.default_inputs.parameters), + fixed_inputs=model.fixed_inputs, + entity_metadata=model.entity_metadata, + labels=model.labels, + annotations=model.annotations, + auth_role=model.auth_role, + raw_output_data_config=model.raw_output_data_config, + ) + + @_exception_scopes.system_entry_point + def register(self, project, domain, name, version): + # NOTE: does this need to be implemented in the control plane? + pass + + @classmethod + @_exception_scopes.system_entry_point + def fetch(cls, project: str, domain: str, name: str, version: str) -> "FlyteLaunchPlan": + """ + This function uses the engine loader to create a hydrated launch plan from Admin. + :param project: + :param domain: + :param name: + :param version: + """ + from flytekit.control_plane import workflow as _workflow + + launch_plan_id = _identifier.Identifier( + _identifier_model.ResourceType.LAUNCH_PLAN, project, domain, name, version + ) + + lp = _flyte_engine.get_client().get_launch_plan(launch_plan_id) + flyte_lp = cls.promote_from_model(lp.spec) + flyte_lp._id = lp.id + + # TODO: Add a test for this, and this function as a whole + wf_id = flyte_lp.workflow_id + lp_wf = _workflow.FlyteWorkflow.fetch(wf_id.project, wf_id.domain, wf_id.name, wf_id.version) + flyte_lp._interface = lp_wf.interface + return flyte_lp + + @_exception_scopes.system_entry_point + def serialize(self): + """ + Serializing a launch plan should produce an object similar to what the registration step produces, + in preparation for actual registration to Admin. + + :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlan + """ + # NOTE: does this need to be implemented in the control plane? 
+ pass + + @property + def id(self) -> _identifier.Identifier: + return self._id + + @property + def is_scheduled(self) -> bool: + if self.entity_metadata.schedule.cron_expression: + return True + elif self.entity_metadata.schedule.rate and self.entity_metadata.schedule.rate.value: + return True + elif self.entity_metadata.schedule.cron_schedule and self.entity_metadata.schedule.cron_schedule.schedule: + return True + else: + return False + + @property + def workflow_id(self) -> _identifier.Identifier: + return self._workflow_id + + @property + def interface(self) -> _interface.TypedInterface: + """ + The interface is not technically part of the admin.LaunchPlanSpec in the IDL, however the workflow ID is, and + from the workflow ID, fetch will fill in the interface. This is nice because then you can __call__ the + object and get a node. + """ + return self._interface + + @property + def resource_type(self) -> _identifier_model.ResourceType: + return _identifier_model.ResourceType.LAUNCH_PLAN + + @property + def entity_type_text(self) -> str: + return "Launch Plan" + + @_exception_scopes.system_entry_point + def validate(self): + # TODO: Validate workflow is satisfied + pass + + @_exception_scopes.system_entry_point + def update(self, state: _launch_plan_models.LaunchPlanState): + if not self.id: + raise _user_exceptions.FlyteAssertion( + "Failed to update launch plan because the launch plan's ID is not set. Please call fetch or register " + "to set the identifier first." + ) + return _flyte_engine.get_client().update_launch_plan(self.id, state) + + @_exception_scopes.system_entry_point + def launch_with_literals( + self, + project: str, + domain: str, + literal_inputs: _literal_models.LiteralMap, + name: str = None, + notification_overrides: List[_common_models.Notification] = None, + label_overrides: _common_models.Labels = None, + annotation_overrides: _common_models.Annotations = None, + ) -> _workflow_execution.FlyteWorkflowExecution: + """ + Executes the launch plan and returns the execution identifier. This version of execution is meant for when + you already have a LiteralMap of inputs. + + :param project: + :param domain: + :param literal_inputs: Inputs to the execution. + :param name: If specified, an execution will be created with this name. Note: the name must + be unique within the context of the project and domain. + :param notification_overrides: If specified, these are the notifications that will be honored for this + execution. An empty list signals to disable all notifications. + :param label_overrides: + :param annotation_overrides: + """ + # Kubernetes requires names starting with a letter for some resources. 
+ name = name or "f" + _uuid.uuid4().hex[:19] + disable_all = notification_overrides == [] + if disable_all: + notification_overrides = None + else: + notification_overrides = _execution_models.NotificationList(notification_overrides or []) + disable_all = None + + client = _flyte_engine.get_client() + try: + exec_id = client.create_execution( + project, + domain, + name, + _execution_models.ExecutionSpec( + self.id, + _execution_models.ExecutionMetadata( + _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, + "sdk", # TODO: get principal + 0, # TODO: Detect nesting + ), + notifications=notification_overrides, + disable_all=disable_all, + labels=label_overrides, + annotations=annotation_overrides, + ), + literal_inputs, + ) + except _user_exceptions.FlyteEntityAlreadyExistsException: + exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name) + return _workflow_execution.FlyteWorkflowExecution.promote_from_model(client.get_execution(exec_id)) + + @_exception_scopes.system_entry_point + def __call__(self, *args, **input_map: Any) -> _nodes.FlyteNode: + raise NotImplementedError + + def __repr__(self) -> str: + return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface} WF ID: {self.workflow_id})" diff --git a/flytekit/control_plane/nodes.py b/flytekit/control_plane/nodes.py new file mode 100644 index 00000000000..287695bae0b --- /dev/null +++ b/flytekit/control_plane/nodes.py @@ -0,0 +1,281 @@ +import logging as _logging +import os as _os +from typing import Any, Dict, List, Optional + +from flyteidl.core import literals_pb2 as _literals_pb2 + +from flytekit.clients.helpers import iterate_task_executions as _iterate_task_executions +from flytekit.common import constants as _constants +from flytekit.common import utils as _common_utils +from flytekit.common.exceptions import system as _system_exceptions +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.mixins import artifact as _artifact_mixin +from flytekit.common.mixins import hash as _hash_mixin +from flytekit.common.utils import _dnsify +from flytekit.control_plane import component_nodes as _component_nodes +from flytekit.control_plane import identifier as _identifier +from flytekit.control_plane.tasks import executions as _task_executions +from flytekit.core.promise import NodeOutput +from flytekit.engines.flyte import engine as _flyte_engine +from flytekit.interfaces.data import data_proxy as _data_proxy +from flytekit.models import literals as _literal_models +from flytekit.models import node_execution as _node_execution_models +from flytekit.models import task as _task_model +from flytekit.models.core import execution as _execution_models +from flytekit.models.core import workflow as _workflow_model + + +class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node): + def __init__( + self, + id, + upstream_nodes, + bindings, + metadata, + flyte_task: "flytekit.control_plane.tasks.task.FlyteTask" = None, + flyte_workflow: "flytekit.control_plane.workflow.FlyteWorkflow" = None, + flyte_launch_plan=None, + flyte_branch=None, + parameter_mapping=True, + ): + non_none_entities = list(filter(None, [flyte_task, flyte_workflow, flyte_launch_plan, flyte_branch])) + if len(non_none_entities) != 1: + raise _user_exceptions.FlyteAssertion( + "A Flyte node must have exactly one underlying entity specified. 
Received the following " + "entities: {}".format(non_none_entities) + ) + + workflow_node = None + if flyte_workflow is not None: + workflow_node = _component_nodes.FlyteWorkflowNode(flyte_workflow=flyte_workflow) + elif flyte_launch_plan is not None: + workflow_node = _component_nodes.FlyteWorkflowNode(flyte_launch_plan=flyte_launch_plan) + + super(FlyteNode, self).__init__( + id=_dnsify(id) if id else None, + metadata=metadata, + inputs=bindings, + upstream_node_ids=[n.id for n in upstream_nodes], + output_aliases=[], + task_node=_component_nodes.FlyteTaskNode(flyte_task) if flyte_task else None, + workflow_node=workflow_node, + branch_node=flyte_branch, + ) + self._upstream = upstream_nodes + + @classmethod + def promote_from_model( + cls, + model: _workflow_model.Node, + sub_workflows: Optional[Dict[_identifier.Identifier, _workflow_model.WorkflowTemplate]], + tasks: Optional[Dict[_identifier.Identifier, _task_model.TaskTemplate]], + ) -> "FlyteNode": + id = model.id + if id in {_constants.START_NODE_ID, _constants.END_NODE_ID}: + _logging.warning(f"Should not call promote from model on a start node or end node {model}") + return None + + flyte_task_node, flyte_workflow_node = None, None + if model.task_node is not None: + flyte_task_node = _component_nodes.FlyteTaskNode.promote_from_model(model.task_node, tasks) + elif model.workflow_node is not None: + flyte_workflow_node = _component_nodes.FlyteWorkflowNode.promote_from_model( + model.workflow_node, sub_workflows, tasks + ) + else: + raise _system_exceptions.FlyteSystemException("Bad Node model, neither task nor workflow detected.") + + # When WorkflowTemplate models (containing node models) are returned by Admin, they've been compiled with a + # start node. In order to make the promoted FlyteWorkflow look the same, we strip the start-node text back out. 
+ for model_input in model.inputs: + if ( + model_input.binding.promise is not None + and model_input.binding.promise.node_id == _constants.START_NODE_ID + ): + model_input.binding.promise._node_id = _constants.GLOBAL_INPUT_NODE_ID + + if flyte_task_node is not None: + return cls( + id=id, + upstream_nodes=[], # set downstream, model doesn't contain this information + bindings=model.inputs, + metadata=model.metadata, + flyte_task=flyte_task_node.flyte_task, + ) + elif flyte_workflow_node is not None: + if flyte_workflow_node.flyte_workflow is not None: + return cls( + id=id, + upstream_nodes=[], # set downstream, model doesn't contain this information + bindings=model.inputs, + metadata=model.metadata, + flyte_workflow=flyte_workflow_node.flyte_workflow, + ) + elif flyte_workflow_node.flyte_launch_plan is not None: + return cls( + id=id, + upstream_nodes=[], # set downstream, model doesn't contain this information + bindings=model.inputs, + metadata=model.metadata, + flyte_launch_plan=flyte_workflow_node.flyte_launch_plan, + ) + raise _system_exceptions.FlyteSystemException( + "Bad FlyteWorkflowNode model, both launch plan and workflow are None" + ) + raise _system_exceptions.FlyteSystemException("Bad FlyteNode model, both task and workflow nodes are empty") + + @property + def upstream_nodes(self) -> List["FlyteNode"]: + return self._upstream + + @property + def upstream_node_ids(self) -> List[str]: + return list(sorted(n.id for n in self.upstream_nodes)) + + @property + def outputs(self) -> Dict[str, NodeOutput]: + return self._outputs + + def assign_id_and_return(self, id: str): + if self.id: + raise _user_exceptions.FlyteAssertion( + f"Error assigning ID: {id} because {self} is already assigned. Has this node been assigned to another " + "workflow already?" + ) + self._id = _dnsify(id) if id else None + self._metadata.name = id + return self + + def with_overrides(self, *args, **kwargs): + # TODO: Implement overrides + raise NotImplementedError("Overrides are not supported in Flyte yet.") + + def __repr__(self) -> str: + return f"Node(ID: {self.id} Executable: {self._executable_flyte_object})" + + +class FlyteNodeExecution(_node_execution_models.NodeExecution, _artifact_mixin.ExecutionArtifact): + def __init__(self, *args, **kwargs): + super(FlyteNodeExecution, self).__init__(*args, **kwargs) + self._task_executions = None + self._workflow_executions = None + self._inputs = None + self._outputs = None + + @property + def task_executions(self) -> List["flytekit.control_plane.tasks.executions.FlyteTaskExecution"]: + return self._task_executions or [] + + @property + def workflow_executions(self) -> List["flytekit.control_plane.workflow_execution.FlyteWorkflowExecution"]: + return self._workflow_executions or [] + + @property + def executions(self) -> _artifact_mixin.ExecutionArtifact: + return self.task_executions or self.workflow_executions or [] + + @property + def inputs(self) -> Dict[str, Any]: + """ + Returns the inputs to the execution in the standard python format as dictated by the type engine. + """ + if self._inputs is None: + client = _flyte_engine.get_client() + execution_data = client.get_node_execution_data(self.id) + + # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. 
+ input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({}) + if bool(execution_data.full_inputs.literals): + input_map = execution_data.full_inputs + elif execution_data.inputs.bytes > 0: + with _common_utils.AutoDeletingTempDir() as tmp_dir: + tmp_name = _os.path.join(tmp_dir.name, "inputs.pb") + _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) + input_map = _literal_models.LiteralMap.from_flyte_idl( + _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) + ) + + # TODO: need to convert flyte literals to python types. For now just use literals + # self._inputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=input_map) + self._inputs = input_map + return self._inputs + + @property + def outputs(self) -> Dict[str, Any]: + """ + Returns the outputs to the execution in the standard python format as dictated by the type engine. + + :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error. + """ + if not self.is_complete: + raise _user_exceptions.FlyteAssertion( + "Please wait until the node execution has completed before requesting the outputs." + ) + if self.error: + raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") + + if self._outputs is None: + client = _flyte_engine.get_client() + execution_data = client.get_node_execution_data(self.id) + + # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. + output_map: _literal_models.LiteralMap = _literal_models.LiteralMap({}) + if bool(execution_data.full_outputs.literals): + output_map = execution_data.full_outputs + elif execution_data.outputs.bytes > 0: + with _common_utils.AutoDeletingTempDir() as tmp_dir: + tmp_name = _os.path.join(tmp_dir.name, "outputs.pb") + _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name) + output_map = _literal_models.LiteralMap.from_flyte_idl( + _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) + ) + # TODO: need to convert flyte literals to python types. For now just use literals + # self._outputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=output_map) + self._outputs = output_map + return self._outputs + + @property + def error(self) -> _execution_models.ExecutionError: + """ + If execution is in progress, raise an exception. Otherwise, return None if no error was present upon + reaching completion. + """ + if not self.is_complete: + raise _user_exceptions.FlyteAssertion( + "Please wait until the node execution has completed before requesting error information." + ) + return self.closure.error + + @property + def is_complete(self) -> bool: + """Whether or not the execution is complete.""" + return self.closure.phase in { + _execution_models.NodeExecutionPhase.ABORTED, + _execution_models.NodeExecutionPhase.FAILED, + _execution_models.NodeExecutionPhase.SKIPPED, + _execution_models.NodeExecutionPhase.SUCCEEDED, + _execution_models.NodeExecutionPhase.TIMED_OUT, + } + + @classmethod + def promote_from_model(cls, base_model: _node_execution_models.NodeExecution) -> "FlyteNodeExecution": + return cls(closure=base_model.closure, id=base_model.id, input_uri=base_model.input_uri) + + def sync(self): + """ + Syncs the state of the underlying execution artifact with the state observed by the platform. 
+ """ + if not self.is_complete or self.task_executions is not None: + client = _flyte_engine.get_client() + self._closure = client.get_node_execution(self.id).closure + self._task_executions = [ + _task_executions.FlyteTaskExecution.promote_from_model(t) + for t in _iterate_task_executions(client, self.id) + ] + # TODO: sync sub-workflows as well + + def _sync_closure(self): + """ + Syncs the closure of the underlying execution artifact with the state observed by the platform. + """ + self._closure = _flyte_engine.get_client().get_node_execution(self.id).closure diff --git a/flytekit/control_plane/tasks/__init__.py b/flytekit/control_plane/tasks/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/flytekit/control_plane/tasks/executions.py b/flytekit/control_plane/tasks/executions.py new file mode 100644 index 00000000000..838746f3929 --- /dev/null +++ b/flytekit/control_plane/tasks/executions.py @@ -0,0 +1,132 @@ +from typing import Any, Dict, Optional + +from flytekit.common import utils as _common_utils +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.mixins import artifact as _artifact_mixin +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine +from flytekit.engines.flyte import engine as _flyte_engine +from flytekit.models.admin import task_execution as _task_execution_model +from flytekit.models.core import execution as _execution_models + + +class FlyteTaskExecution(_task_execution_model.TaskExecution, _artifact_mixin.ExecutionArtifact): + def __init__(self, *args, **kwargs): + super(FlyteTaskExecution, self).__init__(*args, **kwargs) + self._inputs = None + self._outputs = None + + @property + def is_complete(self) -> bool: + """Whether or not the execution is complete.""" + return self.closure.phase in { + _execution_models.TaskExecutionPhase.ABORTED, + _execution_models.TaskExecutionPhase.FAILED, + _execution_models.TaskExecutionPhase.SUCCEEDED, + } + + @property + def inputs(self) -> Dict[str, Any]: + """ + Returns the inputs of the task execution in the standard Python format that is produced by + the type engine. + """ + if self._inputs is None: + client = _flyte_engine.get_client() + execution_data = client.get_task_execution_data(self.id) + + # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. + input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({}) + if bool(execution_data.full_inputs.literals): + input_map = execution_data.full_inputs + elif execution_data.inputs.bytes > 0: + with _common_utils.AutoDeletingTempDir() as tmp_dir: + tmp_name = _os.path.join(tmp_dir.name, "inputs.pb") + _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) + input_map = _literal_models.LiteralMap.from_flyte_idl( + _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) + ) + + self._inputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=input_map) + return self._inputs + + @property + def outputs(self) -> Dict[str, Any]: + """ + Returns the outputs of the task execution, if available, in the standard Python format that is produced by + the type engine. + + :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error. + """ + if not self.is_complete: + raise _user_exceptions.FlyteAssertion( + "Please what until the task execution has completed before requesting the outputs." 
+ ) + if self.error: + raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") + + if self._outputs is None: + client = _flyte_engine.get_client() + execution_data = client.get_task_execution_data(self.id) + + # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. + output_map = _literal_models.LiteralMap({}) + if bool(execution_data.full_outputs.literals): + output_map = execution_data.full_outputs + + elif execution_data.outputs.bytes > 0: + with _common_utils.AutoDeletingTempDir() as t: + tmp_name = _os.path.join(t.name, "outputs.pb") + _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name) + output_map = _literal_models.LiteralMap.from_flyte_idl( + _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) + ) + + self._outputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=output_map) + return self._outputs + + @property + def error(self) -> Optional[_execution_models.ExecutionError]: + """ + If execution is in progress, raise an exception. Otherwise, return None if no error was present upon + reaching completion. + """ + if not self.is_complete: + raise _user_exceptions.FlyteAssertion( + "Please wait until the task execution has completed before requesting error information." + ) + return self.closure.error + + def get_child_executions(self, filters=None): + from flytekit.control_plane import nodes as _nodes + + if not self.is_parent: + raise _user_exceptions.FlyteAssertion("Only task executions marked with 'is_parent' have child executions.") + client = _flyte_engine.get_client() + models = { + v.id.node_id: v + for v in _iterate_node_executions(client, task_execution_identifier=self.id, filters=filters) + } + + return {k: _nodes.FlyteNodeExecution.promote_from_model(v) for k, v in models.items()} + + @classmethod + def promote_from_model(cls, base_model: _task_execution_model.TaskExecution) -> "FlyteTaskExecution": + return cls( + closure=base_model.closure, + id=base_model.id, + input_uri=base_model.input_uri, + is_parent=base_model.is_parent, + ) + + def sync(self): + """ + Syncs the state of the underlying execution artifact with the state observed by the platform. + """ + self._sync_closure() + + def _sync_closure(self): + """ + Syncs the closure of the underlying execution artifact with the state observed by the platform. 
+ """ + self._closure = _flyte_engine.get_client().get_task_execution(self.id).closure diff --git a/flytekit/control_plane/tasks/task.py b/flytekit/control_plane/tasks/task.py new file mode 100644 index 00000000000..71f159e5234 --- /dev/null +++ b/flytekit/control_plane/tasks/task.py @@ -0,0 +1,95 @@ +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.mixins import hash as _hash_mixin +from flytekit.control_plane import identifier as _identifier +from flytekit.control_plane import interface as _interfaces +from flytekit.engines.flyte import engine as _flyte_engine +from flytekit.models import common as _common_model +from flytekit.models import task as _task_model +from flytekit.models.admin import common as _admin_common +from flytekit.models.core import identifier as _identifier_model + + +class FlyteTask(_hash_mixin.HashOnReferenceMixin, _task_model.TaskTemplate): + def __init__(self, id, type, metadata, interface, custom, container=None, task_type_version=0, config=None): + super(FlyteTask, self).__init__( + id, + type, + metadata, + interface, + custom, + container=container, + task_type_version=task_type_version, + config=config, + ) + + @property + def interface(self) -> _interfaces.TypedInterface: + return super(FlyteTask, self).interface + + @property + def resource_type(self) -> _identifier_model.ResourceType: + return _identifier_model.ResourceType.TASK + + @property + def entity_type_text(self) -> str: + return "Task" + + @classmethod + def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> "FlyteTask": + t = cls( + id=base_model.id, + type=base_model.type, + metadata=base_model.metadata, + interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), + custom=base_model.custom, + container=base_model.container, + task_type_version=base_model.task_type_version, + ) + # Override the newly generated name if one exists in the base model + if not base_model.id.is_empty: + t._id = _identifier.Identifier.promote_from_model(base_model.id) + + return t + + @classmethod + @_exception_scopes.system_entry_point + def fetch(cls, project: str, domain: str, name: str, version: str) -> "FlyteTask": + """ + This function uses the engine loader to call create a hydrated task from Admin. + + :param project: + :param domain: + :param name: + :param version: + """ + task_id = _identifier.Identifier(_identifier_model.ResourceType.TASK, project, domain, name, version) + admin_task = _flyte_engine.get_client().get_task(task_id) + + flyte_task = cls.promote_from_model(admin_task.closure.compiled_task.template) + flyte_task._id = task_id + return flyte_task + + @classmethod + @_exception_scopes.system_entry_point + def fetch_latest(cls, project: str, domain: str, name: str) -> "FlyteTask": + """ + This function uses the engine loader to call create a latest hydrated task from Admin. 
+ + :param project: + :param domain: + :param name: + """ + named_task = _common_model.NamedEntityIdentifier(project, domain, name) + client = _flyte_engine.get_client() + task_list, _ = client.list_tasks_paginated( + named_task, + limit=1, + sort_by=_admin_common.Sort("created_at", _admin_common.Sort.Direction.DESCENDING), + ) + admin_task = task_list[0] if task_list else None + + if not admin_task: + raise _user_exceptions.FlyteEntityNotExistException("Named task {} not found".format(named_task)) + flyte_task = cls.promote_from_model(admin_task.closure.compiled_task.template) + flyte_task._id = admin_task.id + return flyte_task diff --git a/flytekit/control_plane/workflow.py b/flytekit/control_plane/workflow.py new file mode 100644 index 00000000000..a164b986678 --- /dev/null +++ b/flytekit/control_plane/workflow.py @@ -0,0 +1,167 @@ +from typing import Dict, List, Optional + +from flytekit.common import constants as _constants +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import system as _system_exceptions +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.mixins import hash as _hash_mixin +from flytekit.control_plane import identifier as _identifier +from flytekit.control_plane import interface as _interfaces +from flytekit.control_plane import nodes as _nodes +from flytekit.engines.flyte import engine as _flyte_engine +from flytekit.models import task as _task_models +from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import workflow as _workflow_models + +
class FlyteWorkflow(_hash_mixin.HashOnReferenceMixin, _workflow_models.WorkflowTemplate): + """A Flyte control plane construct.""" + + def __init__( + self, + nodes: List[_nodes.FlyteNode], + interface, + output_bindings, + id, + metadata, + metadata_defaults, + ): + for node in nodes: + for upstream in node.upstream_nodes: + if upstream.id is None: + raise _user_exceptions.FlyteAssertion( + "Some nodes contained in the workflow were not found in the workflow description. Please " + "ensure all nodes are either assigned to attributes within the class or an element in a " + "list, dict, or tuple which is stored as an attribute in the class." 
+ ) + super(FlyteWorkflow, self).__init__( + id=id, + metadata=metadata, + metadata_defaults=metadata_defaults, + interface=interface, + nodes=nodes, + outputs=output_bindings, + ) + self._flyte_nodes = nodes + + @property + def upstream_entities(self): + return set(n.executable_flyte_object for n in self._flyte_nodes) + + @property + def interface(self) -> _interfaces.TypedInterface: + return super(FlyteWorkflow, self).interface + + @property + def entity_type_text(self) -> str: + return "Workflow" + + @property + def resource_type(self): + return _identifier_model.ResourceType.WORKFLOW + + def get_sub_workflows(self) -> List["FlyteWorkflow"]: + result = [] + for node in self.nodes: + if node.workflow_node is not None and node.workflow_node.sub_workflow_ref is not None: + if ( + node.executable_flyte_object is not None + and node.executable_flyte_object.entity_type_text == "Workflow" + ): + result.append(node.executable_flyte_object) + result.extend(node.executable_flyte_object.get_sub_workflows()) + else: + raise _system_exceptions.FlyteSystemException( + "workflow node with subworkflow found but bad executable " + "object {}".format(node.executable_flyte_object) + ) + + # get subworkflows in conditional branches + if node.branch_node is not None: + if_else: _workflow_models.IfElseBlock = node.branch_node.if_else + leaf_nodes: List[_nodes.FlyteNode] = list( + filter( + None, + [ + if_else.case.then_node, + *([] if if_else.other is None else [x.then_node for x in if_else.other]), + if_else.else_node, + ], + ) + ) + for leaf_node in leaf_nodes: + exec_flyte_obj = leaf_node.executable_flyte_object + if exec_flyte_obj is not None and exec_flyte_obj.entity_type_text == "Workflow": + result.append(exec_flyte_obj) + result.extend(exec_flyte_obj.get_sub_workflows()) + + return result + + @classmethod + @_exception_scopes.system_entry_point + def fetch(cls, project: str, domain: str, name: str, version: str): + workflow_id = _identifier.Identifier(_identifier_model.ResourceType.WORKFLOW, project, domain, name, version) + admin_workflow = _flyte_engine.get_client().get_workflow(workflow_id) + cwc = admin_workflow.closure.compiled_workflow + flyte_workflow = cls.promote_from_model( + base_model=cwc.primary.template, + sub_workflows={sw.template.id: sw.template for sw in cwc.sub_workflows}, + tasks={t.template.id: t.template for t in cwc.tasks}, + ) + flyte_workflow._id = workflow_id + return flyte_workflow + + @classmethod + def get_non_system_nodes(cls, nodes: List[_workflow_models.Node]) -> List[_workflow_models.Node]: + return [n for n in nodes if n.id not in {_constants.START_NODE_ID, _constants.END_NODE_ID}] + + @classmethod + def promote_from_model( + cls, + base_model: _workflow_models.WorkflowTemplate, + sub_workflows: Optional[Dict[_identifier.Identifier, _workflow_models.WorkflowTemplate]] = None, + tasks: Optional[Dict[_identifier.Identifier, _task_models.TaskTemplate]] = None, + ) -> "FlyteWorkflow": + base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes) + sub_workflows = sub_workflows or {} + tasks = tasks or {} + node_map = { + n.id: _nodes.FlyteNode.promote_from_model(n, sub_workflows, tasks) for n in base_model_non_system_nodes + } + + # Set upstream nodes for each node + for n in base_model_non_system_nodes: + current = node_map[n.id] + for upstream_id in current.upstream_node_ids: + upstream_node = node_map[upstream_id] + current << upstream_node + + # No inputs/outputs specified, see the constructor for more information on the overrides. 
+ return cls( + nodes=list(node_map.values()), + id=_identifier.Identifier.promote_from_model(base_model.id), + metadata=base_model.metadata, + metadata_defaults=base_model.metadata_defaults, + interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), + output_bindings=base_model.outputs, + ) + + @_exception_scopes.system_entry_point + def register(self, project, domain, name, version): + # TODO + pass + + @_exception_scopes.system_entry_point + def serialize(self): + # TODO + pass + + @_exception_scopes.system_entry_point + def validate(self): + # TODO + pass + + @_exception_scopes.system_entry_point + def create_launch_plan(self, *args, **kwargs): + # TODO + pass + + @_exception_scopes.system_entry_point + def __call__(self, *args, **input_map): + raise NotImplementedError diff --git a/flytekit/control_plane/workflow_execution.py b/flytekit/control_plane/workflow_execution.py new file mode 100644 index 00000000000..11eb3523511 --- /dev/null +++ b/flytekit/control_plane/workflow_execution.py @@ -0,0 +1,150 @@ +import os as _os +from typing import Any, Dict, List + +from flyteidl.core import literals_pb2 as _literals_pb2 + +from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions +from flytekit.common import utils as _common_utils +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.mixins import artifact as _artifact +from flytekit.control_plane import identifier as _core_identifier +from flytekit.control_plane import nodes as _nodes +from flytekit.engines.flyte import engine as _flyte_engine +from flytekit.interfaces.data import data_proxy as _data_proxy +from flytekit.models import execution as _execution_models +from flytekit.models import filters as _filter_models +from flytekit.models import literals as _literal_models +from flytekit.models.core import execution as _core_execution_models + + +class FlyteWorkflowExecution(_execution_models.Execution, _artifact.ExecutionArtifact): + def __init__(self, *args, **kwargs): + super(FlyteWorkflowExecution, self).__init__(*args, **kwargs) + self._node_executions = None + self._inputs = None + self._outputs = None + + @property + def node_executions(self) -> Dict[str, _nodes.FlyteNodeExecution]: + return self._node_executions or {} + + @property + def inputs(self) -> Dict[str, Any]: + """ + Returns the inputs to the execution in the standard python format as dictated by the type engine. + """ + if self._inputs is None: + client = _flyte_engine.get_client() + execution_data = client.get_execution_data(self.id) + + # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. + input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({}) + if bool(execution_data.full_inputs.literals): + input_map = execution_data.full_inputs + elif execution_data.inputs.bytes > 0: + with _common_utils.AutoDeletingTempDir() as tmp_dir: + tmp_name = _os.path.join(tmp_dir.name, "inputs.pb") + _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) + input_map = _literal_models.LiteralMap.from_flyte_idl( + _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) + ) + # TODO: need to convert flyte literals to python types. 
+            # self._inputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=input_map)
+            self._inputs = input_map
+        return self._inputs
+
+    @property
+    def outputs(self) -> Dict[str, Any]:
+        """
+        Returns the outputs of the execution in the standard Python format as dictated by the type engine.
+
+        :raises: ``FlyteAssertion`` error if the execution is in progress or the execution ended in error.
+        """
+        if not self.is_complete:
+            raise _user_exceptions.FlyteAssertion(
+                "Please wait until the execution has completed before requesting the outputs."
+            )
+        if self.error:
+            raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
+
+        if self._outputs is None:
+            client = _flyte_engine.get_client()
+            execution_data = client.get_execution_data(self.id)
+            # Outputs are returned inline unless they are too big, in which case a URL to them is returned.
+            output_map: _literal_models.LiteralMap = _literal_models.LiteralMap({})
+            if bool(execution_data.full_outputs.literals):
+                output_map = execution_data.full_outputs
+            elif execution_data.outputs.bytes > 0:
+                with _common_utils.AutoDeletingTempDir() as tmp_dir:
+                    tmp_name = _os.path.join(tmp_dir.name, "outputs.pb")
+                    _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name)
+                    output_map = _literal_models.LiteralMap.from_flyte_idl(
+                        _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
+                    )
+            # TODO: need to convert Flyte literals to Python types; for now, just use the literal map.
+            # self._outputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=output_map)
+            self._outputs = output_map
+        return self._outputs
+
+    @property
+    def error(self) -> _core_execution_models.ExecutionError:
+        """
+        If the execution is in progress, raise an exception. Otherwise, return None if no error was present upon
+        reaching completion.
+        """
+        if not self.is_complete:
+            raise _user_exceptions.FlyteAssertion(
+                "Please wait until a workflow has completed before checking for an error."
+            )
+        return self.closure.error
+
+    @property
+    def is_complete(self) -> bool:
+        """
+        Whether or not the execution is complete.
+        """
+        return self.closure.phase in {
+            _core_execution_models.WorkflowExecutionPhase.ABORTED,
+            _core_execution_models.WorkflowExecutionPhase.FAILED,
+            _core_execution_models.WorkflowExecutionPhase.SUCCEEDED,
+            _core_execution_models.WorkflowExecutionPhase.TIMED_OUT,
+        }
+
+    @classmethod
+    def promote_from_model(cls, base_model: _execution_models.Execution) -> "FlyteWorkflowExecution":
+        return cls(
+            closure=base_model.closure,
+            id=_core_identifier.WorkflowExecutionIdentifier.promote_from_model(base_model.id),
+            spec=base_model.spec,
+        )
+
+    @classmethod
+    def fetch(cls, project: str, domain: str, name: str) -> "FlyteWorkflowExecution":
+        return cls.promote_from_model(
+            _flyte_engine.get_client().get_execution(
+                _core_identifier.WorkflowExecutionIdentifier(project=project, domain=domain, name=name)
+            )
+        )
+
+    def sync(self):
+        """
+        Syncs the state of the underlying execution artifact with the state observed by the platform.
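+
+        A typical polling loop looks something like this (a hypothetical sketch; the
+        project/domain/name values are placeholders)::
+
+            import time
+
+            execution = FlyteWorkflowExecution.fetch("flytesnacks", "development", "execution-name")
+            while not execution.is_complete:
+                time.sleep(5)
+                execution.sync()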
+ """ + if not self.is_complete or self._node_executions is None: + self._sync_closure() + self._node_executions = self.get_node_executions() + + def _sync_closure(self): + if not self.is_complete: + client = _flyte_engine.get_client() + self._closure = client.get_execution(self.id).closure + + def get_node_executions(self, filters: List[_filter_models.Filter] = None) -> Dict[str, _nodes.FlyteNodeExecution]: + client = _flyte_engine.get_client() + return { + node.id.node_id: _nodes.FlyteNodeExecution.promote_from_model(node) + for node in _iterate_node_executions(client, self.id, filters=filters) + } + + def terminate(self, cause: str): + _flyte_engine.get_client().terminate_execution(self.id, cause) diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 55699627438..62e28648d83 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -82,10 +82,14 @@ entrypoints==0.3 # nbconvert # papermill <<<<<<< HEAD +<<<<<<< HEAD flyteidl==0.18.39 ======= flyteidl==0.18.37 >>>>>>> Sqlalchemy Task (#445) +======= +flyteidl==0.18.38 +>>>>>>> add new control plane classes (#425) # via flytekit gevent==21.1.2 # via sagemaker-training @@ -98,7 +102,9 @@ hmsclient==0.1.1 idna==2.10 # via requests importlib-metadata==4.0.1 - # via keyring + # via + # jsonschema + # keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -326,7 +332,14 @@ typed-ast==1.4.3 # via black >>>>>>> Sqlalchemy Task (#445) typing-extensions==3.7.4.3 +<<<<<<< HEAD # via typing-inspect +======= + # via + # black + # importlib-metadata + # typing-inspect +>>>>>>> add new control plane classes (#425) typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 diff --git a/requirements.txt b/requirements.txt index 8ae4b441395..afa1acaa901 100644 --- a/requirements.txt +++ b/requirements.txt @@ -78,10 +78,14 @@ entrypoints==0.3 # nbconvert # papermill <<<<<<< HEAD +<<<<<<< HEAD flyteidl==0.18.39 ======= flyteidl==0.18.37 >>>>>>> Sqlalchemy Task (#445) +======= +flyteidl==0.18.38 +>>>>>>> add new control plane classes (#425) # via flytekit gevent==21.1.2 # via sagemaker-training @@ -94,7 +98,9 @@ hmsclient==0.1.1 idna==2.10 # via requests importlib-metadata==4.0.1 - # via keyring + # via + # jsonschema + # keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -318,7 +324,14 @@ typed-ast==1.4.3 # via black >>>>>>> Sqlalchemy Task (#445) typing-extensions==3.7.4.3 +<<<<<<< HEAD # via typing-inspect +======= + # via + # black + # importlib-metadata + # typing-inspect +>>>>>>> add new control plane classes (#425) typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/.gitignore b/tests/flytekit/integration/control_plane/mock_flyte_repo/.gitignore new file mode 100644 index 00000000000..9bf95ea6808 --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/.gitignore @@ -0,0 +1 @@ +*.pb \ No newline at end of file diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/README.md b/tests/flytekit/integration/control_plane/mock_flyte_repo/README.md new file mode 100644 index 00000000000..1972a7c6589 --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/README.md @@ -0,0 +1,4 @@ +# Mock Flyte Repo + +This is a trimmed down version of the [flytesnacks](https://github.com/flyteorg/flytesnacks) +repo for the purposes of local integration testing. 
diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/__init__.py b/tests/flytekit/integration/control_plane/mock_flyte_repo/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/in_container.mk b/tests/flytekit/integration/control_plane/mock_flyte_repo/in_container.mk new file mode 100644 index 00000000000..15bc979759c --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/in_container.mk @@ -0,0 +1,24 @@ +SERIALIZED_PB_OUTPUT_DIR := /tmp/output + +.PHONY: clean +clean: + rm -rf $(SERIALIZED_PB_OUTPUT_DIR)/* + +$(SERIALIZED_PB_OUTPUT_DIR): clean + mkdir -p $(SERIALIZED_PB_OUTPUT_DIR) + +.PHONY: serialize +serialize: $(SERIALIZED_PB_OUTPUT_DIR) + pyflyte --config /root/sandbox.config serialize workflows -f $(SERIALIZED_PB_OUTPUT_DIR) + +.PHONY: register +register: serialize + flyte-cli register-files -h ${FLYTE_HOST} ${INSECURE_FLAG} -p ${PROJECT} -d development -v ${VERSION} --kubernetes-service-account ${SERVICE_ACCOUNT} --output-location-prefix ${OUTPUT_DATA_PREFIX} $(SERIALIZED_PB_OUTPUT_DIR)/* + +.PHONY: fast_serialize +fast_serialize: $(SERIALIZED_PB_OUTPUT_DIR) + pyflyte --config /root/sandbox.config serialize fast workflows -f $(SERIALIZED_PB_OUTPUT_DIR) + +.PHONY: fast_register +fast_register: fast_serialize + flyte-cli fast-register-files -h ${FLYTE_HOST} ${INSECURE_FLAG} -p ${PROJECT} -d development --kubernetes-service-account ${SERVICE_ACCOUNT} --output-location-prefix ${OUTPUT_DATA_PREFIX} --additional-distribution-dir ${ADDL_DISTRIBUTION_DIR} $(SERIALIZED_PB_OUTPUT_DIR)/* diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Dockerfile b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Dockerfile new file mode 100644 index 00000000000..7e5d01829fd --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Dockerfile @@ -0,0 +1,35 @@ +FROM python:3.8-slim-buster +LABEL org.opencontainers.image.source https://github.com/flyteorg/flytesnacks + +WORKDIR /root +ENV VENV /opt/venv +ENV LANG C.UTF-8 +ENV LC_ALL C.UTF-8 +ENV PYTHONPATH /root + +# This is necessary for opencv to work +RUN apt-get update && apt-get install -y libsm6 libxext6 libxrender-dev ffmpeg build-essential + +# Install the AWS cli separately to prevent issues with boto being written over +RUN pip3 install awscli + +ENV VENV /opt/venv +# Virtual environment +RUN python3 -m venv ${VENV} +ENV PATH="${VENV}/bin:$PATH" + +# Install Python dependencies +COPY workflows/requirements.txt /root +RUN pip install -r /root/requirements.txt + +# Copy the makefile targets to expose on the container. This makes it easier to register +COPY in_container.mk /root/Makefile +COPY workflows/sandbox.config /root + +# Copy the actual code +COPY workflows /root/workflows + +# This tag is supplied by the build script and will be used to determine the version +# when registering tasks, workflows, and launch plans +ARG tag +ENV FLYTE_INTERNAL_IMAGE $tag diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Makefile b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Makefile new file mode 100644 index 00000000000..5812f4893cc --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Makefile @@ -0,0 +1,208 @@ +.SILENT: + +PREFIX=workflows + +# This is used by the image building script referenced below. 
Normally it just takes the directory name, but in this
+# case we want it to be called something else.
+IMAGE_NAME=flytecookbook
+export VERSION ?= $(shell git rev-parse HEAD)
+
+define PIP_COMPILE
+pip-compile $(1) ${PIP_ARGS} --upgrade --verbose
+endef
+
+# Set SANDBOX=1 to automatically fill in sandbox config
+ifdef SANDBOX
+
+# The URL for the Flyte control plane
+export FLYTE_HOST ?= localhost:30081
+
+# Overrides the S3 URL. This is only needed for sandbox deployments; it shouldn't be overridden for production AWS S3.
+export FLYTE_AWS_ENDPOINT ?= http://localhost:30084/
+
+# Used to authenticate to S3. For production AWS S3, using static keys and key IDs is discouraged.
+export FLYTE_AWS_ACCESS_KEY_ID ?= minio
+
+# Used to authenticate to S3. For production AWS S3, using static keys and key IDs is discouraged.
+export FLYTE_AWS_SECRET_ACCESS_KEY ?= miniostorage
+
+# Used to publish artifacts for fast registration
+export ADDL_DISTRIBUTION_DIR ?= s3://my-s3-bucket/fast/
+
+# The base of where Blobs, Schemas and other offloaded types are, by default, serialized.
+export OUTPUT_DATA_PREFIX ?= s3://my-s3-bucket/raw-data
+
+# Instructs flyte-cli commands to use an insecure channel when communicating with Flyte's control plane.
+# If you're port-forwarding your service or running the sandbox Flyte deployment, specify INSECURE=1 before your make command.
+# If your Flyte Admin is behind SSL, don't specify anything.
+ifndef INSECURE
+	export INSECURE_FLAG=-i
+endif
+
+# The docker registry that should be used to push images.
+# e.g.:
+# export REGISTRY ?= ghcr.io/flyteorg
+endif
+
+# The Flyte project that we want to register under
+export PROJECT ?= flytesnacks
+
+# If the REGISTRY environment variable has been set, the image name will not just be tagged as
+# flytecookbook:<tag> but rather
+# ghcr.io/flyteorg/flytecookbook:<tag>, or whatever your REGISTRY is.
+ifdef REGISTRY
+	FULL_IMAGE_NAME = ${REGISTRY}/${IMAGE_NAME}
+endif
+ifndef REGISTRY
+	FULL_IMAGE_NAME = ${IMAGE_NAME}
+endif
+
+# If you are using a different service account on your k8s cluster, add SERVICE_ACCOUNT=my_account before your make command
+ifndef SERVICE_ACCOUNT
+	SERVICE_ACCOUNT=default
+endif
+
+.PHONY: help
+help: ## show help message
+	@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n  make \033[36m<target>\033[0m\n"} /^[$$()% a-zA-Z_-]+:.*?##/ { printf "  \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)
+
+.PHONY: debug
+debug:
+	echo "IMAGE NAME ${IMAGE_NAME}"
+	echo "FULL IMAGE NAME ${FULL_IMAGE_NAME}"
+	echo "VERSION TAG ${VERSION}"
+	echo "REGISTRY ${REGISTRY}"
+
+TAGGED_IMAGE=${FULL_IMAGE_NAME}:${PREFIX}-${VERSION}
+
+# This should only be used by Admins to push images to the public Dockerhub repo. Make sure you
+# specify REGISTRY=ghcr.io/flyteorg or your registry before the make command, otherwise this won't actually push.
+# Also, if you want to push the docker image for sagemaker consumption, specify ECR_REGISTRY.
+.PHONY: docker_push
+docker_push: docker_build
+ifdef REGISTRY
+	docker push ${TAGGED_IMAGE}
+endif
+
+.PHONY: fmt
+fmt: # Format code with black and isort
+	black .
+	isort .
+
+.PHONY: install-piptools
+install-piptools:
+	pip install -U pip-tools
+
+.PHONY: setup
+setup: install-piptools # Install requirements
+	pip-sync dev-requirements.txt
+
+.PHONY: lint
+lint: # Run linters
+	flake8 .
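+
+# Example invocation (a sketch; SANDBOX=1 fills in the sandbox defaults above):
+#   SANDBOX=1 make register
+# This builds and pushes the workflow image (when REGISTRY is set), then serializes
+# the workflows in a container and registers them with the Flyte control plane.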
+ +requirements.txt: export CUSTOM_COMPILE_COMMAND := $(MAKE) requirements.txt +requirements.txt: requirements.in install-piptools + $(call PIP_COMPILE,requirements.in) + +.PHONY: requirements +requirements: requirements.txt + +.PHONY: fast_serialize +fast_serialize: clean _pb_output + echo ${CURDIR} + docker run -it --rm \ + -e REGISTRY=${REGISTRY} \ + -e MAKEFLAGS=${MAKEFLAGS} \ + -e FLYTE_HOST=${FLYTE_HOST} \ + -e INSECURE_FLAG=${INSECURE_FLAG} \ + -e PROJECT=${PROJECT} \ + -e FLYTE_AWS_ENDPOINT=${FLYTE_AWS_ENDPOINT} \ + -e FLYTE_AWS_ACCESS_KEY_ID=${FLYTE_AWS_ACCESS_KEY_ID} \ + -e FLYTE_AWS_SECRET_ACCESS_KEY=${FLYTE_AWS_SECRET_ACCESS_KEY} \ + -e OUTPUT_DATA_PREFIX=${OUTPUT_DATA_PREFIX} \ + -e ADDL_DISTRIBUTION_DIR=${ADDL_DISTRIBUTION_DIR} \ + -e SERVICE_ACCOUNT=$(SERVICE_ACCOUNT) \ + -e VERSION=${VERSION} \ + -v ${CURDIR}/_pb_output:/tmp/output \ + -v ${CURDIR}:/root/$(shell basename $(CURDIR)) \ + ${TAGGED_IMAGE} make fast_serialize + +.PHONY: fast_register +fast_register: clean _pb_output ## Packages code and registers without building docker images. + @echo "Tagged Image: " + @echo ${TAGGED_IMAGE} + @echo ${CURDIR} + docker run -it --rm \ + --network host \ + -e REGISTRY=${REGISTRY} \ + -e MAKEFLAGS=${MAKEFLAGS} \ + -e FLYTE_HOST=${FLYTE_HOST} \ + -e INSECURE_FLAG=${INSECURE_FLAG} \ + -e PROJECT=${PROJECT} \ + -e FLYTE_AWS_ENDPOINT=${FLYTE_AWS_ENDPOINT} \ + -e FLYTE_AWS_ACCESS_KEY_ID=${FLYTE_AWS_ACCESS_KEY_ID} \ + -e FLYTE_AWS_SECRET_ACCESS_KEY=${FLYTE_AWS_SECRET_ACCESS_KEY} \ + -e OUTPUT_DATA_PREFIX=${OUTPUT_DATA_PREFIX} \ + -e ADDL_DISTRIBUTION_DIR=${ADDL_DISTRIBUTION_DIR} \ + -e SERVICE_ACCOUNT=$(SERVICE_ACCOUNT) \ + -e VERSION=${VERSION} \ + -v ${CURDIR}/_pb_output:/tmp/output \ + -v ${CURDIR}:/root/$(shell basename $(CURDIR)) \ + ${TAGGED_IMAGE} make fast_register + +.PHONY: docker_build +docker_build: + echo "Tagged Image: " + echo ${TAGGED_IMAGE} + docker build ../ --build-arg tag="${TAGGED_IMAGE}" -t "${TAGGED_IMAGE}" -f Dockerfile + +.PHONY: serialize +serialize: clean _pb_output docker_build + @echo ${VERSION} + @echo ${CURDIR} + docker run -it --rm \ + -e REGISTRY=${REGISTRY} \ + -e MAKEFLAGS=${MAKEFLAGS} \ + -e FLYTE_HOST=${FLYTE_HOST} \ + -e INSECURE_FLAG=${INSECURE_FLAG} \ + -e PROJECT=${PROJECT} \ + -e FLYTE_AWS_ENDPOINT=${FLYTE_AWS_ENDPOINT} \ + -e FLYTE_AWS_ACCESS_KEY_ID=${FLYTE_AWS_ACCESS_KEY_ID} \ + -e FLYTE_AWS_SECRET_ACCESS_KEY=${FLYTE_AWS_SECRET_ACCESS_KEY} \ + -e OUTPUT_DATA_PREFIX=${OUTPUT_DATA_PREFIX} \ + -e ADDL_DISTRIBUTION_DIR=${ADDL_DISTRIBUTION_DIR} \ + -e SERVICE_ACCOUNT=$(SERVICE_ACCOUNT) \ + -e VERSION=${VERSION} \ + -v ${CURDIR}/_pb_output:/tmp/output \ + ${TAGGED_IMAGE} make serialize + + +.PHONY: register +register: clean _pb_output docker_push + @echo ${VERSION} + @echo ${CURDIR} + docker run -it --rm \ + --network host \ + -e REGISTRY=${REGISTRY} \ + -e MAKEFLAGS=${MAKEFLAGS} \ + -e FLYTE_HOST=${FLYTE_HOST} \ + -e INSECURE_FLAG=${INSECURE_FLAG} \ + -e PROJECT=${PROJECT} \ + -e FLYTE_AWS_ENDPOINT=${FLYTE_AWS_ENDPOINT} \ + -e FLYTE_AWS_ACCESS_KEY_ID=${FLYTE_AWS_ACCESS_KEY_ID} \ + -e FLYTE_AWS_SECRET_ACCESS_KEY=${FLYTE_AWS_SECRET_ACCESS_KEY} \ + -e OUTPUT_DATA_PREFIX=${OUTPUT_DATA_PREFIX} \ + -e ADDL_DISTRIBUTION_DIR=${ADDL_DISTRIBUTION_DIR} \ + -e SERVICE_ACCOUNT=$(SERVICE_ACCOUNT) \ + -e VERSION=${VERSION} \ + -v ${CURDIR}/_pb_output:/tmp/output \ + ${TAGGED_IMAGE} make register + +_pb_output: + mkdir -p _pb_output + +.PHONY: clean +clean: + rm -rf _pb_output/* diff --git 
a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/__init__.py b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/__init__.py b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/basic_workflow.py b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/basic_workflow.py
new file mode 100644
index 00000000000..49c42c5911d
--- /dev/null
+++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/basic_workflow.py
@@ -0,0 +1,54 @@
+"""
+Write a simple workflow
+------------------------------
+
+Once you have a handle on tasks, you can move on to workflows. Workflows are the other basic building block of Flyte.
+
+Workflows string together two or more tasks. They are also written as Python functions, but it is important to make a
+critical distinction between tasks and workflows.
+
+The body of a task's function runs at "run time", i.e. on the K8s cluster, using the task's container. The body of a
+workflow is not used for computation; it is only used to structure the tasks, i.e. the output of ``t1`` is an input
+of ``t2`` in the workflow below. As such, the body of a workflow runs at "registration" time. Please refer to the
+registration docs for additional information as well, since registration is actually a two-step process.
+
+Take a look at the conceptual `discussion `__
+behind workflows for additional information.
+
+"""
+import typing
+
+from flytekit import task, workflow
+
+
+@task
+def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
+    return a + 2, "world"
+
+
+@task
+def t2(a: str, b: str) -> str:
+    return b + a
+
+
+# %%
+# You can treat the outputs of a task as you normally would a Python function. Assign the output to two variables
+# and use them in subsequent tasks as normal. See :py:func:`flytekit.workflow`
+@workflow
+def my_wf(a: int, b: str) -> (int, str):
+    x, y = t1(a=a)
+    d = t2(a=y, b=b)
+    return x, d
+
+
+# %%
+# Execute the workflow by simply invoking it like a function and passing in
+# the necessary parameters.
+#
+# .. note::
+#
+#     One thing to remember: currently we only support keyword arguments, so
+#     every argument should be passed in the form ``arg=value``. Failure to do
+#     so will result in an error.
+if __name__ == "__main__":
+    print(f"Running my_wf(a=50, b='hello') {my_wf(a=50, b='hello')}")
diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/hello_world.py b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/hello_world.py
new file mode 100644
index 00000000000..da7e61536f1
--- /dev/null
+++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/hello_world.py
@@ -0,0 +1,40 @@
+"""
+Hello World Workflow
+--------------------
+
+This simple workflow calls a task that returns "Hello World" and then just sets that as the final output of the workflow.
+
+"""
+
+from flytekit import task, workflow
+
+
+# You can change the signature of the workflow to take in an argument like this:
+# def say_hello(name: str) -> str:
+@task
+def say_hello() -> str:
+    return "hello world"
+
+
+# %%
+# You can treat the output of a task as you normally would a Python function.
Assign the output to a variable and use it
+# in subsequent tasks as normal. See :py:func:`flytekit.workflow`
+# You can change the signature of the workflow to take in an argument like this:
+# def my_wf(name: str) -> str:
+@workflow
+def my_wf() -> str:
+    res = say_hello()
+    return res
+
+
+# %%
+# Execute the workflow by simply invoking it like a function and passing in
+# the necessary parameters.
+#
+# .. note::
+#
+#     One thing to remember: currently we only support keyword arguments, so
+#     every argument should be passed in the form ``arg=value``. Failure to do
+#     so will result in an error.
if __name__ == "__main__":
+    print(f"Running my_wf() {my_wf()}")
diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.in b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.in
new file mode 100644
index 00000000000..f7d015b8435
--- /dev/null
+++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.in
@@ -0,0 +1,4 @@
+flytekit>=0.17.0b0
+wheel
+matplotlib
+opencv-python
diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.txt
new file mode 100644
index 00000000000..7b2b8803067
--- /dev/null
+++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.txt
@@ -0,0 +1,136 @@
+#
+# This file is autogenerated by pip-compile
+# To update, run:
+#
+#    /Library/Developer/CommandLineTools/usr/bin/make requirements.txt
+#
+attrs==20.3.0
+    # via scantree
+certifi==2020.12.5
+    # via requests
+chardet==4.0.0
+    # via requests
+click==7.1.2
+    # via flytekit
+croniter==1.0.10
+    # via flytekit
+cycler==0.10.0
+    # via matplotlib
+dataclasses-json==0.5.2
+    # via flytekit
+decorator==5.0.4
+    # via retry
+deprecated==1.2.12
+    # via flytekit
+dirhash==0.2.1
+    # via flytekit
+docker-image-py==0.1.10
+    # via flytekit
+flyteidl==0.18.31
+    # via flytekit
+flytekit==0.17.0
+    # via -r ../common/requirements-common.in
+grpcio==1.36.1
+    # via flytekit
+idna==2.10
+    # via requests
+importlib-metadata==3.10.0
+    # via keyring
+keyring==23.0.1
+    # via flytekit
+kiwisolver==1.3.1
+    # via matplotlib
+marshmallow-enum==1.5.1
+    # via dataclasses-json
+marshmallow==3.11.1
+    # via
+    #   dataclasses-json
+    #   marshmallow-enum
+matplotlib==3.4.1
+    # via -r ../common/requirements-common.in
+mypy-extensions==0.4.3
+    # via typing-inspect
+natsort==7.1.1
+    # via flytekit
+numpy==1.20.2
+    # via
+    #   matplotlib
+    #   opencv-python
+    #   pandas
+    #   pyarrow
+opencv-python==4.5.1.48
+    # via -r requirements.in
+pandas==1.2.3
+    # via flytekit
+pathspec==0.8.1
+    # via scantree
+pillow==8.2.0
+    # via matplotlib
+protobuf==3.15.7
+    # via
+    #   flyteidl
+    #   flytekit
+py==1.10.0
+    # via retry
+pyarrow==3.0.0
+    # via flytekit
+pyparsing==2.4.7
+    # via matplotlib
+python-dateutil==2.8.1
+    # via
+    #   croniter
+    #   flytekit
+    #   matplotlib
+    #   pandas
+pytimeparse==1.1.8
+    # via flytekit
+pytz==2018.4
+    # via
+    #   flytekit
+    #   pandas
+regex==2021.3.17
+    # via docker-image-py
+requests==2.25.1
+    # via
+    #   flytekit
+    #   responses
+responses==0.13.2
+    # via flytekit
+retry==0.9.2
+    # via flytekit
+scantree==0.0.1
+    # via dirhash
+six==1.15.0
+    # via
+    #   cycler
+    #   flytekit
+    #   grpcio
+    #   protobuf
+    #   python-dateutil
+    #   responses
+    #   scantree
+sortedcontainers==2.3.0
+    # via flytekit
+statsd==3.3.0
+    # via flytekit
+stringcase==1.2.0
+    # via dataclasses-json
+typing-extensions==3.7.4.3
+    # via typing-inspect
+typing-inspect==0.6.0 + # via dataclasses-json +urllib3==1.25.11 + # via + # flytekit + # requests + # responses +wheel==0.36.2 + # via + # -r ../common/requirements-common.in + # flytekit +wrapt==1.12.1 + # via + # deprecated + # flytekit +zipp==3.4.1 + # via importlib-metadata diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/sandbox.config b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/sandbox.config new file mode 100644 index 00000000000..da3362a4b03 --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/sandbox.config @@ -0,0 +1,7 @@ +[sdk] +workflow_packages=workflows +python_venv=flytekit_venv + +[auth] +assumable_iam_role=arn:aws:iam::173840052742:role/flytefunctionaltestsbatchworker-production-iad +raw_output_data_prefix=s3://lyft-modelbuilder/cookbook diff --git a/tests/flytekit/integration/control_plane/test_workflow.py b/tests/flytekit/integration/control_plane/test_workflow.py new file mode 100644 index 00000000000..8b72cd7c2a5 --- /dev/null +++ b/tests/flytekit/integration/control_plane/test_workflow.py @@ -0,0 +1,90 @@ +import datetime +import os +import pathlib +import time + +import pytest + +from flytekit.common.exceptions.user import FlyteAssertion +from flytekit.control_plane import launch_plan +from flytekit.models import literals + +PROJECT = "flytesnacks" +VERSION = os.getpid() + + +@pytest.fixture(scope="session") +def flyte_workflows_source_dir(): + return pathlib.Path(os.path.dirname(__file__)) / "mock_flyte_repo" + + +@pytest.fixture(scope="session") +def flyte_workflows_register(docker_compose): + docker_compose.execute( + f"exec -w /flyteorg/src -e SANDBOX=1 -e PROJECT={PROJECT} -e VERSION=v{VERSION} " + "backend make -C workflows register" + ) + + +def test_client(flyteclient, flyte_workflows_register): + projects = flyteclient.list_projects_paginated(limit=5, token=None) + assert len(projects) <= 5 + + +def test_launch_workflow(flyteclient, flyte_workflows_register): + execution = launch_plan.FlyteLaunchPlan.fetch( + PROJECT, "development", "workflows.basic.hello_world.my_wf", f"v{VERSION}" + ).launch_with_literals(PROJECT, "development", literals.LiteralMap({})) + execution.wait_for_completion() + assert execution.outputs.literals["o0"].scalar.primitive.string_value == "hello world" + + +def test_launch_workflow_with_args(flyteclient, flyte_workflows_register): + execution = launch_plan.FlyteLaunchPlan.fetch( + PROJECT, "development", "workflows.basic.basic_workflow.my_wf", f"v{VERSION}" + ).launch_with_literals( + PROJECT, + "development", + literals.LiteralMap( + { + "a": literals.Literal(literals.Scalar(literals.Primitive(integer=10))), + "b": literals.Literal(literals.Scalar(literals.Primitive(string_value="foobar"))), + } + ), + ) + execution.wait_for_completion() + assert execution.outputs.literals["o0"].scalar.primitive.integer == 12 + assert execution.outputs.literals["o1"].scalar.primitive.string_value == "foobarworld" + + +def test_monitor_workflow(flyteclient, flyte_workflows_register): + execution = launch_plan.FlyteLaunchPlan.fetch( + PROJECT, "development", "workflows.basic.hello_world.my_wf", f"v{VERSION}" + ).launch_with_literals(PROJECT, "development", literals.LiteralMap({})) + + poll_interval = datetime.timedelta(seconds=1) + time_to_give_up = datetime.datetime.utcnow() + datetime.timedelta(seconds=60) + + execution.sync() + while datetime.datetime.utcnow() < time_to_give_up: + + if execution.is_complete: + execution.sync() + break + + with pytest.raises( 
+            FlyteAssertion, match="Please wait until the execution has completed before requesting the outputs"
+        ):
+            execution.outputs
+
+        time.sleep(poll_interval.total_seconds())
+        execution.sync()
+
+    if execution.node_executions:
+        assert execution.node_executions["start-node"].closure.phase == 3  # SUCCEEDED
+
+    for key in execution.node_executions:
+        assert execution.node_executions[key].closure.phase == 3
+
+    assert execution.node_executions["n0"].outputs.literals["o0"].scalar.primitive.string_value == "hello world"
+    assert execution.outputs.literals["o0"].scalar.primitive.string_value == "hello world"
diff --git a/tests/flytekit/unit/control_plane/__init__.py b/tests/flytekit/unit/control_plane/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/flytekit/unit/control_plane/tasks/test_task.py b/tests/flytekit/unit/control_plane/tasks/test_task.py
new file mode 100644
index 00000000000..479221ae9d2
--- /dev/null
+++ b/tests/flytekit/unit/control_plane/tasks/test_task.py
@@ -0,0 +1,34 @@
+from mock import MagicMock as _MagicMock
+from mock import patch as _patch
+
+from flytekit.control_plane.tasks import task as _task
+from flytekit.models import task as _task_models
+from flytekit.models.core import identifier as _identifier
+
+
+@_patch("flytekit.engines.flyte.engine._FlyteClientManager")
+@_patch("flytekit.configuration.platform.URL")
+def test_flyte_task_fetch(mock_url, mock_client_manager):
+    mock_url.get.return_value = "localhost"
+    admin_task_v1 = _task_models.Task(
+        _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"),
+        _MagicMock(),
+    )
+    admin_task_v2 = _task_models.Task(
+        _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", "v2"),
+        _MagicMock(),
+    )
+    mock_client = _MagicMock()
+    mock_client.list_tasks_paginated = _MagicMock(return_value=([admin_task_v2, admin_task_v1], ""))
+    mock_client_manager.return_value.client = mock_client
+
+    latest_task = _task.FlyteTask.fetch_latest("p1", "d1", "n1")
+    task_v1 = _task.FlyteTask.fetch("p1", "d1", "n1", "v1")
+    task_v2 = _task.FlyteTask.fetch("p1", "d1", "n1", "v2")
+    assert task_v1.id == admin_task_v1.id
+    assert task_v1.id != latest_task.id
+    assert task_v2.id == latest_task.id == admin_task_v2.id
+
+    for task in [task_v1, task_v2]:
+        assert task.entity_type_text == "Task"
+        assert task.resource_type == _identifier.ResourceType.TASK
diff --git a/tests/flytekit/unit/control_plane/test_identifier.py b/tests/flytekit/unit/control_plane/test_identifier.py
new file mode 100644
index 00000000000..8976df0bd3c
--- /dev/null
+++ b/tests/flytekit/unit/control_plane/test_identifier.py
@@ -0,0 +1,77 @@
+import pytest
+
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.control_plane import identifier as _identifier
+from flytekit.models.core import identifier as _core_identifier
+
+
+def test_identifier():
+    identifier = _identifier.Identifier(_core_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "v1")
+    assert identifier == _identifier.Identifier.from_urn("wf:project:domain:name:v1")
+    assert identifier == _core_identifier.Identifier(
+        _core_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "v1"
+    )
+    assert identifier.__str__() == "wf:project:domain:name:v1"
+
+
+@pytest.mark.parametrize(
+    "urn",
+    [
+        "",
+        "project:domain:name:v1",
+        "wf:project:domain:name:v1:foobar",
+        "foobar:project:domain:name:v1",
+    ],
+)
+def test_identifier_exceptions(urn):
+    with pytest.raises(_user_exceptions.FlyteValueException):
+        _identifier.Identifier.from_urn(urn)
+
+
+def test_workflow_execution_identifier():
+    identifier = _identifier.WorkflowExecutionIdentifier("project", "domain", "name")
+    assert identifier == _identifier.WorkflowExecutionIdentifier.from_urn("ex:project:domain:name")
+    assert identifier == _identifier.WorkflowExecutionIdentifier.promote_from_model(
+        _core_identifier.WorkflowExecutionIdentifier("project", "domain", "name")
+    )
+    assert identifier.__str__() == "ex:project:domain:name"
+
+
+@pytest.mark.parametrize(
+    "urn", ["", "project:domain:name", "project:domain:name:foobar", "ex:project:domain:name:foobar"]
+)
+def test_workflow_execution_identifier_exceptions(urn):
+    with pytest.raises(_user_exceptions.FlyteValueException):
+        _identifier.WorkflowExecutionIdentifier.from_urn(urn)
+
+
+def test_task_execution_identifier():
+    task_id = _identifier.Identifier(_core_identifier.ResourceType.TASK, "project", "domain", "name", "version")
+    node_execution_id = _core_identifier.NodeExecutionIdentifier(
+        node_id="n0", execution_id=_core_identifier.WorkflowExecutionIdentifier("project", "domain", "name")
+    )
+    identifier = _identifier.TaskExecutionIdentifier(
+        task_id=task_id,
+        node_execution_id=node_execution_id,
+        retry_attempt=0,
+    )
+    assert identifier == _identifier.TaskExecutionIdentifier.from_urn(
+        "te:project:domain:name:n0:project:domain:name:version:0"
+    )
+    assert identifier == _identifier.TaskExecutionIdentifier.promote_from_model(
+        _core_identifier.TaskExecutionIdentifier(task_id, node_execution_id, 0)
+    )
+    assert identifier.__str__() == "te:project:domain:name:n0:project:domain:name:version:0"
+
+
+@pytest.mark.parametrize(
+    "urn",
+    [
+        "",
+        "te:project:domain:name:n0:project:domain:name:version",
+        "foobar:project:domain:name:n0:project:domain:name:version:0",
+    ],
+)
+def test_task_execution_identifier_exceptions(urn):
+    with pytest.raises(_user_exceptions.FlyteValueException):
+        _identifier.TaskExecutionIdentifier.from_urn(urn)
diff --git a/tests/flytekit/unit/control_plane/test_workflow.py b/tests/flytekit/unit/control_plane/test_workflow.py
new file mode 100644
index 00000000000..81c82bf706d
--- /dev/null
+++ b/tests/flytekit/unit/control_plane/test_workflow.py
@@ -0,0 +1,23 @@
+from mock import MagicMock as _MagicMock
+from mock import patch as _patch
+
+from flytekit.control_plane import workflow as _workflow
+from flytekit.models.admin import workflow as _workflow_models
+from flytekit.models.core import identifier as _identifier
+
+
+@_patch("flytekit.engines.flyte.engine._FlyteClientManager")
+@_patch("flytekit.configuration.platform.URL")
+def test_flyte_workflow_integration(mock_url, mock_client_manager):
+    mock_url.get.return_value = "localhost"
+    admin_workflow = _workflow_models.Workflow(
+        _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p1", "d1", "n1", "v1"),
+        _MagicMock(),
+    )
+    mock_client = _MagicMock()
+    mock_client.list_workflows_paginated = _MagicMock(return_value=([admin_workflow], ""))
+    mock_client_manager.return_value.client = mock_client
+
+    workflow = _workflow.FlyteWorkflow.fetch("p1", "d1", "n1", "v1")
+    assert workflow.entity_type_text == "Workflow"
+    assert workflow.id == admin_workflow.id