diff --git a/dvc/cli.py b/dvc/cli.py index 0dc80c5918..7def718bcf 100644 --- a/dvc/cli.py +++ b/dvc/cli.py @@ -9,7 +9,8 @@ from dvc.command.base import fix_subparsers import dvc.command.init as init -import dvc.command.pkg as pkg +import dvc.command.get as get +import dvc.command.get_url as get_url import dvc.command.destroy as destroy import dvc.command.remove as remove import dvc.command.move as move @@ -20,6 +21,7 @@ import dvc.command.gc as gc import dvc.command.add as add import dvc.command.imp as imp +import dvc.command.imp_url as imp_url import dvc.command.config as config import dvc.command.checkout as checkout import dvc.command.remote as remote @@ -41,7 +43,8 @@ COMMANDS = [ init, - pkg, + get, + get_url, destroy, add, remove, @@ -52,6 +55,7 @@ data_sync, gc, imp, + imp_url, config, checkout, remote, diff --git a/dvc/command/get.py b/dvc/command/get.py new file mode 100644 index 0000000000..f933a7e3c6 --- /dev/null +++ b/dvc/command/get.py @@ -0,0 +1,52 @@ +from __future__ import unicode_literals + +import argparse +import logging + +from dvc.repo import Repo +from dvc.exceptions import DvcException +from .base import CmdBaseNoRepo, append_doc_link + + +logger = logging.getLogger(__name__) + + +class CmdGet(CmdBaseNoRepo): + def run(self): + try: + Repo.get( + self.args.url, + path=self.args.path, + out=self.args.out, + rev=self.args.rev, + ) + return 0 + except DvcException: + logger.exception( + "failed to get '{}' from '{}'".format( + self.args.path, self.args.url + ) + ) + return 1 + + +def add_parser(subparsers, parent_parser): + GET_HELP = "Download or copy files from URL." + get_parser = subparsers.add_parser( + "get", + parents=[parent_parser], + description=append_doc_link(GET_HELP, "get"), + help=GET_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + get_parser.add_argument( + "url", help="DVC repository URL to download data from." + ) + get_parser.add_argument("path", help="Path to data within DVC repository.") + get_parser.add_argument( + "-o", "--out", nargs="?", help="Destination path to put data to." + ) + get_parser.add_argument( + "--rev", nargs="?", help="DVC repository git revision." + ) + get_parser.set_defaults(func=CmdGet) diff --git a/dvc/command/get_url.py b/dvc/command/get_url.py new file mode 100644 index 0000000000..27f4e935ea --- /dev/null +++ b/dvc/command/get_url.py @@ -0,0 +1,39 @@ +from __future__ import unicode_literals + +import argparse +import logging + +from dvc.repo import Repo +from dvc.exceptions import DvcException +from .base import CmdBaseNoRepo, append_doc_link + + +logger = logging.getLogger(__name__) + + +class CmdGetUrl(CmdBaseNoRepo): + def run(self): + try: + Repo.get_url(self.args.url, out=self.args.out) + return 0 + except DvcException: + logger.exception("failed to get '{}'".format(self.args.url)) + return 1 + + +def add_parser(subparsers, parent_parser): + GET_HELP = "Download or copy files from URL." + get_parser = subparsers.add_parser( + "get-url", + parents=[parent_parser], + description=append_doc_link(GET_HELP, "get-url"), + help=GET_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + get_parser.add_argument( + "url", help="See `dvc import-url -h` for full list of supported URLs." + ) + get_parser.add_argument( + "out", nargs="?", help="Destination path to put data to." + ) + get_parser.set_defaults(func=CmdGetUrl) diff --git a/dvc/command/imp.py b/dvc/command/imp.py index b284ebb283..c42558e530 100644 --- a/dvc/command/imp.py +++ b/dvc/command/imp.py @@ -1,10 +1,8 @@ from __future__ import unicode_literals import argparse -import os import logging -from dvc.utils.compat import urlparse from dvc.exceptions import DvcException from dvc.command.base import CmdBase, append_doc_link @@ -15,18 +13,16 @@ class CmdImport(CmdBase): def run(self): try: - default_out = os.path.basename(urlparse(self.args.url).path) - - out = self.args.out or default_out - self.repo.imp( - self.args.url, out, self.args.resume, fname=self.args.file + self.args.url, + path=self.args.path, + out=self.args.out, + rev=self.args.rev, ) except DvcException: logger.exception( - "failed to import {}. You could also try downloading " - "it manually and adding it with `dvc add` command.".format( - self.args.url + "failed to import '{}' from '{}'.".format( + self.args.path, self.args.url ) ) return 1 @@ -34,7 +30,9 @@ def run(self): def add_parser(subparsers, parent_parser): - IMPORT_HELP = "Download or copy files from URL and take under DVC control." + IMPORT_HELP = ( + "Download data from DVC repository and take it under DVC control." + ) import_parser = subparsers.add_parser( "import", @@ -43,28 +41,14 @@ def add_parser(subparsers, parent_parser): help=IMPORT_HELP, formatter_class=argparse.RawTextHelpFormatter, ) + import_parser.add_argument("url", help="DVC repository URL.") import_parser.add_argument( - "url", - help="Supported urls:\n" - "/path/to/file\n" - "C:\\\\path\\to\\file\n" - "https://example.com/path/to/file\n" - "s3://bucket/path/to/file\n" - "gs://bucket/path/to/file\n" - "hdfs://example.com/path/to/file\n" - "ssh://example.com:/path/to/file\n" - "remote://myremote/path/to/file (see `dvc remote`)", - ) - import_parser.add_argument( - "--resume", - action="store_true", - default=False, - help="Resume previously started download.", + "path", nargs="?", help="Path to data within DVC repository." ) import_parser.add_argument( - "out", nargs="?", help="Destination path to put files to." + "-o", "--out", nargs="?", help="Destination path to put data to." ) import_parser.add_argument( - "-f", "--file", help="Specify name of the DVC-file it generates." + "--rev", nargs="?", help="DVC repository git revision." ) import_parser.set_defaults(func=CmdImport) diff --git a/dvc/command/imp_url.py b/dvc/command/imp_url.py new file mode 100644 index 0000000000..89153d59eb --- /dev/null +++ b/dvc/command/imp_url.py @@ -0,0 +1,67 @@ +from __future__ import unicode_literals + +import argparse +import logging + +from dvc.exceptions import DvcException +from dvc.command.base import CmdBase, append_doc_link + + +logger = logging.getLogger(__name__) + + +class CmdImportUrl(CmdBase): + def run(self): + try: + self.repo.imp_url( + self.args.url, + out=self.args.out, + resume=self.args.resume, + fname=self.args.file, + ) + except DvcException: + logger.exception( + "failed to import {}. You could also try downloading " + "it manually and adding it with `dvc add` command.".format( + self.args.url + ) + ) + return 1 + return 0 + + +def add_parser(subparsers, parent_parser): + IMPORT_HELP = "Download or copy files from URL and take under DVC control." + + import_parser = subparsers.add_parser( + "import-url", + parents=[parent_parser], + description=append_doc_link(IMPORT_HELP, "import-url"), + help=IMPORT_HELP, + formatter_class=argparse.RawTextHelpFormatter, + ) + import_parser.add_argument( + "url", + help="Supported urls:\n" + "/path/to/file\n" + "C:\\\\path\\to\\file\n" + "https://example.com/path/to/file\n" + "s3://bucket/path/to/file\n" + "gs://bucket/path/to/file\n" + "hdfs://example.com/path/to/file\n" + "ssh://example.com:/path/to/file\n" + "remote://myremote/path/to/file (see `dvc remote`)", + ) + import_parser.add_argument( + "--resume", + action="store_true", + default=False, + help="Resume previously started download.", + ) + import_parser.add_argument( + "out", nargs="?", help="Destination path to put files to." + ) + import_parser.add_argument( + "-f", "--file", help="Specify name of the DVC-file it generates." + ) + import_parser.set_defaults(func=CmdImportUrl) diff --git a/dvc/command/pkg.py b/dvc/command/pkg.py deleted file mode 100644 index 8ddaa8c498..0000000000 --- a/dvc/command/pkg.py +++ /dev/null @@ -1,175 +0,0 @@ -from __future__ import unicode_literals - -import argparse -import logging - -from dvc.pkg import PkgManager -from dvc.exceptions import DvcException -from .base import CmdBase, CmdBaseNoRepo, fix_subparsers, append_doc_link - - -logger = logging.getLogger(__name__) - - -class CmdPkgInstall(CmdBase): - def run(self): - try: - self.repo.pkg.install( - self.args.url, - version=self.args.version, - name=self.args.name, - force=self.args.force, - ) - return 0 - except DvcException: - logger.exception( - "failed to install package '{}'".format(self.args.url) - ) - return 1 - - -class CmdPkgUninstall(CmdBase): - def run(self): - ret = 0 - for target in self.args.targets: - try: - self.repo.pkg.uninstall(target) - except DvcException: - logger.exception( - "failed to uninstall package '{}'".format(target) - ) - ret = 1 - return ret - - -class CmdPkgImport(CmdBase): - def run(self): - try: - self.repo.pkg.imp( - self.args.name, - self.args.src, - out=self.args.out, - version=self.args.version, - ) - return 0 - except DvcException: - logger.exception( - "failed to import '{}' from package '{}'".format( - self.args.src, self.args.name - ) - ) - return 1 - - -class CmdPkgGet(CmdBaseNoRepo): - def run(self): - try: - PkgManager.get( - self.args.url, - self.args.src, - out=self.args.out, - version=self.args.version, - ) - return 0 - except DvcException: - logger.exception( - "failed to get '{}' from package '{}'".format( - self.args.src, self.args.name - ) - ) - return 1 - - -def add_parser(subparsers, parent_parser): - PKG_HELP = "Manage DVC packages." - pkg_parser = subparsers.add_parser( - "pkg", - parents=[parent_parser], - description=append_doc_link(PKG_HELP, "pkg"), - formatter_class=argparse.RawDescriptionHelpFormatter, - add_help=False, - ) - - pkg_subparsers = pkg_parser.add_subparsers( - dest="cmd", help="Use dvc pkg CMD --help for command-specific help." - ) - - fix_subparsers(pkg_subparsers) - - PKG_INSTALL_HELP = "Install package." - pkg_install_parser = pkg_subparsers.add_parser( - "install", - parents=[parent_parser], - description=append_doc_link(PKG_INSTALL_HELP, "pkg-install"), - help=PKG_INSTALL_HELP, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - pkg_install_parser.add_argument("url", help="Package URL.") - pkg_install_parser.add_argument( - "--version", nargs="?", help="Package version." - ) - pkg_install_parser.add_argument( - "--name", - nargs="?", - help=( - "Package alias. If not specified, the name will be determined " - "from URL." - ), - ) - pkg_install_parser.add_argument( - "-f", - "--force", - action="store_true", - default=False, - help="Reinstall package if it is already installed.", - ) - pkg_install_parser.set_defaults(func=CmdPkgInstall) - - PKG_UNINSTALL_HELP = "Uninstall package(s)." - pkg_uninstall_parser = pkg_subparsers.add_parser( - "uninstall", - parents=[parent_parser], - description=append_doc_link(PKG_UNINSTALL_HELP, "pkg-uninstall"), - help=PKG_UNINSTALL_HELP, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - pkg_uninstall_parser.add_argument( - "targets", nargs="*", default=[None], help="Package name." - ) - pkg_uninstall_parser.set_defaults(func=CmdPkgUninstall) - - PKG_IMPORT_HELP = "Import data from package." - pkg_import_parser = pkg_subparsers.add_parser( - "import", - parents=[parent_parser], - description=append_doc_link(PKG_IMPORT_HELP, "pkg-import"), - help=PKG_IMPORT_HELP, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - pkg_import_parser.add_argument("name", help="Package name or url.") - pkg_import_parser.add_argument("src", help="Path to data in the package.") - pkg_import_parser.add_argument( - "-o", "--out", nargs="?", help="Destination path to put data to." - ) - pkg_import_parser.add_argument( - "--version", nargs="?", help="Package version." - ) - pkg_import_parser.set_defaults(func=CmdPkgImport) - - PKG_GET_HELP = "Download data from the package." - pkg_get_parser = pkg_subparsers.add_parser( - "get", - parents=[parent_parser], - description=append_doc_link(PKG_GET_HELP, "pkg-get"), - help=PKG_GET_HELP, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - pkg_get_parser.add_argument("url", help="Package url.") - pkg_get_parser.add_argument("src", help="Path to data in the package.") - pkg_get_parser.add_argument( - "-o", "--out", nargs="?", help="Destination path to put data to." - ) - pkg_get_parser.add_argument( - "--version", nargs="?", help="Package version." - ) - pkg_get_parser.set_defaults(func=CmdPkgGet) diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index 2779230f9a..52a451c9ff 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -13,10 +13,10 @@ from dvc.dependency.ssh import DependencySSH from dvc.dependency.http import DependencyHTTP from dvc.dependency.https import DependencyHTTPS -from .pkg import DependencyPKG +from .repo import DependencyREPO from dvc.remote import Remote -from dvc.pkg import Pkg +from dvc.external_repo import ExternalRepo DEPS = [ @@ -47,7 +47,7 @@ SCHEMA = output.SCHEMA.copy() del SCHEMA[schema.Optional(OutputBase.PARAM_CACHE)] del SCHEMA[schema.Optional(OutputBase.PARAM_METRIC)] -SCHEMA[schema.Optional(DependencyPKG.PARAM_PKG)] = Pkg.SCHEMA +SCHEMA[schema.Optional(DependencyREPO.PARAM_REPO)] = ExternalRepo.SCHEMA def _get(stage, p, info): @@ -59,9 +59,9 @@ def _get(stage, p, info): remote = Remote(stage.repo, name=parsed.netloc) return DEP_MAP[remote.scheme](stage, p, info, remote=remote) - if info and info.get(DependencyPKG.PARAM_PKG): - pkg = info.pop(DependencyPKG.PARAM_PKG) - return DependencyPKG(pkg, stage, p, info) + if info and info.get(DependencyREPO.PARAM_REPO): + repo = info.pop(DependencyREPO.PARAM_REPO) + return DependencyREPO(repo, stage, p, info) for d in DEPS: if d.supported(p): @@ -77,9 +77,9 @@ def loadd_from(stage, d_list): return ret -def loads_from(stage, s_list, pkg=None): +def loads_from(stage, s_list, erepo=None): ret = [] for s in s_list: - info = {DependencyPKG.PARAM_PKG: pkg} if pkg else {} + info = {DependencyREPO.PARAM_REPO: erepo} if erepo else {} ret.append(_get(stage, s, info)) return ret diff --git a/dvc/dependency/pkg.py b/dvc/dependency/repo.py similarity index 53% rename from dvc/dependency/pkg.py rename to dvc/dependency/repo.py index 755556208f..040f919d4e 100644 --- a/dvc/dependency/pkg.py +++ b/dvc/dependency/repo.py @@ -4,25 +4,28 @@ import copy from dvc.utils.compat import urlparse +from dvc.external_repo import ExternalRepo from .local import DependencyLOCAL -class DependencyPKG(DependencyLOCAL): - PARAM_PKG = "pkg" +class DependencyREPO(DependencyLOCAL): + PARAM_REPO = "repo" - def __init__(self, pkg, stage, *args, **kwargs): - self.pkg = stage.repo.pkg.get_pkg(**pkg) + def __init__(self, erepo, stage, *args, **kwargs): + self.erepo = ExternalRepo(stage.repo.dvc_dir, **erepo) super(DependencyLOCAL, self).__init__(stage, *args, **kwargs) def _parse_path(self, remote, path): + self.erepo.install(self.repo.cache.local.cache_dir) + out_path = os.path.join( - self.pkg.repo.root_dir, urlparse(path).path.lstrip("/") + self.erepo.repo.root_dir, urlparse(path).path.lstrip("/") ) - out, = self.pkg.repo.find_outs_by_path(out_path) + out, = self.erepo.repo.find_outs_by_path(out_path) self.info = copy.copy(out.info) - self._pkg_stage = copy.copy(out.stage.path) + self._erepo_stage = copy.copy(out.stage.path) return self.REMOTE.path_cls(out.cache_path) @property @@ -31,10 +34,10 @@ def is_in_repo(self): def dumpd(self): ret = super(DependencyLOCAL, self).dumpd() - ret[self.PARAM_PKG] = self.pkg.dumpd() + ret[self.PARAM_REPO] = self.erepo.dumpd() return ret def download(self, to, resume=False): - self.pkg.repo.fetch(self._pkg_stage) + self.erepo.repo.fetch(self._erepo_stage) to.info = copy.copy(self.info) to.checkout() diff --git a/dvc/external_repo.py b/dvc/external_repo.py new file mode 100644 index 0000000000..ab8b0e4934 --- /dev/null +++ b/dvc/external_repo.py @@ -0,0 +1,153 @@ +import os +import uuid +import errno +import shutil +import logging + +from funcy import cached_property +from schema import Optional + +from dvc.config import Config +from dvc.cache import CacheConfig +from dvc.exceptions import DvcException + + +logger = logging.getLogger(__name__) + + +class ExternalRepoError(DvcException): + pass + + +class NotInstalledError(ExternalRepoError): + def __init__(self, name): + super(NotInstalledError, self).__init__( + "Repo '{}' is not installed".format(name) + ) + + +class InstallError(ExternalRepoError): + def __init__(self, url, path, cause): + super(InstallError, self).__init__( + "Failed to install repo '{}' to '{}'".format(url, path), + cause=cause, + ) + + +class RevError(ExternalRepoError): + def __init__(self, url, rev, cause): + super(RevError, self).__init__( + "Failed to access revision '{}' for repo '{}'".format(rev, url), + cause=cause, + ) + + +class ExternalRepo(object): + REPOS_DIR = "repos" + + PARAM_URL = "url" + PARAM_VERSION = "rev" + + SCHEMA = {Optional(PARAM_URL): str, Optional(PARAM_VERSION): str} + + def __init__(self, dvc_dir, **kwargs): + self.repos_dir = os.path.join(dvc_dir, self.REPOS_DIR) + self.url = kwargs[self.PARAM_URL] + + self.name = "{}-{}".format(os.path.basename(self.url), hash(self.url)) + + self.rev = kwargs.get(self.PARAM_VERSION) + self.path = os.path.join(self.repos_dir, self.name) + + @cached_property + def repo(self): + from dvc.repo import Repo + + if not self.installed: + raise NotInstalledError(self.name) + + return Repo(self.path, rev=self.rev) + + @property + def installed(self): + return os.path.exists(self.path) + + def _install_to(self, tmp_dir, cache_dir): + import git + + try: + git.Repo.clone_from( + self.url, tmp_dir, depth=1, no_single_branch=True + ) + except git.exc.GitCommandError as exc: + raise InstallError(self.url, tmp_dir, exc) + + if self.rev: + try: + repo = git.Repo(tmp_dir) + repo.git.checkout(self.rev) + repo.close() + except git.exc.GitCommandError as exc: + raise RevError(self.url, self.rev, exc) + + if cache_dir: + from dvc.repo import Repo + + repo = Repo(tmp_dir) + cache_config = CacheConfig(repo.config) + cache_config.set_dir(cache_dir, level=Config.LEVEL_LOCAL) + repo.scm.git.close() + + if self.installed: + self.uninstall() + + def install(self, cache_dir=None, force=False): + if self.installed and not force: + logger.info( + "Skipping installing '{}'('{}') as it is already " + "installed.".format(self.name, self.url) + ) + return + + try: + os.makedirs(self.repos_dir) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise + + # installing package to a temporary directory until we are sure that + # it has been installed correctly. + # + # Note that tempfile.TemporaryDirectory is using symlinks to tmpfs, so + # we won't be able to use move properly. + tmp_dir = os.path.join(self.repos_dir, "." + str(uuid.uuid4())) + try: + self._install_to(tmp_dir, cache_dir) + except ExternalRepoError: + if os.path.exists(tmp_dir): + shutil.rmtree(tmp_dir) + raise + + shutil.move(tmp_dir, self.path) + + def uninstall(self): + if not self.installed: + logger.info( + "Skipping uninstalling '{}' as it is not installed.".format( + self.name + ) + ) + return + + shutil.rmtree(self.path) + + def update(self): + self.repo.scm.fetch(self.rev) + + def dumpd(self): + ret = {self.PARAM_URL: self.url} + + if self.rev: + ret[self.PARAM_VERSION] = self.rev + + return ret diff --git a/dvc/output/base.py b/dvc/output/base.py index 7e0ba51a49..1927d1b1f7 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -82,7 +82,7 @@ def __init__( # By resolved path, which contains actual location, # should be absolute and don't contain remote:// refs. self.stage = stage - self.repo = stage.repo + self.repo = stage.repo if stage else None self.def_path = path self.info = info self.remote = remote or self.REMOTE(self.repo, {}) @@ -395,7 +395,7 @@ def get_used_cache(self, **kwargs): include the `info` of its files. """ - if self.stage.is_pkg_import: + if self.stage.is_repo_import: return [] if not self.use_cache: diff --git a/dvc/pkg.py b/dvc/pkg.py deleted file mode 100644 index 0167038b1b..0000000000 --- a/dvc/pkg.py +++ /dev/null @@ -1,232 +0,0 @@ -import os -import uuid -import shutil -import logging - -from funcy import cached_property -from schema import Optional - -from dvc.config import Config -from dvc.cache import CacheConfig -from dvc.path_info import PathInfo -from dvc.exceptions import DvcException -from dvc.utils.compat import urlparse - - -logger = logging.getLogger(__name__) - - -class PkgError(DvcException): - pass - - -class NotInstalledError(PkgError): - def __init__(self, name): - super(NotInstalledError, self).__init__( - "Package '{}' is not installed".format(name) - ) - - -class InstallError(PkgError): - def __init__(self, url, path, cause): - super(InstallError, self).__init__( - "Failed to install pkg '{}' to '{}'".format(url, path), cause=cause - ) - - -class VersionError(PkgError): - def __init__(self, url, version, cause): - super(VersionError, self).__init__( - "Failed to access version '{}' for package '{}'".format( - version, url - ), - cause=cause, - ) - - -class Pkg(object): - PARAM_NAME = "name" - PARAM_URL = "url" - PARAM_VERSION = "version" - - SCHEMA = { - Optional(PARAM_NAME): str, - Optional(PARAM_URL): str, - Optional(PARAM_VERSION): str, - } - - def __init__(self, pkg_dir, **kwargs): - self.pkg_dir = pkg_dir - self.url = kwargs.get(self.PARAM_URL) - - name = kwargs.get(self.PARAM_NAME) - if name is None: - name = os.path.basename(self.url) - self.name = name - - self.version = kwargs.get(self.PARAM_VERSION) - self.path = os.path.join(pkg_dir, self.name) - - @cached_property - def repo(self): - from dvc.repo import Repo - - if not self.installed: - raise NotInstalledError(self.name) - - return Repo(self.path, version=self.version) - - @property - def installed(self): - return os.path.exists(self.path) - - def _install_to(self, tmp_dir, cache_dir): - import git - - try: - git.Repo.clone_from( - self.url, tmp_dir, depth=1, no_single_branch=True - ) - except git.exc.GitCommandError as exc: - raise InstallError(self.url, tmp_dir, exc) - - if self.version: - try: - repo = git.Repo(tmp_dir) - repo.git.checkout(self.version) - except git.exc.GitCommandError as exc: - raise VersionError(self.url, self.version, exc) - - if cache_dir: - from dvc.repo import Repo - - repo = Repo(tmp_dir) - cache_config = CacheConfig(repo.config) - cache_config.set_dir(cache_dir, level=Config.LEVEL_LOCAL) - - if self.installed: - self.uninstall() - - def install(self, cache_dir=None, force=False): - if self.installed and not force: - logger.info( - "Skipping installing '{}'('{}') as it is already " - "installed.".format(self.name, self.url) - ) - return - - # installing package to a temporary directory until we are sure that - # it has been installed correctly. - # - # Note that tempfile.TemporaryDirectory is using symlinks to tmpfs, so - # we won't be able to use move properly. - tmp_dir = os.path.join(self.pkg_dir, "." + str(uuid.uuid4())) - try: - self._install_to(tmp_dir, cache_dir) - except PkgError: - if os.path.exists(tmp_dir): - shutil.rmtree(tmp_dir) - raise - - shutil.move(tmp_dir, self.path) - - def uninstall(self): - if not self.installed: - logger.info( - "Skipping uninstalling '{}' as it is not installed.".format( - self.name - ) - ) - return - - shutil.rmtree(self.path) - - def update(self): - self.repo.scm.fetch(self.version) - - def dumpd(self): - ret = {self.PARAM_NAME: self.name} - - if self.url: - ret[self.PARAM_URL] = self.url - - if self.version: - ret[self.PARAM_VERSION] = self.version - - return ret - - -class PkgManager(object): - PKG_DIR = "pkg" - - def __init__(self, repo): - self.repo = repo - self.pkg_dir = os.path.join(repo.dvc_dir, self.PKG_DIR) - if not os.path.exists(self.pkg_dir): - os.makedirs(self.pkg_dir) - - self.cache_dir = repo.cache.local.cache_dir - - def install(self, url, force=False, **kwargs): - pkg = Pkg(self.pkg_dir, url=url, **kwargs) - pkg.install(cache_dir=self.cache_dir, force=force) - - def uninstall(self, name): - pkg = Pkg(self.pkg_dir, name=name) - pkg.uninstall() - - def get_pkg(self, **kwargs): - return Pkg(self.pkg_dir, **kwargs) - - def imp(self, name, src, out=None, version=None): - scheme = urlparse(name).scheme - - if os.path.exists(name) or scheme: - pkg = Pkg(self.pkg_dir, url=name) - pkg.install(cache_dir=self.cache_dir) - else: - pkg = Pkg(self.pkg_dir, name=name) - - info = {Pkg.PARAM_NAME: pkg.name} - if version: - info[Pkg.PARAM_VERSION] = version - - self.repo.imp(src, out, pkg=info) - - @classmethod - def get(cls, url, src, out=None, version=None): - if not out: - out = os.path.basename(src) - - # Creating a directory right beside the output to make sure that they - # are on the same filesystem, so we could take the advantage of - # reflink and/or hardlink. Not using tempfile.TemporaryDirectory - # because it will create a symlink to tmpfs, which defeats the purpose - # and won't work with reflink/hardlink. - dpath = os.path.dirname(os.path.abspath(out)) - tmp_dir = os.path.join(dpath, "." + str(uuid.uuid4())) - try: - pkg = Pkg(tmp_dir, url=url, version=version) - pkg.install() - # Try any links possible to avoid data duplication. - # - # Not using symlink, because we need to remove cache after we are - # done, and to make that work we would have to copy data over - # anyway before removing the cache, so we might just copy it - # right away. - # - # Also, we can't use theoretical "move" link type here, because - # the same cache file might be used a few times in a directory. - pkg.repo.config.set( - Config.SECTION_CACHE, - Config.SECTION_CACHE_TYPE, - "reflink,hardlink,copy", - ) - src = os.path.join(pkg.path, urlparse(src).path.lstrip("/")) - output, = pkg.repo.find_outs_by_path(src) - pkg.repo.fetch(output.stage.path) - output.path_info = PathInfo(os.path.abspath(out)) - with output.repo.state: - output.checkout() - finally: - shutil.rmtree(tmp_dir) diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index a5f2624e9c..2a8daba292 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -28,6 +28,7 @@ class Repo(object): from dvc.repo.move import move from dvc.repo.run import run from dvc.repo.imp import imp + from dvc.repo.imp_url import imp_url from dvc.repo.reproduce import reproduce from dvc.repo.checkout import checkout from dvc.repo.push import push @@ -38,8 +39,10 @@ class Repo(object): from dvc.repo.commit import commit from dvc.repo.diff import diff from dvc.repo.brancher import brancher + from dvc.repo.get import get + from dvc.repo.get_url import get_url - def __init__(self, root_dir=None, version=None): + def __init__(self, root_dir=None, rev=None): from dvc.state import State from dvc.lock import Lock from dvc.scm import SCM @@ -49,8 +52,6 @@ def __init__(self, root_dir=None, version=None): from dvc.scm.tree import WorkingTree from dvc.repo.tag import Tag - from dvc.pkg import PkgManager - root_dir = self.find_root(root_dir) self.root_dir = os.path.abspath(os.path.realpath(root_dir)) @@ -60,8 +61,8 @@ def __init__(self, root_dir=None, version=None): self.scm = SCM(self.root_dir, repo=self) - if version: - self.tree = self.scm.get_tree(version) + if rev: + self.tree = self.scm.get_tree(rev) else: self.tree = WorkingTree(self.root_dir) @@ -81,7 +82,6 @@ def __init__(self, root_dir=None, version=None): self.metrics = Metrics(self) self.tag = Tag(self) - self.pkg = PkgManager(self) self._ignore() @@ -121,6 +121,7 @@ def unprotect(self, target): def _ignore(self): from dvc.updater import Updater + from dvc.external_repo import ExternalRepo updater = Updater(self.dvc_dir) @@ -130,7 +131,7 @@ def _ignore(self): self.config.config_local_file, updater.updater_file, updater.lock.lock_file, - self.pkg.pkg_dir, + os.path.join(self.dvc_dir, ExternalRepo.REPOS_DIR), ] + self.state.temp_files if self.cache.local.cache_dir.startswith(self.root_dir): diff --git a/dvc/repo/get.py b/dvc/repo/get.py new file mode 100644 index 0000000000..7d87379afb --- /dev/null +++ b/dvc/repo/get.py @@ -0,0 +1,48 @@ +import os +import shutil +import shortuuid + +from dvc.config import Config +from dvc.path_info import PathInfo +from dvc.external_repo import ExternalRepo +from dvc.utils.compat import urlparse + + +@staticmethod +def get(url, path, out=None, rev=None): + out = out or os.path.basename(urlparse(path).path) + + # Creating a directory right beside the output to make sure that they + # are on the same filesystem, so we could take the advantage of + # reflink and/or hardlink. Not using tempfile.TemporaryDirectory + # because it will create a symlink to tmpfs, which defeats the purpose + # and won't work with reflink/hardlink. + dpath = os.path.dirname(os.path.abspath(out)) + tmp_dir = os.path.join(dpath, "." + str(shortuuid.uuid())) + try: + erepo = ExternalRepo(tmp_dir, url=url, rev=rev) + erepo.install() + # Try any links possible to avoid data duplication. + # + # Not using symlink, because we need to remove cache after we are + # done, and to make that work we would have to copy data over + # anyway before removing the cache, so we might just copy it + # right away. + # + # Also, we can't use theoretical "move" link type here, because + # the same cache file might be used a few times in a directory. + erepo.repo.config.set( + Config.SECTION_CACHE, + Config.SECTION_CACHE_TYPE, + "reflink,hardlink,copy", + ) + src = os.path.join(erepo.path, urlparse(path).path.lstrip("/")) + o, = erepo.repo.find_outs_by_path(src) + erepo.repo.fetch(o.stage.path) + o.path_info = PathInfo(os.path.abspath(out)) + with o.repo.state: + o.checkout() + erepo.repo.scm.git.close() + finally: + if os.path.exists(tmp_dir): + shutil.rmtree(tmp_dir) diff --git a/dvc/repo/get_url.py b/dvc/repo/get_url.py new file mode 100644 index 0000000000..7ba65c4c71 --- /dev/null +++ b/dvc/repo/get_url.py @@ -0,0 +1,19 @@ +import os + +import dvc.output as output +import dvc.dependency as dependency + +from dvc.utils.compat import urlparse + + +@staticmethod +def get_url(url, out=None): + out = out or os.path.basename(urlparse(url).path) + + if os.path.exists(url): + src = os.path.abspath(url) + out = os.path.abspath(out) + + dep, = dependency.loads_from(None, [src]) + out, = output.loads_from(None, [out], use_cache=False) + dep.download(out) diff --git a/dvc/repo/imp.py b/dvc/repo/imp.py index 635c11908a..d0fcd716d4 100644 --- a/dvc/repo/imp.py +++ b/dvc/repo/imp.py @@ -1,23 +1,9 @@ -from dvc.repo.scm_context import scm_context +from dvc.external_repo import ExternalRepo -@scm_context -def imp(self, url, out, resume=False, fname=None, pkg=None): - from dvc.stage import Stage +def imp(self, url, path, out=None, rev=None): + erepo = {ExternalRepo.PARAM_URL: url} + if rev is not None: + erepo[ExternalRepo.PARAM_VERSION] = rev - with self.state: - stage = Stage.create( - repo=self, cmd=None, deps=[url], outs=[out], fname=fname, pkg=pkg - ) - - if stage is None: - return None - - self.check_dag(self.stages() + [stage]) - - with self.state: - stage.run(resume=resume) - - stage.dump() - - return stage + return self.imp_url(path, out=out, erepo=erepo) diff --git a/dvc/repo/imp_url.py b/dvc/repo/imp_url.py new file mode 100644 index 0000000000..e5c6d947c1 --- /dev/null +++ b/dvc/repo/imp_url.py @@ -0,0 +1,34 @@ +import os + +from dvc.utils.compat import urlparse +from dvc.repo.scm_context import scm_context + + +@scm_context +def imp_url(self, url, out=None, resume=False, fname=None, erepo=None): + from dvc.stage import Stage + + default_out = os.path.basename(urlparse(url).path) + out = out or default_out + + with self.state: + stage = Stage.create( + repo=self, + cmd=None, + deps=[url], + outs=[out], + fname=fname, + erepo=erepo, + ) + + if stage is None: + return None + + self.check_dag(self.stages() + [stage]) + + with self.state: + stage.run(resume=resume) + + stage.dump() + + return stage diff --git a/dvc/stage.py b/dvc/stage.py index 7610653f35..772b0d9065 100644 --- a/dvc/stage.py +++ b/dvc/stage.py @@ -216,11 +216,11 @@ def is_import(self): return not self.cmd and len(self.deps) == 1 and len(self.outs) == 1 @property - def is_pkg_import(self): + def is_repo_import(self): if not self.is_import: return False - return isinstance(self.deps[0], dependency.DependencyPKG) + return isinstance(self.deps[0], dependency.DependencyREPO) def _changed_deps(self): if self.locked: @@ -428,7 +428,7 @@ def create( validate_state=True, outs_persist=None, outs_persist_no_cache=None, - pkg=None, + erepo=None, ): if outs is None: outs = [] @@ -469,7 +469,7 @@ def create( outs_persist, outs_persist_no_cache, ) - stage.deps = dependency.loads_from(stage, deps, pkg=pkg) + stage.deps = dependency.loads_from(stage, deps, erepo=erepo) stage._check_circular_dependency() stage._check_duplicated_arguments() diff --git a/setup.py b/setup.py index 0e2f3437f0..b10b8be35e 100644 --- a/setup.py +++ b/setup.py @@ -62,6 +62,7 @@ def run(self): "psutil==5.6.2", "funcy>=1.12", "pathspec>=0.5.9", + "shortuuid>=0.5.0", ] # Extra dependencies for remote integrations @@ -95,7 +96,6 @@ def run(self): "flake8-docstrings", "jaraco.windows==3.9.2", "mock-ssh-server>=0.5.0", - "shortuuid>=0.5.0", ] if (sys.version_info) >= (3, 6): diff --git a/tests/conftest.py b/tests/conftest.py index 7698fb06eb..e349d7639f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -138,7 +138,7 @@ def temporary_windows_drive(repo_dir): @pytest.fixture -def pkg(repo_dir): +def erepo(repo_dir): repo = TestDvcFixture() repo.setUp() try: @@ -146,7 +146,7 @@ def pkg(repo_dir): stage_bar = repo.dvc.add(repo.BAR)[0] stage_data_dir = repo.dvc.add(repo.DATA_DIR)[0] repo.dvc.scm.add([stage_foo.path, stage_bar.path, stage_data_dir.path]) - repo.dvc.scm.commit("init pkg") + repo.dvc.scm.commit("init repo") rconfig = RemoteConfig(repo.dvc.config) rconfig.add("upstream", repo.dvc.cache.local.cache_dir, default=True) diff --git a/tests/func/test_get.py b/tests/func/test_get.py new file mode 100644 index 0000000000..b545f1d4a1 --- /dev/null +++ b/tests/func/test_get.py @@ -0,0 +1,40 @@ +import os +import filecmp + +from dvc.repo import Repo + +from tests.utils import trees_equal + + +def test_get_repo_file(repo_dir, erepo): + src = erepo.FOO + dst = erepo.FOO + "_imported" + + Repo.get(erepo.root_dir, src, dst) + + assert os.path.exists(dst) + assert os.path.isfile(dst) + assert filecmp.cmp(repo_dir.FOO, dst, shallow=False) + + +def test_get_repo_dir(repo_dir, erepo): + src = erepo.DATA_DIR + dst = erepo.DATA_DIR + "_imported" + + Repo.get(erepo.root_dir, src, dst) + + assert os.path.exists(dst) + assert os.path.isdir(dst) + trees_equal(src, dst) + + +def test_get_repo_rev(repo_dir, erepo): + src = "version" + dst = src + + Repo.get(erepo.root_dir, src, dst, rev="branch") + + assert os.path.exists(dst) + assert os.path.isfile(dst) + with open(dst, "r+") as fobj: + assert fobj.read() == "branch" diff --git a/tests/func/test_get_url.py b/tests/func/test_get_url.py new file mode 100644 index 0000000000..ca1a5ae4f2 --- /dev/null +++ b/tests/func/test_get_url.py @@ -0,0 +1,15 @@ +import os +import filecmp + +from dvc.repo import Repo + + +def test_get_file(repo_dir): + src = repo_dir.FOO + dst = repo_dir.FOO + "_imported" + + Repo.get_url(src, dst) + + assert os.path.exists(dst) + assert os.path.isfile(dst) + assert filecmp.cmp(repo_dir.FOO, dst, shallow=False) diff --git a/tests/func/test_import.py b/tests/func/test_import.py index ede967f596..b91dccfc49 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -1,187 +1,38 @@ -import dvc - -from dvc.utils.compat import str - import os -import logging -from uuid import uuid4 - -from dvc.utils.compat import urljoin -from dvc.exceptions import DvcException -from dvc.main import main -from mock import patch, mock_open, call -from tests.basic_env import TestDvc -from tests.utils import spy -from tests.utils.httpd import StaticFileServer - - -class TestCmdImport(TestDvc): - def test(self): - ret = main(["import", self.FOO, "import"]) - self.assertEqual(ret, 0) - self.assertTrue(os.path.exists("import.dvc")) - - ret = main(["import", "non-existing-file", "import"]) - self.assertNotEqual(ret, 0) - - def test_unsupported(self): - ret = main(["import", "unsupported://path", "import_unsupported"]) - self.assertNotEqual(ret, 0) - - -class TestDefaultOutput(TestDvc): - def test(self): - tmpdir = self.mkdtemp() - filename = str(uuid4()) - tmpfile = os.path.join(tmpdir, filename) - - with open(tmpfile, "w") as fd: - fd.write("content") - - ret = main(["import", tmpfile]) - self.assertEqual(ret, 0) - self.assertTrue(os.path.exists(filename)) - with open(filename) as fd: - self.assertEqual(fd.read(), "content") - - -class TestFailedImportMessage(TestDvc): - @patch("dvc.command.imp.urlparse") - def test(self, imp_urlparse_patch): - page_address = "http://somesite.com/file_name" - - def dvc_exception(*args, **kwargs): - raise DvcException("message") - - imp_urlparse_patch.side_effect = dvc_exception - - self._caplog.clear() - - with self._caplog.at_level(logging.ERROR, logger="dvc"): - main(["import", page_address]) - - expected_error = ( - "failed to import http://somesite.com/file_name." - " You could also try downloading it manually and" - " adding it with `dvc add` command." - ) - - assert expected_error in self._caplog.text - - -class TestInterruptedDownload(TestDvc): - @property - def remote(self): - return "http://localhost:8000/" - - def _prepare_interrupted_download(self): - import_url = urljoin(self.remote, self.FOO) - import_output = "imported_file" - tmp_file_name = import_output + ".part" - tmp_file_path = os.path.realpath( - os.path.join(self._root_dir, tmp_file_name) - ) - self._import_with_interrupt(import_output, import_url) - self.assertTrue(os.path.exists(tmp_file_name)) - self.assertFalse(os.path.exists(import_output)) - return import_output, import_url, tmp_file_path - - def _import_with_interrupt(self, import_output, import_url): - def interrupting_generator(): - yield self.FOO[0].encode("utf8") - raise KeyboardInterrupt - - with patch( - "requests.models.Response.iter_content", - return_value=interrupting_generator(), - ): - with patch( - "dvc.remote.http.RemoteHTTP._content_length", return_value=3 - ): - result = main(["import", import_url, import_output]) - self.assertEqual(result, 252) - - -class TestShouldResumeDownload(TestInterruptedDownload): - @patch("dvc.remote.http.RemoteHTTP.CHUNK_SIZE", 1) - def test(self): - with StaticFileServer(): - import_output, import_url, tmp_file_path = ( - self._prepare_interrupted_download() - ) - - m = mock_open() - with patch("dvc.remote.http.open", m): - result = main( - ["import", "--resume", import_url, import_output] - ) - self.assertEqual(result, 0) - m.assert_called_once_with(tmp_file_path, "ab") - m_handle = m() - expected_calls = [call(b"o"), call(b"o")] - m_handle.write.assert_has_calls(expected_calls, any_order=False) - +import filecmp -class TestShouldNotResumeDownload(TestInterruptedDownload): - @patch("dvc.remote.http.RemoteHTTP.CHUNK_SIZE", 1) - def test(self): - with StaticFileServer(): - import_output, import_url, tmp_file_path = ( - self._prepare_interrupted_download() - ) +from tests.utils import trees_equal - m = mock_open() - with patch("dvc.remote.http.open", m): - result = main(["import", import_url, import_output]) - self.assertEqual(result, 0) - m.assert_called_once_with(tmp_file_path, "wb") - m_handle = m() - expected_calls = [call(b"f"), call(b"o"), call(b"o")] - m_handle.write.assert_has_calls(expected_calls, any_order=False) +def test_import(repo_dir, dvc_repo, erepo): + src = erepo.FOO + dst = erepo.FOO + "_imported" -class TestShouldRemoveOutsBeforeImport(TestDvc): - def setUp(self): - super(TestShouldRemoveOutsBeforeImport, self).setUp() - tmp_dir = self.mkdtemp() - self.external_source = os.path.join(tmp_dir, "file") - with open(self.external_source, "w") as fobj: - fobj.write("content") + dvc_repo.imp(erepo.root_dir, src, dst) - def test(self): - remove_outs_call_counter = spy(dvc.stage.Stage.remove_outs) - with patch.object( - dvc.stage.Stage, "remove_outs", remove_outs_call_counter - ): - ret = main(["import", self.external_source]) - self.assertEqual(0, ret) + assert os.path.exists(dst) + assert os.path.isfile(dst) + assert filecmp.cmp(repo_dir.FOO, dst, shallow=False) - self.assertEqual(1, remove_outs_call_counter.mock.call_count) +def test_import_dir(repo_dir, dvc_repo, erepo): + src = erepo.DATA_DIR + dst = erepo.DATA_DIR + "_imported" -class TestImportFilename(TestDvc): - def setUp(self): - super(TestImportFilename, self).setUp() - tmp_dir = self.mkdtemp() - self.external_source = os.path.join(tmp_dir, "file") - with open(self.external_source, "w") as fobj: - fobj.write("content") + dvc_repo.imp(erepo.root_dir, src, dst) - def test(self): - ret = main(["import", "-f", "bar.dvc", self.external_source]) - self.assertEqual(0, ret) - self.assertTrue(os.path.exists("bar.dvc")) + assert os.path.exists(dst) + assert os.path.isdir(dst) + trees_equal(src, dst) - os.remove("bar.dvc") - ret = main(["import", "--file", "bar.dvc", self.external_source]) - self.assertEqual(0, ret) - self.assertTrue(os.path.exists("bar.dvc")) +def test_import_rev(repo_dir, dvc_repo, erepo): + src = "version" + dst = src - os.remove("bar.dvc") - os.mkdir("sub") + dvc_repo.imp(erepo.root_dir, src, dst, rev="branch") - path = os.path.join("sub", "bar.dvc") - ret = main(["import", "--file", path, self.external_source]) - self.assertEqual(0, ret) - self.assertTrue(os.path.exists(path)) + assert os.path.exists(dst) + assert os.path.isfile(dst) + with open(dst, "r+") as fobj: + assert fobj.read() == "branch" diff --git a/tests/func/test_import_url.py b/tests/func/test_import_url.py new file mode 100644 index 0000000000..3b1006f801 --- /dev/null +++ b/tests/func/test_import_url.py @@ -0,0 +1,187 @@ +import dvc + +from dvc.utils.compat import str + +import os +import logging +from uuid import uuid4 + +from dvc.utils.compat import urljoin +from dvc.exceptions import DvcException +from dvc.main import main +from mock import patch, mock_open, call +from tests.basic_env import TestDvc +from tests.utils import spy +from tests.utils.httpd import StaticFileServer + + +class TestCmdImport(TestDvc): + def test(self): + ret = main(["import-url", self.FOO, "import"]) + self.assertEqual(ret, 0) + self.assertTrue(os.path.exists("import.dvc")) + + ret = main(["import-url", "non-existing-file", "import"]) + self.assertNotEqual(ret, 0) + + def test_unsupported(self): + ret = main(["import-url", "unsupported://path", "import_unsupported"]) + self.assertNotEqual(ret, 0) + + +class TestDefaultOutput(TestDvc): + def test(self): + tmpdir = self.mkdtemp() + filename = str(uuid4()) + tmpfile = os.path.join(tmpdir, filename) + + with open(tmpfile, "w") as fd: + fd.write("content") + + ret = main(["import-url", tmpfile]) + self.assertEqual(ret, 0) + self.assertTrue(os.path.exists(filename)) + with open(filename) as fd: + self.assertEqual(fd.read(), "content") + + +class TestFailedImportMessage(TestDvc): + @patch("dvc.repo.imp_url.urlparse") + def test(self, imp_urlparse_patch): + page_address = "http://somesite.com/file_name" + + def dvc_exception(*args, **kwargs): + raise DvcException("message") + + imp_urlparse_patch.side_effect = dvc_exception + + self._caplog.clear() + + with self._caplog.at_level(logging.ERROR, logger="dvc"): + main(["import-url", page_address]) + + expected_error = ( + "failed to import http://somesite.com/file_name." + " You could also try downloading it manually and" + " adding it with `dvc add` command." + ) + + assert expected_error in self._caplog.text + + +class TestInterruptedDownload(TestDvc): + @property + def remote(self): + return "http://localhost:8000/" + + def _prepare_interrupted_download(self): + import_url = urljoin(self.remote, self.FOO) + import_output = "imported_file" + tmp_file_name = import_output + ".part" + tmp_file_path = os.path.realpath( + os.path.join(self._root_dir, tmp_file_name) + ) + self._import_with_interrupt(import_output, import_url) + self.assertTrue(os.path.exists(tmp_file_name)) + self.assertFalse(os.path.exists(import_output)) + return import_output, import_url, tmp_file_path + + def _import_with_interrupt(self, import_output, import_url): + def interrupting_generator(): + yield self.FOO[0].encode("utf8") + raise KeyboardInterrupt + + with patch( + "requests.models.Response.iter_content", + return_value=interrupting_generator(), + ): + with patch( + "dvc.remote.http.RemoteHTTP._content_length", return_value=3 + ): + result = main(["import-url", import_url, import_output]) + self.assertEqual(result, 252) + + +class TestShouldResumeDownload(TestInterruptedDownload): + @patch("dvc.remote.http.RemoteHTTP.CHUNK_SIZE", 1) + def test(self): + with StaticFileServer(): + import_output, import_url, tmp_file_path = ( + self._prepare_interrupted_download() + ) + + m = mock_open() + with patch("dvc.remote.http.open", m): + result = main( + ["import-url", "--resume", import_url, import_output] + ) + self.assertEqual(result, 0) + m.assert_called_once_with(tmp_file_path, "ab") + m_handle = m() + expected_calls = [call(b"o"), call(b"o")] + m_handle.write.assert_has_calls(expected_calls, any_order=False) + + +class TestShouldNotResumeDownload(TestInterruptedDownload): + @patch("dvc.remote.http.RemoteHTTP.CHUNK_SIZE", 1) + def test(self): + with StaticFileServer(): + import_output, import_url, tmp_file_path = ( + self._prepare_interrupted_download() + ) + + m = mock_open() + with patch("dvc.remote.http.open", m): + result = main(["import-url", import_url, import_output]) + self.assertEqual(result, 0) + m.assert_called_once_with(tmp_file_path, "wb") + m_handle = m() + expected_calls = [call(b"f"), call(b"o"), call(b"o")] + m_handle.write.assert_has_calls(expected_calls, any_order=False) + + +class TestShouldRemoveOutsBeforeImport(TestDvc): + def setUp(self): + super(TestShouldRemoveOutsBeforeImport, self).setUp() + tmp_dir = self.mkdtemp() + self.external_source = os.path.join(tmp_dir, "file") + with open(self.external_source, "w") as fobj: + fobj.write("content") + + def test(self): + remove_outs_call_counter = spy(dvc.stage.Stage.remove_outs) + with patch.object( + dvc.stage.Stage, "remove_outs", remove_outs_call_counter + ): + ret = main(["import-url", self.external_source]) + self.assertEqual(0, ret) + + self.assertEqual(1, remove_outs_call_counter.mock.call_count) + + +class TestImportFilename(TestDvc): + def setUp(self): + super(TestImportFilename, self).setUp() + tmp_dir = self.mkdtemp() + self.external_source = os.path.join(tmp_dir, "file") + with open(self.external_source, "w") as fobj: + fobj.write("content") + + def test(self): + ret = main(["import-url", "-f", "bar.dvc", self.external_source]) + self.assertEqual(0, ret) + self.assertTrue(os.path.exists("bar.dvc")) + + os.remove("bar.dvc") + + ret = main(["import-url", "--file", "bar.dvc", self.external_source]) + self.assertEqual(0, ret) + self.assertTrue(os.path.exists("bar.dvc")) + + os.remove("bar.dvc") + os.mkdir("sub") + + path = os.path.join("sub", "bar.dvc") + ret = main(["import-url", "--file", path, self.external_source]) + self.assertEqual(0, ret) + self.assertTrue(os.path.exists(path)) diff --git a/tests/func/test_pkg.py b/tests/func/test_pkg.py deleted file mode 100644 index b6ce904eaf..0000000000 --- a/tests/func/test_pkg.py +++ /dev/null @@ -1,202 +0,0 @@ -import os -import git -import pytest -import filecmp - -from dvc.pkg import PkgManager, InstallError, VersionError - -from tests.utils import trees_equal - - -def test_install_and_uninstall(repo_dir, dvc_repo, pkg): - name = os.path.basename(pkg.root_dir) - pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") - mypkg_dir = os.path.join(pkg_dir, name) - - dvc_repo.pkg.install(pkg.root_dir) - assert os.path.exists(pkg_dir) - assert os.path.isdir(pkg_dir) - assert os.path.exists(mypkg_dir) - assert os.path.isdir(mypkg_dir) - assert os.path.isdir(os.path.join(mypkg_dir, ".git")) - - dvc_repo.pkg.install(pkg.root_dir) - assert os.path.exists(pkg_dir) - assert os.path.isdir(pkg_dir) - assert os.path.exists(mypkg_dir) - assert os.path.isdir(mypkg_dir) - assert os.path.isdir(os.path.join(mypkg_dir, ".git")) - - git_repo = git.Repo(mypkg_dir) - assert git_repo.active_branch.name == "master" - - dvc_repo.pkg.uninstall(name) - assert not os.path.exists(mypkg_dir) - - dvc_repo.pkg.uninstall(name) - assert not os.path.exists(mypkg_dir) - - -def test_failed_install(repo_dir, dvc_repo): - with pytest.raises(InstallError): - dvc_repo.pkg.install("some-non-existing-url") - - -def test_failed_install_version(repo_dir, dvc_repo, pkg): - with pytest.raises(VersionError): - dvc_repo.pkg.install(pkg.root_dir, version="non-existing-version") - - -def test_install_atomic(repo_dir, dvc_repo, pkg): - name = os.path.basename(pkg.root_dir) - pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") - mypkg_dir = os.path.join(pkg_dir, name) - - dvc_repo.pkg.install(pkg.root_dir) - - with pytest.raises(InstallError): - dvc_repo.pkg.install("some-non-existing-url", name=name, force=True) - - assert os.path.exists(pkg_dir) - assert os.path.isdir(pkg_dir) - assert os.path.exists(mypkg_dir) - assert os.path.isdir(mypkg_dir) - assert os.path.isdir(os.path.join(mypkg_dir, ".git")) - - -def test_uninstall_corrupted(repo_dir, dvc_repo): - name = os.path.basename("mypkg") - pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") - mypkg_dir = os.path.join(pkg_dir, name) - - os.makedirs(mypkg_dir) - - dvc_repo.pkg.uninstall(name) - assert not os.path.exists(mypkg_dir) - - -def test_force_install(repo_dir, dvc_repo, pkg): - name = os.path.basename(pkg.root_dir) - pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") - mypkg_dir = os.path.join(pkg_dir, name) - - os.makedirs(mypkg_dir) - - dvc_repo.pkg.install(pkg.root_dir) - assert not os.listdir(mypkg_dir) - - dvc_repo.pkg.install(pkg.root_dir, force=True) - assert os.path.exists(pkg_dir) - assert os.path.isdir(pkg_dir) - assert os.path.exists(mypkg_dir) - assert os.path.isdir(mypkg_dir) - assert os.path.isdir(os.path.join(mypkg_dir, ".git")) - - -def test_install_version(repo_dir, dvc_repo, pkg): - name = os.path.basename(pkg.root_dir) - pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") - mypkg_dir = os.path.join(pkg_dir, name) - - dvc_repo.pkg.install(pkg.root_dir, version="branch") - assert os.path.exists(pkg_dir) - assert os.path.isdir(pkg_dir) - assert os.path.exists(mypkg_dir) - assert os.path.isdir(mypkg_dir) - assert os.path.isdir(os.path.join(mypkg_dir, ".git")) - - git_repo = git.Repo(mypkg_dir) - assert git_repo.active_branch.name == "branch" - - -def test_import(repo_dir, dvc_repo, pkg): - name = os.path.basename(pkg.root_dir) - - src = pkg.FOO - dst = pkg.FOO + "_imported" - - dvc_repo.pkg.install(pkg.root_dir) - dvc_repo.pkg.imp(name, src, dst) - - assert os.path.exists(dst) - assert os.path.isfile(dst) - assert filecmp.cmp(repo_dir.FOO, dst, shallow=False) - - -def test_import_dir(repo_dir, dvc_repo, pkg): - name = os.path.basename(pkg.root_dir) - - src = pkg.DATA_DIR - dst = pkg.DATA_DIR + "_imported" - - dvc_repo.pkg.install(pkg.root_dir) - dvc_repo.pkg.imp(name, src, dst) - - assert os.path.exists(dst) - assert os.path.isdir(dst) - trees_equal(src, dst) - - -def test_import_url(repo_dir, dvc_repo, pkg): - name = os.path.basename(pkg.root_dir) - pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") - mypkg_dir = os.path.join(pkg_dir, name) - - src = pkg.FOO - dst = pkg.FOO + "_imported" - - dvc_repo.pkg.imp(pkg.root_dir, src, dst) - - assert os.path.exists(pkg_dir) - assert os.path.isdir(pkg_dir) - assert os.path.exists(mypkg_dir) - assert os.path.isdir(mypkg_dir) - assert os.path.isdir(os.path.join(mypkg_dir, ".git")) - - assert os.path.exists(dst) - assert os.path.isfile(dst) - assert filecmp.cmp(repo_dir.FOO, dst, shallow=False) - - -def test_import_url_version(repo_dir, dvc_repo, pkg): - name = os.path.basename(pkg.root_dir) - pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") - mypkg_dir = os.path.join(pkg_dir, name) - - src = "version" - dst = src - - dvc_repo.pkg.imp(pkg.root_dir, src, dst, version="branch") - - assert os.path.exists(pkg_dir) - assert os.path.isdir(pkg_dir) - assert os.path.exists(mypkg_dir) - assert os.path.isdir(mypkg_dir) - assert os.path.isdir(os.path.join(mypkg_dir, ".git")) - - assert os.path.exists(dst) - assert os.path.isfile(dst) - with open(dst, "r+") as fobj: - assert fobj.read() == "branch" - - -def test_get_file(repo_dir, dvc_repo, pkg): - src = pkg.FOO - dst = pkg.FOO + "_imported" - - PkgManager.get(pkg.root_dir, src, dst) - - assert os.path.exists(dst) - assert os.path.isfile(dst) - assert filecmp.cmp(repo_dir.FOO, dst, shallow=False) - - -def test_get_dir(repo_dir, dvc_repo, pkg): - src = pkg.DATA_DIR - dst = pkg.DATA_DIR + "_imported" - - PkgManager.get(pkg.root_dir, src, dst) - - assert os.path.exists(dst) - assert os.path.isdir(dst) - trees_equal(src, dst) diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index 1f34a7b714..5f174a78d6 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -893,14 +893,14 @@ def test(self, mock_prompt): self.write(self.bucket, foo_key, self.FOO_CONTENTS) - import_stage = self.dvc.imp(out_foo_path, "import") + import_stage = self.dvc.imp_url(out_foo_path, "import") self.assertTrue(os.path.exists("import")) self.assertTrue(filecmp.cmp("import", self.FOO, shallow=False)) self.assertEqual(self.dvc.status(import_stage.path), {}) self.check_already_cached(import_stage) - import_remote_stage = self.dvc.imp( + import_remote_stage = self.dvc.imp_url( out_foo_path, out_foo_path + "_imported" ) self.assertEqual(self.dvc.status(import_remote_stage.path), {}) @@ -1173,7 +1173,7 @@ def test(self): with StaticFileServer(): import_url = urljoin(self.remote, self.FOO) import_output = "imported_file" - import_stage = self.dvc.imp(import_url, import_output) + import_stage = self.dvc.imp_url(import_url, import_output) self.assertTrue(os.path.exists(import_output)) self.assertTrue(filecmp.cmp(import_output, self.FOO, shallow=False)) @@ -1183,7 +1183,7 @@ def test(self): with StaticFileServer(handler="Content-MD5"): import_url = urljoin(self.remote, self.FOO) import_output = "imported_file" - import_stage = self.dvc.imp(import_url, import_output) + import_stage = self.dvc.imp_url(import_url, import_output) self.assertTrue(os.path.exists(import_output)) self.assertTrue(filecmp.cmp(import_output, self.FOO, shallow=False)) @@ -1318,7 +1318,7 @@ def test_force_with_dependencies(self): self.assertNotEqual(run_out.checksum, repro_out.checksum) def test_force_import(self): - ret = main(["import", self.FOO, self.BAR]) + ret = main(["import-url", self.FOO, self.BAR]) self.assertEqual(ret, 0) patch_download = patch.object( diff --git a/tests/func/test_stage.py b/tests/func/test_stage.py index 8598411aeb..6c56a1a532 100644 --- a/tests/func/test_stage.py +++ b/tests/func/test_stage.py @@ -154,7 +154,7 @@ def test_remote_dependency(self): assert main(["remote", "add", "tmp", tmp_path]) == 0 assert main(["remote", "add", "storage", "remote://tmp/storage"]) == 0 - assert main(["import", "remote://storage/file", "movie.txt"]) == 0 + assert main(["import-url", "remote://storage/file", "movie.txt"]) == 0 assert os.path.exists("movie.txt") diff --git a/tests/unit/command/test_get.py b/tests/unit/command/test_get.py new file mode 100644 index 0000000000..2c464ecc94 --- /dev/null +++ b/tests/unit/command/test_get.py @@ -0,0 +1,16 @@ +from dvc.cli import parse_args +from dvc.command.get import CmdGet + + +def test_get(mocker): + cli_args = parse_args( + ["get", "repo_url", "src", "--out", "out", "--rev", "version"] + ) + assert cli_args.func == CmdGet + + cmd = cli_args.func(cli_args) + m = mocker.patch("dvc.repo.Repo.get") + + assert cmd.run() == 0 + + m.assert_called_once_with("repo_url", path="src", out="out", rev="version") diff --git a/tests/unit/command/test_get_url.py b/tests/unit/command/test_get_url.py new file mode 100644 index 0000000000..9b4a5d54c4 --- /dev/null +++ b/tests/unit/command/test_get_url.py @@ -0,0 +1,14 @@ +from dvc.cli import parse_args +from dvc.command.get_url import CmdGetUrl + + +def test_get_url(mocker): + cli_args = parse_args(["get-url", "src", "out"]) + assert cli_args.func == CmdGetUrl + + cmd = cli_args.func(cli_args) + m = mocker.patch("dvc.repo.Repo.get_url") + + assert cmd.run() == 0 + + m.assert_called_once_with("src", out="out") diff --git a/tests/unit/command/test_imp.py b/tests/unit/command/test_imp.py new file mode 100644 index 0000000000..5775849f1f --- /dev/null +++ b/tests/unit/command/test_imp.py @@ -0,0 +1,16 @@ +from dvc.cli import parse_args +from dvc.command.imp import CmdImport + + +def test_import(mocker, dvc_repo): + cli_args = parse_args( + ["import", "repo_url", "src", "--out", "out", "--rev", "version"] + ) + assert cli_args.func == CmdImport + + cmd = cli_args.func(cli_args) + m = mocker.patch.object(cmd.repo, "imp", autospec=True) + + assert cmd.run() == 0 + + m.assert_called_once_with("repo_url", path="src", out="out", rev="version") diff --git a/tests/unit/command/test_imp_url.py b/tests/unit/command/test_imp_url.py new file mode 100644 index 0000000000..c7d4e2e487 --- /dev/null +++ b/tests/unit/command/test_imp_url.py @@ -0,0 +1,16 @@ +from dvc.cli import parse_args +from dvc.command.imp_url import CmdImportUrl + + +def test_import_url(mocker, dvc_repo): + cli_args = parse_args( + ["import-url", "src", "out", "--resume", "--file", "file"] + ) + assert cli_args.func == CmdImportUrl + + cmd = cli_args.func(cli_args) + m = mocker.patch.object(cmd.repo, "imp_url", autospec=True) + + assert cmd.run() == 0 + + m.assert_called_once_with("src", out="out", resume=True, fname="file") diff --git a/tests/unit/command/test_pkg.py b/tests/unit/command/test_pkg.py deleted file mode 100644 index 9ee45d1bb5..0000000000 --- a/tests/unit/command/test_pkg.py +++ /dev/null @@ -1,95 +0,0 @@ -from dvc.cli import parse_args -from dvc.command.pkg import CmdPkgInstall, CmdPkgUninstall, CmdPkgImport -from dvc.command.pkg import CmdPkgGet - - -def test_pkg_install(mocker, dvc_repo): - args = parse_args( - [ - "pkg", - "install", - "url", - "--version", - "version", - "--name", - "name", - "--force", - ] - ) - assert args.func == CmdPkgInstall - - cmd = args.func(args) - m = mocker.patch.object(cmd.repo.pkg, "install", autospec=True) - - assert cmd.run() == 0 - - m.assert_called_once_with( - "url", version="version", name="name", force=True - ) - - -def test_pkg_uninstall(mocker, dvc_repo): - targets = ["target1", "target2"] - args = parse_args(["pkg", "uninstall"] + targets) - assert args.func == CmdPkgUninstall - - cmd = args.func(args) - m = mocker.patch.object(cmd.repo.pkg, "uninstall", autospec=True) - - assert cmd.run() == 0 - - calls = [] - for target in targets: - calls.append(mocker.call(name=target)) - - m.assert_has_calls(calls) - - -def test_pkg_import(mocker, dvc_repo): - cli_args = parse_args( - [ - "pkg", - "import", - "pkg_name", - "pkg_path", - "--out", - "out", - "--version", - "version", - ] - ) - assert cli_args.func == CmdPkgImport - - cmd = cli_args.func(cli_args) - m = mocker.patch.object(cmd.repo.pkg, "imp", autospec=True) - - assert cmd.run() == 0 - - m.assert_called_once_with( - "pkg_name", "pkg_path", out="out", version="version" - ) - - -def test_pkg_get(mocker, dvc_repo): - cli_args = parse_args( - [ - "pkg", - "get", - "pkg_url", - "pkg_path", - "--out", - "out", - "--version", - "version", - ] - ) - assert cli_args.func == CmdPkgGet - - cmd = cli_args.func(cli_args) - m = mocker.patch("dvc.pkg.PkgManager.get") - - assert cmd.run() == 0 - - m.assert_called_once_with( - "pkg_url", "pkg_path", out="out", version="version" - )