Skip to content

Commit

Permalink
Run compilation even in local execution for dynamic tasks to early de…
Browse files Browse the repository at this point in the history
…tect errors (#1121)

Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Aug 5, 2022
1 parent 9e154a0 commit 0e830c5
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 22 deletions.
11 changes: 10 additions & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from enum import Enum
from typing import Any, Callable, List, Optional, TypeVar, Union

from flytekit.configuration import SerializationSettings
from flytekit.configuration.default_images import DefaultImages
from flytekit.core.base_task import Task, TaskResolverMixin
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.docstring import Docstring
Expand Down Expand Up @@ -257,10 +259,17 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
representing that newly generated workflow, instead of executing it.
"""
ctx = FlyteContextManager.current_context()
# This is a placeholder SerializationSettings placeholder and is only used to test compilation for dynamic tasks
# when run locally. The output of the compilation should never actually be used anywhere.
_LOCAL_ONLY_SS = SerializationSettings.for_image(DefaultImages.default_image(), "v", "p", "d")

if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
updated_exec_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
with FlyteContextManager.with_context(ctx.with_execution_state(updated_exec_state)):
with FlyteContextManager.with_context(
ctx.with_execution_state(updated_exec_state).with_serialization_settings(_LOCAL_ONLY_SS)
) as ctx:
logger.debug(f"Running compilation for {self} as part of local run as check")
self.compile_into_workflow(ctx, task_function, **kwargs)
logger.info("Executing Dynamic workflow, using raw inputs")
return exception_scopes.user_entry_point(task_function)(**kwargs)

Expand Down
48 changes: 34 additions & 14 deletions tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing

import pytest

import flytekit.configuration
from flytekit import dynamic
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig
Expand All @@ -10,6 +12,19 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow

settings = flytekit.configuration.SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
fast_serialization_settings=FastSerializationSettings(
enabled=True,
destination_dir="/User/flyte/workflows",
distribution_location="s3://my-s3-bucket/fast/123",
),
)


def test_wf1_with_fast_dynamic():
@task
Expand All @@ -30,20 +45,7 @@ def my_wf(a: int) -> typing.List[str]:
return v

with context_manager.FlyteContextManager.with_context(
context_manager.FlyteContextManager.current_context().with_serialization_settings(
flytekit.configuration.SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
fast_serialization_settings=FastSerializationSettings(
enabled=True,
destination_dir="/User/flyte/workflows",
distribution_location="s3://my-s3-bucket/fast/123",
),
)
)
context_manager.FlyteContextManager.current_context().with_serialization_settings(settings)
) as ctx:
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
Expand Down Expand Up @@ -111,6 +113,24 @@ def wf(a: int, b: int) -> typing.List[str]:
assert res == ["fast-2", "fast-3", "fast-4", "fast-5", "fast-6"]


def test_dynamic_local_use():
@task
def t1(a: int) -> str:
a = a + 2
return "fast-" + str(a)

@dynamic
def use_result(a: int) -> int:
x = t1(a=a)
if len(x) > 6:
return 5
else:
return 0

with pytest.raises(TypeError):
use_result(a=6)


def test_create_node_dynamic_local():
@task
def task1(s: str) -> str:
Expand Down
10 changes: 6 additions & 4 deletions tests/flytekit/unit/core/test_local_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from pytest import fixture
from typing_extensions import Annotated

from flytekit import SQLTask, dynamic, kwtypes
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.base_task import kwtypes
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.hash import HashMethod
from flytekit.core.local_cache import LocalTaskCache
from flytekit.core.task import TaskMetadata, task
Expand Down Expand Up @@ -309,13 +311,13 @@ def t1(a: int) -> int:

# We should have a cache miss in the first call to downstream_t and have a cache hit
# on the second call.
v_1 = downstream_t(a=v)
downstream_t(a=v)
v_2 = downstream_t(a=v)

return v_1 + v_2
return v_2

assert n_cached_task_calls == 0
assert t1(a=3) == (6 + 6)
assert t1(a=3) == 6
assert n_cached_task_calls == 1


Expand Down
6 changes: 3 additions & 3 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,11 +1197,11 @@ def t1(a: int) -> int:

# We should have a cache miss in the first call to downstream_t
v_1 = downstream_t(a=v, df=df)
v_2 = downstream_t(a=v, df=df)
downstream_t(a=v, df=df)

return v_1 + v_2
return v_1

assert t1(a=3) == (6 + 6 + 6)
assert t1(a=3) == 9


def test_literal_hash_int_not_set():
Expand Down

0 comments on commit 0e830c5

Please sign in to comment.