diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 7236286f29..8be9d8ccae 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1061,7 +1061,6 @@ def flyte_entity_call_handler( ) ctx = FlyteContextManager.current_context() - if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: return create_and_link_node(ctx, entity=entity, **kwargs) elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index f8ba257d7e..b716eb7114 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -230,10 +230,12 @@ def interface(self) -> _interface_models.TypedInterface: @property def output_bindings(self) -> List[_literal_models.Binding]: + self.compile() return self._output_bindings @property def nodes(self) -> List[Node]: + self.compile() return self._nodes def __repr__(self): @@ -257,11 +259,15 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis # Get default arguments and override with kwargs passed in input_kwargs = self.python_interface.default_inputs_as_kwargs input_kwargs.update(kwargs) + self.compile() return flyte_entity_call_handler(self, *args, **input_kwargs) def execute(self, **kwargs): raise Exception("Should not be called") + def compile(self, **kwargs): + pass + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: # This is done to support the invariant that Workflow local executions always work with Promise objects # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value. @@ -272,6 +278,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr # The output of this will always be a combination of Python native values and Promises containing Flyte # Literals. + self.compile() function_outputs = self.execute(**kwargs) # First handle the empty return case. @@ -612,6 +619,7 @@ def __init__( python_interface=native_interface, docs=docs, ) + self.compiled = False @property def function(self): @@ -625,6 +633,9 @@ def compile(self, **kwargs): Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics a 'closure' in the traditional sense of the word. """ + if self.compiled: + return + self.compiled = True ctx = FlyteContextManager.current_context() self._input_parameters = transform_inputs_to_parameters(ctx, self.python_interface) all_nodes = [] @@ -759,7 +770,6 @@ def wrapper(fn): docstring=Docstring(callable_=fn), docs=docs, ) - workflow_instance.compile() update_wrapper(workflow_instance, fn) return workflow_instance diff --git a/plugins/flytekit-pandera/tests/test_plugin.py b/plugins/flytekit-pandera/tests/test_plugin.py index a16d80d781..cc9b26c4fa 100644 --- a/plugins/flytekit-pandera/tests/test_plugin.py +++ b/plugins/flytekit-pandera/tests/test_plugin.py @@ -48,6 +48,8 @@ def my_wf() -> pandera.typing.DataFrame[OutSchema]: def invalid_wf() -> pandera.typing.DataFrame[OutSchema]: return transform2(df=transform1(df=invalid_df)) + invalid_wf() + # raise error when executing workflow with invalid input @workflow def wf_with_df_input(df: pandera.typing.DataFrame[InSchema]) -> pandera.typing.DataFrame[OutSchema]: diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 8eb105777e..6fe2b01e61 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -196,3 +196,5 @@ def t3(c: Optional[int] = 3) -> Optional[int]: @workflow def wf(): return t3() + + wf() diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index 3ab0026cf3..7b0b292baa 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -167,12 +167,16 @@ def decompose_unary() -> int: result = return_true() return conditional("test").if_(result).then(success()).else_().then(failed()) + decompose_unary() + with pytest.raises(AssertionError): @workflow def decompose_none() -> int: return conditional("test").if_(None).then(success()).else_().then(failed()) + decompose_none() + with pytest.raises(AssertionError): @workflow @@ -180,6 +184,8 @@ def decompose_is() -> int: result = return_true() return conditional("test").if_(result is True).then(success()).else_().then(failed()) + decompose_is() + @workflow def decompose() -> int: result = return_true() diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py index cccf406c71..b9b0ebd3fa 100644 --- a/tests/flytekit/unit/core/test_dynamic.py +++ b/tests/flytekit/unit/core/test_dynamic.py @@ -34,11 +34,16 @@ def t1(a: int) -> str: a = a + 2 return "fast-" + str(a) + @workflow + def subwf(a: int): + t1(a=a) + @dynamic def my_subwf(a: int) -> typing.List[str]: s = [] for i in range(a): s.append(t1(a=i)) + subwf(a=a) return s @workflow @@ -58,7 +63,7 @@ def my_wf(a: int) -> typing.List[str]: ) as ctx: input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5}) dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map) - assert len(dynamic_job_spec._nodes) == 5 + assert len(dynamic_job_spec._nodes) == 6 assert len(dynamic_job_spec.tasks) == 1 args = " ".join(dynamic_job_spec.tasks[0].container.args) assert args.startswith( diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 14c9620ae6..95927873d0 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -159,6 +159,8 @@ def wf1(a: int): def wf2(a: typing.List[int]): return map_task(wf1)(a=a) + wf2() + lp = LaunchPlan.create("test", wf1) with pytest.raises(ValueError): @@ -167,6 +169,8 @@ def wf2(a: typing.List[int]): def wf3(a: typing.List[int]): return map_task(lp)(a=a) + wf3() + def test_inputs_outputs_length(): @task diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 0858a08007..b8d32e4c8d 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -96,6 +96,8 @@ def empty_wf2(): def empty_wf2(): create_node(t2, "foo") + empty_wf2() + def test_more_normal_task(): nt = typing.NamedTuple("OneOutput", t1_str_output=str) @@ -141,6 +143,8 @@ def my_wf(a: int) -> str: t1_node = create_node(t1, a=a) return t1_node.outputs + my_wf() + def test_runs_before(): @task @@ -330,6 +334,8 @@ def t1(a: str) -> str: def my_wf(a: str) -> str: return t1(a=a).with_overrides(timeout="foo") + my_wf() + @pytest.mark.parametrize( "retries,expected", @@ -443,3 +449,5 @@ def my_wf(a: str) -> str: @workflow def my_wf(a: str) -> str: return t1(a=a).with_overrides(task_config=None) + + my_wf() diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 8deb406fb2..d47d57969c 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -440,6 +440,8 @@ def my_wf() -> wf_outputs: # Note only Namedtuple can be created like this return wf_outputs(say_hello(), say_hello()) + my_wf() + def test_serialized_docstrings(): @task diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index f41a05ea32..373a536769 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -176,11 +176,11 @@ def my_wf(a: int, b: str) -> (int, str): d = t2(a=y, b=b) return x, d - assert len(my_wf._nodes) == 2 + assert len(my_wf.nodes) == 2 assert my_wf._nodes[0].id == "n0" assert my_wf._nodes[1]._upstream_nodes[0] is my_wf._nodes[0] - assert len(my_wf._output_bindings) == 2 + assert len(my_wf.output_bindings) == 2 assert my_wf._output_bindings[0].var == "o0" assert my_wf._output_bindings[0].binding.promise.var == "t1_int_output" @@ -280,18 +280,24 @@ def test_wf_output_mismatch(): def my_wf(a: int, b: str) -> (int, str): return a + my_wf() + with pytest.raises(AssertionError): @workflow def my_wf2(a: int, b: str) -> int: return a, b # type: ignore + my_wf2() + with pytest.raises(AssertionError): @workflow def my_wf3(a: int, b: str) -> int: return (a,) # type: ignore + my_wf3() + assert context_manager.FlyteContextManager.size() == 1 @@ -676,7 +682,7 @@ def lister() -> typing.List[str]: return s assert len(lister.interface.outputs) == 1 - binding_data = lister._output_bindings[0].binding # the property should be named binding_data + binding_data = lister.output_bindings[0].binding # the property should be named binding_data assert binding_data.collection is not None assert len(binding_data.collection.bindings) == 10 @@ -800,6 +806,8 @@ def my_wf(a: int, b: str) -> (int, str): conditional("test2").if_(x == 4).then(t2(a=b)).elif_(x >= 5).then(t2(a=y)).else_().fail("blah") return x, d + my_wf() + assert context_manager.FlyteContextManager.size() == 1 diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index 23b9d0631e..90a8c712e6 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -7,8 +7,9 @@ from typing_extensions import Annotated # type: ignore import flytekit.configuration -from flytekit import StructuredDataset, kwtypes +from flytekit import FlyteContextManager, StructuredDataset, kwtypes from flytekit.configuration import Image, ImageConfig +from flytekit.core import context_manager from flytekit.core.condition import conditional from flytekit.core.task import task from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow @@ -156,6 +157,8 @@ def no_outputs_wf(): def one_output_wf() -> int: # type: ignore t1(a=3) + one_output_wf() + def test_wf_no_output(): @task @@ -320,3 +323,18 @@ def test_structured_dataset_wf(): assert_frame_equal(sd_to_schema_wf(), superset_df) assert_frame_equal(schema_to_sd_wf()[0], subset_df) assert_frame_equal(schema_to_sd_wf()[1], subset_df) + + +def test_compile_wf_at_compile_time(): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.new_execution_state().with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ): + + @workflow + def wf(): + t4() + + assert ctx.compilation_state is None diff --git a/tests/flytekit/unit/remote/test_calling.py b/tests/flytekit/unit/remote/test_calling.py index 34e4f8e8b8..289fba37d7 100644 --- a/tests/flytekit/unit/remote/test_calling.py +++ b/tests/flytekit/unit/remote/test_calling.py @@ -75,6 +75,8 @@ def test_misnamed(): def wf(a: int) -> int: return ft(b=a) + wf() + def test_calling_lp(): sub_wf_lp = LaunchPlan.get_or_create(sub_wf)