diff --git a/dvc/repo/experiments/apply.py b/dvc/repo/experiments/apply.py index a8906d5afd..ed660dc0b5 100644 --- a/dvc/repo/experiments/apply.py +++ b/dvc/repo/experiments/apply.py @@ -21,7 +21,7 @@ @scm_context def apply(repo, rev, force=False, **kwargs): from dvc.repo.checkout import checkout as dvc_checkout - from dvc.scm.base import SCMError + from dvc.scm.base import MergeConflictError, SCMError exps = repo.experiments @@ -48,7 +48,7 @@ def apply(repo, rev, force=False, **kwargs): if workspace: try: repo.scm.stash.apply(workspace) - except SCMError: + except MergeConflictError as exc: # Applied experiment conflicts with user's workspace changes if force: # prefer applied experiment changes over prior stashed changes @@ -57,7 +57,9 @@ def apply(repo, rev, force=False, **kwargs): # revert applied changes and restore user's workspace repo.scm.reset(hard=True) repo.scm.stash.pop() - raise ApplyConflictError(rev) + raise ApplyConflictError(rev) from exc + except SCMError as exc: + raise ApplyConflictError(rev) from exc repo.scm.stash.drop() repo.scm.reset() diff --git a/dvc/scm/git/backend/base.py b/dvc/scm/git/backend/base.py index fe3eba2f3e..d2bc24935a 100644 --- a/dvc/scm/git/backend/base.py +++ b/dvc/scm/git/backend/base.py @@ -151,9 +151,7 @@ def set_ref( """ @abstractmethod - def get_ref( - self, name: str, follow: Optional[bool] = True - ) -> Optional[str]: + def get_ref(self, name: str, follow: bool = True) -> Optional[str]: """Return the value of specified ref. If follow is false, symbolic refs will not be dereferenced. diff --git a/dvc/scm/git/backend/dulwich.py b/dvc/scm/git/backend/dulwich.py index 074540ec25..ea01d09ed3 100644 --- a/dvc/scm/git/backend/dulwich.py +++ b/dvc/scm/git/backend/dulwich.py @@ -262,7 +262,7 @@ def set_ref( ): raise SCMError(f"Failed to set '{name}'") - def get_ref(self, name, follow: Optional[bool] = True) -> Optional[str]: + def get_ref(self, name, follow: bool = True) -> Optional[str]: from dulwich.refs import parse_symref_value name_b = os.fsencode(name) diff --git a/dvc/scm/git/backend/gitpython.py b/dvc/scm/git/backend/gitpython.py index 058ec0c8bf..091749485c 100644 --- a/dvc/scm/git/backend/gitpython.py +++ b/dvc/scm/git/backend/gitpython.py @@ -370,14 +370,15 @@ def set_ref( except GitCommandError as exc: raise SCMError(f"Failed to set ref '{name}'") from exc - def get_ref( - self, name: str, follow: Optional[bool] = True - ) -> Optional[str]: + def get_ref(self, name: str, follow: bool = True) -> Optional[str]: from git.exc import GitCommandError if name == "HEAD": try: - return self.repo.head.commit.hexsha + if follow or self.repo.head.is_detached: + return self.repo.head.commit.hexsha + else: + return f"refs/heads/{self.repo.active_branch}" except (GitCommandError, ValueError): return None elif name.startswith("refs/heads/"): @@ -495,7 +496,8 @@ def _stash_apply(self, rev: str): try: self.git.stash("apply", rev) except GitCommandError as exc: - if "CONFLICT" in str(exc): + out = str(exc) + if "CONFLICT" in out or "already exists" in out: raise MergeConflictError( "Stash apply resulted in merge conflicts" ) from exc diff --git a/dvc/scm/git/backend/pygit2.py b/dvc/scm/git/backend/pygit2.py index 0a058e5deb..7fd16830c4 100644 --- a/dvc/scm/git/backend/pygit2.py +++ b/dvc/scm/git/backend/pygit2.py @@ -4,7 +4,7 @@ from io import BytesIO, StringIO from typing import Callable, Iterable, List, Mapping, Optional, Tuple -from dvc.scm.base import MergeConflictError, SCMError +from dvc.scm.base import MergeConflictError, RevError, SCMError from dvc.utils import relpath from ..objects import GitObject @@ -93,7 +93,25 @@ def commit(self, msg: str, no_verify: bool = False): def checkout( self, branch: str, create_new: Optional[bool] = False, **kwargs, ): - raise NotImplementedError + from pygit2 import GitError + + if create_new: + commit = self.repo.revparse_single("HEAD") + new_branch = self.repo.branches.local.create(branch, commit) + self.repo.checkout(new_branch) + else: + if branch == "-": + branch = "@{-1}" + try: + commit, ref = self.repo.resolve_refish(branch) + except (KeyError, GitError): + raise RevError(f"unknown Git revision '{branch}'") + self.repo.checkout_tree(commit) + detach = kwargs.get("detach", False) + if ref and not detach: + self.repo.set_head(ref.name) + else: + self.repo.set_head(commit.id) def pull(self, **kwargs): raise NotImplementedError @@ -136,7 +154,24 @@ def get_rev(self) -> str: raise NotImplementedError def resolve_rev(self, rev: str) -> str: - raise NotImplementedError + from pygit2 import GitError + + try: + commit, _ref = self.repo.resolve_refish(rev) + return str(commit.id) + except (KeyError, GitError): + pass + + # Look for single exact match in remote refs + shas = { + self.get_ref(f"refs/remotes/{remote.name}/{rev}") + for remote in self.repo.remotes + } - {None} + if len(shas) > 1: + raise RevError(f"ambiguous Git revision '{rev}'") + if len(shas) == 1: + return shas.pop() # type: ignore + raise RevError(f"unknown Git revision '{rev}'") def resolve_commit(self, rev: str) -> str: raise NotImplementedError @@ -158,16 +193,38 @@ def set_ref( message: Optional[str] = None, symbolic: Optional[bool] = False, ): - raise NotImplementedError + if old_ref and old_ref != self.get_ref(name, follow=False): + raise SCMError(f"Failed to set '{name}'") - def get_ref(self, name, follow: Optional[bool] = True) -> Optional[str]: - raise NotImplementedError + if symbolic: + ref = self.repo.create_reference_symbolic(name, new_ref, True) + else: + ref = self.repo.create_reference_direct(name, new_ref, True) + if message: + ref.set_target(new_ref, message) + + def get_ref(self, name, follow: bool = True) -> Optional[str]: + from pygit2 import GIT_REF_SYMBOLIC + + ref = self.repo.references.get(name) + if not ref: + return None + if follow and ref.type == GIT_REF_SYMBOLIC: + ref = ref.resolve() + return str(ref.target) def remove_ref(self, name: str, old_ref: Optional[str] = None): - raise NotImplementedError + ref = self.repo.references.get(name) + if not ref: + raise SCMError(f"Ref '{name}' does not exist") + if old_ref and old_ref != str(ref.target): + raise SCMError(f"Failed to remove '{name}'") + ref.delete() def iter_refs(self, base: Optional[str] = None): - raise NotImplementedError + for ref in self.repo.references: + if ref.startswith(base): + yield ref def get_refs_containing(self, rev: str, pattern: Optional[str] = None): raise NotImplementedError @@ -223,7 +280,20 @@ def diff(self, rev_a: str, rev_b: str, binary=False) -> str: raise NotImplementedError def reset(self, hard: bool = False, paths: Iterable[str] = None): - raise NotImplementedError + from pygit2 import GIT_RESET_HARD, GIT_RESET_MIXED, IndexEntry + + self.repo.index.read(False) + if paths is not None: + tree = self.repo.revparse_single("HEAD").tree + for path in paths: + rel = relpath(path, self.root_dir) + obj = tree[relpath(rel, self.root_dir)] + self.repo.index.add(IndexEntry(rel, obj.oid, obj.filemode)) + self.repo.index.write() + elif hard: + self.repo.reset(self.repo.head.target, GIT_RESET_HARD) + else: + self.repo.reset(self.repo.head.target, GIT_RESET_MIXED) def checkout_index( self, @@ -299,7 +369,9 @@ def merge( raise SCMError("Merge commit message is required") try: + self.repo.index.read(False) self.repo.merge(rev) + self.repo.index.write() except GitError as exc: raise SCMError("Merge failed") from exc @@ -316,4 +388,5 @@ def merge( if squash: self.repo.reset(self.repo.head.target, GIT_RESET_MIXED) self.repo.state_cleanup() + self.repo.index.write() return None diff --git a/tests/func/test_install.py b/tests/func/test_install.py index 0b3fb39814..38ee9f04c3 100644 --- a/tests/func/test_install.py +++ b/tests/func/test_install.py @@ -50,7 +50,7 @@ def test_post_checkout(self, tmp_dir, scm, dvc): os.unlink("file") scm.install() - scm.checkout("new_branch", create_new=True) + scm.gitpython.git.checkout("-b", "new_branch") assert os.path.isfile("file") diff --git a/tests/unit/scm/test_git.py b/tests/unit/scm/test_git.py index 73a6f9fcf8..3a5b669aa6 100644 --- a/tests/unit/scm/test_git.py +++ b/tests/unit/scm/test_git.py @@ -141,9 +141,6 @@ def test_branch_revs(tmp_dir, scm): def test_set_ref(tmp_dir, git): - if git.test_backend == "pygit2": - pytest.skip() - tmp_dir.scm_gen({"file": "0"}, commit="init") init_rev = tmp_dir.scm.get_rev() tmp_dir.scm_gen({"file": "1"}, commit="commit") @@ -171,9 +168,6 @@ def test_set_ref(tmp_dir, git): def test_get_ref(tmp_dir, git): - if git.test_backend == "pygit2": - pytest.skip() - tmp_dir.scm_gen({"file": "0"}, commit="init") init_rev = tmp_dir.scm.get_rev() tmp_dir.gen( @@ -192,9 +186,6 @@ def test_get_ref(tmp_dir, git): def test_remove_ref(tmp_dir, git): - if git.test_backend == "pygit2": - pytest.skip() - tmp_dir.scm_gen({"file": "0"}, commit="init") init_rev = tmp_dir.scm.get_rev() tmp_dir.gen(os.path.join(".git", "refs", "foo", "bar"), init_rev) @@ -410,3 +401,100 @@ def test_checkout_index_conflicts(tmp_dir, scm, git, strategy, expected): else: git.checkout_index(theirs=True) assert (tmp_dir / "file").read_text() == expected + + +def test_resolve_rev(tmp_dir, scm, make_tmp_dir, git): + from dvc.scm.base import RevError + + if git.test_backend == "dulwich": + pytest.skip() + + remote_dir = make_tmp_dir("git-remote", scm=True) + url = "file://{}".format(remote_dir.resolve().as_posix()) + scm.gitpython.repo.create_remote("origin", url) + scm.gitpython.repo.create_remote("upstream", url) + + tmp_dir.scm_gen({"file": "0"}, commit="init") + init_rev = scm.get_rev() + tmp_dir.scm_gen({"file": "1"}, commit="1") + rev = scm.get_rev() + scm.checkout("branch", create_new=True) + tmp_dir.gen( + { + os.path.join(".git", "refs", "foo"): rev, + os.path.join(".git", "refs", "remotes", "origin", "bar"): rev, + os.path.join(".git", "refs", "remotes", "origin", "baz"): rev, + os.path.join( + ".git", "refs", "remotes", "upstream", "baz" + ): init_rev, + } + ) + + assert git.resolve_rev(rev) == rev + assert git.resolve_rev(rev[:7]) == rev + assert git.resolve_rev("HEAD") == rev + assert git.resolve_rev("branch") == rev + assert git.resolve_rev("refs/foo") == rev + assert git.resolve_rev("bar") == rev + assert git.resolve_rev("origin/baz") == rev + + with pytest.raises(RevError): + git.resolve_rev("qux") + + with pytest.raises(RevError): + git.resolve_rev("baz") + + +def test_checkout(tmp_dir, scm, git): + if git.test_backend == "dulwich": + pytest.skip() + + tmp_dir.scm_gen({"foo": "foo"}, commit="foo") + foo_rev = scm.get_rev() + tmp_dir.scm_gen("foo", "bar", commit="bar") + bar_rev = scm.get_rev() + + git.checkout("branch", create_new=True) + assert git.get_ref("HEAD", follow=False) == "refs/heads/branch" + assert (tmp_dir / "foo").read_text() == "bar" + + git.checkout("master", detach=True) + assert git.get_ref("HEAD", follow=False) == bar_rev + + git.checkout("master") + assert git.get_ref("HEAD", follow=False) == "refs/heads/master" + + git.checkout(foo_rev[:7]) + assert git.get_ref("HEAD", follow=False) == foo_rev + assert (tmp_dir / "foo").read_text() == "foo" + + +def test_reset(tmp_dir, scm, git): + if git.test_backend == "dulwich": + pytest.skip() + + tmp_dir.scm_gen({"foo": "foo"}, commit="init") + + tmp_dir.gen("foo", "bar") + scm.add(["foo"]) + git.reset() + assert (tmp_dir / "foo").read_text() == "bar" + staged, unstaged, _ = scm.status() + assert len(staged) == 0 + assert list(unstaged) == ["foo"] + + scm.add(["foo"]) + git.reset(hard=True) + assert (tmp_dir / "foo").read_text() == "foo" + staged, unstaged, _ = scm.status() + assert len(staged) == 0 + assert len(unstaged) == 0 + + tmp_dir.gen({"foo": "bar", "bar": "bar"}) + scm.add(["foo", "bar"]) + git.reset(paths=["foo"]) + assert (tmp_dir / "foo").read_text() == "bar" + assert (tmp_dir / "bar").read_text() == "bar" + staged, unstaged, _ = scm.status() + assert len(staged) == 1 + assert len(unstaged) == 1