Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Sep 29, 2021
1 parent 301ead5 commit e4c1bf6
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 9 deletions.
12 changes: 7 additions & 5 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,9 @@ def transform_signature_to_interface(signature: inspect.Signature, docstring: Op
except ValueError:
_logging.warning(
f"We change the output type to the PythonPickle "
f"since we can't find a transformer for the type {v}. "
f"We strongly recommend to use python type that flyte support."
f"since we can't find a transformer for the type {v}.\n"
f"Pickle can only be used to send objects between the exact same version of Python, "
f"and we strongly recommend to use python type that flyte support."
)
outputs[k] = FlytePickle(python_type=v) # type: ignore

Expand All @@ -273,9 +274,10 @@ def transform_signature_to_interface(signature: inspect.Signature, docstring: Op
TypeEngine.get_transformer(annotation)
except ValueError:
_logging.warning(
f"We change the input type to the PythonPickle "
f"if we can't find a transformer for the original type {v}. "
f"We strongly recommend to use python type that flyte support."
f"We change the output type to the PythonPickle "
f"since we can't find a transformer for the type {v}.\n"
f"Pickle can only be used to send objects between the exact same version of Python, "
f"and we strongly recommend to use python type that flyte support."
)
annotation = FlytePickle(python_type=v.annotation)
default = v.default if v.default is not inspect.Parameter.empty else None
Expand Down
2 changes: 1 addition & 1 deletion flytekit/extras/cloud_pickle_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from base64 import b64decode, b64encode
from typing import List

import cloudpickle # intentionally not yet part of setup.py
import cloudpickle

from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.context_manager import SerializationSettings
Expand Down
6 changes: 3 additions & 3 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import pickle
import cloudpickle
import typing
from typing import Type

Expand Down Expand Up @@ -38,7 +38,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
ctx.file_access.get_data(uri, self.PICKLE, False)
uri = self.PICKLE
infile = open(uri, "rb")
data = pickle.load(infile)
data = cloudpickle.load(infile)
infile.close()
return data

Expand All @@ -53,7 +53,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
os.makedirs(local_dir, exist_ok=True)
uri = os.path.join(local_dir, self.PICKLE)
outfile = open(uri, "w+b")
pickle.dump(python_val, outfile)
cloudpickle.dump(python_val, outfile)
outfile.close()

remote_path = ctx.file_access.get_random_remote_path(uri)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
"singledispatchmethod; python_version < '3.8.0'",
"docstring-parser>=0.9.0",
"diskcache>=5.2.1",
"cloudpickle>=2.0.0"
],
extras_require=extras_require,
scripts=[
Expand Down
34 changes: 34 additions & 0 deletions tests/flytekit/unit/core/test_flyte_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from flytekit.core import context_manager
from flytekit.models.core.types import BlobType
from flytekit.models.literals import BlobMetadata
from flytekit.models.types import LiteralType
from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer


def test_to_python_value_and_literal():
ctx = context_manager.FlyteContext.current_context()
tf = FlytePickleTransformer()
python_val = "fake_output"
lt = tf.get_literal_type(FlytePickle)

lv = tf.to_literal(ctx, python_val, type(python_val), lt) # type: ignore
assert lv.scalar.blob.metadata == BlobMetadata(
type=BlobType(
format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT,
dimensionality=BlobType.BlobDimensionality.SINGLE,
)
)
assert lv.scalar.blob.uri is not None

output = tf.to_python_value(ctx, lv, str)
assert output == python_val


def test_get_literal_type():
tf = FlytePickleTransformer()
lt = tf.get_literal_type(FlytePickle)
assert lt == LiteralType(
blob=BlobType(
format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE
)
)
19 changes: 19 additions & 0 deletions tests/flytekit/unit/core/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from flytekit.models.core import types as _core_types
from flytekit.types.file import FlyteFile
from flytekit.types.pickle import FlytePickle


def test_extract_only():
Expand Down Expand Up @@ -269,3 +270,21 @@ def z(a: int, b: str) -> typing.NamedTuple("NT", x_str=str, y_int=int):
assert typed_interface.inputs.get("b").description == "bar"
assert typed_interface.outputs.get("x_str").description == "description for x_str"
assert typed_interface.outputs.get("y_int").description == "description for y_int"


def test_parameter_change_to_pickle_type():
ctx = context_manager.FlyteContext.current_context()

class Foo:
def __init__(self, name):
self.name = name

def z(a: Foo) -> Foo:
...

our_interface = transform_signature_to_interface(inspect.signature(z))
params = transform_inputs_to_parameters(ctx, our_interface)
assert params.parameters["a"].required
assert params.parameters["a"].default is None
assert isinstance(our_interface.outputs["o0"], FlytePickle)
assert isinstance(our_interface.inputs["a"], FlytePickle)
3 changes: 3 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from flytekit.models.types import LiteralType, SimpleType
from flytekit.types.directory.types import FlyteDirectory
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import FlytePickleTransformer


def test_type_engine():
Expand Down Expand Up @@ -59,6 +61,7 @@ def test_type_resolution():
assert type(TypeEngine.get_transformer(int)) == SimpleTransformer

assert type(TypeEngine.get_transformer(os.PathLike)) == FlyteFilePathTransformer
assert type(TypeEngine.get_transformer(FlytePickle)) == FlytePickleTransformer

with pytest.raises(ValueError):
TypeEngine.get_transformer(typing.Any)
Expand Down

0 comments on commit e4c1bf6

Please sign in to comment.