From 1b6d605d027470a9b879de52878021548cf7ae9b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 1 Aug 2024 17:42:05 -0700 Subject: [PATCH] reference_task should inherit from PythonTask Signed-off-by: Kevin Su --- flytekit/core/interface.py | 11 ++++++++--- flytekit/core/task.py | 6 +++--- .../flytekit/integration/remote/test_remote.py | 2 +- tests/flytekit/unit/core/test_imperative.py | 4 ++-- tests/flytekit/unit/core/test_references.py | 8 ++++---- tests/flytekit/unit/remote/test_remote.py | 18 +++++++++++++++++- 6 files changed, 35 insertions(+), 14 deletions(-) diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index e671347cee..d9cefb3849 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -369,7 +369,9 @@ def transform_interface_to_list_interface( return Interface(inputs=map_inputs, outputs=map_outputs) -def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Docstring] = None) -> Interface: +def transform_function_to_interface( + fn: typing.Callable, docstring: Optional[Docstring] = None, is_reference_entity: bool = False +) -> Interface: """ From the annotations on a task function that the user should have provided, and the output names they want to use for each output parameter, construct the TypedInterface object @@ -382,9 +384,12 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc return_annotation = type_hints.get("return", None) ctx = FlyteContextManager.current_context() + + # Check if the function has a return statement at compile time locally. + # Skip it if the function is a reference task/workflow since it doesn't have a body. if ( - ctx.execution_state - # Only check if the task/workflow has a return statement at compile time locally. + not is_reference_entity + and ctx.execution_state and ctx.execution_state.mode is None # inspect module does not work correctly with Python <3.10.10. https://github.com/flyteorg/flyte/issues/5608 and sys.version_info >= (3, 10, 10) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index e02034a32e..402862be74 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -11,7 +11,7 @@ from flytekit.core import launch_plan as _annotated_launchplan from flytekit.core import workflow as _annotated_workflow -from flytekit.core.base_task import TaskMetadata, TaskResolverMixin +from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin from flytekit.core.interface import transform_function_to_interface from flytekit.core.pod_template import PodTemplate from flytekit.core.python_function_task import PythonFunctionTask @@ -371,7 +371,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: return wrapper -class ReferenceTask(ReferenceEntity, PythonFunctionTask): # type: ignore +class ReferenceTask(ReferenceEntity, PythonTask): # type: ignore """ This is a reference task, the body of the function passed in through the constructor will never be used, only the signature of the function will be. The signature should also match the signature of the task you're referencing, @@ -412,7 +412,7 @@ def reference_task( """ def wrapper(fn) -> ReferenceTask: - interface = transform_function_to_interface(fn) + interface = transform_function_to_interface(fn, is_reference_entity=True) return ReferenceTask(project, domain, name, version, interface.inputs, interface.outputs) return wrapper diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index fc57cb7573..7fbc8b90a6 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -399,7 +399,7 @@ def test_execute_reference_task(register): version=VERSION, ) def t1(a: int) -> nt: - return nt(t1_int_output=a + 2, c="world") + ... remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) execution = remote.execute( diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py index f361f748b1..aee88e19d1 100644 --- a/tests/flytekit/unit/core/test_imperative.py +++ b/tests/flytekit/unit/core/test_imperative.py @@ -327,7 +327,7 @@ def ref_t1( dataframe: pd.DataFrame, imputation_method: str = "median", ) -> pd.DataFrame: - return dataframe + ... @reference_task( project="flytesnacks", @@ -340,7 +340,7 @@ def ref_t2( split_mask: int, num_features: int, ) -> pd.DataFrame: - return dataframe + ... wb = ImperativeWorkflow(name="core.feature_engineering.workflow.fe_wf") wb.add_workflow_input("sqlite_archive", FlyteFile[typing.TypeVar("sqlite")]) diff --git a/tests/flytekit/unit/core/test_references.py b/tests/flytekit/unit/core/test_references.py index 732b6951d9..b945027570 100644 --- a/tests/flytekit/unit/core/test_references.py +++ b/tests/flytekit/unit/core/test_references.py @@ -81,7 +81,7 @@ def test_ref_task_more(): version="553018f39e519bdb2597b652639c30ce16b99c79", ) def ref_t1(a: typing.List[str]) -> str: - return "hello" + ... @workflow def wf1(in1: typing.List[str]) -> str: @@ -106,7 +106,7 @@ def test_ref_task_more_2(): version="553018f39e519bdb2597b652639c30ce16b99c79", ) def ref_t1(a: typing.List[str]) -> str: - return "hello" + ... @reference_task( project="flytesnacks", @@ -115,7 +115,7 @@ def ref_t1(a: typing.List[str]) -> str: version="553018f39e519bdb2597b652639c30ce16b99c79", ) def ref_t2(a: typing.List[str]) -> str: - return "hello" + ... @workflow def wf1(in1: typing.List[str]) -> str: @@ -435,7 +435,7 @@ def test_ref_dynamic_task(): version="553018f39e519bdb2597b652639c30ce16b99c79", ) def ref_t1(a: int) -> str: - return "hello" + ... @task def t2(a: str, b: str) -> str: diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index f4be6c33c1..3852da9a31 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -15,7 +15,7 @@ from mock import ANY, MagicMock, patch import flytekit.configuration -from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow +from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task from flytekit.configuration import Config, DefaultImages, Image, ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContextManager @@ -527,6 +527,22 @@ def wf(name: str = "union"): version_from_hash_mock.assert_called_once_with(md5_bytes, mock.ANY, mock.ANY, image_spec.image_name()) register_workflow_mock.assert_called_once() + @reference_task( + project="flytesnacks", + domain="development", + name="flytesnacks.examples.basics.basics.workflow.slope", + version="v1", + ) + def ref_basic(x: typing.List[int], y: typing.List[int]) -> float: + ... + + @workflow + def wf1(name: str = "union") -> float: + return ref_basic(x=[1, 2, 3], y=[4, 5, 6]) + + flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + flyte_remote.register_script(wf1) + @mock.patch("flytekit.remote.remote.FlyteRemote.client") def test_local_server(mock_client):