diff --git a/dvc/config.py b/dvc/config.py index b0372fd578..e8fd3e956c 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -246,7 +246,7 @@ class Config(object): # pylint: disable=too-many-instance-attributes def __init__(self, dvc_dir=None, validate=True): self.dvc_dir = dvc_dir - self.validate = validate + self.should_validate = validate if not dvc_dir: try: @@ -376,14 +376,10 @@ def load(self): c = self._lower(c) self.config.merge(c) - if not self.validate: + if not self.should_validate: return - d = self.config.dict() - try: - d = self.COMPILED_SCHEMA(d) - except Invalid as exc: - raise ConfigError(str(exc)) from exc + d = self.validate(self.config) self.config = configobj.ConfigObj(d, write_empty_values=True) def save(self, config=None): @@ -421,6 +417,12 @@ def _save(config): raise config.write() + def validate(self, config): + try: + return self.COMPILED_SCHEMA(config.dict()) + except Invalid as exc: + raise ConfigError(str(exc)) from exc + def unset(self, section, opt=None, level=None, force=False): """Unsets specified option and/or section in the config. @@ -485,6 +487,11 @@ def set(self, section, opt, value, level=None, force=True): ) config[section][opt] = value + + result = copy.deepcopy(self.config) + result.merge(config) + self.validate(result) + self.save(config) def get(self, section, opt=None, level=None): diff --git a/tests/func/test_config.py b/tests/func/test_config.py index f6cae78d3c..7fa4a45718 100644 --- a/tests/func/test_config.py +++ b/tests/func/test_config.py @@ -1,6 +1,8 @@ +import pytest import configobj from dvc.main import main +from dvc.config import Config, ConfigError from tests.basic_env import TestDvc @@ -33,11 +35,11 @@ def test_root(self): self.assertEqual(ret, 0) def _do_test(self, local=False): - section = "setsection" - field = "setfield" + section = "core" + field = "analytics" section_field = "{}.{}".format(section, field) - value = "setvalue" - newvalue = "setnewvalue" + value = "True" + newvalue = "False" base = ["config"] if local: @@ -83,3 +85,13 @@ def test_non_existing(self): ret = main(["config", "core.non_existing_field", "-u"]) self.assertEqual(ret, 251) + + +def test_set_invalid_key(dvc): + with pytest.raises(ConfigError, match=r"extra keys not allowed"): + dvc.config.set("core", "invalid.key", "value") + + +def test_merging_two_levels(dvc): + dvc.config.set('remote "test"', "url", "https://example.com") + dvc.config.set('remote "test"', "password", "1", level=Config.LEVEL_LOCAL) diff --git a/tests/func/test_remote.py b/tests/func/test_remote.py index 97a471a51a..6ebfc19edd 100644 --- a/tests/func/test_remote.py +++ b/tests/func/test_remote.py @@ -32,7 +32,7 @@ def test(self): self.assertEqual(main(["remote", "list"]), 0) self.assertEqual( - main(["remote", "modify", remotes[0], "option", "value"]), 0 + main(["remote", "modify", remotes[0], "checksum_jobs", "1"]), 0 ) self.assertEqual(main(["remote", "remove", remotes[0]]), 0) @@ -113,7 +113,7 @@ def test(self): class TestRemoteRemove(TestDvc): def test(self): - ret = main(["config", "core.jobs", "1"]) + ret = main(["config", "core.checksum_jobs", "1"]) self.assertEqual(ret, 0) remote = "mys3"