Skip to content

Commit

Permalink
Merge branch 'master' into 3011
Browse files Browse the repository at this point in the history
  • Loading branch information
pingsutw authored Feb 16, 2023
2 parents 75a6090 + 014eea9 commit 09fb418
Show file tree
Hide file tree
Showing 12 changed files with 73 additions and 7 deletions.
1 change: 0 additions & 1 deletion flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,6 @@ def flyte_entity_call_handler(
)

ctx = FlyteContextManager.current_context()

if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
return create_and_link_node(ctx, entity=entity, **kwargs)
elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
Expand Down
12 changes: 11 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,12 @@ def interface(self) -> _interface_models.TypedInterface:

@property
def output_bindings(self) -> List[_literal_models.Binding]:
self.compile()
return self._output_bindings

@property
def nodes(self) -> List[Node]:
self.compile()
return self._nodes

def __repr__(self):
Expand All @@ -257,11 +259,15 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis
# Get default arguments and override with kwargs passed in
input_kwargs = self.python_interface.default_inputs_as_kwargs
input_kwargs.update(kwargs)
self.compile()
return flyte_entity_call_handler(self, *args, **input_kwargs)

def execute(self, **kwargs):
raise Exception("Should not be called")

def compile(self, **kwargs):
pass

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]:
# This is done to support the invariant that Workflow local executions always work with Promise objects
# holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
Expand All @@ -272,6 +278,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr

# The output of this will always be a combination of Python native values and Promises containing Flyte
# Literals.
self.compile()
function_outputs = self.execute(**kwargs)

# First handle the empty return case.
Expand Down Expand Up @@ -612,6 +619,7 @@ def __init__(
python_interface=native_interface,
docs=docs,
)
self.compiled = False

@property
def function(self):
Expand All @@ -625,6 +633,9 @@ def compile(self, **kwargs):
Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics
a 'closure' in the traditional sense of the word.
"""
if self.compiled:
return
self.compiled = True
ctx = FlyteContextManager.current_context()
self._input_parameters = transform_inputs_to_parameters(ctx, self.python_interface)
all_nodes = []
Expand Down Expand Up @@ -759,7 +770,6 @@ def wrapper(fn):
docstring=Docstring(callable_=fn),
docs=docs,
)
workflow_instance.compile()
update_wrapper(workflow_instance, fn)
return workflow_instance

Expand Down
2 changes: 2 additions & 0 deletions plugins/flytekit-pandera/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def my_wf() -> pandera.typing.DataFrame[OutSchema]:
def invalid_wf() -> pandera.typing.DataFrame[OutSchema]:
return transform2(df=transform1(df=invalid_df))

invalid_wf()

# raise error when executing workflow with invalid input
@workflow
def wf_with_df_input(df: pandera.typing.DataFrame[InSchema]) -> pandera.typing.DataFrame[OutSchema]:
Expand Down
2 changes: 2 additions & 0 deletions tests/flytekit/unit/core/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,5 @@ def t3(c: Optional[int] = 3) -> Optional[int]:
@workflow
def wf():
return t3()

wf()
6 changes: 6 additions & 0 deletions tests/flytekit/unit/core/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,25 @@ def decompose_unary() -> int:
result = return_true()
return conditional("test").if_(result).then(success()).else_().then(failed())

decompose_unary()

with pytest.raises(AssertionError):

@workflow
def decompose_none() -> int:
return conditional("test").if_(None).then(success()).else_().then(failed())

decompose_none()

with pytest.raises(AssertionError):

@workflow
def decompose_is() -> int:
result = return_true()
return conditional("test").if_(result is True).then(success()).else_().then(failed())

decompose_is()

@workflow
def decompose() -> int:
result = return_true()
Expand Down
7 changes: 6 additions & 1 deletion tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,16 @@ def t1(a: int) -> str:
a = a + 2
return "fast-" + str(a)

@workflow
def subwf(a: int):
t1(a=a)

@dynamic
def my_subwf(a: int) -> typing.List[str]:
s = []
for i in range(a):
s.append(t1(a=i))
subwf(a=a)
return s

@workflow
Expand All @@ -58,7 +63,7 @@ def my_wf(a: int) -> typing.List[str]:
) as ctx:
input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map)
assert len(dynamic_job_spec._nodes) == 5
assert len(dynamic_job_spec._nodes) == 6
assert len(dynamic_job_spec.tasks) == 1
args = " ".join(dynamic_job_spec.tasks[0].container.args)
assert args.startswith(
Expand Down
4 changes: 4 additions & 0 deletions tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def wf1(a: int):
def wf2(a: typing.List[int]):
return map_task(wf1)(a=a)

wf2()

lp = LaunchPlan.create("test", wf1)

with pytest.raises(ValueError):
Expand All @@ -167,6 +169,8 @@ def wf2(a: typing.List[int]):
def wf3(a: typing.List[int]):
return map_task(lp)(a=a)

wf3()


def test_inputs_outputs_length():
@task
Expand Down
8 changes: 8 additions & 0 deletions tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def empty_wf2():
def empty_wf2():
create_node(t2, "foo")

empty_wf2()


def test_more_normal_task():
nt = typing.NamedTuple("OneOutput", t1_str_output=str)
Expand Down Expand Up @@ -141,6 +143,8 @@ def my_wf(a: int) -> str:
t1_node = create_node(t1, a=a)
return t1_node.outputs

my_wf()


def test_runs_before():
@task
Expand Down Expand Up @@ -330,6 +334,8 @@ def t1(a: str) -> str:
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(timeout="foo")

my_wf()


@pytest.mark.parametrize(
"retries,expected",
Expand Down Expand Up @@ -443,3 +449,5 @@ def my_wf(a: str) -> str:
@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(task_config=None)

my_wf()
2 changes: 2 additions & 0 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ def my_wf() -> wf_outputs:
# Note only Namedtuple can be created like this
return wf_outputs(say_hello(), say_hello())

my_wf()


def test_serialized_docstrings():
@task
Expand Down
14 changes: 11 additions & 3 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ def my_wf(a: int, b: str) -> (int, str):
d = t2(a=y, b=b)
return x, d

assert len(my_wf._nodes) == 2
assert len(my_wf.nodes) == 2
assert my_wf._nodes[0].id == "n0"
assert my_wf._nodes[1]._upstream_nodes[0] is my_wf._nodes[0]

assert len(my_wf._output_bindings) == 2
assert len(my_wf.output_bindings) == 2
assert my_wf._output_bindings[0].var == "o0"
assert my_wf._output_bindings[0].binding.promise.var == "t1_int_output"

Expand Down Expand Up @@ -280,18 +280,24 @@ def test_wf_output_mismatch():
def my_wf(a: int, b: str) -> (int, str):
return a

my_wf()

with pytest.raises(AssertionError):

@workflow
def my_wf2(a: int, b: str) -> int:
return a, b # type: ignore

my_wf2()

with pytest.raises(AssertionError):

@workflow
def my_wf3(a: int, b: str) -> int:
return (a,) # type: ignore

my_wf3()

assert context_manager.FlyteContextManager.size() == 1


Expand Down Expand Up @@ -676,7 +682,7 @@ def lister() -> typing.List[str]:
return s

assert len(lister.interface.outputs) == 1
binding_data = lister._output_bindings[0].binding # the property should be named binding_data
binding_data = lister.output_bindings[0].binding # the property should be named binding_data
assert binding_data.collection is not None
assert len(binding_data.collection.bindings) == 10

Expand Down Expand Up @@ -800,6 +806,8 @@ def my_wf(a: int, b: str) -> (int, str):
conditional("test2").if_(x == 4).then(t2(a=b)).elif_(x >= 5).then(t2(a=y)).else_().fail("blah")
return x, d

my_wf()

assert context_manager.FlyteContextManager.size() == 1


Expand Down
20 changes: 19 additions & 1 deletion tests/flytekit/unit/core/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from typing_extensions import Annotated # type: ignore

import flytekit.configuration
from flytekit import StructuredDataset, kwtypes
from flytekit import FlyteContextManager, StructuredDataset, kwtypes
from flytekit.configuration import Image, ImageConfig
from flytekit.core import context_manager
from flytekit.core.condition import conditional
from flytekit.core.task import task
from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow
Expand Down Expand Up @@ -156,6 +157,8 @@ def no_outputs_wf():
def one_output_wf() -> int: # type: ignore
t1(a=3)

one_output_wf()


def test_wf_no_output():
@task
Expand Down Expand Up @@ -320,3 +323,18 @@ def test_structured_dataset_wf():
assert_frame_equal(sd_to_schema_wf(), superset_df)
assert_frame_equal(schema_to_sd_wf()[0], subset_df)
assert_frame_equal(schema_to_sd_wf()[1], subset_df)


def test_compile_wf_at_compile_time():
ctx = FlyteContextManager.current_context()
with FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.new_execution_state().with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
)
):

@workflow
def wf():
t4()

assert ctx.compilation_state is None
2 changes: 2 additions & 0 deletions tests/flytekit/unit/remote/test_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def test_misnamed():
def wf(a: int) -> int:
return ft(b=a)

wf()


def test_calling_lp():
sub_wf_lp = LaunchPlan.get_or_create(sub_wf)
Expand Down

0 comments on commit 09fb418

Please sign in to comment.