Skip to content

Commit

Permalink
updated data persistence api
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 committed Jul 22, 2021
1 parent ea87889 commit 9f3b30d
Show file tree
Hide file tree
Showing 13 changed files with 96 additions and 182 deletions.
4 changes: 4 additions & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@
from flytekit.core.condition import conditional
from flytekit.core.container_task import ContainerTask
from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.map_task import map_task
Expand All @@ -153,6 +154,9 @@
from flytekit.core.task import Secret, reference_task, task
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.extras.persistence import gcs_gsutil as _gcs
from flytekit.extras.persistence import http as _http
from flytekit.extras.persistence import s3_awscli as _s3
from flytekit.loggers import logger
from flytekit.types import schema

Expand Down
2 changes: 1 addition & 1 deletion flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def _dispatch_execute(
for k, v in output_file_dict.items():
_common_utils.write_proto_to_file(v.to_flyte_idl(), _os.path.join(ctx.execution_state.engine_dir, k))

ctx.file_access.upload_directory(ctx.execution_state.engine_dir, output_prefix)
ctx.file_access.put_data(ctx.execution_state.engine_dir, output_prefix, is_multipart=True)
_logging.info(f"Engine folder written successfully to the output prefix {output_prefix}")


Expand Down
11 changes: 5 additions & 6 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@
from flytekit.common.tasks.sdk_runnable import ExecutionParameters
from flytekit.configuration import images, internal
from flytekit.configuration import sdk as _sdk_config
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider
from flytekit.engines.unit import mock_stats as _mock_stats
from flytekit.interfaces.data import data_proxy as _data_proxy
from flytekit.models.core import identifier as _identifier

# TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin
Expand Down Expand Up @@ -463,7 +462,7 @@ def with_compilation_state(self, c: CompilationState) -> Builder:
def with_new_compilation_state(self) -> Builder:
return self.with_compilation_state(self.new_compilation_state())

def with_file_access(self, fa: _data_proxy.FileAccessProvider) -> Builder:
def with_file_access(self, fa: FileAccessProvider) -> Builder:
return self.new_builder().with_file_access(fa)

def with_serialization_settings(self, ss: SerializationSettings) -> Builder:
Expand Down Expand Up @@ -498,7 +497,7 @@ def current_context() -> FlyteContext:

@dataclass
class Builder(object):
file_access: _data_proxy.FileAccessProvider
file_access: FileAccessProvider
level: int = 0
compilation_state: Optional[CompilationState] = None
execution_state: Optional[ExecutionState] = None
Expand Down Expand Up @@ -551,7 +550,7 @@ def with_compilation_state(self, c: CompilationState) -> "Builder":
def with_new_compilation_state(self) -> "Builder":
return self.with_compilation_state(self.new_compilation_state())

def with_file_access(self, fa: _data_proxy.FileAccessProvider) -> "Builder":
def with_file_access(self, fa: FileAccessProvider) -> "Builder":
self.file_access = fa
return self

Expand Down Expand Up @@ -680,7 +679,7 @@ def initialize():
logging=_logging,
tmp_dir=os.path.join(_sdk_config.LOCAL_SANDBOX.get(), "user_space"),
)
default_context = FlyteContext(file_access=_data_proxy.default_local_file_access_provider)
default_context = FlyteContext(file_access=default_local_file_access_provider)
default_context = default_context.with_execution_state(
default_context.new_execution_state().with_params(user_space_params=default_user_space_params)
).build()
Expand Down
94 changes: 41 additions & 53 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from abc import abstractmethod
from distutils import dir_util as _dir_util
from shutil import copyfile as _copyfile
from typing import Union, Dict
from typing import Dict, Union
from uuid import UUID

from flytekit.loggers import logger
from flytekit.common.exceptions.user import FlyteAssertion
from flytekit.common.utils import PerformanceTimer
from flytekit.interfaces.random import random
from flytekit.loggers import logger


class UnsupportedPersistenceOp(Exception):
Expand Down Expand Up @@ -50,30 +50,16 @@ def exists(self, path: str) -> bool:
pass

@abstractmethod
def download_directory(self, remote_path: str, local_path: str):
"""
downloads a directory from path to path recursively
"""
pass

@abstractmethod
def download(self, remote_path: str, local_path: str):
"""
downloads a file from path to path
"""
pass

@abstractmethod
def upload(self, file_path: str, to_path: str):
def get(self, from_path: str, to_path: str, recursive: bool = False):
"""
uploads the given file to path
Retrieves data from from_path and writes it to the given to_path (to_path is locally accessible)
"""
pass

@abstractmethod
def upload_directory(self, local_path: str, remote_path: str):
def put(self, from_path: str, to_path: str, recursive: bool = False):
"""
uploads a directory from path to path recursively
Stores data from the given from_path to the given to_path (from_path is locally accessible)
"""
pass

Expand All @@ -98,6 +84,7 @@ class DataPersistencePlugins(object):
These plugins should always be registered. Follow the plugin registration guidelines to auto-discover your plugins.
"""

_PLUGINS: Dict[str, DataPersistence] = {}

@classmethod
Expand All @@ -115,7 +102,8 @@ def register_plugin(cls, protocol: str, plugin: DataPersistence, force: bool = F
if not force:
raise TypeError(
f"Cannot register plugin {plugin.name} for protocol {protocol} as plugin {p.name} is already"
f" registered for the same protocol. You can force register the new plugin by passing force=True")
f" registered for the same protocol. You can force register the new plugin by passing force=True"
)

cls._PLUGINS[protocol] = plugin

Expand Down Expand Up @@ -146,6 +134,12 @@ def is_supported_protocol(cls, protocol: str) -> bool:


class DiskPersistence(DataPersistence):
"""
The simplest form of persistence that is available with default flytekit - Disk based persistence.
This will store all data locally and retrieve it from local disk. This is helpful for local execution and for
simulating runs.
"""

PROTOCOL = "file://"

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -187,21 +181,22 @@ def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, N
def exists(self, path: str):
return _os.path.exists(self.strip_file_header(path))

def download_directory(self, from_path: str, to_path: str):
def get(self, from_path: str, to_path: str, recursive: bool = False):
if from_path != to_path:
_dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path))

def download(self, from_path: str, to_path: str):
_copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path))

def upload(self, from_path: str, to_path: str):
# Emulate s3's flat storage by automatically creating directory path
self._make_local_path(_os.path.dirname(self.strip_file_header(to_path)))
# Write the object to a local file in the sandbox
_copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path))
if recursive:
_dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path))
else:
_copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path))

def upload_directory(self, from_path, to_path):
self.download_directory(from_path, to_path)
def put(self, from_path: str, to_path: str, recursive: bool = False):
if from_path != to_path:
if recursive:
_dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path))
else:
# Emulate s3's flat storage by automatically creating directory path
self._make_local_path(_os.path.dirname(self.strip_file_header(to_path)))
# Write the object to a local file in the sandbox
_copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path))

def construct_path(self, add_protocol: bool, *args) -> str:
if add_protocol:
Expand Down Expand Up @@ -239,12 +234,13 @@ def local_sandbox_dir(self) -> os.PathLike:
def local_access(self) -> DiskPersistence:
return self._local

def construct_random_path(self, persist: DataPersistence,
file_path_or_file_name: typing.Optional[str] = None) -> str:
def construct_random_path(
self, persist: DataPersistence, file_path_or_file_name: typing.Optional[str] = None
) -> str:
"""
Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name
"""
key = UUID(int=random.random.getrandbits(128)).hex
key = UUID(int=random.getrandbits(128)).hex
if file_path_or_file_name:
_, tail = os.path.split(file_path_or_file_name)
if tail:
Expand Down Expand Up @@ -286,28 +282,27 @@ def download_directory(self, remote_path: str, local_path: str):
"""
Downloads directory from given remote to local path
"""
return DataPersistencePlugins.find_plugin(remote_path).download_directory(remote_path, local_path)
return self.get_data(remote_path, local_path, is_multipart=True)

def download(self, remote_path: str, local_path: str):
"""
Downloads from remote to local
"""
return DataPersistencePlugins.find_plugin(remote_path).download(remote_path, local_path)
return self.get_data(remote_path, local_path)

def upload(self, file_path: str, to_path: str):
"""
:param Text file_path:
:param Text to_path:
"""
return DataPersistencePlugins.find_plugin(to_path).upload(file_path, to_path)
return self.put_data(file_path, to_path)

def upload_directory(self, local_path: str, remote_path: str):
"""
:param Text local_path:
:param Text remote_path:
"""
# TODO: https://github.com/flyteorg/flyte/issues/762 - test if this works!
return DataPersistencePlugins.find_plugin(remote_path).upload_directory(local_path, remote_path)
return self.put_data(local_path, remote_path, is_multipart=True)

def get_data(self, remote_path: str, local_path: str, is_multipart=False):
"""
Expand All @@ -317,10 +312,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart=False):
"""
try:
with PerformanceTimer("Copying ({} -> {})".format(remote_path, local_path)):
if is_multipart:
self.download_directory(remote_path, local_path)
else:
self.download(remote_path, local_path)
DataPersistencePlugins.find_plugin(remote_path).get(remote_path, local_path, recursive=is_multipart)
except Exception as ex:
raise FlyteAssertion(
"Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n"
Expand All @@ -343,10 +335,7 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul
"""
try:
with PerformanceTimer("Writing ({} -> {})".format(local_path, remote_path)):
if is_multipart:
self._default_remote.upload_directory(local_path, remote_path)
else:
self._default_remote.upload(local_path, remote_path)
DataPersistencePlugins.find_plugin(remote_path).put(local_path, remote_path, recursive=is_multipart)
except Exception as ex:
raise FlyteAssertion(
f"Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n"
Expand All @@ -360,6 +349,5 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul
# TODO make this use tmpdir
tmp_dir = os.path.join("/tmp/flyte", datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
default_local_file_access_provider = FileAccessProvider(
local_sandbox_dir=os.path.join(tmp_dir, "sandbox"),
raw_output_prefix=os.path.join(tmp_dir, "raw")
)
local_sandbox_dir=os.path.join(tmp_dir, "sandbox"), raw_output_prefix=os.path.join(tmp_dir, "raw")
)
77 changes: 18 additions & 59 deletions flytekit/extras/persistence/gcs_gsutil.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import os as _os
import sys as _sys
import uuid as _uuid

from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException
from flytekit.configuration import gcp as _gcp_config
from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins
from flytekit.interfaces import random as _flyte_random
from flytekit.tools import subprocess as _subprocess

if _sys.version_info >= (3,):
Expand All @@ -27,20 +25,9 @@ class GCSPersistence(DataPersistence):
_GS_UTIL_CLI = "gsutil"
PROTOCOL = "gs://"

def __init__(self, raw_output_data_prefix_override: str = None):
"""
:param raw_output_data_prefix_override: Instead of relying on the AWS or GCS configuration (see
S3_SHARD_FORMATTER for AWS and GCS_PREFIX for GCP) setting when computing the shard
path (_get_shard_path), use this prefix instead as a base. This code assumes that the
path passed in is correct. That is, an S3 path won't be passed in when running on GCP.
"""
self._raw_output_data_prefix_override = raw_output_data_prefix_override
def __init__(self):
super(GCSPersistence, self).__init__(name="gcs-gsutil")

@property
def raw_output_data_prefix_override(self) -> str:
return self._raw_output_data_prefix_override

@staticmethod
def _check_binary():
"""
Expand Down Expand Up @@ -81,58 +68,30 @@ def exists(self, remote_path):
except Exception:
return False

def download_directory(self, remote_path, local_path):
"""
:param Text remote_path: remote gs:// path
:param Text local_path: directory to copy to
"""
GCSPersistence._check_binary()

if not remote_path.startswith("gs://"):
raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...")

cmd = self._maybe_with_gsutil_parallelism("cp", "-r", _amend_path(remote_path), local_path)
return _update_cmd_config_and_execute(cmd)

def download(self, remote_path, local_path):
"""
:param Text remote_path: remote gs:// path
:param Text local_path: directory to copy to
"""
if not remote_path.startswith("gs://"):
def get(self, from_path: str, to_path: str, recursive: bool = False):
if not from_path.startswith("gs://"):
raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...")

GCSPersistence._check_binary()
if recursive:
cmd = self._maybe_with_gsutil_parallelism("cp", "-r", _amend_path(from_path), to_path)
else:
cmd = self._maybe_with_gsutil_parallelism("cp", from_path, to_path)

cmd = self._maybe_with_gsutil_parallelism("cp", remote_path, local_path)
return _update_cmd_config_and_execute(cmd)

def upload(self, file_path, to_path):
"""
:param Text file_path:
:param Text to_path:
"""
GCSPersistence._check_binary()

cmd = self._maybe_with_gsutil_parallelism("cp", file_path, to_path)
return _update_cmd_config_and_execute(cmd)

def upload_directory(self, local_path, remote_path):
"""
:param Text local_path:
:param Text remote_path:
"""
if not remote_path.startswith("gs://"):
raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...")

def put(self, from_path: str, to_path: str, recursive: bool = False):
GCSPersistence._check_binary()

cmd = self._maybe_with_gsutil_parallelism(
"cp",
"-r",
_amend_path(local_path),
remote_path if remote_path.endswith("/") else remote_path + "/",
)
if recursive:
cmd = self._maybe_with_gsutil_parallelism(
"cp",
"-r",
_amend_path(from_path),
to_path if to_path.endswith("/") else to_path + "/",
)
else:
cmd = self._maybe_with_gsutil_parallelism("cp", from_path, to_path)
return _update_cmd_config_and_execute(cmd)

def construct_path(self, add_protocol: bool, *paths) -> str:
Expand All @@ -142,4 +101,4 @@ def construct_path(self, add_protocol: bool, *paths) -> str:
return path


DataPersistencePlugins.register_plugin("gcs://", GCSPersistence())
DataPersistencePlugins.register_plugin("gcs://", GCSPersistence())
Loading

0 comments on commit 9f3b30d

Please sign in to comment.