From 0fedd0786a9c2a8dae9108bd6d10c93de1440126 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Sun, 12 Jul 2020 05:41:44 +0300 Subject: [PATCH] config: use repo.tree Fixes #4188 --- dvc/config.py | 49 +++++++++++++++++++++++------------- dvc/ignore.py | 3 +++ dvc/repo/__init__.py | 4 +-- dvc/repo/tree.py | 4 +-- dvc/scm/git/tree.py | 2 +- dvc/scm/tree.py | 6 +++++ tests/unit/repo/test_repo.py | 27 +++++++++++++++++++- 7 files changed, 72 insertions(+), 23 deletions(-) diff --git a/dvc/config.py b/dvc/config.py index 43bc165168..92cfde7fe2 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -232,8 +232,10 @@ class Config(dict): CONFIG_LOCAL = "config.local" def __init__( - self, dvc_dir=None, validate=True + self, dvc_dir=None, validate=True, tree=None, ): # pylint: disable=super-init-not-called + from dvc.scm.tree import WorkingTree + self.dvc_dir = dvc_dir if not dvc_dir: @@ -246,6 +248,9 @@ def __init__( else: self.dvc_dir = os.path.abspath(os.path.realpath(dvc_dir)) + self.wtree = WorkingTree(self.dvc_dir) + self.tree = tree.tree if tree else self.wtree + self.load(validate=validate) @classmethod @@ -304,8 +309,32 @@ def load(self, validate=True): if not self["cache"].get("dir") and self.dvc_dir: self["cache"]["dir"] = os.path.join(self.dvc_dir, "cache") + def _load_config(self, level): + filename = self.files[level] + tree = self.tree if level == "repo" else self.wtree + + if tree.exists(filename): + with tree.open(filename) as fobj: + conf_obj = configobj.ConfigObj(fobj) + else: + conf_obj = configobj.ConfigObj() + return _parse_remotes(_lower_keys(conf_obj.dict())) + + def _save_config(self, level, conf_dict): + filename = self.files[level] + tree = self.tree if level == "repo" else self.wtree + + logger.debug(f"Writing '{filename}'.") + + tree.makedirs(os.path.dirname(filename), exist_ok=True) + + config = configobj.ConfigObj(_pack_remotes(conf_dict)) + with tree.open(filename, "wb") as fobj: + config.write(fobj) + config.filename = filename + def load_one(self, level): - conf = _load_config(self.files[level]) + conf = self._load_config(level) conf = self._load_paths(conf, self.files[level]) # Auto-verify sections @@ -375,7 +404,7 @@ def edit(self, level="repo"): _merge(merged_conf, conf) self.validate(merged_conf) - _save_config(self.files[level], conf) + self._save_config(level, conf) self.load() @staticmethod @@ -386,20 +415,6 @@ def validate(data): raise ConfigError(str(exc)) from None -def _load_config(filename): - conf_obj = configobj.ConfigObj(filename) - return _parse_remotes(_lower_keys(conf_obj.dict())) - - -def _save_config(filename, conf_dict): - logger.debug(f"Writing '{filename}'.") - os.makedirs(os.path.dirname(filename), exist_ok=True) - - config = configobj.ConfigObj(_pack_remotes(conf_dict)) - config.filename = filename - config.write() - - def _parse_remotes(conf): result = {"remote": {}} diff --git a/dvc/ignore.py b/dvc/ignore.py index 35d7ede2d4..946e5077e9 100644 --- a/dvc/ignore.py +++ b/dvc/ignore.py @@ -254,3 +254,6 @@ def stat(self, path): @property def hash_jobs(self): return self.tree.hash_jobs + + def makedirs(self, path, mode=0o777, exist_ok=True): + self.tree.makedirs(path, mode=mode, exist_ok=exist_ok) diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 685f179eeb..23ebe38a4a 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -91,14 +91,14 @@ def __init__(self, root_dir=None, scm=None, rev=None): else: root_dir = self.find_root(root_dir) self.root_dir = os.path.abspath(os.path.realpath(root_dir)) + self.tree = WorkingTree(self.root_dir) self.dvc_dir = os.path.join(self.root_dir, self.DVC_DIR) - self.config = Config(self.dvc_dir) + self.config = Config(self.dvc_dir, tree=self.tree) if not scm: no_scm = self.config["core"].get("no_scm", False) self.scm = SCM(self.root_dir, no_scm=no_scm) - self.tree = WorkingTree(self.root_dir) self.tmp_dir = os.path.join(self.dvc_dir, "tmp") self.index_dir = os.path.join(self.tmp_dir, "index") diff --git a/dvc/repo/tree.py b/dvc/repo/tree.py index df6af99756..1b12ad68ef 100644 --- a/dvc/repo/tree.py +++ b/dvc/repo/tree.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -class DvcTree(BaseTree): +class DvcTree(BaseTree): # pylint:disable=abstract-method """DVC repo tree. Args: @@ -236,7 +236,7 @@ def get_file_hash(self, path_info): return out.checksum -class RepoTree(BaseTree): +class RepoTree(BaseTree): # pylint:disable=abstract-method """DVC + git-tracked files tree. Args: diff --git a/dvc/scm/git/tree.py b/dvc/scm/git/tree.py index 5b3197866d..0793d3d04a 100644 --- a/dvc/scm/git/tree.py +++ b/dvc/scm/git/tree.py @@ -19,7 +19,7 @@ def _item_basename(item): return os.path.basename(item.path) -class GitTree(BaseTree): +class GitTree(BaseTree): # pylint:disable=abstract-method """Proxies the repo file access methods to Git objects""" def __init__(self, git, rev): diff --git a/dvc/scm/tree.py b/dvc/scm/tree.py index c923cd6bee..d89f97ee98 100644 --- a/dvc/scm/tree.py +++ b/dvc/scm/tree.py @@ -37,6 +37,9 @@ def walk_files(self, top): # NOTE: os.path.join is ~5.5 times slower yield f"{root}{os.sep}{file}" + def makedirs(self, path, mode=0o777, exist_ok=True): + raise NotImplementedError + class WorkingTree(BaseTree): """Proxies the repo file access methods to working tree files""" @@ -90,6 +93,9 @@ def stat(path): def hash_jobs(self): return max(1, min(4, cpu_count() // 2)) + def makedirs(self, path, mode=0o777, exist_ok=True): + os.makedirs(path, mode=mode, exist_ok=exist_ok) + def is_working_tree(tree): return isinstance(tree, WorkingTree) or isinstance( diff --git a/tests/unit/repo/test_repo.py b/tests/unit/repo/test_repo.py index 6395d47a82..3a91e31bc4 100644 --- a/tests/unit/repo/test_repo.py +++ b/tests/unit/repo/test_repo.py @@ -3,7 +3,8 @@ import pytest from funcy import raiser -from dvc.repo import locked +from dvc.repo import NotDvcRepoError, Repo, locked +from dvc.utils.fs import remove def test_is_dvc_internal(dvc): @@ -127,3 +128,27 @@ def test_skip_graph_checks(tmp_dir, dvc, mocker, run_copy): dvc.add("quux") run_copy("quux", "quuz", single_stage=True) assert mock_collect_graph.called + + +def test_branch_config(tmp_dir, scm): + tmp_dir.scm_gen("foo", "foo", commit="init") + + scm.checkout("branch", create_new=True) + dvc = Repo.init() + with dvc.config.edit() as conf: + conf["remote"]["branch"] = {"url": "/some/path"} + scm.add([".dvc"]) + scm.commit("init dvc") + scm.checkout("master") + + remove(".dvc") + + # sanity check + with pytest.raises(NotDvcRepoError): + Repo() + + with pytest.raises(NotDvcRepoError): + Repo(scm=scm, rev="master") + + dvc = Repo(scm=scm, rev="branch") + assert dvc.config["remote"]["branch"]["url"] == "/some/path"