Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile the workflow only at compile time #1311

Merged
merged 29 commits into from
Feb 16, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion flytekit/core/promise.py
Original file line number Diff line number Diff line change
@@ -1061,7 +1061,6 @@ def flyte_entity_call_handler(
)

ctx = FlyteContextManager.current_context()

if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
return create_and_link_node(ctx, entity=entity, **kwargs)
elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
12 changes: 11 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
@@ -230,10 +230,12 @@ def interface(self) -> _interface_models.TypedInterface:

@property
def output_bindings(self) -> List[_literal_models.Binding]:
self.compile()
return self._output_bindings

@property
def nodes(self) -> List[Node]:
self.compile()
return self._nodes

def __repr__(self):
@@ -257,11 +259,15 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis
# Get default arguments and override with kwargs passed in
input_kwargs = self.python_interface.default_inputs_as_kwargs
input_kwargs.update(kwargs)
self.compile()
return flyte_entity_call_handler(self, *args, **input_kwargs)

def execute(self, **kwargs):
raise Exception("Should not be called")

def compile(self, **kwargs):
pass

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]:
# This is done to support the invariant that Workflow local executions always work with Promise objects
# holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
@@ -272,6 +278,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr

# The output of this will always be a combination of Python native values and Promises containing Flyte
# Literals.
self.compile()
function_outputs = self.execute(**kwargs)

# First handle the empty return case.
@@ -612,6 +619,7 @@ def __init__(
python_interface=native_interface,
docs=docs,
)
self.compiled = False

@property
def function(self):
@@ -625,6 +633,9 @@ def compile(self, **kwargs):
Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics
a 'closure' in the traditional sense of the word.
"""
if self.compiled:
return
self.compiled = True
ctx = FlyteContextManager.current_context()
self._input_parameters = transform_inputs_to_parameters(ctx, self.python_interface)
all_nodes = []
@@ -759,7 +770,6 @@ def wrapper(fn):
docstring=Docstring(callable_=fn),
docs=docs,
)
workflow_instance.compile()
update_wrapper(workflow_instance, fn)
return workflow_instance

2 changes: 2 additions & 0 deletions plugins/flytekit-pandera/tests/test_plugin.py
Original file line number Diff line number Diff line change
@@ -48,6 +48,8 @@ def my_wf() -> pandera.typing.DataFrame[OutSchema]:
def invalid_wf() -> pandera.typing.DataFrame[OutSchema]:
return transform2(df=transform1(df=invalid_df))

invalid_wf()

# raise error when executing workflow with invalid input
@workflow
def wf_with_df_input(df: pandera.typing.DataFrame[InSchema]) -> pandera.typing.DataFrame[OutSchema]:
2 changes: 2 additions & 0 deletions tests/flytekit/unit/core/test_composition.py
Original file line number Diff line number Diff line change
@@ -196,3 +196,5 @@ def t3(c: Optional[int] = 3) -> Optional[int]:
@workflow
def wf():
return t3()

wf()
6 changes: 6 additions & 0 deletions tests/flytekit/unit/core/test_conditions.py
Original file line number Diff line number Diff line change
@@ -167,19 +167,25 @@ def decompose_unary() -> int:
result = return_true()
return conditional("test").if_(result).then(success()).else_().then(failed())

decompose_unary()

with pytest.raises(AssertionError):

@workflow
def decompose_none() -> int:
return conditional("test").if_(None).then(success()).else_().then(failed())

decompose_none()

with pytest.raises(AssertionError):

@workflow
def decompose_is() -> int:
result = return_true()
return conditional("test").if_(result is True).then(success()).else_().then(failed())

decompose_is()

@workflow
def decompose() -> int:
result = return_true()
7 changes: 6 additions & 1 deletion tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
@@ -34,11 +34,16 @@ def t1(a: int) -> str:
a = a + 2
return "fast-" + str(a)

@workflow
def subwf(a: int):
t1(a=a)

@dynamic
def my_subwf(a: int) -> typing.List[str]:
s = []
for i in range(a):
s.append(t1(a=i))
subwf(a=a)
return s

@workflow
@@ -58,7 +63,7 @@ def my_wf(a: int) -> typing.List[str]:
) as ctx:
input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map)
assert len(dynamic_job_spec._nodes) == 5
assert len(dynamic_job_spec._nodes) == 6
assert len(dynamic_job_spec.tasks) == 1
args = " ".join(dynamic_job_spec.tasks[0].container.args)
assert args.startswith(
4 changes: 4 additions & 0 deletions tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
@@ -159,6 +159,8 @@ def wf1(a: int):
def wf2(a: typing.List[int]):
return map_task(wf1)(a=a)

wf2()

lp = LaunchPlan.create("test", wf1)

with pytest.raises(ValueError):
@@ -167,6 +169,8 @@ def wf2(a: typing.List[int]):
def wf3(a: typing.List[int]):
return map_task(lp)(a=a)

wf3()


def test_inputs_outputs_length():
@task
8 changes: 8 additions & 0 deletions tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
@@ -96,6 +96,8 @@ def empty_wf2():
def empty_wf2():
create_node(t2, "foo")

empty_wf2()


def test_more_normal_task():
nt = typing.NamedTuple("OneOutput", t1_str_output=str)
@@ -141,6 +143,8 @@ def my_wf(a: int) -> str:
t1_node = create_node(t1, a=a)
return t1_node.outputs

my_wf()


def test_runs_before():
@task
@@ -330,6 +334,8 @@ def t1(a: str) -> str:
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(timeout="foo")

my_wf()


@pytest.mark.parametrize(
"retries,expected",
@@ -443,3 +449,5 @@ def my_wf(a: str) -> str:
@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(task_config=None)

my_wf()
2 changes: 2 additions & 0 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
@@ -440,6 +440,8 @@ def my_wf() -> wf_outputs:
# Note only Namedtuple can be created like this
return wf_outputs(say_hello(), say_hello())

my_wf()


def test_serialized_docstrings():
@task
14 changes: 11 additions & 3 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
@@ -176,11 +176,11 @@ def my_wf(a: int, b: str) -> (int, str):
d = t2(a=y, b=b)
return x, d

assert len(my_wf._nodes) == 2
assert len(my_wf.nodes) == 2
assert my_wf._nodes[0].id == "n0"
assert my_wf._nodes[1]._upstream_nodes[0] is my_wf._nodes[0]

assert len(my_wf._output_bindings) == 2
assert len(my_wf.output_bindings) == 2
assert my_wf._output_bindings[0].var == "o0"
assert my_wf._output_bindings[0].binding.promise.var == "t1_int_output"

@@ -280,18 +280,24 @@ def test_wf_output_mismatch():
def my_wf(a: int, b: str) -> (int, str):
return a

my_wf()

with pytest.raises(AssertionError):

@workflow
def my_wf2(a: int, b: str) -> int:
return a, b # type: ignore

my_wf2()

with pytest.raises(AssertionError):

@workflow
def my_wf3(a: int, b: str) -> int:
return (a,) # type: ignore

my_wf3()

assert context_manager.FlyteContextManager.size() == 1


@@ -676,7 +682,7 @@ def lister() -> typing.List[str]:
return s

assert len(lister.interface.outputs) == 1
binding_data = lister._output_bindings[0].binding # the property should be named binding_data
binding_data = lister.output_bindings[0].binding # the property should be named binding_data
assert binding_data.collection is not None
assert len(binding_data.collection.bindings) == 10

@@ -800,6 +806,8 @@ def my_wf(a: int, b: str) -> (int, str):
conditional("test2").if_(x == 4).then(t2(a=b)).elif_(x >= 5).then(t2(a=y)).else_().fail("blah")
return x, d

my_wf()

assert context_manager.FlyteContextManager.size() == 1


20 changes: 19 additions & 1 deletion tests/flytekit/unit/core/test_workflows.py
Original file line number Diff line number Diff line change
@@ -7,8 +7,9 @@
from typing_extensions import Annotated # type: ignore

import flytekit.configuration
from flytekit import StructuredDataset, kwtypes
from flytekit import FlyteContextManager, StructuredDataset, kwtypes
from flytekit.configuration import Image, ImageConfig
from flytekit.core import context_manager
from flytekit.core.condition import conditional
from flytekit.core.task import task
from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow
@@ -156,6 +157,8 @@ def no_outputs_wf():
def one_output_wf() -> int: # type: ignore
t1(a=3)

one_output_wf()


def test_wf_no_output():
@task
@@ -320,3 +323,18 @@ def test_structured_dataset_wf():
assert_frame_equal(sd_to_schema_wf(), superset_df)
assert_frame_equal(schema_to_sd_wf()[0], subset_df)
assert_frame_equal(schema_to_sd_wf()[1], subset_df)


def test_compile_wf_at_compile_time():
ctx = FlyteContextManager.current_context()
with FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.new_execution_state().with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
)
):

@workflow
def wf():
t4()

assert ctx.compilation_state is None
2 changes: 2 additions & 0 deletions tests/flytekit/unit/remote/test_calling.py
Original file line number Diff line number Diff line change
@@ -75,6 +75,8 @@ def test_misnamed():
def wf(a: int) -> int:
return ft(b=a)

wf()


def test_calling_lp():
sub_wf_lp = LaunchPlan.get_or_create(sub_wf)