diff --git a/dvc/api.py b/dvc/api.py index 22dccb5e11..44999af754 100644 --- a/dvc/api.py +++ b/dvc/api.py @@ -2,7 +2,6 @@ import importlib import os import sys -import copy from urllib.parse import urlparse from contextlib import contextmanager, _GeneratorContextManager as GCM import threading @@ -16,22 +15,29 @@ from dvc.external_repo import external_repo -SUMMON_SCHEMA = Schema( +SUMMON_FILE_SCHEMA = Schema( { Required("objects"): [ { Required("name"): str, "meta": dict, Required("summon"): { - Required("type"): "python", - Required("call"): str, - "args": dict, + Required("type"): str, "deps": [str], + str: object, }, } ] } ) +SUMMON_PYTHON_SCHEMA = Schema( + { + Required("type"): "python", + Required("call"): str, + "args": dict, + "deps": [str], + } +) class SummonError(DvcException): @@ -99,33 +105,67 @@ def _make_repo(repo_url, rev=None): def summon(name, repo=None, rev=None, summon_file="dvcsummon.yaml", args=None): + """Instantiate an object described in the summon file.""" + with prepare_summon( + name, repo=repo, rev=rev, summon_file=summon_file + ) as desc: + try: + summon = SUMMON_PYTHON_SCHEMA(desc.obj["summon"]) + except Invalid as exc: + raise SummonError(str(exc)) from exc + + _args = {**summon.get("args", {}), **(args or {})} + return _invoke_method(summon["call"], _args, path=desc.repo.root_dir) + + +@contextmanager +def prepare_summon(name, repo=None, rev=None, summon_file="dvcsummon.yaml"): """Instantiate an object described in the summon file.""" with _make_repo(repo, rev=rev) as _repo: try: path = os.path.join(_repo.root_dir, summon_file) - obj = _get_object_from_summon_file(name, path) - info = obj["summon"] + obj = _get_object_desc(name, path) + yield SummonDesc(_repo, obj) except SummonError as exc: raise SummonError( str(exc) + " at '{}' in '{}'".format(summon_file, repo) - ) from exc + ) from exc.__cause__ - _pull_dependencies(_repo, info.get("deps", [])) - _args = copy.deepcopy(info.get("args", {})) - _args.update(args or {}) +class SummonDesc: + def __init__(self, repo, obj): + self.repo = repo + self.obj = obj + self._pull_deps() - return _invoke_method(info["call"], _args, path=_repo.root_dir) + @property + def deps(self): + return [os.path.join(self.repo.root_dir, d) for d in self._deps] + @property + def _deps(self): + return self.obj["summon"].get("deps", []) -def _get_object_from_summon_file(name, path): + def _pull_deps(self): + if not self._deps: + return + + outs = [self.repo.find_out_by_relpath(d) for d in self._deps] + + with self.repo.state: + for out in outs: + self.repo.cloud.pull(out.get_used_cache()) + out.checkout() + + +def _get_object_desc(name, path): """ Given a summonable object's name, search for it on the given file and return its description. """ try: - with builtin_open(path, "r") as fd: - content = SUMMON_SCHEMA(ruamel.yaml.safe_load(fd.read())) + with builtin_open(path, "r") as fobj: + content = SUMMON_FILE_SCHEMA(ruamel.yaml.safe_load(fobj.read())) objects = [x for x in content["objects"] if x["name"] == name] if not objects: @@ -142,19 +182,7 @@ def _get_object_from_summon_file(name, path): except ruamel.yaml.YAMLError as exc: raise SummonError("Failed to parse summon file") from exc except Invalid as exc: - raise SummonError(str(exc)) - - -def _pull_dependencies(repo, deps): - if not deps: - return - - outs = [repo.find_out_by_relpath(dep) for dep in deps] - - with repo.state: - for out in outs: - repo.cloud.pull(out.get_used_cache()) - out.checkout() + raise SummonError(str(exc)) from exc @wrap_with(threading.Lock())