From 9f3b30dfe45f3c2716d52138d0dbb6e3a0a8817e Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Thu, 22 Jul 2021 15:37:58 -0700 Subject: [PATCH] updated data persistence api Signed-off-by: Ketan Umare --- flytekit/__init__.py | 4 + flytekit/bin/entrypoint.py | 2 +- flytekit/core/context_manager.py | 11 +-- flytekit/core/data_persistence.py | 94 ++++++++----------- flytekit/extras/persistence/gcs_gsutil.py | 77 ++++----------- flytekit/extras/persistence/http.py | 19 ++-- flytekit/extras/persistence/s3_awscli.py | 58 +++--------- flytekit/extras/sqlite3/task.py | 2 +- flytekit/interfaces/data/common.py | 3 - flytekit/interfaces/data/s3/s3proxy.py | 2 +- flytekit/types/schema/types.py | 2 +- flytekit/types/schema/types_pandas.py | 2 +- .../pandera/flytekitplugins/pandera/schema.py | 2 +- 13 files changed, 96 insertions(+), 182 deletions(-) diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 2bfdbae885..8b8a5cf6ac 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -141,6 +141,7 @@ from flytekit.core.condition import conditional from flytekit.core.container_task import ContainerTask from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager +from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.launch_plan import LaunchPlan from flytekit.core.map_task import map_task @@ -153,6 +154,9 @@ from flytekit.core.task import Secret, reference_task, task from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow +from flytekit.extras.persistence import gcs_gsutil as _gcs +from flytekit.extras.persistence import http as _http +from flytekit.extras.persistence import s3_awscli as _s3 from flytekit.loggers import logger from flytekit.types import schema diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 650328c6e8..7b70309435 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -176,7 +176,7 @@ def _dispatch_execute( for k, v in output_file_dict.items(): _common_utils.write_proto_to_file(v.to_flyte_idl(), _os.path.join(ctx.execution_state.engine_dir, k)) - ctx.file_access.upload_directory(ctx.execution_state.engine_dir, output_prefix) + ctx.file_access.put_data(ctx.execution_state.engine_dir, output_prefix, is_multipart=True) _logging.info(f"Engine folder written successfully to the output prefix {output_prefix}") diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 9aea017334..361d88eb06 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -33,9 +33,8 @@ from flytekit.common.tasks.sdk_runnable import ExecutionParameters from flytekit.configuration import images, internal from flytekit.configuration import sdk as _sdk_config -from flytekit.core.data_persistence import FileAccessProvider +from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider from flytekit.engines.unit import mock_stats as _mock_stats -from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.models.core import identifier as _identifier # TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin @@ -463,7 +462,7 @@ def with_compilation_state(self, c: CompilationState) -> Builder: def with_new_compilation_state(self) -> Builder: return 
self.with_compilation_state(self.new_compilation_state())
 
-    def with_file_access(self, fa: _data_proxy.FileAccessProvider) -> Builder:
+    def with_file_access(self, fa: FileAccessProvider) -> Builder:
         return self.new_builder().with_file_access(fa)
 
     def with_serialization_settings(self, ss: SerializationSettings) -> Builder:
@@ -498,7 +497,7 @@ def current_context() -> FlyteContext:
 
     @dataclass
     class Builder(object):
-        file_access: _data_proxy.FileAccessProvider
+        file_access: FileAccessProvider
         level: int = 0
         compilation_state: Optional[CompilationState] = None
         execution_state: Optional[ExecutionState] = None
@@ -551,7 +550,7 @@ def with_compilation_state(self, c: CompilationState) -> "Builder":
         def with_new_compilation_state(self) -> "Builder":
             return self.with_compilation_state(self.new_compilation_state())
 
-        def with_file_access(self, fa: _data_proxy.FileAccessProvider) -> "Builder":
+        def with_file_access(self, fa: FileAccessProvider) -> "Builder":
            self.file_access = fa
            return self
 
@@ -680,7 +679,7 @@ def initialize():
             logging=_logging,
             tmp_dir=os.path.join(_sdk_config.LOCAL_SANDBOX.get(), "user_space"),
         )
-        default_context = FlyteContext(file_access=_data_proxy.default_local_file_access_provider)
+        default_context = FlyteContext(file_access=default_local_file_access_provider)
         default_context = default_context.with_execution_state(
             default_context.new_execution_state().with_params(user_space_params=default_user_space_params)
         ).build()
diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py
index 62d6cac0c5..4cc6a1d28f 100644
--- a/flytekit/core/data_persistence.py
+++ b/flytekit/core/data_persistence.py
@@ -6,13 +6,13 @@
 from abc import abstractmethod
 from distutils import dir_util as _dir_util
 from shutil import copyfile as _copyfile
-from typing import Union, Dict
+from typing import Dict, Union
 from uuid import UUID
 
-from flytekit.loggers import logger
 from flytekit.common.exceptions.user import FlyteAssertion
 from flytekit.common.utils import PerformanceTimer
 from flytekit.interfaces.random import random
+from flytekit.loggers import logger
 
 
 class UnsupportedPersistenceOp(Exception):
@@ -50,30 +50,16 @@ def exists(self, path: str) -> bool:
         pass
 
     @abstractmethod
-    def download_directory(self, remote_path: str, local_path: str):
-        """
-        downloads a directory from path to path recursively
-        """
-        pass
-
-    @abstractmethod
-    def download(self, remote_path: str, local_path: str):
-        """
-        downloads a file from path to path
-        """
-        pass
-
-    @abstractmethod
-    def upload(self, file_path: str, to_path: str):
+    def get(self, from_path: str, to_path: str, recursive: bool = False):
         """
-        uploads the given file to path
+        Retrieves data from from_path and writes it to the given to_path (to_path is locally accessible)
         """
         pass
 
     @abstractmethod
-    def upload_directory(self, local_path: str, remote_path: str):
+    def put(self, from_path: str, to_path: str, recursive: bool = False):
         """
-        uploads a directory from path to path recursively
+        Stores data from from_path to the given to_path (from_path is locally accessible)
         """
         pass
 
@@ -98,6 +84,7 @@ class DataPersistencePlugins(object):
 
     These plugins should always be registered. Follow the plugin registration guidelines to auto-discover your plugins.
""" + _PLUGINS: Dict[str, DataPersistence] = {} @classmethod @@ -115,7 +102,8 @@ def register_plugin(cls, protocol: str, plugin: DataPersistence, force: bool = F if not force: raise TypeError( f"Cannot register plugin {plugin.name} for protocol {protocol} as plugin {p.name} is already" - f" registered for the same protocol. You can force register the new plugin by passing force=True") + f" registered for the same protocol. You can force register the new plugin by passing force=True" + ) cls._PLUGINS[protocol] = plugin @@ -146,6 +134,12 @@ def is_supported_protocol(cls, protocol: str) -> bool: class DiskPersistence(DataPersistence): + """ + The simplest form of persistence that is available with default flytekit - Disk based persistence. + This will store all data locally and retreive the data from local. This is helpful for local execution and simulating + runs. + """ + PROTOCOL = "file://" def __init__(self, *args, **kwargs): @@ -187,21 +181,22 @@ def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, N def exists(self, path: str): return _os.path.exists(self.strip_file_header(path)) - def download_directory(self, from_path: str, to_path: str): + def get(self, from_path: str, to_path: str, recursive: bool = False): if from_path != to_path: - _dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path)) - - def download(self, from_path: str, to_path: str): - _copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) - - def upload(self, from_path: str, to_path: str): - # Emulate s3's flat storage by automatically creating directory path - self._make_local_path(_os.path.dirname(self.strip_file_header(to_path))) - # Write the object to a local file in the sandbox - _copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) + if recursive: + _dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path)) + else: + _copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) - def upload_directory(self, from_path, to_path): - self.download_directory(from_path, to_path) + def put(self, from_path: str, to_path: str, recursive: bool = False): + if from_path != to_path: + if recursive: + _dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path)) + else: + # Emulate s3's flat storage by automatically creating directory path + self._make_local_path(_os.path.dirname(self.strip_file_header(to_path))) + # Write the object to a local file in the sandbox + _copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) def construct_path(self, add_protocol: bool, *args) -> str: if add_protocol: @@ -239,12 +234,13 @@ def local_sandbox_dir(self) -> os.PathLike: def local_access(self) -> DiskPersistence: return self._local - def construct_random_path(self, persist: DataPersistence, - file_path_or_file_name: typing.Optional[str] = None) -> str: + def construct_random_path( + self, persist: DataPersistence, file_path_or_file_name: typing.Optional[str] = None + ) -> str: """ Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name """ - key = UUID(int=random.random.getrandbits(128)).hex + key = UUID(int=random.getrandbits(128)).hex if file_path_or_file_name: _, tail = os.path.split(file_path_or_file_name) if tail: @@ -286,28 +282,27 @@ def download_directory(self, remote_path: str, local_path: str): """ Downloads directory from given remote to local path """ - return 
DataPersistencePlugins.find_plugin(remote_path).download_directory(remote_path, local_path) + return self.get_data(remote_path, local_path, is_multipart=True) def download(self, remote_path: str, local_path: str): """ Downloads from remote to local """ - return DataPersistencePlugins.find_plugin(remote_path).download(remote_path, local_path) + return self.get_data(remote_path, local_path) def upload(self, file_path: str, to_path: str): """ :param Text file_path: :param Text to_path: """ - return DataPersistencePlugins.find_plugin(to_path).upload(file_path, to_path) + return self.put_data(file_path, to_path) def upload_directory(self, local_path: str, remote_path: str): """ :param Text local_path: :param Text remote_path: """ - # TODO: https://github.com/flyteorg/flyte/issues/762 - test if this works! - return DataPersistencePlugins.find_plugin(remote_path).upload_directory(local_path, remote_path) + return self.put_data(local_path, remote_path, is_multipart=True) def get_data(self, remote_path: str, local_path: str, is_multipart=False): """ @@ -317,10 +312,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart=False): """ try: with PerformanceTimer("Copying ({} -> {})".format(remote_path, local_path)): - if is_multipart: - self.download_directory(remote_path, local_path) - else: - self.download(remote_path, local_path) + DataPersistencePlugins.find_plugin(remote_path).get(remote_path, local_path, recursive=is_multipart) except Exception as ex: raise FlyteAssertion( "Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n" @@ -343,10 +335,7 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul """ try: with PerformanceTimer("Writing ({} -> {})".format(local_path, remote_path)): - if is_multipart: - self._default_remote.upload_directory(local_path, remote_path) - else: - self._default_remote.upload(local_path, remote_path) + DataPersistencePlugins.find_plugin(remote_path).put(local_path, remote_path, recursive=is_multipart) except Exception as ex: raise FlyteAssertion( f"Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n" @@ -360,6 +349,5 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul # TODO make this use tmpdir tmp_dir = os.path.join("/tmp/flyte", datetime.datetime.now().strftime("%Y%m%d_%H%M%S")) default_local_file_access_provider = FileAccessProvider( - local_sandbox_dir=os.path.join(tmp_dir, "sandbox"), - raw_output_prefix=os.path.join(tmp_dir, "raw") -) \ No newline at end of file + local_sandbox_dir=os.path.join(tmp_dir, "sandbox"), raw_output_prefix=os.path.join(tmp_dir, "raw") +) diff --git a/flytekit/extras/persistence/gcs_gsutil.py b/flytekit/extras/persistence/gcs_gsutil.py index dfd5d3a4a8..31fb68d026 100644 --- a/flytekit/extras/persistence/gcs_gsutil.py +++ b/flytekit/extras/persistence/gcs_gsutil.py @@ -1,11 +1,9 @@ import os as _os import sys as _sys -import uuid as _uuid from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException from flytekit.configuration import gcp as _gcp_config from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins -from flytekit.interfaces import random as _flyte_random from flytekit.tools import subprocess as _subprocess if _sys.version_info >= (3,): @@ -27,20 +25,9 @@ class GCSPersistence(DataPersistence): _GS_UTIL_CLI = "gsutil" PROTOCOL = "gs://" - def __init__(self, raw_output_data_prefix_override: str = None): - """ - :param 
raw_output_data_prefix_override: Instead of relying on the AWS or GCS configuration (see - S3_SHARD_FORMATTER for AWS and GCS_PREFIX for GCP) setting when computing the shard - path (_get_shard_path), use this prefix instead as a base. This code assumes that the - path passed in is correct. That is, an S3 path won't be passed in when running on GCP. - """ - self._raw_output_data_prefix_override = raw_output_data_prefix_override + def __init__(self): super(GCSPersistence, self).__init__(name="gcs-gsutil") - @property - def raw_output_data_prefix_override(self) -> str: - return self._raw_output_data_prefix_override - @staticmethod def _check_binary(): """ @@ -81,58 +68,30 @@ def exists(self, remote_path): except Exception: return False - def download_directory(self, remote_path, local_path): - """ - :param Text remote_path: remote gs:// path - :param Text local_path: directory to copy to - """ - GCSPersistence._check_binary() - - if not remote_path.startswith("gs://"): - raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...") - - cmd = self._maybe_with_gsutil_parallelism("cp", "-r", _amend_path(remote_path), local_path) - return _update_cmd_config_and_execute(cmd) - - def download(self, remote_path, local_path): - """ - :param Text remote_path: remote gs:// path - :param Text local_path: directory to copy to - """ - if not remote_path.startswith("gs://"): + def get(self, from_path: str, to_path: str, recursive: bool = False): + if not from_path.startswith("gs://"): raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...") GCSPersistence._check_binary() + if recursive: + cmd = self._maybe_with_gsutil_parallelism("cp", "-r", _amend_path(from_path), to_path) + else: + cmd = self._maybe_with_gsutil_parallelism("cp", from_path, to_path) - cmd = self._maybe_with_gsutil_parallelism("cp", remote_path, local_path) - return _update_cmd_config_and_execute(cmd) - - def upload(self, file_path, to_path): - """ - :param Text file_path: - :param Text to_path: - """ - GCSPersistence._check_binary() - - cmd = self._maybe_with_gsutil_parallelism("cp", file_path, to_path) return _update_cmd_config_and_execute(cmd) - def upload_directory(self, local_path, remote_path): - """ - :param Text local_path: - :param Text remote_path: - """ - if not remote_path.startswith("gs://"): - raise ValueError("Not an GS Key. 
Please use FQN (GS ARN) of the format gs://...") - + def put(self, from_path: str, to_path: str, recursive: bool = False): GCSPersistence._check_binary() - cmd = self._maybe_with_gsutil_parallelism( - "cp", - "-r", - _amend_path(local_path), - remote_path if remote_path.endswith("/") else remote_path + "/", - ) + if recursive: + cmd = self._maybe_with_gsutil_parallelism( + "cp", + "-r", + _amend_path(from_path), + to_path if to_path.endswith("/") else to_path + "/", + ) + else: + cmd = self._maybe_with_gsutil_parallelism("cp", from_path, to_path) return _update_cmd_config_and_execute(cmd) def construct_path(self, add_protocol: bool, *paths) -> str: @@ -142,4 +101,4 @@ def construct_path(self, add_protocol: bool, *paths) -> str: return path -DataPersistencePlugins.register_plugin("gcs://", GCSPersistence()) \ No newline at end of file +DataPersistencePlugins.register_plugin("gcs://", GCSPersistence()) diff --git a/flytekit/extras/persistence/http.py b/flytekit/extras/persistence/http.py index d2582e8b20..8ff68a797b 100644 --- a/flytekit/extras/persistence/http.py +++ b/flytekit/extras/persistence/http.py @@ -28,10 +28,11 @@ def exists(self, path: str): ) return rsp.status_code == type(self)._HTTP_OK - def download_directory(self, from_path: str, to_path: str): - raise _user_exceptions.FlyteAssertion("Reading data recursively from HTTP endpoint is not currently supported.") - - def download(self, from_path: str, to_path: str): + def get(self, from_path: str, to_path: str, recursive: bool = False): + if recursive: + raise _user_exceptions.FlyteAssertion( + "Reading data recursively from HTTP endpoint is not currently supported." + ) rsp = _requests.get(from_path) if rsp.status_code != type(self)._HTTP_OK: raise _user_exceptions.FlyteValueException( @@ -41,16 +42,14 @@ def download(self, from_path: str, to_path: str): with open(to_path, "wb") as writer: writer.write(rsp.content) - def upload(self, from_path: str, to_path: str): - raise _user_exceptions.FlyteAssertion("Writing data to HTTP endpoint is not currently supported.") - - def upload_directory(self, from_path: str, to_path: str): + def put(self, from_path: str, to_path: str, recursive: bool = False): raise _user_exceptions.FlyteAssertion("Writing data to HTTP endpoint is not currently supported.") def construct_path(self, add_protocol: bool, *paths) -> str: raise _user_exceptions.FlyteAssertion( - "There are multiple ways of creating http links / paths, this is not supported by the persistence layer") + "There are multiple ways of creating http links / paths, this is not supported by the persistence layer" + ) DataPersistencePlugins.register_plugin("http://", HttpPersistence()) -DataPersistencePlugins.register_plugin("https://", HttpPersistence()) \ No newline at end of file +DataPersistencePlugins.register_plugin("https://", HttpPersistence()) diff --git a/flytekit/extras/persistence/s3_awscli.py b/flytekit/extras/persistence/s3_awscli.py index f7fb39d323..3d60562731 100644 --- a/flytekit/extras/persistence/s3_awscli.py +++ b/flytekit/extras/persistence/s3_awscli.py @@ -121,64 +121,32 @@ def exists(self, remote_path): else: raise ex - def download_directory(self, remote_path, local_path): - """ - :param Text remote_path: remote s3:// path - :param Text local_path: directory to copy to - """ + def get(self, from_path: str, to_path: str, recursive: bool = False): S3Persistence._check_binary() - if not remote_path.startswith("s3://"): + if not from_path.startswith("s3://"): raise ValueError("Not an S3 ARN. 
Please use FQN (S3 ARN) of the format s3://...")
 
-        cmd = [S3Persistence._AWS_CLI, "s3", "cp", "--recursive", remote_path, local_path]
-        return _update_cmd_config_and_execute(cmd)
-
-    def download(self, remote_path, local_path):
-        """
-        :param Text remote_path: remote s3:// path
-        :param Text local_path: directory to copy to
-        """
-        if not remote_path.startswith("s3://"):
-            raise ValueError("Not an S3 ARN. Please use FQN (S3 ARN) of the format s3://...")
-
-        S3Persistence._check_binary()
-        cmd = [S3Persistence._AWS_CLI, "s3", "cp", remote_path, local_path]
-        return _update_cmd_config_and_execute(cmd)
-
-    def upload(self, file_path, to_path):
-        """
-        :param Text file_path:
-        :param Text to_path:
-        """
-        S3Persistence._check_binary()
-
-        extra_args = {
-            "ACL": "bucket-owner-full-control",
-        }
-
-        cmd = [S3Persistence._AWS_CLI, "s3", "cp"]
-        cmd.extend(_extra_args(extra_args))
-        cmd += [file_path, to_path]
-
+        if recursive:
+            cmd = [S3Persistence._AWS_CLI, "s3", "cp", "--recursive", from_path, to_path]
+        else:
+            cmd = [S3Persistence._AWS_CLI, "s3", "cp", from_path, to_path]
         return _update_cmd_config_and_execute(cmd)
 
-    def upload_directory(self, local_path, remote_path):
-        """
-        :param Text local_path:
-        :param Text remote_path:
-        """
+    def put(self, from_path: str, to_path: str, recursive: bool = False):
         extra_args = {
             "ACL": "bucket-owner-full-control",
         }
 
-        if not remote_path.startswith("s3://"):
+        if not to_path.startswith("s3://"):
             raise ValueError("Not an S3 ARN. Please use FQN (S3 ARN) of the format s3://...")
 
         S3Persistence._check_binary()
-        cmd = [S3Persistence._AWS_CLI, "s3", "cp", "--recursive"]
+        cmd = [S3Persistence._AWS_CLI, "s3", "cp"]
+        if recursive:
+            cmd += ["--recursive"]
         cmd.extend(_extra_args(extra_args))
-        cmd += [local_path, remote_path]
+        cmd += [from_path, to_path]
         return _update_cmd_config_and_execute(cmd)
 
     def construct_path(self, add_protocol: bool, *paths) -> str:
@@ -188,4 +156,4 @@ def construct_path(self, add_protocol: bool, *paths) -> str:
         return path
 
 
-DataPersistencePlugins.register_plugin("s3://", S3Persistence())
\ No newline at end of file
+DataPersistencePlugins.register_plugin("s3://", S3Persistence())
diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py
index 391b88035b..e4a803f50a 100644
--- a/flytekit/extras/sqlite3/task.py
+++ b/flytekit/extras/sqlite3/task.py
@@ -116,7 +116,7 @@ def execute_from_model(self, tt: task_models.TaskTemplate, **kwargs) -> typing.A
         ctx = FlyteContext.current_context()
         file_ext = os.path.basename(tt.custom["uri"])
         local_path = os.path.join(temp_dir, file_ext)
-        ctx.file_access.download(tt.custom["uri"], local_path)
+        ctx.file_access.get_data(tt.custom["uri"], local_path)
 
         if tt.custom["compressed"]:
             local_path = unarchive_file(local_path, temp_dir)
diff --git a/flytekit/interfaces/data/common.py b/flytekit/interfaces/data/common.py
index bd69844d20..1544b608b6 100644
--- a/flytekit/interfaces/data/common.py
+++ b/flytekit/interfaces/data/common.py
@@ -1,6 +1,3 @@
-import abc as _abc
-
-
 class DataProxy(object):
     def __init__(self, name: str):
         self._name = name
diff --git a/flytekit/interfaces/data/s3/s3proxy.py b/flytekit/interfaces/data/s3/s3proxy.py
index 6c856cef8f..ab11e7b738 100644
--- a/flytekit/interfaces/data/s3/s3proxy.py
+++ b/flytekit/interfaces/data/s3/s3proxy.py
@@ -221,4 +221,4 @@ def _get_shard_path(self) -> str:
         shard = ""
         for _ in _six_moves.range(_aws_config.S3_SHARD_STRING_LENGTH.get()):
             shard += _flyte_random.random.choice(self._SHARD_CHARACTERS)
-        return
_aws_config.S3_SHARD_FORMATTER.get().format(shard) \ No newline at end of file + return _aws_config.S3_SHARD_FORMATTER.get().format(shard) diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 5814910452..76c413a74e 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -365,7 +365,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: raise AssertionError("Can only covert a literal schema to a FlyteSchema") def downloader(x, y): - ctx.file_access.download_directory(x, y) + ctx.file_access.get_data(x, y, is_multipart=True) return expected_python_type( local_path=ctx.file_access.get_random_local_directory(), diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index f0a347282a..89f90738c5 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -139,7 +139,7 @@ def to_python_value( if not (lv and lv.scalar and lv.scalar.schema): return pandas.DataFrame() local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.download_directory(lv.scalar.schema.uri, local_dir) + ctx.file_access.get_data(lv.scalar.schema.uri, local_dir, is_multipart=True) r = PandasSchemaReader(local_dir=local_dir, cols=None, fmt=SchemaFormat.PARQUET) return r.all() diff --git a/plugins/pandera/flytekitplugins/pandera/schema.py b/plugins/pandera/flytekitplugins/pandera/schema.py index f283df6b87..2daaaf5fd7 100644 --- a/plugins/pandera/flytekitplugins/pandera/schema.py +++ b/plugins/pandera/flytekitplugins/pandera/schema.py @@ -76,7 +76,7 @@ def to_python_value( raise AssertionError("Can only covert a literal schema to a pandera schema") def downloader(x, y): - ctx.file_access.download_directory(x, y) + ctx.file_access.get_data(x, y, is_multipart=True) df = FlyteSchema( local_path=ctx.file_access.get_random_local_directory(),
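
For reviewers, a minimal sketch (not part of the diff above) of how a custom backend could hook into the DataPersistence / DataPersistencePlugins API introduced in this patch. The "memory://" protocol and the InMemoryPersistence class are hypothetical names used only for illustration; only the methods added in this patch (exists, get, put, construct_path, listdir) and the base-class name= constructor argument shown in gcs_gsutil.py are assumed.

import typing

from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins


class InMemoryPersistence(DataPersistence):
    """Keeps blobs in a process-local dict, keyed by their full memory:// path (illustration only)."""

    PROTOCOL = "memory://"

    def __init__(self):
        super().__init__(name="in-memory")
        self._blobs: typing.Dict[str, bytes] = {}

    def exists(self, path: str) -> bool:
        return path in self._blobs

    def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, None, None]:
        # Emulate a flat namespace: everything under the given prefix is a "child".
        return (k for k in self._blobs if k.startswith(path))

    def get(self, from_path: str, to_path: str, recursive: bool = False):
        # Single objects only in this sketch; a real plugin would walk the prefix when recursive=True.
        if recursive:
            raise NotImplementedError("recursive get is not implemented in this example")
        with open(to_path, "wb") as f:
            f.write(self._blobs[from_path])

    def put(self, from_path: str, to_path: str, recursive: bool = False):
        if recursive:
            raise NotImplementedError("recursive put is not implemented in this example")
        with open(from_path, "rb") as f:
            self._blobs[to_path] = f.read()

    def construct_path(self, add_protocol: bool, *paths) -> str:
        path = "/".join(paths)
        return self.PROTOCOL + path if add_protocol else path


# Registration keys the plugin by its protocol prefix, mirroring the s3://, gs:// and http(s)://
# registrations in this patch.
DataPersistencePlugins.register_plugin("memory://", InMemoryPersistence())

Once registered, any memory:// path handed to ctx.file_access.get_data(...) or put_data(...) is resolved through DataPersistencePlugins.find_plugin and dispatched to this class, which is the same routing the entrypoint and schema/type transformers rely on after this change.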