Skip to content

Commit

Permalink
Remote client failed to fetch FlytePickle object (#764)
Browse files Browse the repository at this point in the history
* Fetch pickle value from flytekit remote

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Remove default value

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Dec 7, 2021
1 parent b257569 commit 87c665b
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 3 deletions.
12 changes: 10 additions & 2 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,17 @@ def __init__(
# Not exposing this as a property for now.
self._entrypoint_settings = entrypoint_settings

raw_output_data_prefix = auth_config.RAW_OUTPUT_DATA_PREFIX.get() or os.path.join(
sdk_config.LOCAL_SANDBOX.get(), "control_plane_raw"
)
self._file_access = file_access or FileAccessProvider(
local_sandbox_dir=os.path.join(sdk_config.LOCAL_SANDBOX.get(), "control_plane_metadata"),
raw_output_prefix=raw_output_data_prefix,
)
# Save the file access object locally, but also make it available for use from the context.
FlyteContextManager.with_context(FlyteContextManager.current_context().with_file_access(file_access).build())
self._file_access = file_access
FlyteContextManager.with_context(
FlyteContextManager.current_context().with_file_access(self._file_access).build()
)

# TODO: Reconsider whether we want this. Probably best to not cache.
self._serialized_entity_cache = OrderedDict()
Expand Down
7 changes: 6 additions & 1 deletion flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types.pickle.pickle import FlytePickleTransformer


def noop():
Expand Down Expand Up @@ -348,7 +349,11 @@ def _downloader():
return ff

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]:
if literal_type.blob is not None and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE:
if (
literal_type.blob is not None
and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE
and literal_type.blob.format != FlytePickleTransformer.PYTHON_PICKLE_FORMAT
):
return FlyteFile.__class_getitem__(literal_type.blob.format)

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
Expand Down
10 changes: 10 additions & 0 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
ctx.file_access.put_data(uri, remote_path, is_multipart=False)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlytePickle[typing.Any]]:
if (
literal_type.blob is not None
and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
and literal_type.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
):
return FlytePickle

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")

def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(
blob=_core_types.BlobType(
Expand Down
27 changes: 27 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,14 @@ def test_guessing_basic():
pt = TypeEngine.guess_python_type(lt)
assert pt is None

lt = model_types.LiteralType(
blob=BlobType(
format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE
)
)
pt = TypeEngine.guess_python_type(lt)
assert pt is FlytePickle


def test_guessing_containers():
b = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN)
Expand Down Expand Up @@ -552,6 +560,25 @@ def test_enum_type():
TypeEngine.to_literal_type(UnsupportedEnumValues)


def test_pickle_type():
class Foo(object):
def __init__(self, number: int):
self.number = number

lt = TypeEngine.to_literal_type(FlytePickle)
assert lt.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
assert lt.blob.dimensionality == BlobType.BlobDimensionality.SINGLE

ctx = FlyteContextManager.current_context()
lv = TypeEngine.to_literal(ctx, Foo(1), FlytePickle, lt)
assert "/tmp/flyte/" in lv.scalar.blob.uri

transformer = FlytePickleTransformer()
gt = transformer.guess_python_type(lt)
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
assert Foo(1).number == pv.number


def test_enum_in_dataclass():
@dataclass_json
@dataclass
Expand Down

0 comments on commit 87c665b

Please sign in to comment.