Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restyle get: handle non-DVC repositories #3103

Closed
wants to merge 13 commits into from
12 changes: 8 additions & 4 deletions dvc/command/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def run(self):


def add_parser(subparsers, parent_parser):
GET_HELP = "Download/copy files or directories from DVC repository."
GET_HELP = (
"Download a file or directory from any DVC project or Git repository"
)
get_parser = subparsers.add_parser(
"get",
parents=[parent_parser],
Expand All @@ -40,10 +42,12 @@ def add_parser(subparsers, parent_parser):
formatter_class=argparse.RawDescriptionHelpFormatter,
)
get_parser.add_argument(
"url", help="URL of Git repository with DVC project to download from."
"url",
help="Location of DVC project or Git repository to download from",
)
get_parser.add_argument(
"path", help="Path to a file or directory within a DVC repository."
"path",
help="Path to a file or directory within the project or repository",
)
get_parser.add_argument(
"-o",
Expand All @@ -52,6 +56,6 @@ def add_parser(subparsers, parent_parser):
help="Destination path to copy/download files to.",
)
get_parser.add_argument(
"--rev", nargs="?", help="DVC repository git revision."
"--rev", nargs="?", help="Git revision (e.g. branch, tag, SHA)"
)
get_parser.set_defaults(func=CmdGet)
5 changes: 0 additions & 5 deletions dvc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,6 @@ def __init__(self, ignore_dirname):
)


class UrlNotDvcRepoError(DvcException):
def __init__(self, url):
super().__init__("URL '{}' is not a dvc repository.".format(url))


class GitHookAlreadyExistsError(DvcException):
def __init__(self, hook_name):
super().__init__(
Expand Down
4 changes: 2 additions & 2 deletions dvc/external_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def external_repo(url=None, rev=None, rev_lock=None, cache_dir=None):
repo.close()


def cached_clone(url, rev=None, **_ignored_kwargs):
def cached_clone(url, rev=None, clone_path=None, **_ignored_kwargs):
"""Clone an external git repo to a temporary directory.

Returns the path to a local temporary directory with the specified
Expand All @@ -44,7 +44,7 @@ def cached_clone(url, rev=None, **_ignored_kwargs):

"""

new_path = tempfile.mkdtemp("dvc-erepo")
new_path = clone_path or tempfile.mkdtemp("dvc-erepo")

# Copy and adjust existing clean clone
if (url, None, None) in REPO_CACHE:
Expand Down
35 changes: 19 additions & 16 deletions dvc/repo/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
DvcException,
NotDvcRepoError,
OutputNotFoundError,
UrlNotDvcRepoError,
PathMissingError,
)
from dvc.external_repo import external_repo
from dvc.external_repo import cached_clone
from dvc.path_info import PathInfo
from dvc.stage import Stage
from dvc.utils import resolve_output
Expand All @@ -30,6 +29,8 @@ def __init__(self):

@staticmethod
def get(url, path, out=None, rev=None):
from dvc.repo import Repo

out = resolve_output(path, out)

if Stage.is_valid_filename(out):
Expand All @@ -43,7 +44,10 @@ def get(url, path, out=None, rev=None):
dpath = os.path.dirname(os.path.abspath(out))
tmp_dir = os.path.join(dpath, "." + str(shortuuid.uuid()))
try:
with external_repo(cache_dir=tmp_dir, url=url, rev=rev) as repo:
cached_clone(url, rev=rev, clone_path=tmp_dir)
try:
repo = Repo(tmp_dir)

# Try any links possible to avoid data duplication.
#
# Not using symlink, because we need to remove cache after we are
Expand All @@ -55,24 +59,23 @@ def get(url, path, out=None, rev=None):
# the same cache file might be used a few times in a directory.
repo.cache.local.cache_types = ["reflink", "hardlink", "copy"]

try:
output = repo.find_out_by_relpath(path)
except OutputNotFoundError:
output = None

if output and output.use_cache:
output = repo.find_out_by_relpath(path)
if output.use_cache:
_get_cached(repo, output, out)
else:
# Either an uncached out with absolute path or a user error
if os.path.isabs(path):
raise FileNotFoundError
return

except (NotDvcRepoError, OutputNotFoundError):
pass

# It's an uncached out with absolute path, a non-DVC repo, or a
# user error
if os.path.isabs(path):
raise FileNotFoundError

fs_copy(os.path.join(repo.root_dir, path), out)
fs_copy(os.path.join(tmp_dir, path), out)

except (OutputNotFoundError, FileNotFoundError):
raise PathMissingError(path, url)
except NotDvcRepoError:
raise UrlNotDvcRepoError(url)
finally:
remove(tmp_dir)

Expand Down
10 changes: 3 additions & 7 deletions tests/func/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from dvc.cache import Cache
from dvc.config import Config
from dvc.exceptions import UrlNotDvcRepoError
from dvc.repo.get import GetDVCFileError, PathMissingError
from dvc.repo import Repo
from dvc.system import System
Expand Down Expand Up @@ -87,9 +86,10 @@ def test_get_repo_rev(tmp_dir, erepo_dir):
def test_get_from_non_dvc_repo(tmp_dir, erepo_dir):
erepo_dir.scm.repo.index.remove([erepo_dir.dvc.dvc_dir], r=True)
erepo_dir.scm.commit("remove dvc")
erepo_dir.scm_gen({"some_file": "contents"}, commit="create file")

with pytest.raises(UrlNotDvcRepoError):
Repo.get(fspath(erepo_dir), "some_file.zip")
Repo.get(fspath(erepo_dir), "some_file", "file_imported")
assert (tmp_dir / "file_imported").read_text() == "contents"


def test_get_a_dvc_file(tmp_dir, erepo_dir):
Expand Down Expand Up @@ -164,10 +164,6 @@ def test_get_from_non_dvc_master(tmp_dir, erepo_dir, caplog):
erepo_dir.dvc.scm.repo.index.remove([".dvc"], r=True)
erepo_dir.dvc.scm.commit("remove .dvc")

# sanity check
with pytest.raises(UrlNotDvcRepoError):
Repo.get(fspath(erepo_dir), "some_file")

caplog.clear()
dst = "file_imported"
with caplog.at_level(logging.INFO, logger="dvc"):
Expand Down