Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile the workflow only at compile time #1311

Merged
merged 29 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,6 @@ def flyte_entity_call_handler(
)

ctx = FlyteContextManager.current_context()

if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
return create_and_link_node(ctx, entity=entity, **kwargs)
elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
Expand Down
12 changes: 11 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,12 @@ def interface(self) -> _interface_models.TypedInterface:

@property
def output_bindings(self) -> List[_literal_models.Binding]:
self.compile()
return self._output_bindings

@property
def nodes(self) -> List[Node]:
self.compile()
return self._nodes

def __repr__(self):
Expand All @@ -257,11 +259,15 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis
# Get default arguments and override with kwargs passed in
input_kwargs = self.python_interface.default_inputs_as_kwargs
input_kwargs.update(kwargs)
self.compile()
return flyte_entity_call_handler(self, *args, **input_kwargs)

def execute(self, **kwargs):
raise Exception("Should not be called")

def compile(self, **kwargs):
pass

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]:
# This is done to support the invariant that Workflow local executions always work with Promise objects
# holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
Expand All @@ -272,6 +278,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr

# The output of this will always be a combination of Python native values and Promises containing Flyte
# Literals.
self.compile()
function_outputs = self.execute(**kwargs)

# First handle the empty return case.
Expand Down Expand Up @@ -612,6 +619,7 @@ def __init__(
python_interface=native_interface,
docs=docs,
)
self.compiled = False

@property
def function(self):
Expand All @@ -625,6 +633,9 @@ def compile(self, **kwargs):
Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics
a 'closure' in the traditional sense of the word.
"""
if self.compiled:
return
self.compiled = True
ctx = FlyteContextManager.current_context()
self._input_parameters = transform_inputs_to_parameters(ctx, self.python_interface)
all_nodes = []
Expand Down Expand Up @@ -759,7 +770,6 @@ def wrapper(fn):
docstring=Docstring(callable_=fn),
docs=docs,
)
- workflow_instance.compile()
update_wrapper(workflow_instance, fn)
return workflow_instance

Expand Down
2 changes: 2 additions & 0 deletions plugins/flytekit-pandera/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def my_wf() -> pandera.typing.DataFrame[OutSchema]:
def invalid_wf() -> pandera.typing.DataFrame[OutSchema]:
return transform2(df=transform1(df=invalid_df))

invalid_wf()

# raise error when executing workflow with invalid input
@workflow
def wf_with_df_input(df: pandera.typing.DataFrame[InSchema]) -> pandera.typing.DataFrame[OutSchema]:
Expand Down
2 changes: 2 additions & 0 deletions tests/flytekit/unit/core/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,5 @@ def t3(c: Optional[int] = 3) -> Optional[int]:
@workflow
def wf():
return t3()

wf()
6 changes: 6 additions & 0 deletions tests/flytekit/unit/core/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,25 @@ def decompose_unary() -> int:
result = return_true()
return conditional("test").if_(result).then(success()).else_().then(failed())

decompose_unary()

with pytest.raises(AssertionError):

@workflow
def decompose_none() -> int:
return conditional("test").if_(None).then(success()).else_().then(failed())

decompose_none()

with pytest.raises(AssertionError):

@workflow
def decompose_is() -> int:
result = return_true()
return conditional("test").if_(result is True).then(success()).else_().then(failed())

decompose_is()

@workflow
def decompose() -> int:
result = return_true()
Expand Down
7 changes: 6 additions & 1 deletion tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,16 @@ def t1(a: int) -> str:
a = a + 2
return "fast-" + str(a)

@workflow
def subwf(a: int):
t1(a=a)

@dynamic
def my_subwf(a: int) -> typing.List[str]:
s = []
for i in range(a):
s.append(t1(a=i))
subwf(a=a)
return s

@workflow
Expand All @@ -58,7 +63,7 @@ def my_wf(a: int) -> typing.List[str]:
) as ctx:
input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map)
- assert len(dynamic_job_spec._nodes) == 5
+ assert len(dynamic_job_spec._nodes) == 6
assert len(dynamic_job_spec.tasks) == 1
args = " ".join(dynamic_job_spec.tasks[0].container.args)
assert args.startswith(
Expand Down
4 changes: 4 additions & 0 deletions tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def wf1(a: int):
def wf2(a: typing.List[int]):
return map_task(wf1)(a=a)

wf2()

lp = LaunchPlan.create("test", wf1)

with pytest.raises(ValueError):
Expand All @@ -167,6 +169,8 @@ def wf2(a: typing.List[int]):
def wf3(a: typing.List[int]):
return map_task(lp)(a=a)

wf3()


def test_inputs_outputs_length():
@task
Expand Down
8 changes: 8 additions & 0 deletions tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def empty_wf2():
def empty_wf2():
create_node(t2, "foo")

empty_wf2()


def test_more_normal_task():
nt = typing.NamedTuple("OneOutput", t1_str_output=str)
Expand Down Expand Up @@ -141,6 +143,8 @@ def my_wf(a: int) -> str:
t1_node = create_node(t1, a=a)
return t1_node.outputs

my_wf()


def test_runs_before():
@task
Expand Down Expand Up @@ -330,6 +334,8 @@ def t1(a: str) -> str:
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(timeout="foo")

my_wf()


@pytest.mark.parametrize(
"retries,expected",
Expand Down Expand Up @@ -443,3 +449,5 @@ def my_wf(a: str) -> str:
@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(task_config=None)

my_wf()
2 changes: 2 additions & 0 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ def my_wf() -> wf_outputs:
# Note only Namedtuple can be created like this
return wf_outputs(say_hello(), say_hello())

my_wf()


def test_serialized_docstrings():
@task
Expand Down
14 changes: 11 additions & 3 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ def my_wf(a: int, b: str) -> (int, str):
d = t2(a=y, b=b)
return x, d

- assert len(my_wf._nodes) == 2
+ assert len(my_wf.nodes) == 2
assert my_wf._nodes[0].id == "n0"
assert my_wf._nodes[1]._upstream_nodes[0] is my_wf._nodes[0]

- assert len(my_wf._output_bindings) == 2
+ assert len(my_wf.output_bindings) == 2
assert my_wf._output_bindings[0].var == "o0"
assert my_wf._output_bindings[0].binding.promise.var == "t1_int_output"

Expand Down Expand Up @@ -280,18 +280,24 @@ def test_wf_output_mismatch():
def my_wf(a: int, b: str) -> (int, str):
return a

my_wf()

with pytest.raises(AssertionError):

@workflow
def my_wf2(a: int, b: str) -> int:
return a, b # type: ignore

my_wf2()

with pytest.raises(AssertionError):

@workflow
def my_wf3(a: int, b: str) -> int:
return (a,) # type: ignore

my_wf3()

assert context_manager.FlyteContextManager.size() == 1


Expand Down Expand Up @@ -676,7 +682,7 @@ def lister() -> typing.List[str]:
return s

assert len(lister.interface.outputs) == 1
- binding_data = lister._output_bindings[0].binding # the property should be named binding_data
+ binding_data = lister.output_bindings[0].binding # the property should be named binding_data
assert binding_data.collection is not None
assert len(binding_data.collection.bindings) == 10

Expand Down Expand Up @@ -800,6 +806,8 @@ def my_wf(a: int, b: str) -> (int, str):
conditional("test2").if_(x == 4).then(t2(a=b)).elif_(x >= 5).then(t2(a=y)).else_().fail("blah")
return x, d

my_wf()

assert context_manager.FlyteContextManager.size() == 1


Expand Down
20 changes: 19 additions & 1 deletion tests/flytekit/unit/core/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from typing_extensions import Annotated # type: ignore

import flytekit.configuration
- from flytekit import StructuredDataset, kwtypes
+ from flytekit import FlyteContextManager, StructuredDataset, kwtypes
from flytekit.configuration import Image, ImageConfig
from flytekit.core import context_manager
from flytekit.core.condition import conditional
from flytekit.core.task import task
from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow
Expand Down Expand Up @@ -156,6 +157,8 @@ def no_outputs_wf():
def one_output_wf() -> int: # type: ignore
t1(a=3)

one_output_wf()


def test_wf_no_output():
@task
Expand Down Expand Up @@ -320,3 +323,18 @@ def test_structured_dataset_wf():
assert_frame_equal(sd_to_schema_wf(), superset_df)
assert_frame_equal(schema_to_sd_wf()[0], subset_df)
assert_frame_equal(schema_to_sd_wf()[1], subset_df)


def test_compile_wf_at_compile_time():
ctx = FlyteContextManager.current_context()
with FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.new_execution_state().with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
)
):

@workflow
def wf():
t4()

assert ctx.compilation_state is None
2 changes: 2 additions & 0 deletions tests/flytekit/unit/remote/test_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def test_misnamed():
def wf(a: int) -> int:
return ft(b=a)

wf()


def test_calling_lp():
sub_wf_lp = LaunchPlan.get_or_create(sub_wf)
Expand Down