def _change_unrecognized_type_to_pickle(t: Type[T]) -> Type[T]:
    """
    Swap a type the ``TypeEngine`` has no transformer for with its ``FlytePickle`` wrapper.

    ``typing.List[X]`` and ``typing.Dict[str, X]`` are rebuilt with ``X`` converted recursively. Any other
    parameterized generic (an object exposing both ``__origin__`` and ``__args__``) is handed back untouched
    without consulting the ``TypeEngine``. Every remaining type is looked up in the ``TypeEngine``; when that
    lookup raises ``ValueError`` a warning is logged and ``FlytePickle[t]`` is returned instead.

    :param t: A type annotation, e.g. one taken from a task signature.
    :return: ``t`` itself, a rebuilt ``List``/``Dict`` alias, or ``FlytePickle[t]``.
    """
    if hasattr(t, "__origin__") and hasattr(t, "__args__"):
        if t.__origin__ is list:
            return typing.List[_change_unrecognized_type_to_pickle(t.__args__[0])]
        if t.__origin__ is dict and t.__args__[0] == str:
            return typing.Dict[str, _change_unrecognized_type_to_pickle(t.__args__[1])]
        # Other generics (and dicts keyed by something other than str) are deliberately left as they are.
        return t

    # Only the transformer lookup can signal "unsupported", so it is the only statement guarded here.
    try:
        TypeEngine.get_transformer(t)
    except ValueError:
        # Use the module's flytekit logger instead of the root ``logging`` functions: the root-level helpers
        # install a default handler on first use and ignore the level/handlers configured for flytekit.
        logger.warning(
            f"Unsupported Type {t} found, Flyte will default to use PickleFile as the transport. "
            "Pickle can only be used to send objects between the exact same version of Python, "
            "and we strongly recommend to use python type that flyte support."
        )
        return FlytePickle[t]
    return t
@@ -273,7 +298,8 @@ def transform_signature_to_interface(signature: inspect.Signature, docstring: Op def transform_variable_map( - variable_map: Dict[str, type], descriptions: Dict[str, str] = {} + variable_map: Dict[str, type], + descriptions: Dict[str, str] = {}, ) -> Dict[str, _interface_models.Variable]: """ Given a map of str (names of inputs for instance) to their Python native types, return a map of the name to a @@ -283,6 +309,14 @@ def transform_variable_map( if variable_map: for k, v in variable_map.items(): res[k] = transform_type(v, descriptions.get(k, k)) + sub_type: Type[T] = v + if hasattr(v, "__origin__") and hasattr(v, "__args__"): + if v.__origin__ is list: + sub_type = v.__args__[0] + elif v.__origin__ is dict: + sub_type = v.__args__[1] + if hasattr(sub_type, "__origin__") and sub_type.__origin__ is FlytePickle: + res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__} return res diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 475576afe4..02195deb10 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -65,6 +65,7 @@ def my_wf(in1: int, in2: int) -> int: def extract_value( ctx: FlyteContext, input_val: Any, val_type: type, flyte_literal_type: flytekit.models.core.types.LiteralType ) -> _literal_models.Literal: + if isinstance(input_val, list): if flyte_literal_type.collection_type is None: raise TypeError(f"Not a collection type {flyte_literal_type} but got a list {input_val}") diff --git a/flytekit/extras/cloud_pickle_resolver.py b/flytekit/extras/cloud_pickle_resolver.py index 3ea5fb0e5c..99ca5438c4 100644 --- a/flytekit/extras/cloud_pickle_resolver.py +++ b/flytekit/extras/cloud_pickle_resolver.py @@ -1,7 +1,7 @@ from base64 import b64decode, b64encode from typing import List -import cloudpickle # intentionally not yet part of setup.py +import cloudpickle from flytekit.core.base_task import TaskResolverMixin from flytekit.core.context_manager import SerializationSettings diff 
T = typing.TypeVar("T")


class FlytePickle(typing.Generic[T]):
    """
    This type is only used by flytekit internally. User should not use this type.
    Any type that flyte can't recognize will become FlytePickle.

    Subscripting (``FlytePickle[SomeType]``) yields a fresh subclass whose ``python_type()`` returns
    ``SomeType`` and whose ``__origin__`` is ``FlytePickle``.
    """

    @classmethod
    def python_type(cls) -> typing.Optional[typing.Type]:
        """
        Return the Python type wrapped by this class. The unparameterized ``FlytePickle`` wraps nothing and
        therefore returns ``None``; the subclasses built by ``__class_getitem__`` override this.
        """
        return None

    def __class_getitem__(cls, python_type: typing.Optional[typing.Type]) -> typing.Type["FlytePickle"]:
        """
        Build the ``FlytePickle`` subclass that remembers ``python_type``.

        :param python_type: The type being wrapped; ``None`` returns ``cls`` itself.
        :return: ``cls`` or a newly created subclass (a new class object is created on every call).
        """
        if python_type is None:
            return cls

        class _SpecificFormatClass(FlytePickle):
            # Get the type engine to see this as kind of a generic
            __origin__ = FlytePickle

            @classmethod
            def python_type(cls) -> typing.Type:
                return python_type

        return _SpecificFormatClass


class FlytePickleTransformer(TypeTransformer[FlytePickle]):
    """
    Type transformer that moves arbitrary Python objects between tasks as a single cloudpickle blob.
    """

    # Blob format string written into both the literal type and the blob metadata.
    PYTHON_PICKLE_FORMAT = "PythonPickle"

    def __init__(self):
        super().__init__(name="FlytePickle", t=FlytePickle)

    def assert_type(self, t: Type[T], v: T):
        # Every type can serialize to pickle, so we don't need to check the type here.
        ...

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
        """
        Load and return the object pickled at the blob uri of ``lv``.

        :param ctx: Context whose ``file_access`` is used to fetch remote blobs.
        :param lv: Literal carrying ``scalar.blob.uri``.
        :param expected_python_type: Not consulted; whatever object was pickled is returned.
        :return: The unpickled object.
        """
        uri = lv.scalar.blob.uri
        # Download the pickle file first when it is not on the local file system.
        if ctx.file_access.is_remote(uri):
            local_path = ctx.file_access.get_random_local_path()
            ctx.file_access.get_data(uri, local_path, False)
            uri = local_path
        # NOTE(security): unpickling can execute arbitrary code, so the blob must come from a trusted producer.
        with open(uri, "rb") as infile:
            return cloudpickle.load(infile)

    def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        """
        Pickle ``python_val`` to a local file, upload it, and return a blob literal pointing at the upload.

        :param ctx: Context whose ``file_access`` supplies the local/remote paths and performs the upload.
        :param python_val: The object to pickle.
        :param python_type: Not consulted.
        :param expected: Not consulted.
        :return: A scalar blob ``Literal`` whose uri is the remote path the pickle was uploaded to.
        """
        meta = BlobMetadata(
            type=_core_types.BlobType(
                format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
            )
        )
        # Dump the task output into pickle
        local_dir = ctx.file_access.get_random_local_directory()
        os.makedirs(local_dir, exist_ok=True)
        # os.path.join() discards ``local_dir`` whenever its second argument is an absolute path, which would
        # leave the directory created above unused. Reuse only the final path component as the file name.
        # NOTE(review): assumes get_random_local_path() yields a path with a non-empty, unique basename - confirm.
        file_name = os.path.basename(ctx.file_access.get_random_local_path())
        uri = os.path.join(local_dir, file_name)
        with open(uri, "wb") as outfile:
            cloudpickle.dump(python_val, outfile)

        remote_path = ctx.file_access.get_random_remote_path(uri)
        ctx.file_access.put_data(uri, remote_path, is_multipart=False)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

    def get_literal_type(self, t: Type[T]) -> LiteralType:
        """
        Return the single-blob literal type tagged with ``PYTHON_PICKLE_FORMAT``; ``t`` is not consulted.
        """
        return _core_types.LiteralType(
            blob=_core_types.BlobType(
                format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
            )
        )


TypeEngine.register(FlytePickleTransformer())
default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = context_manager.SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


def test_to_python_value_and_literal():
    """A value written with ``to_literal`` must come back unchanged from ``to_python_value``."""
    ctx = context_manager.FlyteContext.current_context()
    tf = FlytePickleTransformer()
    python_val = "fake_output"
    lt = tf.get_literal_type(FlytePickle)

    lv = tf.to_literal(ctx, python_val, type(python_val), lt)  # type: ignore
    assert lv.scalar.blob.metadata == BlobMetadata(
        type=BlobType(
            format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT,
            dimensionality=BlobType.BlobDimensionality.SINGLE,
        )
    )
    assert lv.scalar.blob.uri is not None

    output = tf.to_python_value(ctx, lv, str)
    assert output == python_val


def test_get_literal_type():
    """The transformer always reports a single-blob literal type in the pickle format."""
    tf = FlytePickleTransformer()
    lt = tf.get_literal_type(FlytePickle)
    assert lt == LiteralType(
        blob=BlobType(
            format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE
        )
    )


def test_nested():
    """An unrecognized class nested inside ``List[List[...]]`` is serialized as a pickle blob."""

    class Foo(object):
        def __init__(self, number: int):
            self.number = number

    @task
    def t1(a: int) -> List[List[Foo]]:
        return [[Foo(number=a)]]

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    # Compare by value: ``is`` would test string object identity, which is an implementation detail.
    assert (
        task_spec.template.interface.outputs["o0"].type.collection_type.collection_type.blob.format
        == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
    )


def test_nested2():
    """An unrecognized class nested inside ``List[Dict[str, ...]]`` is serialized as a pickle blob."""

    class Foo(object):
        def __init__(self, number: int):
            self.number = number

    @task
    def t1(a: int) -> List[Dict[str, Foo]]:
        return [{"a": Foo(number=a)}]

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    # Compare by value: ``is`` would test string object identity, which is an implementation detail.
    assert (
        task_spec.template.interface.outputs["o0"].type.collection_type.map_value_type.blob.format
        == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
    )
def test_arbit_class():
    """
    A class without a registered transformer can flow through tasks as a bare value, as a list element
    and as a dict value.
    """

    class Foo(object):
        def __init__(self, number: int):
            self.number = number

    @task
    def t1(a: int) -> Foo:
        return Foo(number=a)

    @task
    def t2(a: Foo) -> typing.List[Foo]:
        return [a, a]

    @task
    def t3(a: typing.List[Foo]) -> typing.Dict[str, Foo]:
        return {"hello": a[0]}

    def wf(a: int) -> typing.Dict[str, Foo]:
        # Chain the three tasks: int -> Foo -> [Foo, Foo] -> {"hello": Foo}
        single = t1(a=a)
        doubled = t2(a=single)
        return t3(a=doubled)

    result = wf(1)
    assert result["hello"].number == 1