Skip to content

Commit

Permalink
reference_task should inherit from PythonTask
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Aug 2, 2024
1 parent 1b67f16 commit 1b6d605
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 14 deletions.
11 changes: 8 additions & 3 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,9 @@ def transform_interface_to_list_interface(
return Interface(inputs=map_inputs, outputs=map_outputs)


def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Docstring] = None) -> Interface:
def transform_function_to_interface(
fn: typing.Callable, docstring: Optional[Docstring] = None, is_reference_entity: bool = False
) -> Interface:
"""
From the annotations on a task function that the user should have provided, and the output names they want to use
for each output parameter, construct the TypedInterface object
Expand All @@ -382,9 +384,12 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc
return_annotation = type_hints.get("return", None)

ctx = FlyteContextManager.current_context()

# Check if the function has a return statement at compile time locally.
# Skip it if the function is a reference task/workflow since it doesn't have a body.
if (
ctx.execution_state
# Only check if the task/workflow has a return statement at compile time locally.
not is_reference_entity
and ctx.execution_state
and ctx.execution_state.mode is None
# inspect module does not work correctly with Python <3.10.10. https://github.com/flyteorg/flyte/issues/5608
and sys.version_info >= (3, 10, 10)
Expand Down
6 changes: 3 additions & 3 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
from flytekit.core.base_task import TaskMetadata, TaskResolverMixin
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask
Expand Down Expand Up @@ -371,7 +371,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
return wrapper


class ReferenceTask(ReferenceEntity, PythonFunctionTask): # type: ignore
class ReferenceTask(ReferenceEntity, PythonTask): # type: ignore
"""
This is a reference task, the body of the function passed in through the constructor will never be used, only the
signature of the function will be. The signature should also match the signature of the task you're referencing,
Expand Down Expand Up @@ -412,7 +412,7 @@ def reference_task(
"""

def wrapper(fn) -> ReferenceTask:
interface = transform_function_to_interface(fn)
interface = transform_function_to_interface(fn, is_reference_entity=True)
return ReferenceTask(project, domain, name, version, interface.inputs, interface.outputs)

return wrapper
2 changes: 1 addition & 1 deletion tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def test_execute_reference_task(register):
version=VERSION,
)
def t1(a: int) -> nt:
return nt(t1_int_output=a + 2, c="world")
...

remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.execute(
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/core/test_imperative.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def ref_t1(
dataframe: pd.DataFrame,
imputation_method: str = "median",
) -> pd.DataFrame:
return dataframe
...

@reference_task(
project="flytesnacks",
Expand All @@ -340,7 +340,7 @@ def ref_t2(
split_mask: int,
num_features: int,
) -> pd.DataFrame:
return dataframe
...

wb = ImperativeWorkflow(name="core.feature_engineering.workflow.fe_wf")
wb.add_workflow_input("sqlite_archive", FlyteFile[typing.TypeVar("sqlite")])
Expand Down
8 changes: 4 additions & 4 deletions tests/flytekit/unit/core/test_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_ref_task_more():
version="553018f39e519bdb2597b652639c30ce16b99c79",
)
def ref_t1(a: typing.List[str]) -> str:
return "hello"
...

@workflow
def wf1(in1: typing.List[str]) -> str:
Expand All @@ -106,7 +106,7 @@ def test_ref_task_more_2():
version="553018f39e519bdb2597b652639c30ce16b99c79",
)
def ref_t1(a: typing.List[str]) -> str:
return "hello"
...

@reference_task(
project="flytesnacks",
Expand All @@ -115,7 +115,7 @@ def ref_t1(a: typing.List[str]) -> str:
version="553018f39e519bdb2597b652639c30ce16b99c79",
)
def ref_t2(a: typing.List[str]) -> str:
return "hello"
...

@workflow
def wf1(in1: typing.List[str]) -> str:
Expand Down Expand Up @@ -435,7 +435,7 @@ def test_ref_dynamic_task():
version="553018f39e519bdb2597b652639c30ce16b99c79",
)
def ref_t1(a: int) -> str:
return "hello"
...

@task
def t2(a: str, b: str) -> str:
Expand Down
18 changes: 17 additions & 1 deletion tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mock import ANY, MagicMock, patch

import flytekit.configuration
from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow
from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task
from flytekit.configuration import Config, DefaultImages, Image, ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContextManager
Expand Down Expand Up @@ -527,6 +527,22 @@ def wf(name: str = "union"):
version_from_hash_mock.assert_called_once_with(md5_bytes, mock.ANY, mock.ANY, image_spec.image_name())
register_workflow_mock.assert_called_once()

@reference_task(
project="flytesnacks",
domain="development",
name="flytesnacks.examples.basics.basics.workflow.slope",
version="v1",
)
def ref_basic(x: typing.List[int], y: typing.List[int]) -> float:
...

@workflow
def wf1(name: str = "union") -> float:
return ref_basic(x=[1, 2, 3], y=[4, 5, 6])

flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
flyte_remote.register_script(wf1)


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_local_server(mock_client):
Expand Down

0 comments on commit 1b6d605

Please sign in to comment.