Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get/import: retrieve files inside directory outs #3309

Merged
merged 2 commits into from Feb 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions dvc/external_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

from funcy import retry, suppress, wrap_with, cached_property

from dvc.path_info import PathInfo
from dvc.compat import fspath
from dvc.repo import Repo
from dvc.config import NoRemoteError, NotDvcRepoError
from dvc.exceptions import NoRemoteInExternalRepoError
from dvc.exceptions import OutputNotFoundError, NoOutputInExternalRepoError
from dvc.exceptions import FileMissingError, PathMissingError
from dvc.utils.fs import remove, fs_copy
from dvc.utils.fs import remove, fs_copy, move
from dvc.utils import tmp_fname
from dvc.scm.git import Git


Expand Down Expand Up @@ -67,32 +69,49 @@ def __init__(self, root_dir, url):
self._set_upstream()

def pull_to(self, path, to_info):
try:
out = None
with suppress(OutputNotFoundError):
out = self.find_out_by_relpath(path)
"""
Pull the corresponding file or directory specified by `path` and
checkout it into `to_info`.

It works with files tracked by Git and DVC, and also local files
outside the repository.
"""
out = None
path_info = PathInfo(self.root_dir) / path

with suppress(OutputNotFoundError):
(out,) = self.find_outs_by_path(fspath(path_info), strict=False)

try:
if out and out.use_cache:
self._pull_cached(out, to_info)
self._pull_cached(out, path, to_info)
return

# Git handled files can't have absolute path
# Check if it is handled by Git (it can't have an absolute path)
if os.path.isabs(path):
raise FileNotFoundError

fs_copy(os.path.join(self.root_dir, path), fspath(to_info))
fs_copy(fspath(path_info), fspath(to_info))
except FileNotFoundError:
raise PathMissingError(path, self.url)

def _pull_cached(self, out, to_info):
def _pull_cached(self, out, src, dest):
with self.state:
tmp = PathInfo(tmp_fname(dest))
target = (out.path_info.parent / src).relative_to(out.path_info)
src = tmp / target

out.path_info = tmp

# Only pull unless all needed cache is present
if out.changed_cache():
self.cloud.pull(out.get_used_cache())
if out.changed_cache(filter_info=src):
self.cloud.pull(out.get_used_cache(filter_info=src))

failed = out.checkout(filter_info=src)

move(src, dest)
remove(tmp)

out.path_info = to_info
failed = out.checkout()
# This might happen when pull haven't really pulled all the files
if failed:
raise FileNotFoundError

Expand Down
4 changes: 2 additions & 2 deletions dvc/output/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ def exists(self):
def changed_checksum(self):
return self.checksum != self.remote.get_checksum(self.path_info)

def changed_cache(self):
def changed_cache(self, filter_info=None):
if not self.use_cache or not self.checksum:
return True

return self.cache.changed_cache(self.checksum)
return self.cache.changed_cache(self.checksum, filter_info=filter_info)

def status(self):
if self.checksum and self.use_cache and self.changed_cache():
Expand Down
4 changes: 2 additions & 2 deletions dvc/utils/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def move(src, dst, mode=None):
of data is happening.
"""

src = fspath_py35(src)
dst = fspath_py35(dst)
src = fspath(src)
dst = fspath(dst)

dst = os.path.abspath(dst)
tmp = "{}.{}".format(dst, uuid())
Expand Down
27 changes: 27 additions & 0 deletions tests/func/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,33 @@ def test_get_from_non_dvc_master(tmp_dir, git_dir, caplog):
assert (tmp_dir / "some_dst").read_text() == "some text"


def test_get_file_from_dir(tmp_dir, erepo_dir):
with erepo_dir.chdir():
erepo_dir.dvc_gen(
{
"dir": {
"1": "1",
"2": "2",
"subdir": {"foo": "foo", "bar": "bar"},
}
},
commit="create dir",
)

Repo.get(fspath(erepo_dir), os.path.join("dir", "1"))
assert (tmp_dir / "1").read_text() == "1"

Repo.get(fspath(erepo_dir), os.path.join("dir", "2"), out="file")
assert (tmp_dir / "file").read_text() == "2"

Repo.get(fspath(erepo_dir), os.path.join("dir", "subdir"))
assert (tmp_dir / "subdir" / "foo").read_text() == "foo"
assert (tmp_dir / "subdir" / "bar").read_text() == "bar"

Repo.get(fspath(erepo_dir), os.path.join("dir", "subdir", "foo"), out="X")
assert (tmp_dir / "X").read_text() == "foo"


def test_get_url_positive(tmp_dir, erepo_dir, caplog):
with erepo_dir.chdir():
erepo_dir.dvc_gen("foo", "foo")
Expand Down
36 changes: 36 additions & 0 deletions tests/func/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,42 @@ def test_import_dir(tmp_dir, scm, dvc, erepo_dir):
}


def test_import_file_from_dir(tmp_dir, scm, dvc, erepo_dir):
with erepo_dir.chdir():
erepo_dir.dvc_gen(
{
"dir": {
"1": "1",
"2": "2",
"subdir": {"foo": "foo", "bar": "bar"},
}
},
commit="create dir",
)

stage = dvc.imp(fspath(erepo_dir), os.path.join("dir", "1"))

assert (tmp_dir / "1").read_text() == "1"
assert scm.repo.git.check_ignore("1")
assert stage.deps[0].def_repo == {
"url": fspath(erepo_dir),
"rev_lock": erepo_dir.scm.get_rev(),
}

dvc.imp(fspath(erepo_dir), os.path.join("dir", "2"), out="file")
assert (tmp_dir / "file").read_text() == "2"
assert (tmp_dir / "file.dvc").exists()

dvc.imp(fspath(erepo_dir), os.path.join("dir", "subdir"))
assert (tmp_dir / "subdir" / "foo").read_text() == "foo"
assert (tmp_dir / "subdir" / "bar").read_text() == "bar"
assert (tmp_dir / "subdir.dvc").exists()

dvc.imp(fspath(erepo_dir), os.path.join("dir", "subdir", "foo"), out="X")
assert (tmp_dir / "X").read_text() == "foo"
assert (tmp_dir / "X.dvc").exists()


def test_import_non_cached(erepo_dir, tmp_dir, dvc, scm):
src = "non_cached_output"
dst = src + "_imported"
Expand Down