diff --git a/.dvc/.gitignore b/.dvc/.gitignore
index 5b2f952a64..775b312bc0 100644
--- a/.dvc/.gitignore
+++ b/.dvc/.gitignore
@@ -6,3 +6,4 @@
 /state-journal
 /state-wal
 /cache
+/pkg
diff --git a/dvc/command/pkg.py b/dvc/command/pkg.py
index 7b2bccc0f9..c82f46c218 100644
--- a/dvc/command/pkg.py
+++ b/dvc/command/pkg.py
@@ -3,8 +3,9 @@
 import argparse
 import logging
 
+from dvc.pkg import PkgManager
 from dvc.exceptions import DvcException
-from dvc.command.base import CmdBase, fix_subparsers, append_doc_link
+from .base import CmdBase, CmdBaseNoRepo, fix_subparsers, append_doc_link
 
 
 logger = logging.getLogger(__name__)
@@ -14,22 +15,69 @@ class CmdPkgInstall(CmdBase):
     def run(self):
         try:
             self.repo.pkg.install(
-                self.args.address,
-                self.args.target_dir,
-                self.args.select,
-                self.args.file,
+                self.args.url, version=self.args.version, name=self.args.name
             )
             return 0
         except DvcException:
             logger.exception(
-                "failed to install package '{}'".format(self.args.address)
+                "failed to install package '{}'".format(self.args.url)
             )
             return 1
 
 
-def add_parser(subparsers, parent_parser):
-    from dvc.command.config import parent_config_parser
+class CmdPkgUninstall(CmdBase):
+    def run(self):
+        ret = 0
+        for target in self.args.targets:
+            try:
+                self.repo.pkg.uninstall(target)
+            except DvcException:
+                logger.exception(
+                    "failed to uninstall package '{}'".format(target)
+                )
+                ret = 1
+        return ret
+
+
+class CmdPkgImport(CmdBase):
+    def run(self):
+        try:
+            self.repo.pkg.imp(
+                self.args.name,
+                self.args.src,
+                out=self.args.out,
+                version=self.args.version,
+            )
+            return 0
+        except DvcException:
+            logger.exception(
+                "failed to import '{}' from package '{}'".format(
+                    self.args.src, self.args.name
+                )
+            )
+            return 1
+
+
+class CmdPkgGet(CmdBaseNoRepo):
+    def run(self):
+        try:
+            PkgManager.get(
+                self.args.url,
+                self.args.src,
+                out=self.args.out,
+                version=self.args.version,
+            )
+            return 0
+        except DvcException:
+            logger.exception(
+                "failed to get '{}' from package '{}'".format(
+                    self.args.src, self.args.url
+                )
+            )
+            return 1
+
+def add_parser(subparsers, parent_parser):
 
     PKG_HELP = "Manage DVC packages."
     pkg_parser = subparsers.add_parser(
         "pkg",
@@ -48,42 +96,70 @@ def add_parser(subparsers, parent_parser):
     PKG_INSTALL_HELP = "Install package."
     pkg_install_parser = pkg_subparsers.add_parser(
         "install",
-        parents=[parent_config_parser, parent_parser],
+        parents=[parent_parser],
         description=append_doc_link(PKG_INSTALL_HELP, "pkg-install"),
         help=PKG_INSTALL_HELP,
         formatter_class=argparse.RawDescriptionHelpFormatter,
     )
+    pkg_install_parser.add_argument("url", help="Package URL.")
     pkg_install_parser.add_argument(
-        "address",
-        nargs="?",
-        default="",
-        help="Package address: git:// or https://github.com/...",
+        "--version", nargs="?", help="Package version."
    )
     pkg_install_parser.add_argument(
-        "target_dir",
-        metavar="target",
+        "--name",
         nargs="?",
-        default=".",
-        help="Target directory to deploy package outputs. "
-        "Default value is the current dir.",
+        help=(
+            "Package alias. If not specified, the name will be determined "
+            "from URL."
+        ),
     )
-    pkg_install_parser.add_argument(
-        "-s",
-        "--select",
-        metavar="OUT",
-        action="append",
-        default=[],
-        help="Select and persist only specified outputs from a package. "
-        "The parameter can be used multiple times. "
-        "All outputs will be selected by default.",
+    pkg_install_parser.set_defaults(func=CmdPkgInstall)
+
+    PKG_UNINSTALL_HELP = "Uninstall package(s)."
+ pkg_uninstall_parser = pkg_subparsers.add_parser( + "uninstall", + parents=[parent_parser], + description=append_doc_link(PKG_UNINSTALL_HELP, "pkg-uninstall"), + help=PKG_UNINSTALL_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, ) - pkg_install_parser.add_argument( - "-f", - "--file", - help="Specify name of the stage file. It should be " - "either 'Dvcfile' or have a '.dvc' suffix (e.g. " - "'prepare.dvc', 'clean.dvc', etc). " - "By default the file has 'mod_' prefix and imported package name " - "followed by .dvc", + pkg_uninstall_parser.add_argument( + "targets", nargs="*", default=[None], help="Package name." ) - pkg_install_parser.set_defaults(func=CmdPkgInstall) + pkg_uninstall_parser.set_defaults(func=CmdPkgUninstall) + + PKG_IMPORT_HELP = "Import data from package." + pkg_import_parser = pkg_subparsers.add_parser( + "import", + parents=[parent_parser], + description=append_doc_link(PKG_IMPORT_HELP, "pkg-import"), + help=PKG_IMPORT_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + pkg_import_parser.add_argument("name", help="Package name or url.") + pkg_import_parser.add_argument("src", help="Path to data in the package.") + pkg_import_parser.add_argument( + "-o", "--out", nargs="?", help="Destination path to put data to." + ) + pkg_import_parser.add_argument( + "--version", nargs="?", help="Package version." + ) + pkg_import_parser.set_defaults(func=CmdPkgImport) + + PKG_GET_HELP = "Download data from the package." + pkg_get_parser = pkg_subparsers.add_parser( + "get", + parents=[parent_parser], + description=append_doc_link(PKG_GET_HELP, "pkg-get"), + help=PKG_GET_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + pkg_get_parser.add_argument("url", help="Package url.") + pkg_get_parser.add_argument("src", help="Path to data in the package.") + pkg_get_parser.add_argument( + "-o", "--out", nargs="?", help="Destination path to put data to." + ) + pkg_get_parser.add_argument( + "--version", nargs="?", help="Package version." 
+ ) + pkg_get_parser.set_defaults(func=CmdPkgGet) diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index 78ac2da312..2779230f9a 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -13,8 +13,11 @@ from dvc.dependency.ssh import DependencySSH from dvc.dependency.http import DependencyHTTP from dvc.dependency.https import DependencyHTTPS +from .pkg import DependencyPKG from dvc.remote import Remote +from dvc.pkg import Pkg + DEPS = [ DependencyGS, @@ -44,6 +47,7 @@ SCHEMA = output.SCHEMA.copy() del SCHEMA[schema.Optional(OutputBase.PARAM_CACHE)] del SCHEMA[schema.Optional(OutputBase.PARAM_METRIC)] +SCHEMA[schema.Optional(DependencyPKG.PARAM_PKG)] = Pkg.SCHEMA def _get(stage, p, info): @@ -55,6 +59,10 @@ def _get(stage, p, info): remote = Remote(stage.repo, name=parsed.netloc) return DEP_MAP[remote.scheme](stage, p, info, remote=remote) + if info and info.get(DependencyPKG.PARAM_PKG): + pkg = info.pop(DependencyPKG.PARAM_PKG) + return DependencyPKG(pkg, stage, p, info) + for d in DEPS: if d.supported(p): return d(stage, p, info) @@ -69,8 +77,9 @@ def loadd_from(stage, d_list): return ret -def loads_from(stage, s_list): +def loads_from(stage, s_list, pkg=None): ret = [] for s in s_list: - ret.append(_get(stage, s, {})) + info = {DependencyPKG.PARAM_PKG: pkg} if pkg else {} + ret.append(_get(stage, s, info)) return ret diff --git a/dvc/dependency/pkg.py b/dvc/dependency/pkg.py new file mode 100644 index 0000000000..755556208f --- /dev/null +++ b/dvc/dependency/pkg.py @@ -0,0 +1,40 @@ +from __future__ import unicode_literals + +import os +import copy + +from dvc.utils.compat import urlparse + +from .local import DependencyLOCAL + + +class DependencyPKG(DependencyLOCAL): + PARAM_PKG = "pkg" + + def __init__(self, pkg, stage, *args, **kwargs): + self.pkg = stage.repo.pkg.get_pkg(**pkg) + super(DependencyLOCAL, self).__init__(stage, *args, **kwargs) + + def _parse_path(self, remote, path): + out_path = os.path.join( + self.pkg.repo.root_dir, urlparse(path).path.lstrip("/") + ) + + out, = self.pkg.repo.find_outs_by_path(out_path) + self.info = copy.copy(out.info) + self._pkg_stage = copy.copy(out.stage.path) + return self.REMOTE.path_cls(out.cache_path) + + @property + def is_in_repo(self): + return False + + def dumpd(self): + ret = super(DependencyLOCAL, self).dumpd() + ret[self.PARAM_PKG] = self.pkg.dumpd() + return ret + + def download(self, to, resume=False): + self.pkg.repo.fetch(self._pkg_stage) + to.info = copy.copy(self.info) + to.checkout() diff --git a/dvc/output/base.py b/dvc/output/base.py index f9360a78c3..7e0ba51a49 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -5,6 +5,7 @@ from schema import Or, Optional +import dvc.prompt as prompt from dvc.exceptions import DvcException from dvc.utils.compat import str, urlparse from dvc.remote.base import RemoteBASE @@ -271,8 +272,8 @@ def verify_metric(self): "verify metric is not supported for {}".format(self.scheme) ) - def download(self, to_info, resume=False): - self.remote.download([self.path_info], [to_info], resume=resume) + def download(self, to, resume=False): + self.remote.download([self.path_info], [to.path_info], resume=resume) def checkout(self, force=False, progress_callback=None, tag=None): if not self.use_cache: @@ -323,3 +324,96 @@ def get_files_number(self): def unprotect(self): if self.exists: self.remote.unprotect(self.path_info) + + def _collect_used_dir_cache(self, remote=None, force=False, jobs=None): + """Get a list of `info`s retaled to the given directory. 
+ + - Pull the directory entry from the remote cache if it was changed. + + Example: + + Given the following commands: + + $ echo "foo" > directory/foo + $ echo "bar" > directory/bar + $ dvc add directory + + It will return something similar to the following list: + + [ + { 'path': 'directory/foo', 'md5': 'c157a79031e1', ... }, + { 'path': 'directory/bar', 'md5': 'd3b07384d113', ... }, + ] + """ + + ret = [] + + if self.cache.changed_cache_file(self.checksum): + try: + self.repo.cloud.pull( + [ + { + self.remote.PARAM_CHECKSUM: self.checksum, + "name": str(self), + } + ], + jobs=jobs, + remote=remote, + show_checksums=False, + ) + except DvcException: + logger.debug("failed to pull cache for '{}'".format(self)) + + if self.cache.changed_cache_file(self.checksum): + msg = ( + "Missing cache for directory '{}'. " + "Cache for files inside will be lost. " + "Would you like to continue? Use '-f' to force." + ) + if not force and not prompt.confirm(msg): + raise DvcException( + "unable to fully collect used cache" + " without cache for directory '{}'".format(self) + ) + else: + return ret + + for entry in self.dir_cache: + info = copy(entry) + path_info = self.path_info / entry[self.remote.PARAM_RELPATH] + info["name"] = str(path_info) + ret.append(info) + + return ret + + def get_used_cache(self, **kwargs): + """Get a dumpd of the given `out`, with an entry including the branch. + + The `used_cache` of an output is no more than its `info`. + + In case that the given output is a directory, it will also + include the `info` of its files. + """ + + if self.stage.is_pkg_import: + return [] + + if not self.use_cache: + return [] + + if not self.info: + logger.warning( + "Output '{}'({}) is missing version info. Cache for it will " + "not be collected. Use dvc repro to get your pipeline up to " + "date.".format(self, self.stage) + ) + return [] + + ret = [{self.remote.PARAM_CHECKSUM: self.checksum, "name": str(self)}] + + if not self.is_dir_checksum: + return ret + + ret.extend(self._collect_used_dir_cache(**kwargs)) + + return ret diff --git a/dvc/pkg.py b/dvc/pkg.py new file mode 100644 index 0000000000..ad7194bfd5 --- /dev/null +++ b/dvc/pkg.py @@ -0,0 +1,173 @@ +import os +import shutil +import logging + +from funcy import cached_property +from schema import Optional + +from dvc.config import Config +from dvc.cache import CacheConfig +from dvc.path_info import PathInfo +from dvc.exceptions import DvcException +from dvc.utils.compat import urlparse, TemporaryDirectory + + +logger = logging.getLogger(__name__) + + +class NotInstalledPkgError(DvcException): + def __init__(self, name): + super(NotInstalledPkgError, self).__init__( + "Package '{}' is not installed".format(name) + ) + + +class Pkg(object): + PARAM_NAME = "name" + PARAM_URL = "url" + PARAM_VERSION = "version" + + SCHEMA = { + Optional(PARAM_NAME): str, + Optional(PARAM_URL): str, + Optional(PARAM_VERSION): str, + } + + def __init__(self, pkg_dir, **kwargs): + self.pkg_dir = pkg_dir + self.url = kwargs.get(self.PARAM_URL) + + name = kwargs.get(self.PARAM_NAME) + if name is None: + name = os.path.basename(self.url) + self.name = name + + self.version = kwargs.get(self.PARAM_VERSION) + self.path = os.path.join(pkg_dir, self.name) + + @cached_property + def repo(self): + from dvc.repo import Repo + + if not self.installed: + raise NotInstalledPkgError(self.name) + + return Repo(self.path, version=self.version) + + @property + def installed(self): + return os.path.exists(self.path) + + def install(self, cache_dir=None): + import git + + if 
self.installed: + logger.info( + "Skipping installing '{}'('{}') as it is already " + "installed.".format(self.name, self.url) + ) + return + + git.Repo.clone_from( + self.url, self.path, depth=1, no_single_branch=True + ) + + if self.version: + self.repo.scm.checkout(self.version) + + if cache_dir: + cache_config = CacheConfig(self.repo.config) + cache_config.set_dir(cache_dir, level=Config.LEVEL_LOCAL) + + def uninstall(self): + if not self.installed: + logger.info( + "Skipping uninstalling '{}' as it is not installed.".format( + self.name + ) + ) + return + + shutil.rmtree(self.path) + + def update(self): + self.repo.scm.fetch(self.version) + + def dumpd(self): + ret = {self.PARAM_NAME: self.name} + + if self.url: + ret[self.PARAM_URL] = self.url + + if self.version: + ret[self.PARAM_VERSION] = self.version + + return ret + + +class PkgManager(object): + PKG_DIR = "pkg" + + def __init__(self, repo): + self.repo = repo + self.pkg_dir = os.path.join(repo.dvc_dir, self.PKG_DIR) + self.cache_dir = repo.cache.local.cache_dir + + def install(self, url, **kwargs): + pkg = Pkg(self.pkg_dir, url=url, **kwargs) + pkg.install(cache_dir=self.cache_dir) + + def uninstall(self, name): + pkg = Pkg(self.pkg_dir, name=name) + pkg.uninstall() + + def get_pkg(self, **kwargs): + return Pkg(self.pkg_dir, **kwargs) + + def imp(self, name, src, out=None, version=None): + scheme = urlparse(name).scheme + + if os.path.exists(name) or scheme: + pkg = Pkg(self.pkg_dir, url=name) + pkg.install(cache_dir=self.cache_dir) + else: + pkg = Pkg(self.pkg_dir, name=name) + + info = {Pkg.PARAM_NAME: pkg.name} + if version: + info[Pkg.PARAM_VERSION] = version + + self.repo.imp(src, out, pkg=info) + + @classmethod + def get(cls, url, src, out=None, version=None): + if not out: + out = os.path.basename(src) + + # Creating a directory right beside the output to make sure that they + # are on the same filesystem, so we could take the advantage of + # reflink and/or hardlink. + dpath = os.path.dirname(os.path.abspath(out)) + with TemporaryDirectory(dir=dpath, prefix=".") as tmp_dir: + pkg = Pkg(tmp_dir, url=url, version=version) + pkg.install() + # Try any links possible to avoid data duplication. + # + # Not using symlink, because we need to remove cache after we are + # done, and to make that work we would have to copy data over + # anyway before removing the cache, so we might just copy it + # right away. + # + # Also, we can't use theoretical "move" link type here, because + # the same cache file might be used a few times in a directory. 
+ pkg.repo.config.set( + Config.SECTION_CACHE, + Config.SECTION_CACHE_TYPE, + "reflink,hardlink,copy", + ) + src = os.path.join(pkg.path, urlparse(src).path.lstrip("/")) + output, = pkg.repo.find_outs_by_path(src) + pkg.repo.fetch(output.stage.path) + output.path_info = PathInfo(os.path.abspath(out)) + with output.repo.state: + output.checkout() diff --git a/dvc/remote/local/__init__.py b/dvc/remote/local/__init__.py index 94c48e1c0e..8c37e7a51b 100644 --- a/dvc/remote/local/__init__.py +++ b/dvc/remote/local/__init__.py @@ -285,7 +285,7 @@ def _group(self, checksum_infos, show_checksums=False): by_md5[md5] = {"name": md5} continue - name = info[self.PARAM_PATH] + name = info["name"] branch = info.get("branch") if branch: name += "({})".format(branch) diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 5b86d1ce78..b2e2dfad81 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -3,11 +3,8 @@ import os import logging -import dvc.prompt as prompt - from dvc.config import Config from dvc.exceptions import ( - DvcException, NotDvcRepoError, OutputNotFoundError, TargetNotDirectoryError, @@ -41,7 +38,7 @@ class Repo(object): from dvc.repo.diff import diff from dvc.repo.brancher import brancher - def __init__(self, root_dir=None): + def __init__(self, root_dir=None, version=None): from dvc.state import State from dvc.lock import Lock from dvc.scm import SCM @@ -50,7 +47,8 @@ def __init__(self, root_dir=None): from dvc.repo.metrics import Metrics from dvc.scm.tree import WorkingTree from dvc.repo.tag import Tag - from dvc.repo.pkg import Pkg + + from dvc.pkg import PkgManager root_dir = self.find_root(root_dir) @@ -59,9 +57,13 @@ def __init__(self, root_dir=None): self.config = Config(self.dvc_dir) - self.tree = WorkingTree(self.root_dir) - self.scm = SCM(self.root_dir, repo=self) + + if version: + self.tree = self.scm.get_tree(version) + else: + self.tree = WorkingTree(self.root_dir) + self.lock = Lock(self.dvc_dir) # NOTE: storing state and link_state in the repository itself to avoid # any possible state corruption in 'shared cache dir' scenario. @@ -78,7 +80,7 @@ def __init__(self, root_dir=None): self.metrics = Metrics(self) self.tag = Tag(self) - self.pkg = Pkg(self) + self.pkg = PkgManager(self) self._ignore() @@ -127,6 +129,7 @@ def _ignore(self): self.config.config_local_file, updater.updater_file, updater.lock.lock_file, + self.pkg.pkg_dir, ] + self.state.temp_files if self.cache.local.cache_dir.startswith(self.root_dir): @@ -173,100 +176,6 @@ def collect(self, target, with_deps=False, recursive=False): return ret - def _collect_dir_cache( - self, out, branch=None, remote=None, force=False, jobs=None - ): - """Get a list of `info`s retaled to the given directory. - - - Pull the directory entry from the remote cache if it was changed. - - Example: - - Given the following commands: - - $ echo "foo" > directory/foo - $ echo "bar" > directory/bar - $ dvc add directory - - It will return something similar to the following list: - - [ - { 'path': 'directory', 'md5': '168fd6761b9c.dir', ... }, - { 'path': 'directory/foo', 'md5': 'c157a79031e1', ... }, - { 'path': 'directory/bar', 'md5': 'd3b07384d113', ... 
}, - ] - """ - info = out.dumpd() - ret = [info] - r = out.remote - md5 = info[r.PARAM_CHECKSUM] - - if self.cache.local.changed_cache_file(md5): - try: - self.cloud.pull( - ret, jobs=jobs, remote=remote, show_checksums=False - ) - except DvcException as exc: - msg = "Failed to pull cache for '{}': {}" - logger.debug(msg.format(out, exc)) - - if self.cache.local.changed_cache_file(md5): - msg = ( - "Missing cache for directory '{}'. " - "Cache for files inside will be lost. " - "Would you like to continue? Use '-f' to force. " - ) - if not force and not prompt.confirm(msg): - raise DvcException( - "unable to fully collect used cache" - " without cache for directory '{}'".format(out) - ) - else: - return ret - - for i in out.dir_cache: - i["branch"] = branch - i[r.PARAM_PATH] = os.path.join( - info[r.PARAM_PATH], i[r.PARAM_RELPATH] - ) - ret.append(i) - - return ret - - def _collect_used_cache( - self, out, branch=None, remote=None, force=False, jobs=None - ): - """Get a dumpd of the given `out`, with an entry including the branch. - - The `used_cache` of an output is no more than its `info`. - - In case that the given output is a directory, it will also - include the `info` of its files. - """ - if not out.use_cache or not out.info: - if not out.info: - logger.warning( - "Output '{}'({}) is missing version " - "info. Cache for it will not be collected. " - "Use dvc repro to get your pipeline up to " - "date.".format(out, out.stage) - ) - return [] - - info = out.dumpd() - info["branch"] = branch - ret = [info] - - if out.scheme != "local": - return ret - - if not out.is_dir_checksum: - return ret - - return self._collect_dir_cache( - out, branch=branch, remote=remote, force=force, jobs=jobs - ) - def used_cache( self, target=None, @@ -325,14 +234,12 @@ def used_cache( for out in stage.outs: scheme = out.path_info.scheme + used_cache = out.get_used_cache( + remote=remote, force=force, jobs=jobs + ) + cache[scheme].extend( - self._collect_used_cache( - out, - branch=branch, - remote=remote, - force=force, - jobs=jobs, - ) + dict(entry, branch=branch) for entry in used_cache ) return cache @@ -547,7 +454,7 @@ def open(self, path, remote=None, mode="r", encoding=None): raise ValueError("Can't open a dir") with self.state: - cache_info = self._collect_used_cache(out, remote=remote) + cache_info = out.get_used_cache(remote=remote) self.cloud.pull(cache_info, remote=remote) cache_file = self.cache.local.checksum_to_path_info(out.checksum) diff --git a/dvc/repo/imp.py b/dvc/repo/imp.py index acbbc4afd6..635c11908a 100644 --- a/dvc/repo/imp.py +++ b/dvc/repo/imp.py @@ -2,12 +2,12 @@ @scm_context -def imp(self, url, out, resume=False, fname=None): +def imp(self, url, out, resume=False, fname=None, pkg=None): from dvc.stage import Stage with self.state: stage = Stage.create( - repo=self, cmd=None, deps=[url], outs=[out], fname=fname + repo=self, cmd=None, deps=[url], outs=[out], fname=fname, pkg=pkg ) if stage is None: diff --git a/dvc/repo/pkg/__init__.py b/dvc/repo/pkg/__init__.py deleted file mode 100644 index f511203cea..0000000000 --- a/dvc/repo/pkg/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -class Pkg(object): - def __init__(self, repo): - self.repo = repo - - def install(self, *args, **kwargs): - from dvc.repo.pkg.install import install - - return install(self.repo, *args, **kwargs) diff --git a/dvc/repo/pkg/install.py b/dvc/repo/pkg/install.py deleted file mode 100644 index 671f2e9d9e..0000000000 --- a/dvc/repo/pkg/install.py +++ /dev/null @@ -1,173 +0,0 @@ -from __future__ import 
unicode_literals - -import os -import shutil -import logging - -from dvc.exceptions import DvcException -from dvc.stage import Stage -from dvc.scm.git.temp_git_repo import TempGitRepo - - -logger = logging.getLogger(__name__) - - -class PackageManager(object): - PACKAGE_FILE = "package.yaml" - - @staticmethod - def read_packages(): - return [] - - @staticmethod - def get_package(addr): - for pkg_class in [GitPackage]: - try: - return pkg_class() - except Exception: - pass - return None - - def __init__(self, addr): - self._addr = addr - - -class Package(object): - MODULES_DIR = "dvc_mod" - - def install_or_update( - self, parent_repo, address, target_dir, select=[], fname=None - ): - raise NotImplementedError( - "A method of abstract Package class was called" - ) - - def is_in_root(self): - return True - - -class GitPackage(Package): - DEF_DVC_FILE_PREFIX = "mod_" - - def install_or_update( - self, parent_repo, address, target_dir, select=[], fname=None - ): - from git.cmd import Git - - if not self.is_in_root(): - raise DvcException( - "This command can be run only from a repository root" - ) - - if not os.path.exists(self.MODULES_DIR): - logger.debug("Creating modules dir {}".format(self.MODULES_DIR)) - os.makedirs(self.MODULES_DIR) - parent_repo.scm.ignore(os.path.abspath(self.MODULES_DIR)) - - module_name = Git.polish_url(address).strip("/").split("/")[-1] - if not module_name: - raise DvcException( - "Package address error: unable to extract package name" - ) - - with TempGitRepo( - address, module_name, Package.MODULES_DIR - ) as tmp_repo: - outputs_to_copy = tmp_repo.outs - if select: - outputs_to_copy = list( - filter(lambda out: out.dvc_path in select, outputs_to_copy) - ) - - fetched_stage_files = set( - map(lambda o: o.stage.path, outputs_to_copy) - ) - tmp_repo.fetch(fetched_stage_files) - - module_dir = self.create_module_dir(module_name) - tmp_repo.persist_to(module_dir, parent_repo) - - dvc_file = self.get_dvc_file_name(fname, target_dir, module_name) - try: - self.persist_stage_and_scm_state( - parent_repo, outputs_to_copy, target_dir, dvc_file - ) - except Exception as ex: - raise DvcException( - "Package '{}' was installed " - "but stage file '{}' " - "was not created properly: {}".format( - address, dvc_file, ex - ) - ) - - parent_repo.checkout(dvc_file) - - @staticmethod - def persist_stage_and_scm_state( - parent_repo, outputs_to_copy, target_dir, dvc_file - ): - stage = Stage.create( - repo=parent_repo, - fname=dvc_file, - validate_state=False, - wdir=target_dir, - ) - stage.outs = list( - map(lambda o: o.assign_to_stage_file(stage), outputs_to_copy) - ) - - for out in stage.outs: - parent_repo.scm.ignore(out.fspath, in_curr_dir=True) - - stage.dump() - - @staticmethod - def create_module_dir(module_name): - module_dir = os.path.join(GitPackage.MODULES_DIR, module_name) - if os.path.exists(module_dir): - logger.info("Updating package {}".format(module_name)) - shutil.rmtree(module_dir) - else: - logger.info("Adding package {}".format(module_name)) - return module_dir - - def get_dvc_file_name(self, stage_file, target_dir, module_name): - if stage_file: - dvc_file_path = stage_file - else: - dvc_file_name = self.DEF_DVC_FILE_PREFIX + module_name + ".dvc" - dvc_file_path = os.path.join(target_dir, dvc_file_name) - return dvc_file_path - - -def install(self, address, target_dir, select=[], fname=None): - """ - Install package. - - The command can be run only from DVC project root. - - E.g. 
- Having: DVC package in https://github.com/dmpetrov/tag_classifier - - $ dvc pkg install https://github.com/dmpetrov/tag_classifier - - Result: tag_classifier package in dvc_mod/ directory - """ - - if not os.path.isdir(target_dir): - raise DvcException( - "target directory '{}' does not exist".format(target_dir) - ) - - curr_dir = os.path.realpath(os.curdir) - if not os.path.realpath(target_dir).startswith(curr_dir): - raise DvcException( - "the current directory should be a subdirectory of the target " - "dir '{}'".format(target_dir) - ) - - addresses = [address] if address else PackageManager.read_packages() - for addr in addresses: - mgr = PackageManager.get_package(addr) - mgr.install_or_update(self, address, target_dir, select, fname) diff --git a/dvc/scm/git/temp_git_repo.py b/dvc/scm/git/temp_git_repo.py deleted file mode 100644 index c40a5f36b0..0000000000 --- a/dvc/scm/git/temp_git_repo.py +++ /dev/null @@ -1,150 +0,0 @@ -import os -import shutil -import tempfile -import logging - -from dvc.exceptions import DvcException - - -logger = logging.getLogger(__name__) - - -class TempRepoException(DvcException): - """Raise when temporary package repository has not properly created""" - - def __init__(self, temp_repo, msg, cause=None): - m = "temp repository '{}' error: {}".format(temp_repo.addr, msg) - super(TempRepoException, self).__init__(m, cause) - - -class TempGitRepo(object): - GIT_DIR_TO_REMOVE = ".git" - GIT_FILES_TO_REMOVE = [".gitignore", ".gitmodules"] - SUFFIX = "tmp_DVC_mod" - - def __init__(self, addr, module_name, modules_dir): - self.addr = addr - self.modules_dir = modules_dir - self._tmp_mod_prefix = module_name + "_" + self.SUFFIX + "_" - self._reset_state() - - def _reset_state(self): - self._set_state(None, []) - - def _set_state(self, cloned_tmp_dir, outs): - from dvc.repo import Repo - - self.repo = Repo(cloned_tmp_dir) if cloned_tmp_dir else None - - self._cloned_tmp_dir = cloned_tmp_dir - self.outs = outs - self._moved_files = [] - - def fetch(self, targets): - for target in targets: - try: - self.repo.fetch(target) - except Exception as ex: - msg = "error in fetching data from {}: {}".format( - os.path.basename(target), ex - ) - raise TempRepoException(self, msg) - - @property - def is_state_set(self): - return self._cloned_tmp_dir is not None - - def __enter__(self): - if self.is_state_set: - raise TempRepoException(self, "Git repo cloning duplication") - - module_temp_dir = None - try: - module_temp_dir = self._clone_to_temp_dir() - self._remove_git_files_and_dirs(module_temp_dir) - - outputs = self._read_outputs(module_temp_dir) - self._set_state(module_temp_dir, outputs) - finally: - self._clean_mod_temp_dir(module_temp_dir) - - return self - - def _clone_to_temp_dir(self): - import git - - result = tempfile.mktemp( - prefix=self._tmp_mod_prefix, dir=self.modules_dir - ) - logger.debug( - "Cloning git repository {} to temp dir {}".format( - self.addr, result - ) - ) - git.Repo.clone_from(self.addr, result, depth=1) - return result - - def _remove_git_files_and_dirs(self, module_temp_dir): - logger.debug( - "Removing git meta files from {}: {} dif and {}".format( - module_temp_dir, - TempGitRepo.GIT_DIR_TO_REMOVE, - ", ".join(TempGitRepo.GIT_FILES_TO_REMOVE), - ) - ) - shutil.rmtree( - os.path.join(module_temp_dir, TempGitRepo.GIT_DIR_TO_REMOVE) - ) - for item in TempGitRepo.GIT_FILES_TO_REMOVE: - fname = os.path.join(module_temp_dir, item) - if os.path.exists(fname): - os.remove(fname) - - def _clean_mod_temp_dir(self, module_temp_dir): - if not 
self.is_state_set: - if module_temp_dir and os.path.exists(module_temp_dir): - shutil.rmtree(module_temp_dir, ignore_errors=True) - - def _read_outputs(self, module_temp_dir): - from dvc.repo import Repo - - pkg_repo = Repo(module_temp_dir) - stages = pkg_repo.stages() - return [out for s in stages for out in s.outs] - - def persist_to(self, module_dir, parent_repo): - if not self.is_state_set: - raise TempRepoException(self, "cannot persist") - - tmp_repo_cache = self.repo.cache.local.path_info.fspath - - for prefix in os.listdir(tmp_repo_cache): - if len(prefix) != 2: - logger.warning( - "wrong dir format in cache {}: dir {}".format( - tmp_repo_cache, prefix - ) - ) - self._move_all_cache_files(parent_repo, prefix, tmp_repo_cache) - - shutil.move(self._cloned_tmp_dir, module_dir) - self._reset_state() - - @staticmethod - def _move_all_cache_files(parent_repo, prefix, tmp_repo_cache): - obj_name = os.path.join(tmp_repo_cache, prefix) - for suffix in os.listdir(obj_name): - src_path = os.path.join(tmp_repo_cache, prefix, suffix) - pre = os.path.join( - parent_repo.cache.local.path_info.fspath, prefix - ) - if not os.path.exists(pre): - os.mkdir(pre) - dest = os.path.join(pre, suffix) - shutil.move(src_path, dest) - - def __exit__(self, exc_type, exc_value, traceback): - if self.is_state_set: - shutil.rmtree(self._cloned_tmp_dir) - self._reset_state() - pass diff --git a/dvc/scm/git/tree.py b/dvc/scm/git/tree.py index 3312942ff5..2ef8016ab2 100644 --- a/dvc/scm/git/tree.py +++ b/dvc/scm/git/tree.py @@ -3,6 +3,7 @@ from dvc.ignore import DvcIgnoreFilter from dvc.utils.compat import StringIO, BytesIO +from dvc.exceptions import DvcException from dvc.scm.tree import BaseTree @@ -76,10 +77,42 @@ def _is_tree_and_contains(obj, path): return True return False + def _try_fetch_from_remote(self): + import git + + try: + # checking if tag/branch exists locally + self.git.git.show_ref(self.rev, verify=True) + return + except git.exc.GitCommandError: + pass + + try: + # checking if it exists on the remote + self.git.git.ls_remote("origin", self.rev, exit_code=True) + # fetching remote tag/branch so we can reference it locally + self.git.git.fetch("origin", "{rev}:{rev}".format(rev=self.rev)) + except git.exc.GitCommandError: + pass + def git_object_by_path(self, path): + import git + path = os.path.relpath(os.path.realpath(path), self.git.working_dir) assert path.split(os.sep, 1)[0] != ".." 
- tree = self.git.tree(self.rev) + + self._try_fetch_from_remote() + + try: + tree = self.git.tree(self.rev) + except git.exc.BadName as exc: + raise DvcException( + "revision '{}' not found in git '{}'".format( + self.rev, os.path.relpath(self.git.working_dir) + ), + cause=exc, + ) + if not path or path == ".": return tree for i in path.split(os.sep): diff --git a/dvc/stage.py b/dvc/stage.py index 118ee6e180..f5feccf0c0 100644 --- a/dvc/stage.py +++ b/dvc/stage.py @@ -213,6 +213,13 @@ def is_import(self): """Whether the stage file was created with `dvc import`.""" return not self.cmd and len(self.deps) == 1 and len(self.outs) == 1 + @property + def is_pkg_import(self): + if not self.is_import: + return False + + return isinstance(self.deps[0], dependency.DependencyPKG) + def _changed_deps(self): if self.locked: return False @@ -419,6 +426,7 @@ def create( validate_state=True, outs_persist=None, outs_persist_no_cache=None, + pkg=None, ): if outs is None: outs = [] @@ -459,7 +467,7 @@ def create( outs_persist, outs_persist_no_cache, ) - stage.deps = dependency.loads_from(stage, deps) + stage.deps = dependency.loads_from(stage, deps, pkg=pkg) stage._check_circular_dependency() stage._check_duplicated_arguments() @@ -795,9 +803,7 @@ def run(self, dry=False, resume=False, no_commit=False, force=False): if self._already_cached() and not force: self.outs[0].checkout() else: - self.deps[0].download( - self.outs[0].path_info, resume=resume - ) + self.deps[0].download(self.outs[0], resume=resume) elif self.is_data_source: msg = "Verifying data sources in '{}'".format(self.relpath) diff --git a/dvc/utils/compat.py b/dvc/utils/compat.py index 0abb08c4aa..e8ff4a1a02 100644 --- a/dvc/utils/compat.py +++ b/dvc/utils/compat.py @@ -98,6 +98,7 @@ def _makedirs(name, mode=0o777, exist_ok=False): from io import open # noqa: F401 import pathlib2 as pathlib # noqa: F401 from collections import Mapping # noqa: F401 + from backports.tempfile import TemporaryDirectory # noqa: F403 builtin_str = str # noqa: F821 bytes = str # noqa: F821 @@ -144,6 +145,7 @@ def __exit__(self, *args): ) # noqa: F401 import configparser as ConfigParser # noqa: F401 from collections.abc import Mapping # noqa: F401 + from tempfile import TemporaryDirectory # noqa: F401 builtin_str = str # noqa: F821 str = str # noqa: F821 diff --git a/setup.py b/setup.py index dc413d4895..b0d1391e7a 100644 --- a/setup.py +++ b/setup.py @@ -118,7 +118,11 @@ def run(self): "oss": oss, "ssh": ssh, # NOTE: https://github.com/inveniosoftware/troubleshooting/issues/1 - ":python_version=='2.7'": ["futures", "pathlib2"], + ":python_version=='2.7'": [ + "futures", + "pathlib2", + "backports.tempfile", + ], "tests": tests_requirements, }, keywords="data science, data version control, machine learning", diff --git a/tests/basic_env.py b/tests/basic_env.py index 8fc20e08c3..44248a2bea 100644 --- a/tests/basic_env.py +++ b/tests/basic_env.py @@ -251,19 +251,3 @@ def __init__(self, methodName, root_dir=None): @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): self._caplog = caplog - - -class TestDvcPkg(TestDvcFixture, TestCase): - GIT_PKG = "git_pkg" - CACHE_DIR = "mycache" - - def __init__(self, methodName, root_dir=None): - TestDvcFixture.__init__(self, root_dir) - TestCase.__init__(self, methodName) - - self.pkg_dir = os.path.join(self._root_dir, self.GIT_PKG) - cache_dir = os.path.join(self._root_dir, self.CACHE_DIR) - self.pkg_fixture = TestDvcDataFileFixture( - root_dir=self.pkg_dir, cache_dir=cache_dir - ) - self.pkg_fixture.setUp() diff --git 
a/tests/conftest.py b/tests/conftest.py
index 7e7d1208ec..522baf3fe3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,10 +5,11 @@
 from git import Repo
 from git.exc import GitCommandNotFound
 
+from dvc.remote.config import RemoteConfig
 from dvc.utils.compat import cast_bytes_py2
 from dvc.remote.ssh.connection import SSHConnection
 from dvc.repo import Repo as DvcRepo
-from .basic_env import TestDirFixture
+from .basic_env import TestDirFixture, TestDvcFixture
 
 
 # Prevent updater and analytics from running their processes
@@ -132,3 +133,39 @@ def temporary_windows_drive(repo_dir):
         )
         if tear_down_result == 0:
             raise RuntimeError("Could not unmount windows drive!")
+
+
+@pytest.fixture
+def pkg(repo_dir):
+    repo = TestDvcFixture()
+    repo.setUp()
+    try:
+        stage_foo = repo.dvc.add(repo.FOO)[0]
+        stage_bar = repo.dvc.add(repo.BAR)[0]
+        stage_data_dir = repo.dvc.add(repo.DATA_DIR)[0]
+        repo.dvc.scm.add([stage_foo.path, stage_bar.path, stage_data_dir.path])
+        repo.dvc.scm.commit("init pkg")
+
+        rconfig = RemoteConfig(repo.dvc.config)
+        rconfig.add("upstream", repo.dvc.cache.local.cache_dir, default=True)
+        repo.dvc.scm.add([repo.dvc.config.config_file])
+        repo.dvc.scm.commit("add remote")
+
+        repo.create("version", "master")
+        repo.dvc.add("version")
+        repo.dvc.scm.add([".gitignore", "version.dvc"])
+        repo.dvc.scm.commit("master")
+
+        repo.dvc.scm.checkout("branch", create_new=True)
+        os.unlink(os.path.join(repo.root_dir, "version"))
+        repo.create("version", "branch")
+        repo.dvc.add("version")
+        repo.dvc.scm.add([".gitignore", "version.dvc"])
+        repo.dvc.scm.commit("branch")
+
+        repo.dvc.scm.checkout("master")
+
+        repo._popd()
+        yield repo
+    finally:
+        repo.tearDown()
diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py
index bc7a3864e6..644a534a9d 100644
--- a/tests/func/test_data_cloud.py
+++ b/tests/func/test_data_cloud.py
@@ -290,18 +290,21 @@ def _test_cloud(self):
         self.assertEqual(len(stages), 1)
         stage = stages[0]
         self.assertTrue(stage is not None)
-        cache = stage.outs[0].cache_path
-        info = stage.outs[0].dumpd()
-        md5 = info["md5"]
+        out = stage.outs[0]
+        cache = out.cache_path
+        name = str(out)
+        md5 = out.checksum
+        info = {"name": name, out.remote.PARAM_CHECKSUM: md5}
 
         stages = self.dvc.add(self.DATA_DIR)
         self.assertEqual(len(stages), 1)
         stage_dir = stages[0]
         self.assertTrue(stage_dir is not None)
-
-        cache_dir = stage_dir.outs[0].cache_path
-        info_dir = stage_dir.outs[0].dumpd()
-        md5_dir = info_dir["md5"]
+        out_dir = stage_dir.outs[0]
+        cache_dir = out_dir.cache_path
+        name_dir = str(out_dir)
+        md5_dir = out_dir.checksum
+        info_dir = {"name": name_dir, out_dir.remote.PARAM_CHECKSUM: md5_dir}
 
         with self.cloud.repo.state:
             # Check status
diff --git a/tests/func/test_pkg.py b/tests/func/test_pkg.py
index 06a73e4e18..14bfdba8cf 100644
--- a/tests/func/test_pkg.py
+++ b/tests/func/test_pkg.py
@@ -1,35 +1,156 @@
 import os
+import git
+import filecmp
 
-from dvc.repo.pkg.install import Package
-from tests.basic_env import TestDvcPkg
+from dvc.pkg import PkgManager
+from tests.utils import trees_equal
 
 
-class TestPkgBasic(TestDvcPkg):
-    NEW_PKG_DIR = "new_pkg"
-    def setUp(self):
-        super(TestPkgBasic, self).setUp()
+def test_install_and_uninstall(repo_dir, dvc_repo, pkg):
+    name = os.path.basename(pkg.root_dir)
+    pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg")
+    mypkg_dir = os.path.join(pkg_dir, name)
 
-        self.to_dir = os.path.join(self._root_dir, self.NEW_PKG_DIR)
-        os.mkdir(self.to_dir)
-        self.dvc.pkg.install(self.pkg_fixture._root_dir, self.to_dir)
+
dvc_repo.pkg.install(pkg.root_dir) + assert os.path.exists(pkg_dir) + assert os.path.isdir(pkg_dir) + assert os.path.exists(mypkg_dir) + assert os.path.isdir(mypkg_dir) + assert os.path.isdir(os.path.join(mypkg_dir, ".git")) - def test_installed_dirs_and_files(self): - self.assertTrue(os.path.isdir(self.pkg_dir)) + dvc_repo.pkg.install(pkg.root_dir) + assert os.path.exists(pkg_dir) + assert os.path.isdir(pkg_dir) + assert os.path.exists(mypkg_dir) + assert os.path.isdir(mypkg_dir) + assert os.path.isdir(os.path.join(mypkg_dir, ".git")) - mod_dir_fullpath = os.path.join(self._root_dir, Package.MODULES_DIR) - self.assertTrue(os.path.isdir(mod_dir_fullpath)) + git_repo = git.Repo(mypkg_dir) + assert git_repo.active_branch.name == "master" - installed_pkg_dir = os.path.join(mod_dir_fullpath, self.GIT_PKG) - for file in [self.FOO, self.BAR, self.CODE]: - installed_file = os.path.join(installed_pkg_dir, file) - self.assertTrue(os.path.isfile(installed_file)) + dvc_repo.pkg.uninstall(name) + assert not os.path.exists(mypkg_dir) - def test_target_dir_checkout(self): - pkg_data_file = os.path.join(self.to_dir, self.DATA) - self.assertTrue(os.path.exists(pkg_data_file)) - self.assertTrue(os.path.isfile(pkg_data_file)) + dvc_repo.pkg.uninstall(name) + assert not os.path.exists(mypkg_dir) - pkg_data_dir = os.path.join(self.to_dir, self.DATA_SUB_DIR) - self.assertTrue(os.path.exists(pkg_data_dir)) - self.assertTrue(os.path.isdir(pkg_data_dir)) + +def test_uninstall_corrupted(repo_dir, dvc_repo): + name = os.path.basename("mypkg") + pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") + mypkg_dir = os.path.join(pkg_dir, name) + + os.makedirs(mypkg_dir) + + dvc_repo.pkg.uninstall(name) + assert not os.path.exists(mypkg_dir) + + +def test_install_version(repo_dir, dvc_repo, pkg): + name = os.path.basename(pkg.root_dir) + pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") + mypkg_dir = os.path.join(pkg_dir, name) + + dvc_repo.pkg.install(pkg.root_dir, version="branch") + assert os.path.exists(pkg_dir) + assert os.path.isdir(pkg_dir) + assert os.path.exists(mypkg_dir) + assert os.path.isdir(mypkg_dir) + assert os.path.isdir(os.path.join(mypkg_dir, ".git")) + + git_repo = git.Repo(mypkg_dir) + assert git_repo.active_branch.name == "branch" + + +def test_import(repo_dir, dvc_repo, pkg): + name = os.path.basename(pkg.root_dir) + + src = pkg.FOO + dst = pkg.FOO + "_imported" + + dvc_repo.pkg.install(pkg.root_dir) + dvc_repo.pkg.imp(name, src, dst) + + assert os.path.exists(dst) + assert os.path.isfile(dst) + assert filecmp.cmp(repo_dir.FOO, dst, shallow=False) + + +def test_import_dir(repo_dir, dvc_repo, pkg): + name = os.path.basename(pkg.root_dir) + + src = pkg.DATA_DIR + dst = pkg.DATA_DIR + "_imported" + + dvc_repo.pkg.install(pkg.root_dir) + dvc_repo.pkg.imp(name, src, dst) + + assert os.path.exists(dst) + assert os.path.isdir(dst) + trees_equal(src, dst) + + +def test_import_url(repo_dir, dvc_repo, pkg): + name = os.path.basename(pkg.root_dir) + pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") + mypkg_dir = os.path.join(pkg_dir, name) + + src = pkg.FOO + dst = pkg.FOO + "_imported" + + dvc_repo.pkg.imp(pkg.root_dir, src, dst) + + assert os.path.exists(pkg_dir) + assert os.path.isdir(pkg_dir) + assert os.path.exists(mypkg_dir) + assert os.path.isdir(mypkg_dir) + assert os.path.isdir(os.path.join(mypkg_dir, ".git")) + + assert os.path.exists(dst) + assert os.path.isfile(dst) + assert filecmp.cmp(repo_dir.FOO, dst, shallow=False) + + +def test_import_url_version(repo_dir, dvc_repo, pkg): + 
name = os.path.basename(pkg.root_dir) + pkg_dir = os.path.join(repo_dir.root_dir, ".dvc", "pkg") + mypkg_dir = os.path.join(pkg_dir, name) + + src = "version" + dst = src + + dvc_repo.pkg.imp(pkg.root_dir, src, dst, version="branch") + + assert os.path.exists(pkg_dir) + assert os.path.isdir(pkg_dir) + assert os.path.exists(mypkg_dir) + assert os.path.isdir(mypkg_dir) + assert os.path.isdir(os.path.join(mypkg_dir, ".git")) + + assert os.path.exists(dst) + assert os.path.isfile(dst) + with open(dst, "r+") as fobj: + assert fobj.read() == "branch" + + +def test_get_file(repo_dir, dvc_repo, pkg): + src = pkg.FOO + dst = pkg.FOO + "_imported" + + PkgManager.get(pkg.root_dir, src, dst) + + assert os.path.exists(dst) + assert os.path.isfile(dst) + assert filecmp.cmp(repo_dir.FOO, dst, shallow=False) + + +def test_get_dir(repo_dir, dvc_repo, pkg): + src = pkg.DATA_DIR + dst = pkg.DATA_DIR + "_imported" + + PkgManager.get(pkg.root_dir, src, dst) + + assert os.path.exists(dst) + assert os.path.isdir(dst) + trees_equal(src, dst) diff --git a/tests/unit/command/test_pkg.py b/tests/unit/command/test_pkg.py new file mode 100644 index 0000000000..0e2ebe81a6 --- /dev/null +++ b/tests/unit/command/test_pkg.py @@ -0,0 +1,84 @@ +from dvc.cli import parse_args +from dvc.command.pkg import CmdPkgInstall, CmdPkgUninstall, CmdPkgImport +from dvc.command.pkg import CmdPkgGet + + +def test_pkg_install(mocker, dvc_repo): + args = parse_args( + ["pkg", "install", "url", "--version", "version", "--name", "name"] + ) + assert args.func == CmdPkgInstall + + cmd = args.func(args) + m = mocker.patch.object(cmd.repo.pkg, "install", autospec=True) + + assert cmd.run() == 0 + + m.assert_called_once_with("url", version="version", name="name") + + +def test_pkg_uninstall(mocker, dvc_repo): + targets = ["target1", "target2"] + args = parse_args(["pkg", "uninstall"] + targets) + assert args.func == CmdPkgUninstall + + cmd = args.func(args) + m = mocker.patch.object(cmd.repo.pkg, "uninstall", autospec=True) + + assert cmd.run() == 0 + + calls = [] + for target in targets: + calls.append(mocker.call(name=target)) + + m.assert_has_calls(calls) + + +def test_pkg_import(mocker, dvc_repo): + cli_args = parse_args( + [ + "pkg", + "import", + "pkg_name", + "pkg_path", + "--out", + "out", + "--version", + "version", + ] + ) + assert cli_args.func == CmdPkgImport + + cmd = cli_args.func(cli_args) + m = mocker.patch.object(cmd.repo.pkg, "imp", autospec=True) + + assert cmd.run() == 0 + + m.assert_called_once_with( + "pkg_name", "pkg_path", out="out", version="version" + ) + + +def test_pkg_get(mocker, dvc_repo): + cli_args = parse_args( + [ + "pkg", + "get", + "pkg_url", + "pkg_path", + "--out", + "out", + "--version", + "version", + ] + ) + assert cli_args.func == CmdPkgGet + + cmd = cli_args.func(cli_args) + m = mocker.patch("dvc.pkg.PkgManager.get") + + assert cmd.run() == 0 + + m.assert_called_once_with( + "pkg_url", "pkg_path", out="out", version="version" + ) diff --git a/tests/unit/remote/test_local.py b/tests/unit/remote/test_local.py index dd547595e4..cc4f579faa 100644 --- a/tests/unit/remote/test_local.py +++ b/tests/unit/remote/test_local.py @@ -8,21 +8,9 @@ def test_status_download_optimization(mocker): """ remote = RemoteLOCAL(None, {}) - checksum_infos = [ - { - "path": "foo", - "metric": False, - "cache": True, - "persist": False, - "md5": "acbd18db4cc2f85cedef654fccc4a4d8", - }, - { - "path": "bar", - "metric": False, - "cache": True, - "persist": False, - "md5": "37b51d194a7513e45b56f6524f2d51f2", - }, + infos = [ + 
{"name": "foo", "md5": "acbd18db4cc2f85cedef654fccc4a4d8"}, + {"name": "bar", "md5": "37b51d194a7513e45b56f6524f2d51f2"}, ] local_exists = [ @@ -36,6 +24,6 @@ def test_status_download_optimization(mocker): other_remote.url = "other_remote" other_remote.cache_exists.return_value = [] - remote.status(checksum_infos, other_remote, download=True) + remote.status(infos, other_remote, download=True) assert other_remote.cache_exists.call_count == 0