Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reference_task should inherit from PythonTask #2643

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,9 @@ def transform_interface_to_list_interface(
return Interface(inputs=map_inputs, outputs=map_outputs)


def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Docstring] = None) -> Interface:
def transform_function_to_interface(
fn: typing.Callable, docstring: Optional[Docstring] = None, is_reference_entity: bool = False
) -> Interface:
"""
From the annotations on a task function that the user should have provided, and the output names they want to use
for each output parameter, construct the TypedInterface object
Expand All @@ -382,9 +384,12 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc
return_annotation = type_hints.get("return", None)

ctx = FlyteContextManager.current_context()

# Check if the function has a return statement at compile time locally.
# Skip it if the function is a reference task/workflow since it doesn't have a body.
if (
ctx.execution_state
# Only check if the task/workflow has a return statement at compile time locally.
not is_reference_entity
and ctx.execution_state
and ctx.execution_state.mode is None
# inspect module does not work correctly with Python <3.10.10. https://github.com/flyteorg/flyte/issues/5608
and sys.version_info >= (3, 10, 10)
Expand Down
6 changes: 3 additions & 3 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
from flytekit.core.base_task import TaskMetadata, TaskResolverMixin
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask
Expand Down Expand Up @@ -371,7 +371,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
return wrapper


class ReferenceTask(ReferenceEntity, PythonFunctionTask): # type: ignore
class ReferenceTask(ReferenceEntity, PythonTask): # type: ignore
"""
This is a reference task, the body of the function passed in through the constructor will never be used, only the
signature of the function will be. The signature should also match the signature of the task you're referencing,
Expand Down Expand Up @@ -412,7 +412,7 @@ def reference_task(
"""

def wrapper(fn) -> ReferenceTask:
interface = transform_function_to_interface(fn)
interface = transform_function_to_interface(fn, is_reference_entity=True)
return ReferenceTask(project, domain, name, version, interface.inputs, interface.outputs)

return wrapper
2 changes: 1 addition & 1 deletion tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def test_execute_reference_task(register):
version=VERSION,
)
def t1(a: int) -> nt:
return nt(t1_int_output=a + 2, c="world")
...

remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.execute(
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/core/test_imperative.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def ref_t1(
dataframe: pd.DataFrame,
imputation_method: str = "median",
) -> pd.DataFrame:
return dataframe
...

@reference_task(
project="flytesnacks",
Expand All @@ -340,7 +340,7 @@ def ref_t2(
split_mask: int,
num_features: int,
) -> pd.DataFrame:
return dataframe
...

wb = ImperativeWorkflow(name="core.feature_engineering.workflow.fe_wf")
wb.add_workflow_input("sqlite_archive", FlyteFile[typing.TypeVar("sqlite")])
Expand Down
8 changes: 4 additions & 4 deletions tests/flytekit/unit/core/test_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_ref_task_more():
version="553018f39e519bdb2597b652639c30ce16b99c79",
)
def ref_t1(a: typing.List[str]) -> str:
return "hello"
...

@workflow
def wf1(in1: typing.List[str]) -> str:
Expand All @@ -106,7 +106,7 @@ def test_ref_task_more_2():
version="553018f39e519bdb2597b652639c30ce16b99c79",
)
def ref_t1(a: typing.List[str]) -> str:
return "hello"
...

@reference_task(
project="flytesnacks",
Expand All @@ -115,7 +115,7 @@ def ref_t1(a: typing.List[str]) -> str:
version="553018f39e519bdb2597b652639c30ce16b99c79",
)
def ref_t2(a: typing.List[str]) -> str:
return "hello"
...

@workflow
def wf1(in1: typing.List[str]) -> str:
Expand Down Expand Up @@ -435,7 +435,7 @@ def test_ref_dynamic_task():
version="553018f39e519bdb2597b652639c30ce16b99c79",
)
def ref_t1(a: int) -> str:
return "hello"
...

@task
def t2(a: str, b: str) -> str:
Expand Down
18 changes: 17 additions & 1 deletion tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mock import ANY, MagicMock, patch

import flytekit.configuration
from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow
from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task
from flytekit.configuration import Config, DefaultImages, Image, ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContextManager
Expand Down Expand Up @@ -527,6 +527,22 @@ def wf(name: str = "union"):
version_from_hash_mock.assert_called_once_with(md5_bytes, mock.ANY, mock.ANY, image_spec.image_name())
register_workflow_mock.assert_called_once()

@reference_task(
project="flytesnacks",
domain="development",
name="flytesnacks.examples.basics.basics.workflow.slope",
version="v1",
)
def ref_basic(x: typing.List[int], y: typing.List[int]) -> float:
...

@workflow
def wf1(name: str = "union") -> float:
return ref_basic(x=[1, 2, 3], y=[4, 5, 6])

flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
flyte_remote.register_script(wf1)


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_local_server(mock_client):
Expand Down
Loading