diff --git a/dvc/data_cloud.py b/dvc/data_cloud.py index 4e8f5a8f18..5e1757bb13 100644 --- a/dvc/data_cloud.py +++ b/dvc/data_cloud.py @@ -93,7 +93,7 @@ def pull( def _save_pulled_checksums(self, cache): for checksum in cache.scheme_keys("local"): cache_file = self.repo.cache.local.checksum_to_path_info(checksum) - if self.repo.cache.local.exists(cache_file): + if self.repo.cache.local.tree.exists(cache_file): # We can safely save here, as existing corrupted files will # be removed upon status, while files corrupted during # download will not be moved from tmp_file diff --git a/dvc/output/base.py b/dvc/output/base.py index 344774c2c1..c651d570a4 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -181,7 +181,7 @@ def is_dir_checksum(self): @property def exists(self): - return self.remote.exists(self.path_info) + return self.remote.tree.exists(self.path_info) def save_info(self): return self.remote.save_info(self.path_info) @@ -217,13 +217,13 @@ def changed(self): @property def is_empty(self): - return self.remote.is_empty(self.path_info) + return self.remote.tree.is_empty(self.path_info) def isdir(self): - return self.remote.isdir(self.path_info) + return self.remote.tree.isdir(self.path_info) def isfile(self): - return self.remote.isfile(self.path_info) + return self.remote.tree.isfile(self.path_info) def ignore(self): if not self.use_scm_ignore: @@ -326,7 +326,7 @@ def checkout( ) def remove(self, ignore_remove=False): - self.remote.remove(self.path_info) + self.remote.tree.remove(self.path_info) if self.scheme != "local": return @@ -337,7 +337,7 @@ def move(self, out): if self.scheme == "local" and self.use_scm_ignore: self.repo.scm.ignore_remove(self.fspath) - self.remote.move(self.path_info, out.path_info) + self.remote.tree.move(self.path_info, out.path_info) self.def_path = out.def_path self.path_info = out.path_info self.save() diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index e23d8ed2db..0b662b4a08 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -8,12 +8,45 @@ from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote +from dvc.remote.base import BaseRemote, BaseRemoteTree from dvc.scheme import Schemes logger = logging.getLogger(__name__) +class AzureRemoteTree(BaseRemoteTree): + @property + def blob_service(self): + return self.remote.blob_service + + def _generate_download_url(self, path_info, expires=3600): + from azure.storage.blob import BlobPermissions + + expires_at = datetime.utcnow() + timedelta(seconds=expires) + + sas_token = self.blob_service.generate_blob_shared_access_signature( + path_info.bucket, + path_info.path, + permission=BlobPermissions.READ, + expiry=expires_at, + ) + download_url = self.blob_service.make_blob_url( + path_info.bucket, path_info.path, sas_token=sas_token + ) + return download_url + + def exists(self, path_info): + paths = self.remote.list_paths(path_info.bucket, path_info.path) + return any(path_info.path == path for path in paths) + + def remove(self, path_info): + if path_info.scheme != self.scheme: + raise NotImplementedError + + logger.debug(f"Removing {path_info}") + self.blob_service.delete_blob(path_info.bucket, path_info.path) + + class AzureRemote(BaseRemote): scheme = Schemes.AZURE path_cls = CloudURLInfo @@ -21,6 +54,7 @@ class AzureRemote(BaseRemote): PARAM_CHECKSUM = "etag" COPY_POLL_SECONDS = 5 LIST_OBJECT_PAGE_SIZE = 5000 + TREE_CLS = AzureRemoteTree def __init__(self, repo, config): super().__init__(repo, config) @@ -65,14 +99,7 @@ def get_etag(self, path_info): def get_file_checksum(self, path_info): return self.get_etag(path_info) - def remove(self, path_info): - if path_info.scheme != self.scheme: - raise NotImplementedError - - logger.debug(f"Removing {path_info}") - self.blob_service.delete_blob(path_info.bucket, path_info.path) - - def _list_paths(self, bucket, prefix, progress_callback=None): + def list_paths(self, bucket, prefix, progress_callback=None): blob_service = self.blob_service next_marker = None while True: @@ -97,7 +124,7 @@ def list_cache_paths(self, prefix=None, progress_callback=None): ) else: prefix = self.path_info.path - return self._list_paths( + return self.list_paths( self.path_info.bucket, prefix, progress_callback ) @@ -122,23 +149,3 @@ def _download( to_file, progress_callback=pbar.update_to, ) - - def exists(self, path_info): - paths = self._list_paths(path_info.bucket, path_info.path) - return any(path_info.path == path for path in paths) - - def _generate_download_url(self, path_info, expires=3600): - from azure.storage.blob import BlobPermissions - - expires_at = datetime.utcnow() + timedelta(seconds=expires) - - sas_token = self.blob_service.generate_blob_shared_access_signature( - path_info.bucket, - path_info.path, - permission=BlobPermissions.READ, - expiry=expires_at, - ) - download_url = self.blob_service.make_blob_url( - path_info.bucket, path_info.path, sas_token=sas_token - ) - return download_url diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 8c6d6ac1d7..60d502dbec 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -83,12 +83,99 @@ def wrapper(remote_obj, *args, **kwargs): return wrapper +class BaseRemoteTree: + SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} + + def __init__(self, remote, config): + self.remote = remote + shared = config.get("shared") + self._file_mode, self._dir_mode = self.SHARED_MODE_MAP[shared] + + @property + def file_mode(self): + return self._file_mode + + @property + def dir_mode(self): + return self._dir_mode + + @property + def scheme(self): + return self.remote.scheme + + @property + def path_cls(self): + return self.remote.path_cls + + def open(self, path_info, mode="r", encoding=None): + if hasattr(self, "_generate_download_url"): + get_url = partial(self._generate_download_url, path_info) + return open_url(get_url, mode=mode, encoding=encoding) + + raise RemoteActionNotImplemented("open", self.scheme) + + def exists(self, path_info): + raise NotImplementedError + + def isdir(self, path_info): + """Optional: Overwrite only if the remote has a way to distinguish + between a directory and a file. + """ + return False + + def isfile(self, path_info): + """Optional: Overwrite only if the remote has a way to distinguish + between a directory and a file. + """ + return True + + def iscopy(self, path_info): + """Check if this file is an independent copy.""" + return False # We can't be sure by default + + def walk_files(self, path_info): + """Return a generator with `PathInfo`s to all the files""" + raise NotImplementedError + + def is_empty(self, path_info): + return False + + def remove(self, path_info): + raise RemoteActionNotImplemented("remove", self.scheme) + + def makedirs(self, path_info): + """Optional: Implement only if the remote needs to create + directories before copying/linking/moving data + """ + + def move(self, from_info, to_info, mode=None): + assert mode is None + self.copy(from_info, to_info) + self.remove(from_info) + + def copy(self, from_info, to_info): + raise RemoteActionNotImplemented("copy", self.scheme) + + def copy_fobj(self, fobj, to_info): + raise RemoteActionNotImplemented("copy_fobj", self.scheme) + + def symlink(self, from_info, to_info): + raise RemoteActionNotImplemented("symlink", self.scheme) + + def hardlink(self, from_info, to_info): + raise RemoteActionNotImplemented("hardlink", self.scheme) + + def reflink(self, from_info, to_info): + raise RemoteActionNotImplemented("reflink", self.scheme) + + class BaseRemote: scheme = "base" path_cls = URLInfo REQUIRES = {} JOBS = 4 * cpu_count() INDEX_CLS = RemoteIndex + TREE_CLS = BaseRemoteTree PARAM_RELPATH = "relpath" CHECKSUM_DIR_SUFFIX = ".dir" @@ -102,7 +189,6 @@ class BaseRemote: CAN_TRAVERSE = True CACHE_MODE = None - SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} state = StateNoop() @@ -111,9 +197,6 @@ def __init__(self, repo, config): self._check_requires(config) - shared = config.get("shared") - self._file_mode, self._dir_mode = self.SHARED_MODE_MAP[shared] - self.checksum_jobs = ( config.get("checksum_jobs") or (self.repo and self.repo.config["core"].get("checksum_jobs")) @@ -134,6 +217,8 @@ def __init__(self, repo, config): else: self.index = RemoteIndexNoop() + self.tree = self.TREE_CLS(self, config) + @classmethod def get_missing_deps(cls): import importlib @@ -218,7 +303,7 @@ def _collect_dir(self, path_info, tree=None, save_tree=False, **kwargs): if tree: walk_files = tree.walk_files else: - walk_files = self.walk_files + walk_files = self.tree.walk_files for fname in walk_files(path_info, **kwargs): if DvcIgnore.DVCIGNORE_FILE == fname.name: @@ -273,8 +358,8 @@ def _save_dir_info(self, dir_info, path_info=None): checksum, tmp_info = self._get_dir_info_checksum(dir_info) new_info = self.cache.checksum_to_path_info(checksum) if self.cache.changed_cache_file(checksum): - self.cache.makedirs(new_info.parent) - self.cache.move(tmp_info, new_info, mode=self.CACHE_MODE) + self.cache.tree.makedirs(new_info.parent) + self.cache.tree.move(tmp_info, new_info, mode=self.CACHE_MODE) if path_info: self.state.save(path_info, checksum) @@ -344,7 +429,7 @@ def is_dir_checksum(cls, checksum): def get_checksum(self, path_info): assert isinstance(path_info, str) or path_info.scheme == self.scheme - if not self.exists(path_info): + if not self.tree.exists(path_info): return None checksum = self.state.get(path_info) @@ -355,14 +440,16 @@ def get_checksum(self, path_info): if ( checksum and self.is_dir_checksum(checksum) - and not self.exists(self.cache.checksum_to_path_info(checksum)) + and not self.tree.exists( + self.cache.checksum_to_path_info(checksum) + ) ): checksum = None if checksum: return checksum - if self.isdir(path_info): + if self.tree.isdir(path_info): checksum = self.get_dir_checksum(path_info) else: checksum = self.get_file_checksum(path_info) @@ -396,7 +483,7 @@ def changed(self, path_info, checksum_info): "checking if '%s'('%s') has changed.", path_info, checksum_info ) - if not self.exists(path_info): + if not self.tree.exists(path_info): logger.debug("'%s' doesn't exist.", path_info) return True @@ -428,9 +515,9 @@ def link(self, from_info, to_info): self._link(from_info, to_info, self.cache_types) def _link(self, from_info, to_info, link_types): - assert self.isfile(from_info) + assert self.tree.isfile(from_info) - self.makedirs(to_info.parent) + self.tree.makedirs(to_info.parent) self._try_links(from_info, to_info, link_types) @@ -438,9 +525,9 @@ def _verify_link(self, path_info, link_type): if self.cache_type_confirmed: return - is_link = getattr(self, f"is_{link_type}", None) + is_link = getattr(self.tree, f"is_{link_type}", None) if is_link and not is_link(path_info): - self.remove(path_info) + self.tree.remove(path_info) raise DvcException(f"failed to verify {link_type}") self.cache_type_confirmed = True @@ -448,7 +535,7 @@ def _verify_link(self, path_info, link_type): @slow_link_guard def _try_links(self, from_info, to_info, link_types): while link_types: - link_method = getattr(self, link_types[0]) + link_method = getattr(self.tree, link_types[0]) try: self._do_link(from_info, to_info, link_method) self._verify_link(to_info, link_types[0]) @@ -463,7 +550,7 @@ def _try_links(self, from_info, to_info, link_types): raise DvcException("no possible cache types left to try out.") def _do_link(self, from_info, to_info, link_method): - if self.exists(to_info): + if self.tree.exists(to_info): raise DvcException(f"Link '{to_info}' already exists!") link_method(from_info, to_info) @@ -486,19 +573,21 @@ def _save_file( if not ( tree.isdvc(path_info, strict=False) and tree.fetch ): - self.copy_fobj(fobj, cache_info) + self.tree.copy_fobj(fobj, cache_info) callback = kwargs.get("download_callback") if callback: callback(1) else: if self.changed_cache(checksum): - self.move(path_info, cache_info, mode=self.CACHE_MODE) + self.tree.move(path_info, cache_info, mode=self.CACHE_MODE) self.link(cache_info, path_info) - elif self.iscopy(path_info) and self._cache_is_copy(path_info): + elif self.tree.iscopy(path_info) and self._cache_is_copy( + path_info + ): # Default relink procedure involves unneeded copy self.unprotect(path_info) else: - self.remove(path_info) + self.tree.remove(path_info) self.link(cache_info, path_info) if save_link: @@ -522,14 +611,14 @@ def _cache_is_copy(self, path_info): workspace_file = path_info.with_name("." + uuid()) test_cache_file = self.path_info / ".cache_type_test_file" - if not self.exists(test_cache_file): - with self.open(test_cache_file, "wb") as fobj: + if not self.tree.exists(test_cache_file): + with self.tree.open(test_cache_file, "wb") as fobj: fobj.write(bytes(1)) try: self.link(test_cache_file, workspace_file) finally: - self.remove(workspace_file) - self.remove(test_cache_file) + self.tree.remove(workspace_file) + self.tree.remove(test_cache_file) self.cache_type_confirmed = True return self.cache_types[0] == "copy" @@ -561,29 +650,6 @@ def _save_dir( self.state.save(path_info, checksum) return {self.PARAM_CHECKSUM: checksum} - def is_empty(self, path_info): - return False - - def isfile(self, path_info): - """Optional: Overwrite only if the remote has a way to distinguish - between a directory and a file. - """ - return True - - def isdir(self, path_info): - """Optional: Overwrite only if the remote has a way to distinguish - between a directory and a file. - """ - return False - - def iscopy(self, path_info): - """Check if this file is an independent copy.""" - return False # We can't be sure by default - - def walk_files(self, path_info): - """Return a generator with `PathInfo`s to all the files""" - raise NotImplementedError - @staticmethod def protect(path_info): pass @@ -617,7 +683,7 @@ def _save(self, path_info, checksum, save_link=True, tree=None, **kwargs): isdir = tree.isdir save_link = False else: - isdir = self.isdir + isdir = self.tree.isdir if isdir(path_info): return self._save_dir( @@ -625,6 +691,9 @@ def _save(self, path_info, checksum, save_link=True, tree=None, **kwargs): ) return self._save_file(path_info, checksum, save_link, tree, **kwargs) + def open(self, *args, **kwargs): + return self.tree.open(*args, **kwargs) + def _handle_transfer_exception( self, from_info, to_info, exception, operation ): @@ -680,13 +749,13 @@ def download( raise NotImplementedError if to_info.scheme == self.scheme != "local": - self.copy(from_info, to_info) + self.tree.copy(from_info, to_info) return 0 if to_info.scheme != "local": raise NotImplementedError - if self.isdir(from_info): + if self.tree.isdir(from_info): return self._download_dir( from_info, to_info, name, no_progress_bar, file_mode, dir_mode ) @@ -697,7 +766,7 @@ def download( def _download_dir( self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode ): - from_infos = list(self.walk_files(from_info)) + from_infos = list(self.tree.walk_files(from_info)) to_infos = ( to_info / info.relative_to(from_info) for info in from_infos ) @@ -744,39 +813,6 @@ def _download_file( return 0 - def open(self, path_info, mode="r", encoding=None): - if hasattr(self, "_generate_download_url"): - get_url = partial(self._generate_download_url, path_info) - return open_url(get_url, mode=mode, encoding=encoding) - - raise RemoteActionNotImplemented("open", self.scheme) - - def remove(self, path_info): - raise RemoteActionNotImplemented("remove", self.scheme) - - def move(self, from_info, to_info, mode=None): - assert mode is None - self.copy(from_info, to_info) - self.remove(from_info) - - def copy(self, from_info, to_info): - raise RemoteActionNotImplemented("copy", self.scheme) - - def copy_fobj(self, fobj, to_info): - raise RemoteActionNotImplemented("copy_fobj", self.scheme) - - def symlink(self, from_info, to_info): - raise RemoteActionNotImplemented("symlink", self.scheme) - - def hardlink(self, from_info, to_info): - raise RemoteActionNotImplemented("hardlink", self.scheme) - - def reflink(self, from_info, to_info): - raise RemoteActionNotImplemented("reflink", self.scheme) - - def exists(self, path_info): - raise NotImplementedError - def path_to_checksum(self, path): parts = self.path_cls(path).parts[-2:] @@ -849,7 +885,7 @@ def gc(self, named_cache, jobs=None): if self.is_dir_checksum(checksum): # backward compatibility self._remove_unpacked_dir(checksum) - self.remove(path_info) + self.tree.remove(path_info) removed = True if removed: self.index.clear() @@ -896,9 +932,9 @@ def changed_cache_file(self, checksum): self.protect(cache_info) return False - if self.exists(cache_info): + if self.tree.exists(cache_info): logger.warning("corrupted cache file '%s'.", cache_info) - self.remove(cache_info) + self.tree.remove(cache_info) return True @@ -1142,7 +1178,7 @@ def _cache_object_exists(self, checksums, jobs=None, name=None): ) as pbar: def exists_with_progress(path_info): - ret = self.exists(path_info) + ret = self.tree.exists(path_info) pbar.update_msg(str(path_info)) return ret @@ -1161,7 +1197,7 @@ def already_cached(self, path_info): return not self.changed_cache(current) def safe_remove(self, path_info, force=False): - if not self.exists(path_info): + if not self.tree.exists(path_info): return if not force and not self.already_cached(path_info): @@ -1173,7 +1209,7 @@ def safe_remove(self, path_info, force=False): if not prompt.confirm(msg): raise ConfirmRemoveError(str(path_info)) - self.remove(path_info) + self.tree.remove(path_info) def _checkout_file( self, path_info, checksum, force, progress_callback=None, relink=False @@ -1181,7 +1217,7 @@ def _checkout_file( """The file is changed we need to checkout a new copy""" added, modified = True, False cache_info = self.checksum_to_path_info(checksum) - if self.exists(path_info): + if self.tree.exists(path_info): logger.debug("data '%s' will be replaced.", path_info) self.safe_remove(path_info, force=force) added, modified = False, True @@ -1194,11 +1230,6 @@ def _checkout_file( return added, modified and not relink - def makedirs(self, path_info): - """Optional: Implement only if the remote needs to create - directories before copying/linking/moving data - """ - def _checkout_dir( self, path_info, @@ -1211,9 +1242,9 @@ def _checkout_dir( added, modified = False, False # Create dir separately so that dir is created # even if there are no files in it - if not self.exists(path_info): + if not self.tree.exists(path_info): added = True - self.makedirs(path_info) + self.tree.makedirs(path_info) dir_info = self.get_dir_cache(checksum) @@ -1249,7 +1280,7 @@ def _checkout_dir( return added, not added and modified and not relink def _remove_redundant_files(self, path_info, dir_info, force): - existing_files = set(self.walk_files(path_info)) + existing_files = set(self.tree.walk_files(path_info)) needed_files = { path_info / entry[self.PARAM_RELPATH] for entry in dir_info diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index dfc7b1d5a5..3cef0760ee 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -12,7 +12,7 @@ from dvc.exceptions import DvcException from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote +from dvc.remote.base import BaseRemote, BaseRemoteTree from dvc.scheme import Schemes from dvc.utils import format_link, tmp_fname @@ -87,6 +87,20 @@ def __init__(self, url): self._spath = re.sub("/{2,}", "/", self._spath.rstrip("/")) +class GDriveRemoteTree(BaseRemoteTree): + def exists(self, path_info): + try: + self.remote.get_item_id(path_info) + except GDrivePathNotFound: + return False + else: + return True + + def remove(self, path_info): + item_id = self.remote.get_item_id(path_info) + self.remote.gdrive_delete_file(item_id) + + class GDriveRemote(BaseRemote): scheme = Schemes.GDRIVE path_cls = GDriveURLInfo @@ -95,6 +109,7 @@ class GDriveRemote(BaseRemote): # Always prefer traverse for GDrive since API usage quotas are a concern. TRAVERSE_WEIGHT_MULTIPLIER = 1 TRAVERSE_PREFIX_LEN = 2 + TREE_CLS = GDriveRemoteTree GDRIVE_CREDENTIALS_DATA = "GDRIVE_CREDENTIALS_DATA" DEFAULT_USER_CREDENTIALS_FILE = "gdrive-user-credentials.json" @@ -281,7 +296,7 @@ def _ids_cache(self): cache = { "dirs": defaultdict(list), "ids": {}, - "root_id": self._get_item_id( + "root_id": self.get_item_id( self.path_info, use_cache=False, hint="Confirm the directory exists and you can access it.", @@ -394,7 +409,7 @@ def _gdrive_download_file( gdrive_file.GetContentFile(to_file, callback=pbar.update_to) @_gdrive_retry - def _gdrive_delete_file(self, item_id): + def gdrive_delete_file(self, item_id): from pydrive2.files import ApiRequestError param = {"id": item_id} @@ -498,7 +513,7 @@ def _path_to_item_ids(self, path, create, use_cache): [self._create_dir(min(parent_ids), title, path)] if create else [] ) - def _get_item_id(self, path_info, create=False, use_cache=True, hint=None): + def get_item_id(self, path_info, create=False, use_cache=True, hint=None): assert path_info.bucket == self._bucket item_ids = self._path_to_item_ids(path_info.path, create, use_cache) @@ -508,25 +523,17 @@ def _get_item_id(self, path_info, create=False, use_cache=True, hint=None): assert not create raise GDrivePathNotFound(path_info, hint) - def exists(self, path_info): - try: - self._get_item_id(path_info) - except GDrivePathNotFound: - return False - else: - return True - def _upload(self, from_file, to_info, name=None, no_progress_bar=False): dirname = to_info.parent assert dirname - parent_id = self._get_item_id(dirname, True) + parent_id = self.get_item_id(dirname, True) self._gdrive_upload_file( parent_id, to_info.name, no_progress_bar, from_file, name ) def _download(self, from_info, to_file, name=None, no_progress_bar=False): - item_id = self._get_item_id(from_info) + item_id = self.get_item_id(from_info) self._gdrive_download_file(item_id, to_file, name, no_progress_bar) def list_cache_paths(self, prefix=None, progress_callback=None): @@ -552,12 +559,5 @@ def list_cache_paths(self, prefix=None, progress_callback=None): self._ids_cache["ids"][parent_id], item["title"] ) - def remove(self, path_info): - item_id = self._get_item_id(path_info) - self._gdrive_delete_file(item_id) - def get_file_checksum(self, path_info): raise NotImplementedError - - def walk_files(self, path_info): - raise NotImplementedError diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index ffafee2403..5b42c8f24f 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -10,7 +10,7 @@ from dvc.exceptions import DvcException from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote +from dvc.remote.base import BaseRemote, BaseRemoteTree from dvc.scheme import Schemes logger = logging.getLogger(__name__) @@ -65,11 +65,82 @@ def _upload_to_bucket( blob.upload_from_file(wrapped) +class GSRemoteTree(BaseRemoteTree): + @property + def gs(self): + return self.remote.gs + + def _generate_download_url(self, path_info, expires=3600): + expiration = timedelta(seconds=int(expires)) + + bucket = self.gs.bucket(path_info.bucket) + blob = bucket.get_blob(path_info.path) + if blob is None: + raise FileNotFoundError + return blob.generate_signed_url(expiration=expiration) + + def exists(self, path_info): + """Check if the blob exists. If it does not exist, + it could be a part of a directory path. + + eg: if `data/file.txt` exists, check for `data` should return True + """ + return self.isfile(path_info) or self.isdir(path_info) + + def isdir(self, path_info): + dir_path = path_info / "" + return bool(list(self.remote.list_paths(dir_path, max_items=1))) + + def isfile(self, path_info): + if path_info.path.endswith("/"): + return False + + blob = self.gs.bucket(path_info.bucket).blob(path_info.path) + return blob.exists() + + def walk_files(self, path_info): + for fname in self.remote.list_paths(path_info / ""): + # skip nested empty directories + if fname.endswith("/"): + continue + yield path_info.replace(fname) + + def remove(self, path_info): + if path_info.scheme != "gs": + raise NotImplementedError + + logger.debug(f"Removing gs://{path_info}") + blob = self.gs.bucket(path_info.bucket).get_blob(path_info.path) + if not blob: + return + + blob.delete() + + def makedirs(self, path_info): + if not path_info.path: + return + + self.gs.bucket(path_info.bucket).blob( + (path_info / "").path + ).upload_from_string("") + + def copy(self, from_info, to_info): + from_bucket = self.gs.bucket(from_info.bucket) + blob = from_bucket.get_blob(from_info.path) + if not blob: + msg = f"'{from_info.path}' doesn't exist in the cloud" + raise DvcException(msg) + + to_bucket = self.gs.bucket(to_info.bucket) + from_bucket.copy_blob(blob, to_bucket, new_name=to_info.path) + + class GSRemote(BaseRemote): scheme = Schemes.GS path_cls = CloudURLInfo REQUIRES = {"google-cloud-storage": "google.cloud.storage"} PARAM_CHECKSUM = "md5" + TREE_CLS = GSRemoteTree def __init__(self, repo, config): super().__init__(repo, config) @@ -105,28 +176,7 @@ def get_file_checksum(self, path_info): md5 = base64.b64decode(b64_md5) return codecs.getencoder("hex")(md5)[0].decode("utf-8") - def copy(self, from_info, to_info): - from_bucket = self.gs.bucket(from_info.bucket) - blob = from_bucket.get_blob(from_info.path) - if not blob: - msg = f"'{from_info.path}' doesn't exist in the cloud" - raise DvcException(msg) - - to_bucket = self.gs.bucket(to_info.bucket) - from_bucket.copy_blob(blob, to_bucket, new_name=to_info.path) - - def remove(self, path_info): - if path_info.scheme != "gs": - raise NotImplementedError - - logger.debug(f"Removing gs://{path_info}") - blob = self.gs.bucket(path_info.bucket).get_blob(path_info.path) - if not blob: - return - - blob.delete() - - def _list_paths( + def list_paths( self, path_info, max_items=None, prefix=None, progress_callback=None ): if prefix: @@ -141,44 +191,10 @@ def _list_paths( yield blob.name def list_cache_paths(self, prefix=None, progress_callback=None): - return self._list_paths( + return self.list_paths( self.path_info, prefix=prefix, progress_callback=progress_callback ) - def walk_files(self, path_info): - for fname in self._list_paths(path_info / ""): - # skip nested empty directories - if fname.endswith("/"): - continue - yield path_info.replace(fname) - - def makedirs(self, path_info): - if not path_info.path: - return - - self.gs.bucket(path_info.bucket).blob( - (path_info / "").path - ).upload_from_string("") - - def isdir(self, path_info): - dir_path = path_info / "" - return bool(list(self._list_paths(dir_path, max_items=1))) - - def isfile(self, path_info): - if path_info.path.endswith("/"): - return False - - blob = self.gs.bucket(path_info.bucket).blob(path_info.path) - return blob.exists() - - def exists(self, path_info): - """Check if the blob exists. If it does not exist, - it could be a part of a directory path. - - eg: if `data/file.txt` exists, check for `data` should return True - """ - return self.isfile(path_info) or self.isdir(path_info) - def _upload(self, from_file, to_info, name=None, no_progress_bar=False): bucket = self.gs.bucket(to_info.bucket) _upload_to_bucket( @@ -201,12 +217,3 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): disable=no_progress_bar, ) as wrapped: blob.download_to_file(wrapped) - - def _generate_download_url(self, path_info, expires=3600): - expiration = timedelta(seconds=int(expires)) - - bucket = self.gs.bucket(path_info.bucket) - blob = bucket.get_blob(path_info.path) - if blob is None: - raise FileNotFoundError - return blob.generate_signed_url(expiration=expiration) diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index 6d2e7cfb66..4e9430ab7c 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -11,18 +11,75 @@ from dvc.scheme import Schemes from dvc.utils import fix_env, tmp_fname -from .base import BaseRemote, RemoteCmdError +from .base import BaseRemote, BaseRemoteTree, RemoteCmdError from .pool import get_connection logger = logging.getLogger(__name__) +class HDFSRemoteTree(BaseRemoteTree): + @property + def hdfs(self): + return self.remote.hdfs + + @contextmanager + def open(self, path_info, mode="r", encoding=None): + assert mode in {"r", "rt", "rb"} + + try: + with self.hdfs(path_info) as hdfs, closing( + hdfs.open(path_info.path, mode="rb") + ) as fd: + if mode == "rb": + yield fd + else: + yield io.TextIOWrapper(fd, encoding=encoding) + except OSError as e: + # Empty .errno and not specific enough error class in pyarrow, + # see https://issues.apache.org/jira/browse/ARROW-6248 + if "file does not exist" in str(e): + raise FileNotFoundError(*e.args) + raise + + def exists(self, path_info): + assert not isinstance(path_info, list) + assert path_info.scheme == "hdfs" + with self.hdfs(path_info) as hdfs: + return hdfs.exists(path_info.path) + + def remove(self, path_info): + if path_info.scheme != "hdfs": + raise NotImplementedError + + if self.exists(path_info): + logger.debug(f"Removing {path_info.path}") + with self.hdfs(path_info) as hdfs: + hdfs.rm(path_info.path) + + def copy(self, from_info, to_info, **_kwargs): + dname = posixpath.dirname(to_info.path) + with self.hdfs(to_info) as hdfs: + hdfs.mkdir(dname) + # NOTE: this is how `hadoop fs -cp` works too: it copies through + # your local machine. + with hdfs.open(from_info.path, "rb") as from_fobj: + tmp_info = to_info.parent / tmp_fname(to_info.name) + try: + with hdfs.open(tmp_info.path, "wb") as tmp_fobj: + tmp_fobj.upload(from_fobj) + hdfs.rename(tmp_info.path, to_info.path) + except Exception: + self.remove(tmp_info) + raise + + class HDFSRemote(BaseRemote): scheme = Schemes.HDFS REGEX = r"^hdfs://((?P.*)@)?.*$" PARAM_CHECKSUM = "checksum" REQUIRES = {"pyarrow": "pyarrow"} TRAVERSE_PREFIX_LEN = 2 + TREE_CLS = HDFSRemoteTree def __init__(self, repo, config): super().__init__(repo, config) @@ -91,37 +148,6 @@ def get_file_checksum(self, path_info): ) return self._group(regex, stdout, "checksum") - def copy(self, from_info, to_info, **_kwargs): - dname = posixpath.dirname(to_info.path) - with self.hdfs(to_info) as hdfs: - hdfs.mkdir(dname) - # NOTE: this is how `hadoop fs -cp` works too: it copies through - # your local machine. - with hdfs.open(from_info.path, "rb") as from_fobj: - tmp_info = to_info.parent / tmp_fname(to_info.name) - try: - with hdfs.open(tmp_info.path, "wb") as tmp_fobj: - tmp_fobj.upload(from_fobj) - hdfs.rename(tmp_info.path, to_info.path) - except Exception: - self.remove(tmp_info) - raise - - def remove(self, path_info): - if path_info.scheme != "hdfs": - raise NotImplementedError - - if self.exists(path_info): - logger.debug(f"Removing {path_info.path}") - with self.hdfs(path_info) as hdfs: - hdfs.rm(path_info.path) - - def exists(self, path_info): - assert not isinstance(path_info, list) - assert path_info.scheme == "hdfs" - with self.hdfs(path_info) as hdfs: - return hdfs.exists(path_info.path) - def _upload(self, from_file, to_info, **_kwargs): with self.hdfs(to_info) as hdfs: hdfs.mkdir(posixpath.dirname(to_info.path)) @@ -135,27 +161,8 @@ def _download(self, from_info, to_file, **_kwargs): with open(to_file, "wb+") as fobj: hdfs.download(from_info.path, fobj) - @contextmanager - def open(self, path_info, mode="r", encoding=None): - assert mode in {"r", "rt", "rb"} - - try: - with self.hdfs(path_info) as hdfs, closing( - hdfs.open(path_info.path, mode="rb") - ) as fd: - if mode == "rb": - yield fd - else: - yield io.TextIOWrapper(fd, encoding=encoding) - except OSError as e: - # Empty .errno and not specific enough error class in pyarrow, - # see https://issues.apache.org/jira/browse/ARROW-6248 - if "file does not exist" in str(e): - raise FileNotFoundError(*e.args) - raise - def list_cache_paths(self, prefix=None, progress_callback=None): - if not self.exists(self.path_info): + if not self.tree.exists(self.path_info): return if prefix: diff --git a/dvc/remote/http.py b/dvc/remote/http.py index a056c28439..62252ea720 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -8,7 +8,7 @@ from dvc.exceptions import DvcException, HTTPError from dvc.path_info import HTTPURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote +from dvc.remote.base import BaseRemote, BaseRemoteTree from dvc.scheme import Schemes logger = logging.getLogger(__name__) @@ -23,6 +23,11 @@ def ask_password(host, user): ) +class HTTPRemoteTree(BaseRemoteTree): + def exists(self, path_info): + return bool(self.remote.request("HEAD", path_info.url)) + + class HTTPRemote(BaseRemote): scheme = Schemes.HTTP path_cls = HTTPURLInfo @@ -32,6 +37,7 @@ class HTTPRemote(BaseRemote): CHUNK_SIZE = 2 ** 16 PARAM_CHECKSUM = "etag" CAN_TRAVERSE = False + TREE_CLS = HTTPRemoteTree def __init__(self, repo, config): super().__init__(repo, config) @@ -52,7 +58,7 @@ def __init__(self, repo, config): self.headers = {} def _download(self, from_info, to_file, name=None, no_progress_bar=False): - response = self._request("GET", from_info.url, stream=True) + response = self.request("GET", from_info.url, stream=True) if response.status_code != 200: raise HTTPError(response.status_code, response.reason) with open(to_file, "wb") as fd: @@ -88,20 +94,17 @@ def chunks(): break yield chunk - response = self._request("POST", to_info.url, data=chunks()) + response = self.request("POST", to_info.url, data=chunks()) if response.status_code not in (200, 201): raise HTTPError(response.status_code, response.reason) - def exists(self, path_info): - return bool(self._request("HEAD", path_info.url)) - def _content_length(self, response): res = response.headers.get("Content-Length") return int(res) if res else None def get_file_checksum(self, path_info): url = path_info.url - headers = self._request("HEAD", url).headers + headers = self.request("HEAD", url).headers etag = headers.get("ETag") or headers.get("Content-MD5") if not etag: @@ -155,7 +158,7 @@ def _session(self): return session - def _request(self, method, url, **kwargs): + def request(self, method, url, **kwargs): import requests kwargs.setdefault("allow_redirects", True) diff --git a/dvc/remote/local.py b/dvc/remote/local.py index d9a347f687..6ae7451326 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -17,6 +17,7 @@ STATUS_MISSING, STATUS_NEW, BaseRemote, + BaseRemoteTree, index_locked, ) from dvc.remote.index import RemoteIndexNoop @@ -36,37 +37,12 @@ logger = logging.getLogger(__name__) -class LocalRemote(BaseRemote): - scheme = Schemes.LOCAL - path_cls = PathInfo - PARAM_CHECKSUM = "md5" - PARAM_PATH = "path" - TRAVERSE_PREFIX_LEN = 2 - INDEX_CLS = RemoteIndexNoop - - UNPACKED_DIR_SUFFIX = ".unpacked" - - DEFAULT_CACHE_TYPES = ["reflink", "copy"] - - CACHE_MODE = 0o444 +class LocalRemoteTree(BaseRemoteTree): SHARED_MODE_MAP = {None: (0o644, 0o755), "group": (0o664, 0o775)} - def __init__(self, repo, config): - super().__init__(repo, config) - self.cache_dir = config.get("url") - self._dir_info = {} - @property - def state(self): - return self.repo.state - - @property - def cache_dir(self): - return self.path_info.fspath if self.path_info else None - - @cache_dir.setter - def cache_dir(self, value): - self.path_info = PathInfo(value) if value else None + def repo(self): + return self.remote.repo @cached_property def _work_tree(self): @@ -83,42 +59,9 @@ def work_tree(self): return self._work_tree return None - @classmethod - def supported(cls, config): - return True - - @cached_property - def cache_path(self): - return os.path.abspath(self.cache_dir) - - def checksum_to_path(self, checksum): - # NOTE: `self.cache_path` is already normalized so we can simply use - # `os.sep` instead of `os.path.join`. This results in this helper - # being ~5.5 times faster. - return ( - f"{self.cache_path}{os.sep}{checksum[0:2]}{os.sep}{checksum[2:]}" - ) - - def list_cache_paths(self, prefix=None, progress_callback=None): - assert self.path_info is not None - if prefix: - path_info = self.path_info / prefix[:2] - if not self.exists(path_info): - return - else: - path_info = self.path_info - if progress_callback: - for path in walk_files(path_info): - progress_callback() - yield path - else: - yield from walk_files(path_info) - - def get(self, md5): - if not md5: - return None - - return self.checksum_to_path_info(md5).url + @staticmethod + def open(path_info, mode="r", encoding=None): + return open(path_info, mode=mode, encoding=encoding) def exists(self, path_info): assert isinstance(path_info, str) or path_info.scheme == "local" @@ -128,36 +71,6 @@ def exists(self, path_info): return True return self.repo.tree.exists(path_info) - def makedirs(self, path_info): - makedirs(path_info, exist_ok=True, mode=self._dir_mode) - - def already_cached(self, path_info): - assert path_info.scheme in ["", "local"] - - current_md5 = self.get_checksum(path_info) - - if not current_md5: - return False - - return not self.changed_cache(current_md5) - - def _verify_link(self, path_info, link_type): - if link_type == "hardlink" and self.getsize(path_info) == 0: - return - - super()._verify_link(path_info, link_type) - - def is_empty(self, path_info): - path = path_info.fspath - - if self.isfile(path_info) and os.path.getsize(path) == 0: - return True - - if self.isdir(path_info) and len(os.listdir(path)) == 0: - return True - - return False - def isfile(self, path_info): if not self.repo: return os.path.isfile(path_info) @@ -177,10 +90,6 @@ def iscopy(self, path_info): System.is_symlink(path_info) or System.is_hardlink(path_info) ) - @staticmethod - def getsize(path_info): - return os.path.getsize(path_info) - def walk_files(self, path_info): if self.work_tree: tree = self.work_tree @@ -189,8 +98,16 @@ def walk_files(self, path_info): for fname in tree.walk_files(path_info): yield PathInfo(fname) - def get_file_checksum(self, path_info): - return file_md5(path_info)[0] + def is_empty(self, path_info): + path = path_info.fspath + + if self.isfile(path_info) and os.path.getsize(path) == 0: + return True + + if self.isdir(path_info) and len(os.listdir(path)) == 0: + return True + + return False def remove(self, path_info): if isinstance(path_info, PathInfo): @@ -203,6 +120,9 @@ def remove(self, path_info): if self.exists(path): remove(path) + def makedirs(self, path_info): + makedirs(path_info, exist_ok=True, mode=self.dir_mode) + def move(self, from_info, to_info, mode=None): if from_info.scheme != "local" or to_info.scheme != "local": raise NotImplementedError @@ -211,9 +131,9 @@ def move(self, from_info, to_info, mode=None): if mode is None: if self.isfile(from_info): - mode = self._file_mode + mode = self.file_mode else: - mode = self._dir_mode + mode = self.dir_mode move(from_info, to_info, mode=mode) @@ -221,7 +141,7 @@ def copy(self, from_info, to_info): tmp_info = to_info.parent / tmp_fname(to_info.name) try: System.copy(from_info, tmp_info) - os.chmod(tmp_info, self._file_mode) + os.chmod(tmp_info, self.file_mode) os.rename(tmp_info, to_info) except Exception: self.remove(tmp_info) @@ -232,7 +152,7 @@ def copy_fobj(self, fobj, to_info): tmp_info = to_info.parent / tmp_fname(to_info.name) try: copy_fobj_to_file(fobj, tmp_info) - os.chmod(tmp_info, self._file_mode) + os.chmod(tmp_info, self.file_mode) os.rename(tmp_info, to_info) except Exception: self.remove(tmp_info) @@ -281,9 +201,102 @@ def reflink(self, from_info, to_info): System.reflink(from_info, tmp_info) # NOTE: reflink has its own separate inode, so you can set permissions # that are different from the source. - os.chmod(tmp_info, self._file_mode) + os.chmod(tmp_info, self.file_mode) os.rename(tmp_info, to_info) + @staticmethod + def getsize(path_info): + return os.path.getsize(path_info) + + +class LocalRemote(BaseRemote): + scheme = Schemes.LOCAL + path_cls = PathInfo + PARAM_CHECKSUM = "md5" + PARAM_PATH = "path" + TRAVERSE_PREFIX_LEN = 2 + INDEX_CLS = RemoteIndexNoop + TREE_CLS = LocalRemoteTree + + UNPACKED_DIR_SUFFIX = ".unpacked" + + DEFAULT_CACHE_TYPES = ["reflink", "copy"] + + CACHE_MODE = 0o444 + + def __init__(self, repo, config): + super().__init__(repo, config) + self.cache_dir = config.get("url") + self._dir_info = {} + + @property + def state(self): + return self.repo.state + + @property + def cache_dir(self): + return self.path_info.fspath if self.path_info else None + + @cache_dir.setter + def cache_dir(self, value): + self.path_info = PathInfo(value) if value else None + + @classmethod + def supported(cls, config): + return True + + @cached_property + def cache_path(self): + return os.path.abspath(self.cache_dir) + + def checksum_to_path(self, checksum): + # NOTE: `self.cache_path` is already normalized so we can simply use + # `os.sep` instead of `os.path.join`. This results in this helper + # being ~5.5 times faster. + return ( + f"{self.cache_path}{os.sep}{checksum[0:2]}{os.sep}{checksum[2:]}" + ) + + def list_cache_paths(self, prefix=None, progress_callback=None): + assert self.path_info is not None + if prefix: + path_info = self.path_info / prefix[:2] + if not self.tree.exists(path_info): + return + else: + path_info = self.path_info + if progress_callback: + for path in walk_files(path_info): + progress_callback() + yield path + else: + yield from walk_files(path_info) + + def get(self, md5): + if not md5: + return None + + return self.checksum_to_path_info(md5).url + + def already_cached(self, path_info): + assert path_info.scheme in ["", "local"] + + current_md5 = self.get_checksum(path_info) + + if not current_md5: + return False + + return not self.changed_cache(current_md5) + + def _verify_link(self, path_info, link_type): + if link_type == "hardlink" and self.tree.getsize(path_info) == 0: + return + + super()._verify_link(path_info, link_type) + + def get_file_checksum(self, path_info): + return file_md5(path_info)[0] + def cache_exists(self, checksums, jobs=None, name=None): return [ checksum @@ -316,10 +329,6 @@ def _download( from_info, to_file, no_progress_bar=no_progress_bar, name=name ) - @staticmethod - def open(path_info, mode="r", encoding=None): - return open(path_info, mode=mode, encoding=encoding) - @index_locked def status( self, @@ -503,8 +512,8 @@ def _process( if download: func = partial( remote.download, - dir_mode=self._dir_mode, - file_mode=self._file_mode, + dir_mode=self.tree.dir_mode, + file_mode=self.tree.file_mode, ) status = STATUS_DELETED desc = "Downloading" @@ -662,7 +671,7 @@ def _unprotect_file(self, path): "a symlink or a hardlink.".format(path) ) - os.chmod(path, self._file_mode) + os.chmod(path, self.tree.file_mode) def _unprotect_dir(self, path): assert is_working_tree(self.repo.tree) @@ -703,7 +712,7 @@ def protect(self, path_info): def _remove_unpacked_dir(self, checksum): info = self.checksum_to_path_info(checksum) path_info = info.with_name(info.name + self.UNPACKED_DIR_SUFFIX) - self.remove(path_info) + self.tree.remove(path_info) def is_protected(self, path_info): try: diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index 7575a7e380..bf29d79064 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -7,12 +7,34 @@ from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote +from dvc.remote.base import BaseRemote, BaseRemoteTree from dvc.scheme import Schemes logger = logging.getLogger(__name__) +class OSSRemoteTree(BaseRemoteTree): + @property + def oss_service(self): + return self.remote.oss_service + + def _generate_download_url(self, path_info, expires=3600): + assert path_info.bucket == self.remote.path_info.bucket + + return self.oss_service.sign_url("GET", path_info.path, expires) + + def exists(self, path_info): + paths = self.remote.list_paths(path_info.path) + return any(path_info.path == path for path in paths) + + def remove(self, path_info): + if path_info.scheme != self.scheme: + raise NotImplementedError + + logger.debug(f"Removing oss://{path_info}") + self.oss_service.delete_object(path_info.path) + + class OSSRemote(BaseRemote): """ oss2 document: @@ -38,6 +60,7 @@ class OSSRemote(BaseRemote): PARAM_CHECKSUM = "etag" COPY_POLL_SECONDS = 5 LIST_OBJECT_PAGE_SIZE = 100 + TREE_CLS = OSSRemoteTree def __init__(self, repo, config): super().__init__(repo, config) @@ -83,14 +106,7 @@ def oss_service(self): ) return bucket - def remove(self, path_info): - if path_info.scheme != self.scheme: - raise NotImplementedError - - logger.debug(f"Removing oss://{path_info}") - self.oss_service.delete_object(path_info.path) - - def _list_paths(self, prefix, progress_callback=None): + def list_paths(self, prefix, progress_callback=None): import oss2 for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix): @@ -105,7 +121,7 @@ def list_cache_paths(self, prefix=None, progress_callback=None): ) else: prefix = self.path_info.path - return self._list_paths(prefix, progress_callback) + return self.list_paths(prefix, progress_callback) def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs @@ -122,12 +138,3 @@ def _download( self.oss_service.get_object_to_file( from_info.path, to_file, progress_callback=pbar.update_to ) - - def _generate_download_url(self, path_info, expires=3600): - assert path_info.bucket == self.path_info.bucket - - return self.oss_service.sign_url("GET", path_info.path, expires) - - def exists(self, path_info): - paths = self._list_paths(path_info.path) - return any(path_info.path == path for path in paths) diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 62924e3e4c..e98b12fea0 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -9,85 +9,98 @@ from dvc.exceptions import DvcException, ETagMismatchError from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote +from dvc.remote.base import BaseRemote, BaseRemoteTree from dvc.scheme import Schemes logger = logging.getLogger(__name__) -class S3Remote(BaseRemote): - scheme = Schemes.S3 - path_cls = CloudURLInfo - REQUIRES = {"boto3": "boto3"} - PARAM_CHECKSUM = "etag" - - def __init__(self, repo, config): - super().__init__(repo, config) - - url = config.get("url", "s3://") - self.path_info = self.path_cls(url) - - self.region = config.get("region") - self.profile = config.get("profile") - self.endpoint_url = config.get("endpointurl") - - if config.get("listobjects"): - self.list_objects_api = "list_objects" - else: - self.list_objects_api = "list_objects_v2" +class S3RemoteTree(BaseRemoteTree): + @property + def s3(self): + return self.remote.s3 - self.use_ssl = config.get("use_ssl", True) + def _generate_download_url(self, path_info, expires=3600): + params = {"Bucket": path_info.bucket, "Key": path_info.path} + return self.s3.generate_presigned_url( + ClientMethod="get_object", Params=params, ExpiresIn=int(expires) + ) - self.extra_args = {} + def exists(self, path_info): + """Check if the blob exists. If it does not exist, + it could be a part of a directory path. - self.sse = config.get("sse") - if self.sse: - self.extra_args["ServerSideEncryption"] = self.sse + eg: if `data/file.txt` exists, check for `data` should return True + """ + return self.isfile(path_info) or self.isdir(path_info) - self.sse_kms_key_id = config.get("sse_kms_key_id") - if self.sse_kms_key_id: - self.extra_args["SSEKMSKeyId"] = self.sse_kms_key_id + def isdir(self, path_info): + # S3 doesn't have a concept for directories. + # + # Using `head_object` with a path pointing to a directory + # will throw a 404 error. + # + # A reliable way to know if a given path is a directory is by + # checking if there are more files sharing the same prefix + # with a `list_objects` call. + # + # We need to make sure that the path ends with a forward slash, + # since we can end with false-positives like the following example: + # + # bucket + # └── data + # ├── alice + # └── alpha + # + # Using `data/al` as prefix will return `[data/alice, data/alpha]`, + # While `data/al/` will return nothing. + # + dir_path = path_info / "" + return bool(list(self.remote.list_paths(dir_path, max_items=1))) - self.acl = config.get("acl") - if self.acl: - self.extra_args["ACL"] = self.acl + def isfile(self, path_info): + from botocore.exceptions import ClientError - self._append_aws_grants_to_extra_args(config) + if path_info.path.endswith("/"): + return False - shared_creds = config.get("credentialpath") - if shared_creds: - os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds) + try: + self.s3.head_object(Bucket=path_info.bucket, Key=path_info.path) + except ClientError as exc: + if exc.response["Error"]["Code"] != "404": + raise + return False - @wrap_prop(threading.Lock()) - @cached_property - def s3(self): - import boto3 + return True - session = boto3.session.Session( - profile_name=self.profile, region_name=self.region - ) + def walk_files(self, path_info, max_items=None): + for fname in self.remote.list_paths(path_info / "", max_items): + if fname.endswith("/"): + continue - return session.client( - "s3", endpoint_url=self.endpoint_url, use_ssl=self.use_ssl - ) + yield path_info.replace(path=fname) - @classmethod - def get_etag(cls, s3, bucket, path): - obj = cls.get_head_object(s3, bucket, path) + def remove(self, path_info): + if path_info.scheme != "s3": + raise NotImplementedError - return obj["ETag"].strip('"') + logger.debug(f"Removing {path_info}") + self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path) - def get_file_checksum(self, path_info): - return self.get_etag(self.s3, path_info.bucket, path_info.path) + def makedirs(self, path_info): + # We need to support creating empty directories, which means + # creating an object with an empty body and a trailing slash `/`. + # + # We are not creating directory objects for every parent prefix, + # as it is not required. + if not path_info.path: + return - @staticmethod - def get_head_object(s3, bucket, path, *args, **kwargs): + dir_path = path_info / "" + self.s3.put_object(Bucket=path_info.bucket, Key=dir_path.path, Body="") - try: - obj = s3.head_object(Bucket=bucket, Key=path, *args, **kwargs) - except Exception as exc: - raise DvcException(f"s3://{bucket}/{path} does not exist") from exc - return obj + def copy(self, from_info, to_info): + self._copy(self.s3, from_info, to_info, self.remote.extra_args) @classmethod def _copy_multipart( @@ -101,7 +114,7 @@ def _copy_multipart( parts = [] byte_position = 0 for i in range(1, n_parts + 1): - obj = cls.get_head_object( + obj = S3Remote.get_head_object( s3, from_info.bucket, from_info.path, PartNumber=i ) part_size = obj["ContentLength"] @@ -156,7 +169,7 @@ def _copy(cls, s3, from_info, to_info, extra_args): # object is transfered in the same chunks as it was originally. from boto3.s3.transfer import TransferConfig - obj = cls.get_head_object(s3, from_info.bucket, from_info.path) + obj = S3Remote.get_head_object(s3, from_info.bucket, from_info.path) etag = obj["ETag"].strip('"') size = obj["ContentLength"] @@ -176,19 +189,85 @@ def _copy(cls, s3, from_info, to_info, extra_args): Config=TransferConfig(multipart_threshold=size + 1), ) - cached_etag = cls.get_etag(s3, to_info.bucket, to_info.path) + cached_etag = S3Remote.get_etag(s3, to_info.bucket, to_info.path) if etag != cached_etag: raise ETagMismatchError(etag, cached_etag) - def copy(self, from_info, to_info): - self._copy(self.s3, from_info, to_info, self.extra_args) - def remove(self, path_info): - if path_info.scheme != "s3": - raise NotImplementedError +class S3Remote(BaseRemote): + scheme = Schemes.S3 + path_cls = CloudURLInfo + REQUIRES = {"boto3": "boto3"} + PARAM_CHECKSUM = "etag" + TREE_CLS = S3RemoteTree - logger.debug(f"Removing {path_info}") - self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path) + def __init__(self, repo, config): + super().__init__(repo, config) + + url = config.get("url", "s3://") + self.path_info = self.path_cls(url) + + self.region = config.get("region") + self.profile = config.get("profile") + self.endpoint_url = config.get("endpointurl") + + if config.get("listobjects"): + self.list_objects_api = "list_objects" + else: + self.list_objects_api = "list_objects_v2" + + self.use_ssl = config.get("use_ssl", True) + + self.extra_args = {} + + self.sse = config.get("sse") + if self.sse: + self.extra_args["ServerSideEncryption"] = self.sse + + self.sse_kms_key_id = config.get("sse_kms_key_id") + if self.sse_kms_key_id: + self.extra_args["SSEKMSKeyId"] = self.sse_kms_key_id + + self.acl = config.get("acl") + if self.acl: + self.extra_args["ACL"] = self.acl + + self._append_aws_grants_to_extra_args(config) + + shared_creds = config.get("credentialpath") + if shared_creds: + os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds) + + @wrap_prop(threading.Lock()) + @cached_property + def s3(self): + import boto3 + + session = boto3.session.Session( + profile_name=self.profile, region_name=self.region + ) + + return session.client( + "s3", endpoint_url=self.endpoint_url, use_ssl=self.use_ssl + ) + + @classmethod + def get_etag(cls, s3, bucket, path): + obj = cls.get_head_object(s3, bucket, path) + + return obj["ETag"].strip('"') + + def get_file_checksum(self, path_info): + return self.get_etag(self.s3, path_info.bucket, path_info.path) + + @staticmethod + def get_head_object(s3, bucket, path, *args, **kwargs): + + try: + obj = s3.head_object(Bucket=bucket, Key=path, *args, **kwargs) + except Exception as exc: + raise DvcException(f"s3://{bucket}/{path} does not exist") from exc + return obj def _list_objects( self, path_info, max_items=None, prefix=None, progress_callback=None @@ -211,7 +290,7 @@ def _list_objects( else: yield from contents - def _list_paths( + def list_paths( self, path_info, max_items=None, prefix=None, progress_callback=None ): return ( @@ -222,69 +301,10 @@ def _list_paths( ) def list_cache_paths(self, prefix=None, progress_callback=None): - return self._list_paths( + return self.list_paths( self.path_info, prefix=prefix, progress_callback=progress_callback ) - def isfile(self, path_info): - from botocore.exceptions import ClientError - - if path_info.path.endswith("/"): - return False - - try: - self.s3.head_object(Bucket=path_info.bucket, Key=path_info.path) - except ClientError as exc: - if exc.response["Error"]["Code"] != "404": - raise - return False - - return True - - def exists(self, path_info): - """Check if the blob exists. If it does not exist, - it could be a part of a directory path. - - eg: if `data/file.txt` exists, check for `data` should return True - """ - return self.isfile(path_info) or self.isdir(path_info) - - def makedirs(self, path_info): - # We need to support creating empty directories, which means - # creating an object with an empty body and a trailing slash `/`. - # - # We are not creating directory objects for every parent prefix, - # as it is not required. - if not path_info.path: - return - - dir_path = path_info / "" - self.s3.put_object(Bucket=path_info.bucket, Key=dir_path.path, Body="") - - def isdir(self, path_info): - # S3 doesn't have a concept for directories. - # - # Using `head_object` with a path pointing to a directory - # will throw a 404 error. - # - # A reliable way to know if a given path is a directory is by - # checking if there are more files sharing the same prefix - # with a `list_objects` call. - # - # We need to make sure that the path ends with a forward slash, - # since we can end with false-positives like the following example: - # - # bucket - # └── data - # ├── alice - # └── alpha - # - # Using `data/al` as prefix will return `[data/alice, data/alpha]`, - # While `data/al/` will return nothing. - # - dir_path = path_info / "" - return bool(list(self._list_paths(dir_path, max_items=1))) - def _upload(self, from_file, to_info, name=None, no_progress_bar=False): total = os.path.getsize(from_file) with Tqdm( @@ -312,19 +332,6 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): from_info.bucket, from_info.path, to_file, Callback=pbar.update ) - def _generate_download_url(self, path_info, expires=3600): - params = {"Bucket": path_info.bucket, "Key": path_info.path} - return self.s3.generate_presigned_url( - ClientMethod="get_object", Params=params, ExpiresIn=int(expires) - ) - - def walk_files(self, path_info, max_items=None): - for fname in self._list_paths(path_info / "", max_items): - if fname.endswith("/"): - continue - - yield path_info.replace(path=fname) - def _append_aws_grants_to_extra_args(self, config): # Keys for extra_args can be one of the following list: # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index eb0482ad6c..416b2d4eed 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -14,7 +14,7 @@ import dvc.prompt as prompt from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote +from dvc.remote.base import BaseRemote, BaseRemoteTree from dvc.remote.pool import get_connection from dvc.scheme import Schemes from dvc.utils import to_chunks @@ -33,6 +33,105 @@ def ask_password(host, user, port): ) +class SSHRemoteTree(BaseRemoteTree): + @property + def ssh(self): + return self.remote.ssh + + @contextmanager + def open(self, path_info, mode="r", encoding=None): + assert mode in {"r", "rt", "rb", "wb"} + + with self.ssh(path_info) as ssh, closing( + ssh.sftp.open(path_info.path, mode) + ) as fd: + if "b" in mode: + yield fd + else: + yield io.TextIOWrapper(fd, encoding=encoding) + + def exists(self, path_info): + with self.ssh(path_info) as ssh: + return ssh.exists(path_info.path) + + def isdir(self, path_info): + with self.ssh(path_info) as ssh: + return ssh.isdir(path_info.path) + + def isfile(self, path_info): + with self.ssh(path_info) as ssh: + return ssh.isfile(path_info.path) + + def walk_files(self, path_info): + with self.ssh(path_info) as ssh: + for fname in ssh.walk_files(path_info.path): + yield path_info.replace(path=fname) + + def remove(self, path_info): + if path_info.scheme != self.scheme: + raise NotImplementedError + + with self.ssh(path_info) as ssh: + ssh.remove(path_info.path) + + def makedirs(self, path_info): + with self.ssh(path_info) as ssh: + ssh.makedirs(path_info.path) + + def move(self, from_info, to_info, mode=None): + assert mode is None + if from_info.scheme != self.scheme or to_info.scheme != self.scheme: + raise NotImplementedError + + with self.ssh(from_info) as ssh: + ssh.move(from_info.path, to_info.path) + + def copy(self, from_info, to_info): + if not from_info.scheme == to_info.scheme == self.scheme: + raise NotImplementedError + + with self.ssh(from_info) as ssh: + ssh.atomic_copy(from_info.path, to_info.path) + + def symlink(self, from_info, to_info): + if not from_info.scheme == to_info.scheme == self.scheme: + raise NotImplementedError + + with self.ssh(from_info) as ssh: + ssh.symlink(from_info.path, to_info.path) + + def hardlink(self, from_info, to_info): + if not from_info.scheme == to_info.scheme == self.scheme: + raise NotImplementedError + + # See dvc/remote/local/__init__.py - hardlink() + if self.getsize(from_info) == 0: + + with self.ssh(to_info) as ssh: + ssh.sftp.open(to_info.path, "w").close() + + logger.debug( + "Created empty file: {src} -> {dest}".format( + src=str(from_info), dest=str(to_info) + ) + ) + return + + with self.ssh(from_info) as ssh: + ssh.hardlink(from_info.path, to_info.path) + + def reflink(self, from_info, to_info): + if from_info.scheme != self.scheme or to_info.scheme != self.scheme: + raise NotImplementedError + + with self.ssh(from_info) as ssh: + ssh.reflink(from_info.path, to_info.path) + + def getsize(self, path_info): + with self.ssh(path_info) as ssh: + return ssh.getsize(path_info.path) + + class SSHRemote(BaseRemote): scheme = Schemes.SSH REQUIRES = {"paramiko": "paramiko"} @@ -46,6 +145,7 @@ class SSHRemote(BaseRemote): # We use conservative setting of 4 instead to not exhaust max sessions. CHECKSUM_JOBS = 4 TRAVERSE_PREFIX_LEN = 2 + TREE_CLS = SSHRemoteTree DEFAULT_CACHE_TYPES = ["copy"] @@ -148,10 +248,6 @@ def ssh(self, path_info): sock=self.sock, ) - def exists(self, path_info): - with self.ssh(path_info) as ssh: - return ssh.exists(path_info.path) - def get_file_checksum(self, path_info): if path_info.scheme != self.scheme: raise NotImplementedError @@ -159,74 +255,6 @@ def get_file_checksum(self, path_info): with self.ssh(path_info) as ssh: return ssh.md5(path_info.path) - def isdir(self, path_info): - with self.ssh(path_info) as ssh: - return ssh.isdir(path_info.path) - - def isfile(self, path_info): - with self.ssh(path_info) as ssh: - return ssh.isfile(path_info.path) - - def getsize(self, path_info): - with self.ssh(path_info) as ssh: - return ssh.getsize(path_info.path) - - def copy(self, from_info, to_info): - if not from_info.scheme == to_info.scheme == self.scheme: - raise NotImplementedError - - with self.ssh(from_info) as ssh: - ssh.atomic_copy(from_info.path, to_info.path) - - def symlink(self, from_info, to_info): - if not from_info.scheme == to_info.scheme == self.scheme: - raise NotImplementedError - - with self.ssh(from_info) as ssh: - ssh.symlink(from_info.path, to_info.path) - - def hardlink(self, from_info, to_info): - if not from_info.scheme == to_info.scheme == self.scheme: - raise NotImplementedError - - # See dvc/remote/local/__init__.py - hardlink() - if self.getsize(from_info) == 0: - - with self.ssh(to_info) as ssh: - ssh.sftp.open(to_info.path, "w").close() - - logger.debug( - "Created empty file: {src} -> {dest}".format( - src=str(from_info), dest=str(to_info) - ) - ) - return - - with self.ssh(from_info) as ssh: - ssh.hardlink(from_info.path, to_info.path) - - def reflink(self, from_info, to_info): - if from_info.scheme != self.scheme or to_info.scheme != self.scheme: - raise NotImplementedError - - with self.ssh(from_info) as ssh: - ssh.reflink(from_info.path, to_info.path) - - def remove(self, path_info): - if path_info.scheme != self.scheme: - raise NotImplementedError - - with self.ssh(path_info) as ssh: - ssh.remove(path_info.path) - - def move(self, from_info, to_info, mode=None): - assert mode is None - if from_info.scheme != self.scheme or to_info.scheme != self.scheme: - raise NotImplementedError - - with self.ssh(from_info) as ssh: - ssh.move(from_info.path, to_info.path) - def _download(self, from_info, to_file, name=None, no_progress_bar=False): assert from_info.isin(self.path_info) with self.ssh(self.path_info) as ssh: @@ -247,18 +275,6 @@ def _upload(self, from_file, to_info, name=None, no_progress_bar=False): no_progress_bar=no_progress_bar, ) - @contextmanager - def open(self, path_info, mode="r", encoding=None): - assert mode in {"r", "rt", "rb", "wb"} - - with self.ssh(path_info) as ssh, closing( - ssh.sftp.open(path_info.path, mode) - ) as fd: - if "b" in mode: - yield fd - else: - yield io.TextIOWrapper(fd, encoding=encoding) - def list_cache_paths(self, prefix=None, progress_callback=None): if prefix: root = posixpath.join(self.path_info.path, prefix[:2]) @@ -275,15 +291,6 @@ def list_cache_paths(self, prefix=None, progress_callback=None): else: yield from ssh.walk_files(root) - def walk_files(self, path_info): - with self.ssh(path_info) as ssh: - for fname in ssh.walk_files(path_info.path): - yield path_info.replace(path=fname) - - def makedirs(self, path_info): - with self.ssh(path_info) as ssh: - ssh.makedirs(path_info.path) - def batch_exists(self, path_infos, callback): def _exists(chunk_and_channel): chunk, channel = chunk_and_channel diff --git a/tests/func/test_add.py b/tests/func/test_add.py index 957f19ce2d..fc4a249c3c 100644 --- a/tests/func/test_add.py +++ b/tests/func/test_add.py @@ -21,7 +21,7 @@ ) from dvc.main import main from dvc.output.base import OutputAlreadyTrackedError, OutputIsStageFileError -from dvc.remote import LocalRemote +from dvc.remote.local import LocalRemote, LocalRemoteTree from dvc.repo import Repo as DvcRepo from dvc.stage import Stage from dvc.system import System @@ -592,7 +592,7 @@ def test_should_not_checkout_when_adding_cached_copy(tmp_dir, dvc, mocker): shutil.copy("bar", "foo") - copy_spy = mocker.spy(dvc.cache.local, "copy") + copy_spy = mocker.spy(dvc.cache.local.tree, "copy") dvc.add("foo") @@ -618,7 +618,7 @@ def test_should_relink_on_repeated_add( tmp_dir.dvc_gen({"foo": "foo", "bar": "bar"}) os.remove("foo") - getattr(dvc.cache.local, link)(PathInfo("bar"), PathInfo("foo")) + getattr(dvc.cache.local.tree, link)(PathInfo("bar"), PathInfo("foo")) dvc.cache.local.cache_types = [new_link] @@ -698,7 +698,7 @@ def test_add_empty_files(tmp_dir, dvc, link): def test_add_optimization_for_hardlink_on_empty_files(tmp_dir, dvc, mocker): dvc.cache.local.cache_types = ["hardlink"] tmp_dir.gen({"foo": "", "bar": "", "lorem": "lorem", "ipsum": "ipsum"}) - m = mocker.spy(LocalRemote, "is_hardlink") + m = mocker.spy(LocalRemoteTree, "is_hardlink") stages = dvc.add(["foo", "bar", "lorem", "ipsum"]) assert m.call_count == 1 diff --git a/tests/func/test_checkout.py b/tests/func/test_checkout.py index 818c76a747..6a5085f72c 100644 --- a/tests/func/test_checkout.py +++ b/tests/func/test_checkout.py @@ -765,10 +765,10 @@ def test_checkout_for_external_outputs(tmp_dir, dvc): dvc.add(str(remote.path_info / "foo")) - remote.remove(file_path) + remote.tree.remove(file_path) stats = dvc.checkout(force=True) assert stats == {**empty_checkout, "added": [str(file_path)]} - assert remote.exists(file_path) + assert remote.tree.exists(file_path) remote.s3.put_object( Bucket=remote.path_info.bucket, Key=file_path.path, Body="foo\nfoo" diff --git a/tests/func/test_gc.py b/tests/func/test_gc.py index a12f4b6e85..5520893718 100644 --- a/tests/func/test_gc.py +++ b/tests/func/test_gc.py @@ -8,7 +8,7 @@ from dvc.exceptions import CollectCacheError from dvc.main import main -from dvc.remote.local import LocalRemote +from dvc.remote.local import LocalRemote, LocalRemoteTree from dvc.repo import Repo as DvcRepo from dvc.utils.fs import remove from tests.basic_env import TestDir, TestDvcGit @@ -321,7 +321,9 @@ def test_gc_cloud_remove_order(tmp_dir, scm, dvc, tmp_path_factory, mocker): dvc.remove(dir2.relpath) dvc.gc(workspace=True) - mocked_remove = mocker.patch.object(LocalRemote, "remove", autospec=True) + mocked_remove = mocker.patch.object( + LocalRemoteTree, "remove", autospec=True + ) dvc.gc(workspace=True, cloud=True) assert len(mocked_remove.mock_calls) == 8 # dir (and unpacked dir) should be first 4 checksums removed from diff --git a/tests/func/test_ignore.py b/tests/func/test_ignore.py index 03806cb1e0..a0c648932b 100644 --- a/tests/func/test_ignore.py +++ b/tests/func/test_ignore.py @@ -140,7 +140,7 @@ def test_match_nested(tmp_dir, dvc): ) remote = LocalRemote(dvc, {}) - result = {os.fspath(f) for f in remote.walk_files(".")} + result = {os.fspath(f) for f in remote.tree.walk_files(".")} assert result == {".dvcignore", "foo"} @@ -150,7 +150,7 @@ def test_ignore_external(tmp_dir, scm, dvc, tmp_path_factory): ext_dir.gen({"y.backup": "y", "tmp": "ext tmp"}) remote = LocalRemote(dvc, {}) - result = {relpath(f, ext_dir) for f in remote.walk_files(ext_dir)} + result = {relpath(f, ext_dir) for f in remote.tree.walk_files(ext_dir)} assert result == {"y.backup", "tmp"} diff --git a/tests/func/test_remote.py b/tests/func/test_remote.py index 990d07d730..40d3b8156b 100644 --- a/tests/func/test_remote.py +++ b/tests/func/test_remote.py @@ -194,9 +194,13 @@ def unreliable_upload(self, from_file, to_info, name=None, **kwargs): assert upload_error_info.value.amount == 3 remote = dvc.cloud.get_remote("upstream") - assert not remote.exists(remote.checksum_to_path_info(foo.checksum)) - assert remote.exists(remote.checksum_to_path_info(bar.checksum)) - assert not remote.exists(remote.checksum_to_path_info(baz.checksum)) + assert not remote.tree.exists( + remote.checksum_to_path_info(foo.checksum) + ) + assert remote.tree.exists(remote.checksum_to_path_info(bar.checksum)) + assert not remote.tree.exists( + remote.checksum_to_path_info(baz.checksum) + ) # Push everything and delete local cache dvc.push() diff --git a/tests/func/test_s3.py b/tests/func/test_s3.py index b90e7fb93b..f2f4cef9d3 100644 --- a/tests/func/test_s3.py +++ b/tests/func/test_s3.py @@ -5,7 +5,7 @@ import pytest from moto import mock_s3 -from dvc.remote.s3 import S3Remote +from dvc.remote.s3 import S3Remote, S3RemoteTree from tests.remotes import S3 # from https://github.com/spulec/moto/blob/v1.3.5/tests/test_s3/test_s3.py#L40 @@ -42,7 +42,7 @@ def test_copy_singlepart_preserve_etag(): s3.create_bucket(Bucket=from_info.bucket) s3.put_object(Bucket=from_info.bucket, Key=from_info.path, Body="data") - S3Remote._copy(s3, from_info, to_info, {}) + S3RemoteTree._copy(s3, from_info, to_info, {}) @mock_s3 @@ -58,15 +58,15 @@ def test_link_created_on_non_nested_path(base_info, tmp_dir, dvc, scm): ) remote.link(base_info / "from", base_info / "to") - assert remote.exists(base_info / "from") - assert remote.exists(base_info / "to") + assert remote.tree.exists(base_info / "from") + assert remote.tree.exists(base_info / "to") @mock_s3 def test_makedirs_doesnot_try_on_top_level_paths(tmp_dir, dvc, scm): base_info = S3Remote.path_cls("s3://bucket/") remote = S3Remote(dvc, {"url": str(base_info)}) - remote.makedirs(base_info) + remote.tree.makedirs(base_info) def _upload_multipart(s3, Bucket, Key): @@ -107,4 +107,4 @@ def test_copy_multipart_preserve_etag(): s3 = boto3.client("s3") s3.create_bucket(Bucket=from_info.bucket) _upload_multipart(s3, from_info.bucket, from_info.path) - S3Remote._copy(s3, from_info, to_info, {}) + S3RemoteTree._copy(s3, from_info, to_info, {}) diff --git a/tests/unit/dependency/test_local.py b/tests/unit/dependency/test_local.py index 0feeaff0bc..bff5590f1f 100644 --- a/tests/unit/dependency/test_local.py +++ b/tests/unit/dependency/test_local.py @@ -15,6 +15,6 @@ def _get_dependency(self): def test_save_missing(self): d = self._get_dependency() - with mock.patch.object(d.remote, "exists", return_value=False): + with mock.patch.object(d.remote.tree, "exists", return_value=False): with self.assertRaises(d.DoesNotExistError): d.save() diff --git a/tests/unit/output/test_local.py b/tests/unit/output/test_local.py index e56dfd55ae..00944524da 100644 --- a/tests/unit/output/test_local.py +++ b/tests/unit/output/test_local.py @@ -19,7 +19,7 @@ def _get_output(self): def test_save_missing(self): o = self._get_output() - with patch.object(o.remote, "exists", return_value=False): + with patch.object(o.remote.tree, "exists", return_value=False): with self.assertRaises(o.DoesNotExistError): o.save() diff --git a/tests/unit/remote/ssh/test_ssh.py b/tests/unit/remote/ssh/test_ssh.py index 037ebe1970..b0b22bec9a 100644 --- a/tests/unit/remote/ssh/test_ssh.py +++ b/tests/unit/remote/ssh/test_ssh.py @@ -207,5 +207,5 @@ def test_hardlink_optimization(dvc, tmp_dir, ssh_server): else: link_path = to_info.path - remote.hardlink(from_info, to_info) + remote.tree.hardlink(from_info, to_info) assert not System.is_hardlink(link_path) diff --git a/tests/unit/remote/test_azure.py b/tests/unit/remote/test_azure.py index a623379b7b..b5462193da 100644 --- a/tests/unit/remote/test_azure.py +++ b/tests/unit/remote/test_azure.py @@ -41,7 +41,7 @@ def test_get_file_checksum(tmp_dir): remote = AzureRemote(None, {}) to_info = remote.path_cls(Azure.get_url()) remote.upload(PathInfo("foo"), to_info) - assert remote.exists(to_info) + assert remote.tree.exists(to_info) checksum = remote.get_file_checksum(to_info) assert checksum assert isinstance(checksum, str) diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index c89b0474f9..fa2325fb2b 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -33,12 +33,12 @@ def test_cmd_error(dvc): err = "sed: expression #1, char 2: extra characters after command" with mock.patch.object( - REMOTE_CLS, + REMOTE_CLS.TREE_CLS, "remove", side_effect=RemoteCmdError("base", cmd, ret, err), ): with pytest.raises(RemoteCmdError): - REMOTE_CLS(dvc, config).remove("file") + REMOTE_CLS(dvc, config).tree.remove("file") @mock.patch.object(BaseRemote, "_cache_checksums_traverse") diff --git a/tests/unit/remote/test_local.py b/tests/unit/remote/test_local.py index cc2ef08fa4..a42e3d50c0 100644 --- a/tests/unit/remote/test_local.py +++ b/tests/unit/remote/test_local.py @@ -34,7 +34,7 @@ def test_status_download_optimization(mocker, dvc): @pytest.mark.parametrize("link_name", ["hardlink", "symlink"]) def test_is_protected(tmp_dir, dvc, link_name): remote = LocalRemote(dvc, {}) - link_method = getattr(remote, link_name) + link_method = getattr(remote.tree, link_name) (tmp_dir / "foo").write_text("foo") diff --git a/tests/unit/remote/test_remote.py b/tests/unit/remote/test_remote.py index 58dd29636f..8e0f82a5d8 100644 --- a/tests/unit/remote/test_remote.py +++ b/tests/unit/remote/test_remote.py @@ -37,5 +37,5 @@ def test_makedirs_not_create_for_top_level_path(remote_cls, dvc, mocker): # we use remote clients with same name as scheme to interact with remote mocker.patch.object(remote_cls, remote.scheme, mocked_client) - remote.makedirs(remote.path_info) + remote.tree.makedirs(remote.path_info) assert not mocked_client.called diff --git a/tests/unit/remote/test_remote_dir.py b/tests/unit/remote/test_remote_tree.py similarity index 89% rename from tests/unit/remote/test_remote_dir.py rename to tests/unit/remote/test_remote_tree.py index 9a0f60d4da..fc9f557df2 100644 --- a/tests/unit/remote/test_remote_dir.py +++ b/tests/unit/remote/test_remote_tree.py @@ -48,7 +48,7 @@ def test_isdir(remote): ] for expected, path in test_cases: - assert remote.isdir(remote.path_info / path) == expected + assert remote.tree.isdir(remote.path_info / path) == expected @pytest.mark.parametrize("remote", remotes, indirect=True) @@ -68,7 +68,7 @@ def test_exists(remote): ] for expected, path in test_cases: - assert remote.exists(remote.path_info / path) == expected + assert remote.tree.exists(remote.path_info / path) == expected @pytest.mark.parametrize("remote", remotes, indirect=True) @@ -83,7 +83,7 @@ def test_walk_files(remote): remote.path_info / "data/subdir/empty_file", ] - assert list(remote.walk_files(remote.path_info / "data")) == files + assert list(remote.tree.walk_files(remote.path_info / "data")) == files @pytest.mark.parametrize("remote", [S3Mocked], indirect=True) @@ -96,7 +96,7 @@ def test_copy_preserve_etag_across_buckets(remote, dvc): from_info = remote.path_info / "foo" to_info = another.path_info / "foo" - remote.copy(from_info, to_info) + remote.tree.copy(from_info, to_info) from_etag = S3Remote.get_etag(s3, from_info.bucket, from_info.path) to_etag = S3Remote.get_etag(s3, "another", "foo") @@ -106,12 +106,13 @@ def test_copy_preserve_etag_across_buckets(remote, dvc): @pytest.mark.parametrize("remote", remotes, indirect=True) def test_makedirs(remote): + tree = remote.tree empty_dir = remote.path_info / "empty_dir" / "" - remote.remove(empty_dir) - assert not remote.exists(empty_dir) - remote.makedirs(empty_dir) - assert remote.exists(empty_dir) - assert remote.isdir(empty_dir) + tree.remove(empty_dir) + assert not tree.exists(empty_dir) + tree.makedirs(empty_dir) + assert tree.exists(empty_dir) + assert tree.isdir(empty_dir) @pytest.mark.parametrize("remote", [GCP, S3Mocked], indirect=True) @@ -133,7 +134,7 @@ def test_isfile(remote): ] for expected, path in test_cases: - assert remote.isfile(remote.path_info / path) == expected + assert remote.tree.isfile(remote.path_info / path) == expected @pytest.mark.parametrize("remote", remotes, indirect=True)