Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core feature] Flytekit should support unsafe mode for types #2419

Merged
merged 19 commits into from
Nov 6, 2024
Merged
10 changes: 9 additions & 1 deletion flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,9 @@ def transform_interface_to_list_interface(
return Interface(inputs=map_inputs, outputs=map_outputs)


def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Docstring] = None) -> Interface:
def transform_function_to_interface(
fn: typing.Callable, docstring: Optional[Docstring] = None, unsafe: bool = False
) -> Interface:
"""
From the annotations on a task function that the user should have provided, and the output names they want to use
for each output parameter, construct the TypedInterface object
Expand All @@ -371,13 +373,19 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc
type_hints = get_type_hints(fn, include_extras=True)
signature = inspect.signature(fn)
return_annotation = type_hints.get("return", None)
# If the function has no return annotation and unsafe is True, treat the return type as Any
if return_annotation is None and unsafe:
return_annotation = Any

outputs = extract_return_annotation(return_annotation)
for k, v in outputs.items():
outputs[k] = v # type: ignore
inputs: Dict[str, Tuple[Type, Any]] = OrderedDict()
for k, v in signature.parameters.items(): # type: ignore
annotation = type_hints.get(k, None)
# If the parameter has no annotation and unsafe is True, treat its type as Any
if annotation is None and unsafe:
annotation = Any
default = v.default if v.default is not inspect.Parameter.empty else None
# Inputs with default values are currently ignored, we may want to look into that in the future
inputs[k] = (annotation, default) # type: ignore
Expand Down
5 changes: 4 additions & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(
node_dependency_hints: Optional[
Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]
] = None,
unsafe: bool = False,
**kwargs,
):
"""
Expand All @@ -129,7 +130,9 @@ def __init__(
"""
if task_function is None:
raise ValueError("TaskFunction is a required parameter for PythonFunctionTask")
self._native_interface = transform_function_to_interface(task_function, Docstring(callable_=task_function))
self._native_interface = transform_function_to_interface(
task_function, Docstring(callable_=task_function), unsafe
)
mutated_interface = self._native_interface.remove_inputs(ignore_input_vars)
name, _, _, _ = extract_task_module(task_function)
super().__init__(
Expand Down
5 changes: 5 additions & 0 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def task(
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
unsafe: bool = ...,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to have a more explicit name. Something like:

Suggested change
unsafe: bool = ...,
pickle_untyped: bool = ...,

) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ...


Expand Down Expand Up @@ -154,6 +155,7 @@ def task(
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
unsafe: bool = ...,
) -> Union[PythonFunctionTask[T], Callable[..., FuncOut]]: ...


Expand Down Expand Up @@ -190,6 +192,7 @@ def task(
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
unsafe: bool = False,
) -> Union[
Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]],
PythonFunctionTask[T],
Expand Down Expand Up @@ -311,6 +314,7 @@ def launch_dynamically():
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
:param unsafe: Boolean that indicates if the task allows unspecified data types.
"""

def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
Expand Down Expand Up @@ -343,6 +347,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
pod_template=pod_template,
pod_template_name=pod_template_name,
accelerator=accelerator,
unsafe=unsafe,
)
update_wrapper(task_instance, fn)
return task_instance
Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from typing import Dict, List, NamedTuple, Optional, Type, cast
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast

from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
Expand Down Expand Up @@ -1151,7 +1151,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
"actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
"return v.x, instead of v, even if this has a single element"
)
if python_val is None and expected and expected.union_type is None:
if python_val is None and expected and expected.union_type is None and python_type is not Any:
raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")
transformer = cls.get_transformer(python_type)
if transformer.type_assertions_enabled:
Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,10 +650,11 @@ def __init__(
docstring: Optional[Docstring] = None,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
unsafe: bool = False,
):
name, _, _, _ = extract_task_module(workflow_function)
self._workflow_function = workflow_function
native_interface = transform_function_to_interface(workflow_function, docstring=docstring)
native_interface = transform_function_to_interface(workflow_function, docstring=docstring, unsafe=unsafe)

# TODO do we need this - can this not be in launchplan only?
# This can be in launch plan only, but is here only so that we don't have to re-evaluate. Or
Expand Down Expand Up @@ -804,6 +805,7 @@ def workflow(
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
unsafe: bool = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]: ...


Expand All @@ -823,6 +825,7 @@ def workflow(
interruptible: bool = False,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
unsafe: bool = False,
) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow, Callable[..., FuncOut]]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
Expand Down Expand Up @@ -854,6 +857,8 @@ def workflow(
:param on_failure: Invoke this workflow or task on failure. The Workflow / task has to match the signature of
the current workflow, with an additional parameter called `error` Error
:param docs: Description entity for the workflow
:param unsafe: This is a flag that allows users to bypass the type-checking that Flytekit does when constructing
the workflow. This is not recommended for general use.
"""

def wrapper(fn: Callable[..., Any]) -> PythonFunctionWorkflow:
Expand All @@ -868,6 +873,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionWorkflow:
docstring=Docstring(callable_=fn),
on_failure=on_failure,
docs=docs,
unsafe=unsafe,
)
update_wrapper(workflow_instance, fn)
return workflow_instance
Expand Down
11 changes: 7 additions & 4 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import typing
from typing import Type
from typing import Optional, Type

import cloudpickle

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Void
from flytekit.models.types import LiteralType

T = typing.TypeVar("T")
Expand Down Expand Up @@ -86,13 +86,16 @@ def assert_type(self, t: Type[T], v: T):
# Every type can serialize to pickle, so we don't need to check the type here.
...

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]:
if lv.scalar.blob is None:
return None
uri = lv.scalar.blob.uri
return FlytePickle.from_pickle(uri)

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if python_val is None:
raise AssertionError("Cannot pickle None Value.")
# raise AssertionError("Cannot pickle None Value.")
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
return Literal(scalar=Scalar(none_type=Void()))
meta = BlobMetadata(
type=_core_types.BlobType(
format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
Expand Down
6 changes: 3 additions & 3 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2032,9 +2032,9 @@ def __init__(self, number: int):
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
assert Foo(1).number == pv.number

with pytest.raises(AssertionError, match="Cannot pickle None Value"):
lt = TypeEngine.to_literal_type(typing.Optional[typing.Any])
TypeEngine.to_literal(ctx, None, FlytePickle, lt)
lt = TypeEngine.to_literal_type(typing.Optional[typing.Any])
lv = TypeEngine.to_literal(ctx, None, FlytePickle, lt)
assert lv.scalar.none_type == Void()

with pytest.raises(
AssertionError,
Expand Down
122 changes: 119 additions & 3 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from flytekit.core.resources import Resources
from flytekit.core.task import TaskMetadata, task
from flytekit.core.testing import patch, task_mock
from flytekit.core.type_engine import RestrictedTypeError, SimpleTransformer, TypeEngine
from flytekit.core.type_engine import RestrictedTypeError, SimpleTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteValidationException
from flytekit.models import literals as _literal_models
Expand Down Expand Up @@ -80,7 +80,9 @@ def test_forwardref_namedtuple_output():
# This test case tests typing.NamedTuple outputs for cases where eg.
# from __future__ import annotations is enabled, such that all type hints become ForwardRef
@task
def my_task(a: int) -> typing.NamedTuple("OutputsBC", b=typing.ForwardRef("int"), c=typing.ForwardRef("str")):
def my_task(
a: int,
) -> typing.NamedTuple("OutputsBC", b=typing.ForwardRef("int"), c=typing.ForwardRef("str")):
ctx = flytekit.current_context()
assert str(ctx.execution_id) == "ex:local:local:local"
return a + 2, "hello world"
Expand Down Expand Up @@ -1915,7 +1917,12 @@ def wf() -> pd.DataFrame:

df = wf()

expected_df = pd.DataFrame(data={"col1": [1 + 10 + 100, 2 + 20 + 200], "col2": [3 + 30 + 300, 4 + 40 + 400]})
expected_df = pd.DataFrame(
data={
"col1": [1 + 10 + 100, 2 + 20 + 200],
"col2": [3 + 30 + 300, 4 + 40 + 400],
}
)
assert expected_df.equals(df)


Expand Down Expand Up @@ -2000,3 +2007,112 @@ def my_wf(a: int, retries: int) -> int:

with pytest.raises(AssertionError):
my_wf(a=1, retries=1)


def test_unsafe_input_wf_and_task():
@task(unsafe=True)
def t1(a) -> int:
if type(a) == int:
return a + 1
return 0

@task
def t2_wo_unsafe(a) -> int:
return a + 1

@workflow(unsafe=True)
def wf1_with_unsafe(a) -> int:
return t1(a=a)

assert wf1_with_unsafe(a=1) == 2
assert wf1_with_unsafe(a="1") == 0
assert wf1_with_unsafe(a=None) == 0

@workflow
def wf1_wo_unsafe(a) -> int:
return t1(a=a)

@workflow
def wf1_wo_unsafe2(a: int) -> int:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you test this workflow in the sandbox cluster?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I have only tested it in my local environment. How can I try it on the sandbox cluster?

Copy link
Member

@Future-Outlier Future-Outlier May 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR can help you: you can use either ImageSpec or a Dockerfile to build your own image.
#1870

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from flytekit import task, Secret, workflow, ImageSpec

flytekit_dev_version = "https://github.com/Mecoli1219/flytekit.git@c7c9a2ad2bdf6799ec324955a023281cca030038"
image = ImageSpec(
    registry="localhost:30000",
    platform="linux/amd64",
    apt_packages=["git"],
    packages=[
        f"git+{flytekit_dev_version}",
    ],
)


@task(unsafe=True, container_image=image)
def t1(a) -> int:
    if type(a) == int:
        return a + 1
    return 0


@workflow(unsafe=True)
def wf1_with_unsafe(a) -> int:
    return t1(a=a)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above code will fail. The reason is detailed in flyteorg/flyte#5261.

Copy link
Contributor Author

@Mecoli1219 Mecoli1219 May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from flytekit import task, Secret, workflow, ImageSpec
from typing import Tuple

flytekit_dev_version = "https://github.com/Mecoli1219/flytekit.git@9885189b2af95c0c0dfa94ff92b71e865d95cf80"
image = ImageSpec(
    registry="localhost:30000",
    platform="linux/amd64",
    builder="fast-builder",
    apt_packages=["git"],
    packages=[
        f"git+{flytekit_dev_version}",
    ],
)


@task(unsafe=True, container_image=image)
def t1(a) -> int:
    if type(a) == int:
        return a + 1
    return 0


@workflow
def wf1_with_unsafe() -> Tuple[int, int, int]:
    a1 = t1(a=1)
    a2 = t1(a="1")
    a3 = t1(a=None)    #! This will fail
    return a1, a2, a3

Currently, an input whose value is None will fail, but that issue is fixed by this PR (flyteorg/flyte#5408).

return t2_wo_unsafe(a=a)

with pytest.raises(TypeError):
wf1_wo_unsafe(a=1)

with pytest.raises(TypeError):
wf1_wo_unsafe2(a=1)


def test_unsafe_wf_and_task():
    """Chain two unannotated ``unsafe=True`` tasks inside an unannotated ``unsafe=True`` workflow.

    With ``unsafe=True`` the missing input and return annotations are treated
    as ``typing.Any``, so both an int and a str can be passed through the
    workflow, and a task may return ``None``.
    """

    @task(unsafe=True)
    def t1(a):
        # isinstance() is the idiomatic type check (pycodestyle E721) instead
        # of comparing type(a) against int directly.
        if not isinstance(a, int):
            return None
        return a + 1

    @task(unsafe=True)
    def t2(a):
        if not isinstance(a, int):
            return None
        return a + 2

    @workflow(unsafe=True)
    def wf1_with_unsafe(a):
        a1 = t1(a=a)
        return t2(a=a1)

    # int input flows through both tasks: 1 + 1 + 2
    assert wf1_with_unsafe(a=1) == 4
    # non-int input: t1 returns None, which t2 passes on as None
    assert wf1_with_unsafe(a="1") is None


def test_wf_with_unsafe_and_safe_tasks():
    """Mix an unannotated ``unsafe=True`` task with a regular task annotated with ``typing.Any``.

    The two tasks are chained in both orders inside ``unsafe=True`` workflows
    to check that values (including ``None``) pass between an implicitly-Any
    task and an explicitly-Any task.
    """

    @task(unsafe=True)
    def t1(a):
        # isinstance() is the idiomatic type check (pycodestyle E721) instead
        # of comparing type(a) against int directly.
        if not isinstance(a, int):
            return None
        return a + 1

    @task
    def t2(a: typing.Any) -> typing.Any:
        if not isinstance(a, int):
            return None
        return a + 2

    @workflow(unsafe=True)
    def wf1_with_unsafe(a):
        # unsafe task feeding the explicitly annotated task
        a1 = t1(a=a)
        return t2(a=a1)

    assert wf1_with_unsafe(a=1) == 4
    assert wf1_with_unsafe(a="1") is None

    @workflow(unsafe=True)
    def wf2_with_unsafe(a):
        # explicitly annotated task feeding the unsafe task
        a1 = t2(a=a)
        return t1(a=a1)

    assert wf2_with_unsafe(a=1) == 4
    assert wf2_with_unsafe(a="1") is None


def test_unsafe_task_with_specified_input():
    """Use ``unsafe=True`` on entities where only some of the inputs are annotated.

    The task annotates ``b`` and leaves ``a`` bare; the workflow annotates
    ``a`` and leaves ``b`` bare. Every int/str combination of the two inputs
    is exercised.
    """

    @task(unsafe=True)
    def t1(a, b: typing.Any):
        # isinstance() is the idiomatic type check (pycodestyle E721); guard
        # clauses replace the original nested if/else pyramid.
        a_is_int = isinstance(a, int)
        b_is_int = isinstance(b, int)
        if a_is_int and b_is_int:
            return a + b
        if a_is_int:
            return a
        if b_is_int:
            return b
        return None

    @workflow(unsafe=True)
    def wf1_with_unsafe(a: typing.Any, b):
        r = t1(a=a, b=b)
        return r

    assert wf1_with_unsafe(a=1, b=2) == 3
    assert wf1_with_unsafe(a="1", b=2) == 2
    assert wf1_with_unsafe(a=1, b="2") == 1
    assert wf1_with_unsafe(a="1", b="2") is None
Loading