diff --git a/dvc/api.py b/dvc/api.py index ffb044a027..655594bf1c 100644 --- a/dvc/api.py +++ b/dvc/api.py @@ -1,34 +1,43 @@ +from builtins import open as builtin_open import importlib import os import sys -import copy from urllib.parse import urlparse from contextlib import contextmanager, _GeneratorContextManager as GCM +import threading +from funcy import wrap_with import ruamel.yaml from voluptuous import Schema, Required, Invalid from dvc.repo import Repo -from dvc.exceptions import DvcException, FileMissingError +from dvc.exceptions import DvcException from dvc.external_repo import external_repo -SUMMON_SCHEMA = Schema( +SUMMON_FILE_SCHEMA = Schema( { Required("objects"): [ { Required("name"): str, "meta": dict, Required("summon"): { - Required("type"): "python", - Required("call"): str, - "args": dict, + Required("type"): str, "deps": [str], + str: object, }, } ] } ) +SUMMON_PYTHON_SCHEMA = Schema( + { + Required("type"): "python", + Required("call"): str, + "args": dict, + "deps": [str], + } +) class SummonError(DvcException): @@ -97,32 +106,74 @@ def _make_repo(repo_url, rev=None): def summon(name, repo=None, rev=None, summon_file="dvcsummon.yaml", args=None): """Instantiate an object described in the summon file.""" + with prepare_summon( + name, repo=repo, rev=rev, summon_file=summon_file + ) as desc: + try: + summon_dict = SUMMON_PYTHON_SCHEMA(desc.obj["summon"]) + except Invalid as exc: + raise SummonError(str(exc)) from exc + + _args = {**summon_dict.get("args", {}), **(args or {})} + return _invoke_method(summon_dict["call"], _args, desc.repo.root_dir) + + +@contextmanager +def prepare_summon(name, repo=None, rev=None, summon_file="dvcsummon.yaml"): + """Does a couple of things every summon needs as a prerequisite: + clones the repo, parses the summon file and pulls the deps. + + Calling code is expected to complete the summon logic following + instructions stated in "summon" dict of the object spec. + + Returns a SummonDesc instance, which contains references to a Repo object, + named object specification and resolved paths to deps. + """ with _make_repo(repo, rev=rev) as _repo: try: path = os.path.join(_repo.root_dir, summon_file) - obj = _get_object_from_summon_file(name, path) - info = obj["summon"] + obj = _get_object_spec(name, path) + yield SummonDesc(_repo, obj) except SummonError as exc: raise SummonError( str(exc) + " at '{}' in '{}'".format(summon_file, repo) - ) from exc + ) from exc.__cause__ + + +class SummonDesc: + def __init__(self, repo, obj): + self.repo = repo + self.obj = obj + self._pull_deps() + + @property + def deps(self): + return [os.path.join(self.repo.root_dir, d) for d in self._deps] + + @property + def _deps(self): + return self.obj["summon"].get("deps", []) - _pull_dependencies(_repo, info.get("deps", [])) + def _pull_deps(self): + if not self._deps: + return - _args = copy.deepcopy(info.get("args", {})) - _args.update(args or {}) + outs = [self.repo.find_out_by_relpath(d) for d in self._deps] - return _invoke_method(info["call"], _args, path=_repo.root_dir) + with self.repo.state: + for out in outs: + self.repo.cloud.pull(out.get_used_cache()) + out.checkout() -def _get_object_from_summon_file(name, path): +def _get_object_spec(name, path): """ Given a summonable object's name, search for it on the given file and return its description. """ try: - with open(path, "r") as fobj: - content = SUMMON_SCHEMA(ruamel.yaml.safe_load(fobj.read())) + with builtin_open(path, "r") as fobj: + content = SUMMON_FILE_SCHEMA(ruamel.yaml.safe_load(fobj.read())) objects = [x for x in content["objects"] if x["name"] == name] if not objects: @@ -134,34 +185,20 @@ def _get_object_from_summon_file(name, path): return objects[0] - except FileMissingError: - raise SummonError("Summon file not found") + except FileNotFoundError as exc: + raise SummonError("Summon file not found") from exc except ruamel.yaml.YAMLError as exc: raise SummonError("Failed to parse summon file") from exc except Invalid as exc: - raise SummonError(str(exc)) - - -def _pull_dependencies(repo, deps): - if not deps: - return - - outs = [repo.find_out_by_relpath(dep) for dep in deps] - - with repo.state: - for out in outs: - repo.cloud.pull(out.get_used_cache()) - out.checkout() + raise SummonError(str(exc)) from exc +@wrap_with(threading.Lock()) def _invoke_method(call, args, path): # XXX: Some issues with this approach: - # * Not thread safe # * Import will pollute sys.modules - # * Weird errors if there is a name clash within sys.modules - - # XXX: sys.path manipulation is "theoretically" not needed - # but tests are failing for an unknown reason. + # * sys.path manipulation is "theoretically" not needed, + # but tests are failing for an unknown reason. cwd = os.getcwd() try: