Skip to content

Commit

Permalink
[Core feature] Flytekit should support unsafe mode for types (#2419)
Browse files Browse the repository at this point in the history
* Enable unsafe type task & workflow
Signed-off-by: Mecoli1219 <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
Mecoli1219 authored Nov 6, 2024
1 parent b1cb262 commit 55e9f8d
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 11 deletions.
9 changes: 8 additions & 1 deletion flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def transform_function_to_interface(
fn: typing.Callable,
docstring: Optional[Docstring] = None,
is_reference_entity: bool = False,
pickle_untyped: bool = False,
) -> Interface:
"""
From the annotations on a task function that the user should have provided, and the output names they want to use
Expand All @@ -395,6 +396,9 @@ def transform_function_to_interface(
type_hints = get_type_hints(fn, include_extras=True)
signature = inspect.signature(fn)
return_annotation = type_hints.get("return", None)
# If the return annotation is None and the pickle_untyped is True, we will use it as Any
if return_annotation is None and pickle_untyped:
return_annotation = Any

ctx = FlyteContextManager.current_context()

Expand All @@ -420,7 +424,10 @@ def transform_function_to_interface(
for k, v in signature.parameters.items(): # type: ignore
annotation = type_hints.get(k, None)
if annotation is None:
raise FlyteMissingTypeException(fn=fn, param_name=k)
if not pickle_untyped:
raise FlyteMissingTypeException(fn=fn, param_name=k)
# If the pickle_untyped is True, we will use it as Any
annotation = Any
default = v.default if v.default is not inspect.Parameter.empty else None
# Inputs with default values are currently ignored, we may want to look into that in the future
inputs[k] = (annotation, default) # type: ignore
Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
node_dependency_hints: Optional[
Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]
] = None,
pickle_untyped: bool = False,
**kwargs,
):
"""
Expand All @@ -125,10 +126,15 @@ def __init__(
:param Optional[Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]] node_dependency_hints:
A list of tasks, launchplans, or workflows that this task depends on. This is only
for dynamic tasks/workflows, where flyte cannot automatically determine the dependencies prior to runtime.
:param bool pickle_untyped: If set to True, the task will pickle untyped outputs. This is just a convenience
flag to avoid having to specify the output types in the interface. This is not recommended for production
use.
"""
if task_function is None:
raise ValueError("TaskFunction is a required parameter for PythonFunctionTask")
self._native_interface = transform_function_to_interface(task_function, Docstring(callable_=task_function))
self._native_interface = transform_function_to_interface(
task_function, Docstring(callable_=task_function), pickle_untyped=pickle_untyped
)
mutated_interface = self._native_interface.remove_inputs(ignore_input_vars)
name, _, _, _ = extract_task_module(task_function)
super().__init__(
Expand Down
5 changes: 5 additions & 0 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def task(
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ...


Expand Down Expand Up @@ -168,6 +169,7 @@ def task(
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ...


Expand Down Expand Up @@ -211,6 +213,7 @@ def task(
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
pickle_untyped: bool = False,
) -> Union[
Callable[P, FuncOut],
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]],
Expand Down Expand Up @@ -340,6 +343,7 @@ def launch_dynamically():
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
:param pickle_untyped: Boolean that indicates if the task allows unspecified data types.
"""

def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
Expand Down Expand Up @@ -375,6 +379,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
pod_template=pod_template,
pod_template_name=pod_template_name,
accelerator=accelerator,
pickle_untyped=pickle_untyped,
)
update_wrapper(task_instance, decorated_fn)
return task_instance
Expand Down
7 changes: 6 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,7 +1306,12 @@ def to_literal_checks(cls, python_val: typing.Any, python_type: Type[T], expecte
"actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
"return v.x, instead of v, even if this has a single element"
)
if (python_val is None and python_type != type(None)) and expected and expected.union_type is None:
if (
(python_val is None and python_type != type(None))
and expected
and expected.union_type is None
and python_type is not Any
):
raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")

@classmethod
Expand Down
11 changes: 10 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,11 +668,14 @@ def __init__(
docstring: Optional[Docstring] = None,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
pickle_untyped: bool = False,
default_options: Optional[Options] = None,
):
name, _, _, _ = extract_task_module(workflow_function)
self._workflow_function = workflow_function
native_interface = transform_function_to_interface(workflow_function, docstring=docstring)
native_interface = transform_function_to_interface(
workflow_function, docstring=docstring, pickle_untyped=pickle_untyped
)

# TODO do we need this - can this not be in launchplan only?
# This can be in launch plan only, but is here only so that we don't have to re-evaluate. Or
Expand Down Expand Up @@ -837,6 +840,7 @@ def workflow(
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
pickle_untyped: bool = ...,
default_options: Optional[Options] = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]: ...

Expand All @@ -848,6 +852,7 @@ def workflow(
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
pickle_untyped: bool = ...,
default_options: Optional[Options] = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ...

Expand All @@ -858,6 +863,7 @@ def workflow(
interruptible: bool = False,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
pickle_untyped: bool = False,
default_options: Optional[Options] = None,
) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
Expand Down Expand Up @@ -890,6 +896,8 @@ def workflow(
:param on_failure: Invoke this workflow or task on failure. The Workflow / task has to match the signature of
the current workflow, with an additional parameter called `error` Error
:param docs: Description entity for the workflow
:param pickle_untyped: This is a flag that allows users to bypass the type-checking that Flytekit does when constructing
the workflow. This is not recommended for general use.
:param default_options: Default options for the workflow when creating a default launch plan. Currently only
the labels and annotations are allowed to be set as defaults.
"""
Expand All @@ -906,6 +914,7 @@ def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
docstring=Docstring(callable_=fn),
on_failure=on_failure,
docs=docs,
pickle_untyped=pickle_untyped,
default_options=default_options,
)
update_wrapper(workflow_instance, fn)
Expand Down
6 changes: 5 additions & 1 deletion flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p
return await FlytePickle.from_pickle(uri)

async def async_to_literal(
self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType
self,
ctx: FlyteContext,
python_val: T,
python_type: Type[T],
expected: LiteralType,
) -> Literal:
if python_val is None:
raise AssertionError("Cannot pickle None Value.")
Expand Down
120 changes: 114 additions & 6 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
import flytekit
import flytekit.configuration
from flytekit import Secret, SQLTask, dynamic, kwtypes, map_task
from flytekit.configuration import (FastSerializationSettings, Image,
ImageConfig)
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig
from flytekit.core import context_manager, launch_plan, promise
from flytekit.core.condition import conditional
from flytekit.core.constants import MESSAGEPACK
Expand All @@ -33,11 +32,18 @@
from flytekit.core.resources import Resources
from flytekit.core.task import TaskMetadata, task
from flytekit.core.testing import patch, task_mock
from flytekit.core.type_engine import (RestrictedTypeError, SimpleTransformer,
TypeEngine, TypeTransformerFailedError)
from flytekit.core.type_engine import (
RestrictedTypeError,
SimpleTransformer,
TypeEngine,
TypeTransformerFailedError,
)
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import (FlyteFailureNodeInputMismatchException,
FlyteValidationException)
from flytekit.exceptions.user import (
FlyteFailureNodeInputMismatchException,
FlyteValidationException,
FlyteMissingTypeException,
)
from flytekit.models import literals as _literal_models
from flytekit.models.core import types as _core_types
from flytekit.models.interface import Parameter
Expand Down Expand Up @@ -2207,3 +2213,105 @@ def my_wf(a: int, retries: int) -> int:

with pytest.raises(AssertionError):
my_wf(a=1, retries=1)


def test_pickle_untyped_input_wf_and_task():
@task(pickle_untyped=True)
def t1(a) -> int:
if type(a) == int:
return a + 1
return 0

with pytest.raises(FlyteMissingTypeException):

@task
def t2_wo_pickle_untyped(a) -> int:
return a + 1

@workflow(pickle_untyped=True)
def wf1_with_pickle_untyped(a) -> int:
return t1(a=a)

assert wf1_with_pickle_untyped(a=1) == 2
assert wf1_with_pickle_untyped(a="1") == 0

with pytest.raises(FlyteMissingTypeException):

@workflow
def wf1_wo_pickle_untyped(a) -> int:
return t1(a=a)


def test_pickle_untyped_wf_and_task():
@task(pickle_untyped=True)
def t1(a):
if type(a) != int:
return "t1"
return a + 1

@task(pickle_untyped=True)
def t2(a):
if type(a) != int:
return "t2"
return a + 2

@workflow(pickle_untyped=True)
def wf1_with_pickle_untyped(a):
a1 = t1(a=a)
return t2(a=a1)

assert wf1_with_pickle_untyped(a=1) == 4
assert wf1_with_pickle_untyped(a="1") == "t2"


def test_wf_with_pickle_untyped_and_regular_tasks():
@task(pickle_untyped=True)
def t1(a):
if type(a) != int:
return "t1"
return a + 1

@task
def t2(a: typing.Any) -> typing.Any:
if type(a) != int:
return "t2"
return a + 2

@workflow(pickle_untyped=True)
def wf1_with_pickle_untyped(a):
a1 = t1(a=a)
return t2(a=a1)

assert wf1_with_pickle_untyped(a=1) == 4
assert wf1_with_pickle_untyped(a="1") == "t2"

@workflow(pickle_untyped=True)
def wf2_with_pickle_untyped(a):
a1 = t2(a=a)
return t1(a=a1)

assert wf2_with_pickle_untyped(a=1) == 4
assert wf2_with_pickle_untyped(a="1") == "t1"


def test_pickle_untyped_task_with_specified_input():
@task(pickle_untyped=True)
def t1(a, b: typing.Any):
if type(a) != int:
if type(b) != int:
return "t1"
else:
return b
elif type(b) != int:
return a
return a + b

@workflow(pickle_untyped=True)
def wf1_with_pickle_untyped(a: typing.Any, b):
r = t1(a=a, b=b)
return r

assert wf1_with_pickle_untyped(a=1, b=2) == 3
assert wf1_with_pickle_untyped(a="1", b=2) == 2
assert wf1_with_pickle_untyped(a=1, b="2") == 1
assert wf1_with_pickle_untyped(a="1", b="2") == "t1"

0 comments on commit 55e9f8d

Please sign in to comment.