diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 117abf27e3..60ae7bbf3e 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -235,9 +235,17 @@ def __init__( # Not exposing this as a property for now. self._entrypoint_settings = entrypoint_settings + raw_output_data_prefix = auth_config.RAW_OUTPUT_DATA_PREFIX.get() or os.path.join( + sdk_config.LOCAL_SANDBOX.get(), "control_plane_raw" + ) + self._file_access = file_access or FileAccessProvider( + local_sandbox_dir=os.path.join(sdk_config.LOCAL_SANDBOX.get(), "control_plane_metadata"), + raw_output_prefix=raw_output_data_prefix, + ) # Save the file access object locally, but also make it available for use from the context. - FlyteContextManager.with_context(FlyteContextManager.current_context().with_file_access(file_access).build()) - self._file_access = file_access + FlyteContextManager.with_context( + FlyteContextManager.current_context().with_file_access(self._file_access).build() + ) # TODO: Reconsider whether we want this. Probably best to not cache. self._serialized_entity_cache = OrderedDict() diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index eb64e5943d..9bff3a5dec 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -10,6 +10,7 @@ from flytekit.models.core.types import BlobType from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType +from flytekit.types.pickle.pickle import FlytePickleTransformer def noop(): @@ -348,7 +349,11 @@ def _downloader(): return ff def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]: - if literal_type.blob is not None and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format != FlytePickleTransformer.PYTHON_PICKLE_FORMAT + ): return FlyteFile.__class_getitem__(literal_type.blob.format) raise ValueError(f"Transformer {self} cannot reverse {literal_type}") diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index 9219d3a8b4..3472dec7e6 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -78,6 +78,16 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp ctx.file_access.put_data(uri, remote_path, is_multipart=False) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlytePickle[typing.Any]]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT + ): + return FlytePickle + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType( blob=_core_types.BlobType( diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 50a77c43d9..9f88a80fb2 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -356,6 +356,14 @@ def test_guessing_basic(): pt = TypeEngine.guess_python_type(lt) assert pt is None + lt = model_types.LiteralType( + blob=BlobType( + format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE + ) + ) + pt = TypeEngine.guess_python_type(lt) + assert pt is FlytePickle + def test_guessing_containers(): b = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN) @@ -552,6 +560,25 @@ def test_enum_type(): TypeEngine.to_literal_type(UnsupportedEnumValues) +def test_pickle_type(): + class Foo(object): + def __init__(self, number: int): + self.number = number + + lt = TypeEngine.to_literal_type(FlytePickle) + assert lt.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT + assert lt.blob.dimensionality == BlobType.BlobDimensionality.SINGLE + + ctx = FlyteContextManager.current_context() + lv = TypeEngine.to_literal(ctx, Foo(1), FlytePickle, lt) + assert "/tmp/flyte/" in lv.scalar.blob.uri + + transformer = FlytePickleTransformer() + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert Foo(1).number == pv.number + + def test_enum_in_dataclass(): @dataclass_json @dataclass