From ad1a055d2b59b081bcaab01a975298b8e2372327 Mon Sep 17 00:00:00 2001 From: Chun-Mao Lai <72752478+Mecoli1219@users.noreply.github.com> Date: Wed, 6 Nov 2024 14:55:16 -0800 Subject: [PATCH] [Core feature] Flytekit should support `unsafe` mode for types (#2419) * Enable unsafe type task & workflow Signed-off-by: Mecoli1219 Signed-off-by: Kevin Su Co-authored-by: Kevin Su Signed-off-by: Katrina Rogan --- flytekit/core/interface.py | 9 +- flytekit/core/python_function_task.py | 8 +- flytekit/core/task.py | 5 + flytekit/core/type_engine.py | 7 +- flytekit/core/workflow.py | 11 +- flytekit/types/pickle/pickle.py | 6 +- tests/flytekit/unit/core/test_type_hints.py | 120 +++++++++++++++++++- 7 files changed, 155 insertions(+), 11 deletions(-) diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index cbfd08ae2f..5c2ab53f92 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -384,6 +384,7 @@ def transform_function_to_interface( fn: typing.Callable, docstring: Optional[Docstring] = None, is_reference_entity: bool = False, + pickle_untyped: bool = False, ) -> Interface: """ From the annotations on a task function that the user should have provided, and the output names they want to use @@ -395,6 +396,9 @@ def transform_function_to_interface( type_hints = get_type_hints(fn, include_extras=True) signature = inspect.signature(fn) return_annotation = type_hints.get("return", None) + # If the return annotation is None and the pickle_untyped is True, we will use it as Any + if return_annotation is None and pickle_untyped: + return_annotation = Any ctx = FlyteContextManager.current_context() @@ -420,7 +424,10 @@ def transform_function_to_interface( for k, v in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) if annotation is None: - raise FlyteMissingTypeException(fn=fn, param_name=k) + if not pickle_untyped: + raise FlyteMissingTypeException(fn=fn, param_name=k) + # If the pickle_untyped is True, we will use 
it as Any + annotation = Any default = v.default if v.default is not inspect.Parameter.empty else None # Inputs with default values are currently ignored, we may want to look into that in the future inputs[k] = (annotation, default) # type: ignore diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 7319975aa9..c63199ad04 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -112,6 +112,7 @@ def __init__( node_dependency_hints: Optional[ Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]] ] = None, + pickle_untyped: bool = False, **kwargs, ): """ @@ -125,10 +126,15 @@ def __init__( :param Optional[Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]] node_dependency_hints: A list of tasks, launchplans, or workflows that this task depends on. This is only for dynamic tasks/workflows, where flyte cannot automatically determine the dependencies prior to runtime. + :param bool pickle_untyped: If set to True, the task will pickle untyped inputs and outputs. This is just a convenience + flag to avoid having to specify the input and output types in the interface. This is not recommended for production + use.
""" if task_function is None: raise ValueError("TaskFunction is a required parameter for PythonFunctionTask") - self._native_interface = transform_function_to_interface(task_function, Docstring(callable_=task_function)) + self._native_interface = transform_function_to_interface( + task_function, Docstring(callable_=task_function), pickle_untyped=pickle_untyped + ) mutated_interface = self._native_interface.remove_inputs(ignore_input_vars) name, _, _, _ = extract_task_module(task_function) super().__init__( diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 745f452a83..9709adda08 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -130,6 +130,7 @@ def task( pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., accelerator: Optional[BaseAccelerator] = ..., + pickle_untyped: bool = ..., ) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ... @@ -168,6 +169,7 @@ def task( pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., accelerator: Optional[BaseAccelerator] = ..., + pickle_untyped: bool = ..., ) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ... @@ -211,6 +213,7 @@ def task( pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, accelerator: Optional[BaseAccelerator] = None, + pickle_untyped: bool = False, ) -> Union[ Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]], @@ -340,6 +343,7 @@ def launch_dynamically(): :param pod_template: Custom PodTemplate for this task. :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. :param accelerator: The accelerator to use for this task. + :param pickle_untyped: Boolean that indicates if the task allows unspecified data types. 
""" def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: @@ -375,6 +379,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: pod_template=pod_template, pod_template_name=pod_template_name, accelerator=accelerator, + pickle_untyped=pickle_untyped, ) update_wrapper(task_instance, decorated_fn) return task_instance diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 5102a7df74..160c0595a4 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1306,7 +1306,12 @@ def to_literal_checks(cls, python_val: typing.Any, python_type: Type[T], expecte "actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then" "return v.x, instead of v, even if this has a single element" ) - if (python_val is None and python_type != type(None)) and expected and expected.union_type is None: + if ( + (python_val is None and python_type != type(None)) + and expected + and expected.union_type is None + and python_type is not Any + ): raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}") @classmethod diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index de0f620e96..9cccf19e58 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -668,11 +668,14 @@ def __init__( docstring: Optional[Docstring] = None, on_failure: Optional[Union[WorkflowBase, Task]] = None, docs: Optional[Documentation] = None, + pickle_untyped: bool = False, default_options: Optional[Options] = None, ): name, _, _, _ = extract_task_module(workflow_function) self._workflow_function = workflow_function - native_interface = transform_function_to_interface(workflow_function, docstring=docstring) + native_interface = transform_function_to_interface( + workflow_function, docstring=docstring, pickle_untyped=pickle_untyped + ) # TODO do we need this - can this not be in launchplan only? 
# This can be in launch plan only, but is here only so that we don't have to re-evaluate. Or @@ -837,6 +840,7 @@ def workflow( interruptible: bool = ..., on_failure: Optional[Union[WorkflowBase, Task]] = ..., docs: Optional[Documentation] = ..., + pickle_untyped: bool = ..., default_options: Optional[Options] = ..., ) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]: ... @@ -848,6 +852,7 @@ def workflow( interruptible: bool = ..., on_failure: Optional[Union[WorkflowBase, Task]] = ..., docs: Optional[Documentation] = ..., + pickle_untyped: bool = ..., default_options: Optional[Options] = ..., ) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ... @@ -858,6 +863,7 @@ def workflow( interruptible: bool = False, on_failure: Optional[Union[WorkflowBase, Task]] = None, docs: Optional[Documentation] = None, + pickle_untyped: bool = False, default_options: Optional[Options] = None, ) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]: """ @@ -890,6 +896,8 @@ def workflow( :param on_failure: Invoke this workflow or task on failure. The Workflow / task has to match the signature of the current workflow, with an additional parameter called `error` Error :param docs: Description entity for the workflow + :param pickle_untyped: This is a flag that allows users to bypass the type-checking that Flytekit does when constructing + the workflow. This is not recommended for general use. :param default_options: Default options for the workflow when creating a default launch plan. Currently only the labels and annotations are allowed to be set as defaults. 
""" @@ -906,6 +914,7 @@ def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow: docstring=Docstring(callable_=fn), on_failure=on_failure, docs=docs, + pickle_untyped=pickle_untyped, default_options=default_options, ) update_wrapper(workflow_instance, fn) diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index b49b03205a..bbe5b71eca 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -78,7 +78,11 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p return await FlytePickle.from_pickle(uri) async def async_to_literal( - self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType + self, + ctx: FlyteContext, + python_val: T, + python_type: Type[T], + expected: LiteralType, ) -> Literal: if python_val is None: raise AssertionError("Cannot pickle None Value.") diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 6664deaaeb..1074423baf 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -20,8 +20,7 @@ import flytekit import flytekit.configuration from flytekit import Secret, SQLTask, dynamic, kwtypes, map_task -from flytekit.configuration import (FastSerializationSettings, Image, - ImageConfig) +from flytekit.configuration import FastSerializationSettings, Image, ImageConfig from flytekit.core import context_manager, launch_plan, promise from flytekit.core.condition import conditional from flytekit.core.constants import MESSAGEPACK @@ -33,11 +32,18 @@ from flytekit.core.resources import Resources from flytekit.core.task import TaskMetadata, task from flytekit.core.testing import patch, task_mock -from flytekit.core.type_engine import (RestrictedTypeError, SimpleTransformer, - TypeEngine, TypeTransformerFailedError) +from flytekit.core.type_engine import ( + RestrictedTypeError, + SimpleTransformer, + TypeEngine, + 
TypeTransformerFailedError, +) from flytekit.core.workflow import workflow -from flytekit.exceptions.user import (FlyteFailureNodeInputMismatchException, - FlyteValidationException) +from flytekit.exceptions.user import ( + FlyteFailureNodeInputMismatchException, + FlyteValidationException, + FlyteMissingTypeException, +) from flytekit.models import literals as _literal_models from flytekit.models.core import types as _core_types from flytekit.models.interface import Parameter @@ -2207,3 +2213,105 @@ def my_wf(a: int, retries: int) -> int: with pytest.raises(AssertionError): my_wf(a=1, retries=1) + + +def test_pickle_untyped_input_wf_and_task(): + @task(pickle_untyped=True) + def t1(a) -> int: + if type(a) == int: + return a + 1 + return 0 + + with pytest.raises(FlyteMissingTypeException): + + @task + def t2_wo_pickle_untyped(a) -> int: + return a + 1 + + @workflow(pickle_untyped=True) + def wf1_with_pickle_untyped(a) -> int: + return t1(a=a) + + assert wf1_with_pickle_untyped(a=1) == 2 + assert wf1_with_pickle_untyped(a="1") == 0 + + with pytest.raises(FlyteMissingTypeException): + + @workflow + def wf1_wo_pickle_untyped(a) -> int: + return t1(a=a) + + +def test_pickle_untyped_wf_and_task(): + @task(pickle_untyped=True) + def t1(a): + if type(a) != int: + return "t1" + return a + 1 + + @task(pickle_untyped=True) + def t2(a): + if type(a) != int: + return "t2" + return a + 2 + + @workflow(pickle_untyped=True) + def wf1_with_pickle_untyped(a): + a1 = t1(a=a) + return t2(a=a1) + + assert wf1_with_pickle_untyped(a=1) == 4 + assert wf1_with_pickle_untyped(a="1") == "t2" + + +def test_wf_with_pickle_untyped_and_regular_tasks(): + @task(pickle_untyped=True) + def t1(a): + if type(a) != int: + return "t1" + return a + 1 + + @task + def t2(a: typing.Any) -> typing.Any: + if type(a) != int: + return "t2" + return a + 2 + + @workflow(pickle_untyped=True) + def wf1_with_pickle_untyped(a): + a1 = t1(a=a) + return t2(a=a1) + + assert wf1_with_pickle_untyped(a=1) == 4 + assert 
wf1_with_pickle_untyped(a="1") == "t2" + + @workflow(pickle_untyped=True) + def wf2_with_pickle_untyped(a): + a1 = t2(a=a) + return t1(a=a1) + + assert wf2_with_pickle_untyped(a=1) == 4 + assert wf2_with_pickle_untyped(a="1") == "t1" + + +def test_pickle_untyped_task_with_specified_input(): + @task(pickle_untyped=True) + def t1(a, b: typing.Any): + if type(a) != int: + if type(b) != int: + return "t1" + else: + return b + elif type(b) != int: + return a + return a + b + + @workflow(pickle_untyped=True) + def wf1_with_pickle_untyped(a: typing.Any, b): + r = t1(a=a, b=b) + return r + + assert wf1_with_pickle_untyped(a=1, b=2) == 3 + assert wf1_with_pickle_untyped(a="1", b=2) == 2 + assert wf1_with_pickle_untyped(a=1, b="2") == 1 + assert wf1_with_pickle_untyped(a="1", b="2") == "t1"