Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core feature] Flytekit should support unsafe mode for types #2419

Merged
merged 19 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def transform_function_to_interface(
fn: typing.Callable,
docstring: Optional[Docstring] = None,
is_reference_entity: bool = False,
pickle_untyped: bool = False,
) -> Interface:
"""
From the annotations on a task function that the user should have provided, and the output names they want to use
Expand All @@ -395,6 +396,9 @@ def transform_function_to_interface(
type_hints = get_type_hints(fn, include_extras=True)
signature = inspect.signature(fn)
return_annotation = type_hints.get("return", None)
# If the return annotation is None and the pickle_untyped is True, we will use it as Any
if return_annotation is None and pickle_untyped:
return_annotation = Any

ctx = FlyteContextManager.current_context()

Expand All @@ -420,7 +424,10 @@ def transform_function_to_interface(
for k, v in signature.parameters.items(): # type: ignore
annotation = type_hints.get(k, None)
if annotation is None:
raise FlyteMissingTypeException(fn=fn, param_name=k)
if not pickle_untyped:
raise FlyteMissingTypeException(fn=fn, param_name=k)
# If the pickle_untyped is True, we will use it as Any
annotation = Any
default = v.default if v.default is not inspect.Parameter.empty else None
# Inputs with default values are currently ignored, we may want to look into that in the future
inputs[k] = (annotation, default) # type: ignore
Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
node_dependency_hints: Optional[
Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]
] = None,
pickle_untyped: bool = False,
Copy link
Member

@thomasjpfan thomasjpfan Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kumare3 Since you proposed the safe parameter name, are you okay with using pickle_untyped?

My concerned is that safe=False could mean so many unsafe behavior. pickle_untyped is explicit about the behavior.

**kwargs,
):
"""
Expand All @@ -125,10 +126,15 @@ def __init__(
:param Optional[Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]] node_dependency_hints:
A list of tasks, launchplans, or workflows that this task depends on. This is only
for dynamic tasks/workflows, where flyte cannot automatically determine the dependencies prior to runtime.
:param bool pickle_untyped: If set to True, the task will pickle untyped outputs. This is just a convenience
flag to avoid having to specify the output types in the interface. This is not recommended for production
use.
"""
if task_function is None:
raise ValueError("TaskFunction is a required parameter for PythonFunctionTask")
self._native_interface = transform_function_to_interface(task_function, Docstring(callable_=task_function))
self._native_interface = transform_function_to_interface(
task_function, Docstring(callable_=task_function), pickle_untyped=pickle_untyped
)
mutated_interface = self._native_interface.remove_inputs(ignore_input_vars)
name, _, _, _ = extract_task_module(task_function)
super().__init__(
Expand Down
5 changes: 5 additions & 0 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def task(
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ...


Expand Down Expand Up @@ -168,6 +169,7 @@ def task(
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ...


Expand Down Expand Up @@ -211,6 +213,7 @@ def task(
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
pickle_untyped: bool = False,
) -> Union[
Callable[P, FuncOut],
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]],
Expand Down Expand Up @@ -340,6 +343,7 @@ def launch_dynamically():
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
:param pickle_untyped: Boolean that indicates if the task allows unspecified data types.
"""

def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
Expand Down Expand Up @@ -375,6 +379,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
pod_template=pod_template,
pod_template_name=pod_template_name,
accelerator=accelerator,
pickle_untyped=pickle_untyped,
)
update_wrapper(task_instance, decorated_fn)
return task_instance
Expand Down
7 changes: 6 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,7 +1285,12 @@ def to_literal_checks(cls, python_val: typing.Any, python_type: Type[T], expecte
"actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
"return v.x, instead of v, even if this has a single element"
)
if (python_val is None and python_type != type(None)) and expected and expected.union_type is None:
if (
(python_val is None and python_type != type(None))
and expected
and expected.union_type is None
and python_type is not Any
):
raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")

@classmethod
Expand Down
11 changes: 10 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,11 +668,14 @@ def __init__(
docstring: Optional[Docstring] = None,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
pickle_untyped: bool = False,
default_options: Optional[Options] = None,
):
name, _, _, _ = extract_task_module(workflow_function)
self._workflow_function = workflow_function
native_interface = transform_function_to_interface(workflow_function, docstring=docstring)
native_interface = transform_function_to_interface(
workflow_function, docstring=docstring, pickle_untyped=pickle_untyped
)

# TODO do we need this - can this not be in launchplan only?
# This can be in launch plan only, but is here only so that we don't have to re-evaluate. Or
Expand Down Expand Up @@ -837,6 +840,7 @@ def workflow(
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
pickle_untyped: bool = ...,
default_options: Optional[Options] = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]: ...

Expand All @@ -848,6 +852,7 @@ def workflow(
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
pickle_untyped: bool = ...,
default_options: Optional[Options] = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ...

Expand All @@ -858,6 +863,7 @@ def workflow(
interruptible: bool = False,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
pickle_untyped: bool = False,
default_options: Optional[Options] = None,
) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
Expand Down Expand Up @@ -890,6 +896,8 @@ def workflow(
:param on_failure: Invoke this workflow or task on failure. The Workflow / task has to match the signature of
the current workflow, with an additional parameter called `error` Error
:param docs: Description entity for the workflow
:param pickle_untyped: This is a flag that allows users to bypass the type-checking that Flytekit does when constructing
the workflow. This is not recommended for general use.
:param default_options: Default options for the workflow when creating a default launch plan. Currently only
the labels and annotations are allowed to be set as defaults.
"""
Expand All @@ -906,6 +914,7 @@ def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
docstring=Docstring(callable_=fn),
on_failure=on_failure,
docs=docs,
pickle_untyped=pickle_untyped,
default_options=default_options,
)
update_wrapper(workflow_instance, fn)
Expand Down
6 changes: 5 additions & 1 deletion flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p
return await FlytePickle.from_pickle(uri)

async def async_to_literal(
self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType
self,
ctx: FlyteContext,
python_val: T,
python_type: Type[T],
expected: LiteralType,
) -> Literal:
if python_val is None:
raise AssertionError("Cannot pickle None Value.")
Expand Down
120 changes: 114 additions & 6 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
import flytekit
import flytekit.configuration
from flytekit import Secret, SQLTask, dynamic, kwtypes, map_task
from flytekit.configuration import (FastSerializationSettings, Image,
ImageConfig)
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig
from flytekit.core import context_manager, launch_plan, promise
from flytekit.core.condition import conditional
from flytekit.core.constants import MESSAGEPACK
Expand All @@ -33,11 +32,18 @@
from flytekit.core.resources import Resources
from flytekit.core.task import TaskMetadata, task
from flytekit.core.testing import patch, task_mock
from flytekit.core.type_engine import (RestrictedTypeError, SimpleTransformer,
TypeEngine, TypeTransformerFailedError)
from flytekit.core.type_engine import (
RestrictedTypeError,
SimpleTransformer,
TypeEngine,
TypeTransformerFailedError,
)
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import (FlyteFailureNodeInputMismatchException,
FlyteValidationException)
from flytekit.exceptions.user import (
FlyteFailureNodeInputMismatchException,
FlyteValidationException,
FlyteMissingTypeException,
)
from flytekit.models import literals as _literal_models
from flytekit.models.core import types as _core_types
from flytekit.models.interface import Parameter
Expand Down Expand Up @@ -2207,3 +2213,105 @@ def my_wf(a: int, retries: int) -> int:

with pytest.raises(AssertionError):
my_wf(a=1, retries=1)


def test_pickle_untyped_input_wf_and_task():
@task(pickle_untyped=True)
def t1(a) -> int:
if type(a) == int:
return a + 1
return 0

with pytest.raises(FlyteMissingTypeException):

@task
def t2_wo_pickle_untyped(a) -> int:
return a + 1

@workflow(pickle_untyped=True)
def wf1_with_pickle_untyped(a) -> int:
return t1(a=a)

assert wf1_with_pickle_untyped(a=1) == 2
assert wf1_with_pickle_untyped(a="1") == 0

with pytest.raises(FlyteMissingTypeException):

@workflow
def wf1_wo_pickle_untyped(a) -> int:
return t1(a=a)


def test_pickle_untyped_wf_and_task():
@task(pickle_untyped=True)
def t1(a):
if type(a) != int:
return "t1"
return a + 1

@task(pickle_untyped=True)
def t2(a):
if type(a) != int:
return "t2"
return a + 2

@workflow(pickle_untyped=True)
def wf1_with_pickle_untyped(a):
a1 = t1(a=a)
return t2(a=a1)

assert wf1_with_pickle_untyped(a=1) == 4
assert wf1_with_pickle_untyped(a="1") == "t2"


def test_wf_with_pickle_untyped_and_regular_tasks():
@task(pickle_untyped=True)
def t1(a):
if type(a) != int:
return "t1"
return a + 1

@task
def t2(a: typing.Any) -> typing.Any:
if type(a) != int:
return "t2"
return a + 2

@workflow(pickle_untyped=True)
def wf1_with_pickle_untyped(a):
a1 = t1(a=a)
return t2(a=a1)

assert wf1_with_pickle_untyped(a=1) == 4
assert wf1_with_pickle_untyped(a="1") == "t2"

@workflow(pickle_untyped=True)
def wf2_with_pickle_untyped(a):
a1 = t2(a=a)
return t1(a=a1)

assert wf2_with_pickle_untyped(a=1) == 4
assert wf2_with_pickle_untyped(a="1") == "t1"


def test_pickle_untyped_task_with_specified_input():
@task(pickle_untyped=True)
def t1(a, b: typing.Any):
if type(a) != int:
if type(b) != int:
return "t1"
else:
return b
elif type(b) != int:
return a
return a + b

@workflow(pickle_untyped=True)
def wf1_with_pickle_untyped(a: typing.Any, b):
r = t1(a=a, b=b)
return r

assert wf1_with_pickle_untyped(a=1, b=2) == 3
assert wf1_with_pickle_untyped(a="1", b=2) == 2
assert wf1_with_pickle_untyped(a=1, b="2") == 1
assert wf1_with_pickle_untyped(a="1", b="2") == "t1"
Loading