Skip to content

Commit

Permalink
Restrict dynamic task & add unit tests (#2849)
Browse files Browse the repository at this point in the history
Signed-off-by: Mecoli1219 <[email protected]>
  • Loading branch information
Mecoli1219 authored Oct 22, 2024
1 parent da8436e commit f1adbe8
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 30 deletions.
65 changes: 38 additions & 27 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@
PythonAutoContainerTask,
default_notebook_task_resolver,
)
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceSpec
from flytekit.core.task import ReferenceTask
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import LiteralsResolver, TypeEngine
from flytekit.core.workflow import PythonFunctionWorkflow, ReferenceWorkflow, WorkflowBase, WorkflowFailurePolicy
from flytekit.exceptions import user as user_exceptions
from flytekit.exceptions.user import (
FlyteAssertion,
FlyteEntityAlreadyExistsException,
FlyteEntityNotExistException,
FlyteValueException,
Expand Down Expand Up @@ -198,6 +200,38 @@ def _get_git_repo_url(source_path: str):
return ""


def _get_pickled_target_dict(root_entity: typing.Union[WorkflowBase, PythonTask]) -> typing.Dict[str, typing.Any]:
    """
    Collect every task reachable from ``root_entity`` that has to be pickled.

    Workflows are expanded through their ``nodes`` and nodes through their ``flyte_entity``. Each collected
    task has ``default_notebook_task_resolver`` set on it and is stored under its ``name``. For an
    ``ArrayNodeMapTask`` the wrapped ``_run_task`` is collected instead of the map task itself.

    :param root_entity: The entity to get the pickled target for.
    :return: The pickled target dictionary, mapping task name to task.
    :raises FlyteAssertion: If a dynamic ``PythonFunctionTask`` is encountered during the traversal.
    """
    collected = {}
    pending = [root_entity]
    while pending:
        # pop() takes from the end of the list, so the traversal is depth-first.
        entity = pending.pop()

        is_dynamic = (
            isinstance(entity, PythonFunctionTask)
            and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC
        )
        if is_dynamic:
            raise FlyteAssertion(
                f"Dynamic tasks are not supported in interactive mode. {entity.name} is a dynamic task."
            )

        if isinstance(entity, PythonTask):
            if isinstance(entity, ArrayNodeMapTask):
                target = entity._run_task
            elif isinstance(entity, PythonAutoContainerTask):
                target = entity
            else:
                # Other PythonTask flavours are neither pickled nor expanded.
                continue
            target.set_resolver(default_notebook_task_resolver)
            collected[target.name] = target
        elif isinstance(entity, WorkflowBase):
            pending.extend(entity.nodes)
        elif isinstance(entity, CoreNode):
            pending.append(entity.flyte_entity)
    return collected


class FlyteRemote(object):
"""Main entrypoint for programmatically accessing a Flyte remote backend.
Expand Down Expand Up @@ -2583,39 +2617,16 @@ def download(
for var, literal in lm.items():
download_literal(self.file_access, var, literal, download_to)

def _get_pickled_target_dict(self, root_entity: typing.Any) -> typing.Dict[str, typing.Any]:
    """
    Get the pickled target dictionary for the entity.

    Starting at ``root_entity``, workflows are expanded through their ``nodes`` and nodes through their
    ``flyte_entity``. Every ``PythonAutoContainerTask`` found gets ``default_notebook_task_resolver`` set and
    is stored under its ``name``; for an ``ArrayNodeMapTask`` the wrapped ``_run_task`` is stored instead.

    Dynamic tasks are rejected, matching the module-level ``_get_pickled_target_dict`` helper, so both
    entry points apply the same restriction.

    :param root_entity: The entity to get the pickled target for.
    :return: The pickled target dictionary, mapping task name to task.
    :raises FlyteAssertion: If a dynamic ``PythonFunctionTask`` is reachable from ``root_entity``.
    """
    queue = [root_entity]
    pickled_target_dict = {}
    while queue:
        entity = queue.pop()
        if (
            isinstance(entity, PythonFunctionTask)
            and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC
        ):
            raise FlyteAssertion(
                f"Dynamic tasks are not supported in interactive mode. {entity.name} is a dynamic task."
            )

        if isinstance(entity, PythonTask):
            if isinstance(entity, ArrayNodeMapTask):
                # Map tasks are pickled via the task they wrap.
                run_task = entity._run_task
                run_task.set_resolver(default_notebook_task_resolver)
                pickled_target_dict[run_task.name] = run_task
            elif isinstance(entity, PythonAutoContainerTask):
                entity.set_resolver(default_notebook_task_resolver)
                pickled_target_dict[entity.name] = entity
        elif isinstance(entity, WorkflowBase):
            queue.extend(entity.nodes)
        elif isinstance(entity, CoreNode):
            queue.append(entity.flyte_entity)
    return pickled_target_dict

def _pickle_and_upload_entity(self, entity: typing.Any) -> typing.Tuple[bytes, FastSerializationSettings]:
def _pickle_and_upload_entity(
self, entity: typing.Union[WorkflowBase, PythonTask]
) -> typing.Tuple[bytes, FastSerializationSettings]:
"""
Pickle the entity to the specified location. This is useful for debugging and for sharing entities across
different environments.
:param entity: The entity to pickle
"""
# get all entity tasks
pickled_dict = self._get_pickled_target_dict(entity)
pickled_dict = _get_pickled_target_dict(entity)
with tempfile.TemporaryDirectory() as tmp_dir:
dest = pathlib.Path(tmp_dir, PICKLE_FILE_PATH)
with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped:
Expand Down
68 changes: 65 additions & 3 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import uuid
from collections import OrderedDict
from datetime import datetime, timedelta
from functools import partial

import mock
import pytest
Expand All @@ -15,13 +16,13 @@
from mock import ANY, MagicMock, patch

import flytekit.configuration
from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task
from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task, map_task, dynamic
from flytekit.configuration import Config, DefaultImages, Image, ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions import user as user_exceptions
from flytekit.exceptions.user import FlyteEntityNotExistException
from flytekit.exceptions.user import FlyteEntityNotExistException, FlyteAssertion
from flytekit.models import common as common_models
from flytekit.models import security
from flytekit.models.admin.workflow import Workflow, WorkflowClosure
Expand All @@ -33,7 +34,7 @@
from flytekit.models.task import Task
from flytekit.remote import FlyteTask
from flytekit.remote.lazy_entity import LazyEntity
from flytekit.remote.remote import FlyteRemote, _get_git_repo_url
from flytekit.remote.remote import FlyteRemote, _get_git_repo_url, _get_pickled_target_dict
from flytekit.tools.translator import Options, get_serializable, get_serializable_launch_plan
from tests.flytekit.common.parameterizers import LIST_OF_TASK_CLOSURES

Expand Down Expand Up @@ -690,3 +691,64 @@ def test_register_wf_script_mode(compress_scripts_mock, upload_file_mock, regist
def test_fetch_active_launchplan_not_found(mock_client, remote):
mock_client.get_active_launch_plan.side_effect = FlyteEntityNotExistException("not found")
assert remote.fetch_active_launchplan(name="basic.list_float_wf.fake_wf") is None


def test_get_pickled_target_dict():
    """A workflow chaining two plain tasks yields exactly those two tasks, keyed by name."""

    @task
    def t1() -> int:
        return 1

    @task
    def t2(a: int) -> int:
        return a + 2

    @workflow
    def w() -> int:
        return t2(a=t1())

    pickled = _get_pickled_target_dict(w)

    assert len(pickled) == 2
    for expected in (t1, t2):
        assert expected.name in pickled
        assert pickled[expected.name] == expected

def test_get_pickled_target_dict_with_map_task():
    """A workflow using a map task over a partial yields only the underlying task."""

    @task
    def t1(x: int, y: int) -> int:
        return x + y

    # NOTE(review): `w` is annotated `-> int` although it returns the result of a map task over a
    # list; presumably this should be a list type -- confirm.
    @workflow
    def w() -> int:
        return map_task(partial(t1, y=2))(x=[1, 2, 3])

    pickled = _get_pickled_target_dict(w)

    assert len(pickled) == 1
    assert t1.name in pickled
    assert pickled[t1.name] == t1

def test_get_pickled_target_dict_with_dynamic():
    """A workflow that calls a dynamic task must be rejected with FlyteAssertion."""

    @task
    def t1(a: int) -> str:
        a = a + 2
        return "fast-" + str(a)

    @workflow
    def subwf(a: int):
        t1(a=a)

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        results = [t1(a=i) for i in range(a)]
        subwf(a=a)
        return results

    @workflow
    def my_wf(a: int) -> typing.List[str]:
        return my_subwf(a=a)

    with pytest.raises(FlyteAssertion):
        _get_pickled_target_dict(my_wf)

0 comments on commit f1adbe8

Please sign in to comment.