diff --git a/dvc/env.py b/dvc/env.py index 0ac4e46e8b..e5f8ad912b 100644 --- a/dvc/env.py +++ b/dvc/env.py @@ -7,3 +7,5 @@ DVCLIVE_HTML = "DVCLIVE_HTML" DVCLIVE_RESUME = "DVCLIVE_RESUME" DVC_IGNORE_ISATTY = "DVC_IGNORE_ISATTY" +DVC_EXP_GIT_REMOTE = "DVC_EXP_GIT_REMOTE" +DVC_EXP_AUTO_PUSH = "DVC_EXP_AUTO_PUSH" diff --git a/dvc/repo/experiments/executor/base.py b/dvc/repo/experiments/executor/base.py index f797728042..729836f152 100644 --- a/dvc/repo/experiments/executor/base.py +++ b/dvc/repo/experiments/executor/base.py @@ -16,8 +16,10 @@ from funcy import cached_property +from dvc.env import DVC_EXP_AUTO_PUSH, DVC_EXP_GIT_REMOTE from dvc.exceptions import DvcException from dvc.path_info import PathInfo +from dvc.repo import Repo from dvc.repo.experiments.base import ( EXEC_BASELINE, EXEC_BRANCH, @@ -36,7 +38,7 @@ from dvc.stage import PipelineStage from dvc.stage.monitor import CheckpointKilledError from dvc.stage.serialize import to_lockfile -from dvc.utils import dict_sha256 +from dvc.utils import dict_sha256, env2bool from dvc.utils.fs import remove if TYPE_CHECKING: @@ -248,6 +250,19 @@ def on_diverged_ref(orig_ref: str, new_rev: str): ) return refs + @classmethod + def _validate_remotes(cls, dvc: "Repo", git_remote: Optional[str]): + + if git_remote == dvc.root_dir: + logger.warning( + f"'{git_remote}' points to the current Git repo, experiment " + "Git refs will not be pushed. But DVC cache and run cache " + "will automatically be pushed to the default DVC remote " + "(if any) on each experiment commit." + ) + dvc.scm.validate_git_remote(git_remote) + dvc.cloud.get_remote_odb() + @classmethod def reproduce( cls, @@ -270,6 +285,9 @@ def reproduce( from dvc.repo.checkout import checkout as dvc_checkout from dvc.repo.reproduce import reproduce as dvc_reproduce + auto_push = env2bool(DVC_EXP_AUTO_PUSH) + git_remote = os.getenv(DVC_EXP_GIT_REMOTE, None) + unchanged = [] if queue is not None: @@ -292,6 +310,9 @@ def filter_pipeline(stages): log_errors, **kwargs, ) as dvc: + if auto_push: + cls._validate_remotes(dvc, git_remote) + args, kwargs = cls._repro_args(dvc) if args: targets: Optional[Union[list, str]] = args[0] @@ -331,6 +352,7 @@ def filter_pipeline(stages): checkpoint_func = partial( cls.checkpoint_callback, + dvc, dvc.scm, name, repro_force or checkpoint_reset, @@ -361,6 +383,8 @@ def filter_pipeline(stages): force=repro_force, checkpoint=is_checkpoint, ) + if auto_push: + cls._auto_push(dvc, dvc.scm, git_remote) except UnchangedExperimentError: pass ref = dvc.scm.get_ref(EXEC_BRANCH, follow=False) @@ -393,7 +417,6 @@ def _repro_dvc( git_url: Optional[str] = None, **kwargs, ): - from dvc.repo import Repo from dvc.utils.serialize import modify_yaml dvc = Repo(dvc_dir) @@ -450,9 +473,32 @@ def _repro_args(cls, dvc): kwargs = {} return args, kwargs + @staticmethod + def _auto_push( + dvc: "Repo", + scm: "Git", + git_remote: Optional[str], + push_cache=True, + run_cache=True, + ): + branch = scm.get_ref(EXEC_BRANCH, follow=False) + try: + dvc.experiments.push( + git_remote, + branch, + push_cache=push_cache, + run_cache=run_cache, + ) + except BaseException as exc: + logger.warning( + "Something went wrong while auto pushing experiment " + f"to the remote '{git_remote}': {exc}" + ) + @classmethod def checkpoint_callback( cls, + dvc: "Repo", scm: "Git", name: Optional[str], force: bool, @@ -464,6 +510,10 @@ def checkpoint_callback( exp_rev = cls.commit( scm, exp_hash, exp_name=name, force=force, checkpoint=True ) + + if env2bool(DVC_EXP_AUTO_PUSH): + git_remote = os.getenv(DVC_EXP_GIT_REMOTE) + cls._auto_push(dvc, scm, git_remote) logger.info("Checkpoint experiment iteration '%s'.", exp_rev[:7]) except UnchangedExperimentError: pass diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index ca0867bd8d..7b58682e33 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -2,6 +2,7 @@ from dvc.exceptions import DvcException, InvalidArgumentError from dvc.repo import locked +from dvc.repo.experiments.base import ExpRefInfo from dvc.repo.scm_context import scm_context from .utils import exp_commits, exp_refs_by_name @@ -12,7 +13,13 @@ @locked @scm_context def push( - repo, git_remote, exp_name, *args, force=False, push_cache=False, **kwargs + repo, + git_remote, + exp_name: str, + *args, + force=False, + push_cache=False, + **kwargs, ): exp_ref = _get_exp_ref(repo, exp_name) @@ -35,9 +42,9 @@ def on_diverged(refname: str, rev: str) -> bool: _push_cache(repo, exp_ref, **kwargs) -def _get_exp_ref(repo, exp_name): +def _get_exp_ref(repo, exp_name: str) -> ExpRefInfo: if exp_name.startswith("refs/"): - return exp_name + return ExpRefInfo.from_ref(exp_name) exp_refs = list(exp_refs_by_name(repo.scm, exp_name)) if not exp_refs: diff --git a/dvc/scm/base.py b/dvc/scm/base.py index cf7083bb3f..c4813ea272 100644 --- a/dvc/scm/base.py +++ b/dvc/scm/base.py @@ -41,6 +41,12 @@ class MergeConflictError(SCMError): pass +class InvalidRemoteSCMRepo(SCMError): + def __init__(self, url: str): + msg = f"'{url}' is not a valid Git remote or URL" + super().__init__(msg) + + class Base: """Base class for source control management driver implementations.""" diff --git a/dvc/scm/git/__init__.py b/dvc/scm/git/__init__.py index 8be455a546..41eef55a63 100644 --- a/dvc/scm/git/__init__.py +++ b/dvc/scm/git/__init__.py @@ -392,6 +392,7 @@ def get_fs(self, rev: str): checkout_index = partialmethod(_backend_func, "checkout_index") status = partialmethod(_backend_func, "status") merge = partialmethod(_backend_func, "merge") + validate_git_remote = partialmethod(_backend_func, "validate_git_remote") def resolve_rev(self, rev: str) -> str: from dvc.repo.experiments.utils import exp_refs_by_name diff --git a/dvc/scm/git/backend/base.py b/dvc/scm/git/backend/base.py index 3196f35907..ad950574a0 100644 --- a/dvc/scm/git/backend/base.py +++ b/dvc/scm/git/backend/base.py @@ -347,3 +347,7 @@ def merge( Returns revision of the merge commit or None if no commit was made. """ + + @abstractmethod + def validate_git_remote(self, url: str): + """Verify that url is a valid git URL or remote name.""" diff --git a/dvc/scm/git/backend/dulwich.py b/dvc/scm/git/backend/dulwich.py index c1860f85ea..12e1691426 100644 --- a/dvc/scm/git/backend/dulwich.py +++ b/dvc/scm/git/backend/dulwich.py @@ -20,7 +20,7 @@ from dvc.path_info import PathInfo from dvc.progress import Tqdm -from dvc.scm.base import SCMError +from dvc.scm.base import InvalidRemoteSCMRepo, SCMError from dvc.utils import relpath from ..objects import GitObject @@ -348,24 +348,26 @@ def iter_refs(self, base: Optional[str] = None): def iter_remote_refs(self, url: str, base: Optional[str] = None): from dulwich.client import get_transport_and_path + from dulwich.errors import NotGitRepository from dulwich.porcelain import get_remote_repo try: _remote, location = get_remote_repo(self.repo, url) client, path = get_transport_and_path(location) except Exception as exc: - raise SCMError( - f"'{url}' is not a valid Git remote or URL" - ) from exc + raise InvalidRemoteSCMRepo(url) from exc - if base: - yield from ( - os.fsdecode(ref) - for ref in client.get_refs(path) - if ref.startswith(os.fsencode(base)) - ) - else: - yield from (os.fsdecode(ref) for ref in client.get_refs(path)) + try: + if base: + yield from ( + os.fsdecode(ref) + for ref in client.get_refs(path) + if ref.startswith(os.fsencode(base)) + ) + else: + yield from (os.fsdecode(ref) for ref in client.get_refs(path)) + except NotGitRepository as exc: + raise InvalidRemoteSCMRepo(url) from exc def get_refs_containing(self, rev: str, pattern: Optional[str] = None): raise NotImplementedError @@ -642,3 +644,17 @@ def merge( squash: bool = False, ) -> Optional[str]: raise NotImplementedError + + def validate_git_remote(self, url: str): + from dulwich.client import LocalGitClient, get_transport_and_path + from dulwich.porcelain import get_remote_repo + + try: + _, location = get_remote_repo(self.repo, url) + client, path = get_transport_and_path(location) + except Exception as exc: + raise InvalidRemoteSCMRepo(url) from exc + if isinstance(client, LocalGitClient) and not os.path.exists( + os.path.join("", path) + ): + raise InvalidRemoteSCMRepo(url) diff --git a/dvc/scm/git/backend/gitpython.py b/dvc/scm/git/backend/gitpython.py index c49ee8ef7c..4c9b40896f 100644 --- a/dvc/scm/git/backend/gitpython.py +++ b/dvc/scm/git/backend/gitpython.py @@ -618,3 +618,6 @@ def merge( raise MergeConflictError("Merge contained conflicts") from exc raise SCMError("Merge failed") from exc return None + + def validate_git_remote(self, url: str): + raise NotImplementedError diff --git a/dvc/scm/git/backend/pygit2.py b/dvc/scm/git/backend/pygit2.py index aeee30579b..db30f441db 100644 --- a/dvc/scm/git/backend/pygit2.py +++ b/dvc/scm/git/backend/pygit2.py @@ -561,3 +561,6 @@ def merge( self.repo.state_cleanup() self.repo.index.write() return None + + def validate_git_remote(self, url: str): + raise NotImplementedError diff --git a/tests/func/experiments/conftest.py b/tests/func/experiments/conftest.py index 455aa687c4..fb592ae152 100644 --- a/tests/func/experiments/conftest.py +++ b/tests/func/experiments/conftest.py @@ -88,3 +88,21 @@ def checkpoint_stage(tmp_dir, scm, dvc, mocker): scm.commit("init") stage.iterations = DEFAULT_ITERATIONS return stage + + +@pytest.fixture +def git_upstream(tmp_dir, erepo_dir): + url = "file://{}".format(erepo_dir.resolve().as_posix()) + tmp_dir.scm.gitpython.repo.create_remote("upstream", url) + erepo_dir.remote = "upstream" + erepo_dir.url = url + return erepo_dir + + +@pytest.fixture +def git_downstream(tmp_dir, erepo_dir): + url = "file://{}".format(tmp_dir.resolve().as_posix()) + erepo_dir.scm.gitpython.repo.create_remote("upstream", url) + erepo_dir.remote = "upstream" + erepo_dir.url = url + return erepo_dir diff --git a/tests/func/experiments/test_checkpoints.py b/tests/func/experiments/test_checkpoints.py index 02c93cd14c..4859b462ca 100644 --- a/tests/func/experiments/test_checkpoints.py +++ b/tests/func/experiments/test_checkpoints.py @@ -1,9 +1,16 @@ +import logging + import pytest from funcy import first +from dvc.config import NoRemoteError +from dvc.env import DVC_EXP_AUTO_PUSH, DVC_EXP_GIT_REMOTE from dvc.exceptions import DvcException from dvc.repo.experiments import MultipleBranchError from dvc.repo.experiments.base import EXEC_APPLY, EXEC_CHECKPOINT +from dvc.repo.experiments.executor.base import BaseExecutor +from dvc.repo.experiments.utils import exp_refs_by_rev +from dvc.scm.base import InvalidRemoteSCMRepo @pytest.mark.parametrize("workspace", [True, False]) @@ -188,3 +195,90 @@ def test_resume_non_head_checkpoint( ) new_head = first(results) assert orig_branch != dvc.experiments.get_branch_by_rev(new_head) + + +@pytest.fixture +def clear_env(monkeypatch): + yield + monkeypatch.delenv(DVC_EXP_GIT_REMOTE, raising=False) + monkeypatch.delenv(DVC_EXP_AUTO_PUSH, raising=False) + + +@pytest.mark.parametrize("use_url", [True, False]) +def test_auto_push_during_iterations( + tmp_dir, + scm, + dvc, + checkpoint_stage, + git_upstream, + local_remote, + use_url, + monkeypatch, + mocker, + clear_env, +): + # set up remote repo + remote = git_upstream.url if use_url else git_upstream.remote + git_upstream.scm.fetch_refspecs(str(tmp_dir), ["master:master"]) + monkeypatch.setenv(DVC_EXP_GIT_REMOTE, remote) + auto_push_spy = mocker.spy(BaseExecutor, "_auto_push") + + # without auto push + results = dvc.experiments.run(checkpoint_stage.addressing) + assert auto_push_spy.call_count == 0 + + # add auto push + monkeypatch.setenv(DVC_EXP_AUTO_PUSH, "true") + results = dvc.experiments.run(checkpoint_stage.addressing) + assert (tmp_dir / "foo").read_text() == "4" + exp = first(results) + ref_info = first(exp_refs_by_rev(scm, exp)) + assert git_upstream.scm.get_ref(str(ref_info)) == exp + + assert auto_push_spy.call_count == 2 + assert auto_push_spy.call_args[0][2] == remote + + +def test_auto_push_error_url( + dvc, scm, checkpoint_stage, local_remote, monkeypatch, clear_env +): + monkeypatch.setenv(DVC_EXP_GIT_REMOTE, "true") + monkeypatch.setenv(DVC_EXP_AUTO_PUSH, "true") + with pytest.raises(InvalidRemoteSCMRepo): + dvc.experiments.run(checkpoint_stage.addressing, params=["foo=2"]) + + +def test_auto_push_no_remote( + dvc, scm, checkpoint_stage, git_upstream, monkeypatch, clear_env +): + monkeypatch.setenv(DVC_EXP_GIT_REMOTE, git_upstream.url) + monkeypatch.setenv(DVC_EXP_AUTO_PUSH, "true") + with pytest.raises(NoRemoteError): + dvc.experiments.run(checkpoint_stage.addressing, params=["foo=2"]) + + +def test_auto_push_self_remote( + tmp_dir, + dvc, + scm, + checkpoint_stage, + local_remote, + caplog, + monkeypatch, + clear_env, +): + root_dir = str(tmp_dir) + monkeypatch.setenv(DVC_EXP_GIT_REMOTE, root_dir) + monkeypatch.setenv(DVC_EXP_AUTO_PUSH, "true") + assert ( + dvc.experiments.run(checkpoint_stage.addressing, params=["foo=2"]) + != {} + ) + + with caplog.at_level(logging.WARNING, logger="dvc.repo.experiments"): + assert ( + f"'{root_dir}' points to the current Git repo, experiment " + "Git refs will not be pushed. But DVC cache and run cache will " + "automatically be pushed to the default DVC remote (if any) " + "on each experiment commit." in caplog.text + ) diff --git a/tests/func/experiments/test_remote.py b/tests/func/experiments/test_remote.py index 177476853b..dffb1970c6 100644 --- a/tests/func/experiments/test_remote.py +++ b/tests/func/experiments/test_remote.py @@ -7,24 +7,6 @@ from dvc.repo.experiments.utils import exp_refs_by_rev -@pytest.fixture -def git_upstream(tmp_dir, erepo_dir): - url = f"file://{erepo_dir.resolve().as_posix()}" - tmp_dir.scm.gitpython.repo.create_remote("upstream", url) - erepo_dir.remote = "upstream" - erepo_dir.url = url - return erepo_dir - - -@pytest.fixture -def git_downstream(tmp_dir, erepo_dir): - url = f"file://{tmp_dir.resolve().as_posix()}" - erepo_dir.scm.gitpython.repo.create_remote("upstream", url) - erepo_dir.remote = "upstream" - erepo_dir.url = url - return erepo_dir - - @pytest.mark.parametrize("use_url", [True, False]) def test_push(tmp_dir, scm, dvc, git_upstream, exp_stage, use_url): from dvc.exceptions import InvalidArgumentError