Skip to content

Commit

Permalink
scm: implement pygit2 resolve rev and native checkout (#5270)
Browse files Browse the repository at this point in the history
* pygit2: implement get/set/remove refs

* pygit: implement iter_refs

* pygit: implement resolve_rev

* pygit2: implement working tree checkout

* pygit2: implement reset, explicitly read/write index when used

* pygit2: checkout does not need to reset index

* pygit2: handle "-" branch shorthand in checkout

* tests: install post-checkout hook should be run via gitpython

* exp apply: use stricter exception handling for merge conflict

* pygit: fix index read/write usage

* gitpython: handle untracked file conflicts in stash apply

* fix flaky/race index issue
  • Loading branch information
pmrowla authored Jan 18, 2021
1 parent 78e1232 commit 407f560
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 31 deletions.
8 changes: 5 additions & 3 deletions dvc/repo/experiments/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
@scm_context
def apply(repo, rev, force=False, **kwargs):
from dvc.repo.checkout import checkout as dvc_checkout
from dvc.scm.base import SCMError
from dvc.scm.base import MergeConflictError, SCMError

exps = repo.experiments

Expand All @@ -48,7 +48,7 @@ def apply(repo, rev, force=False, **kwargs):
if workspace:
try:
repo.scm.stash.apply(workspace)
except SCMError:
except MergeConflictError as exc:
# Applied experiment conflicts with user's workspace changes
if force:
# prefer applied experiment changes over prior stashed changes
Expand All @@ -57,7 +57,9 @@ def apply(repo, rev, force=False, **kwargs):
# revert applied changes and restore user's workspace
repo.scm.reset(hard=True)
repo.scm.stash.pop()
raise ApplyConflictError(rev)
raise ApplyConflictError(rev) from exc
except SCMError as exc:
raise ApplyConflictError(rev) from exc
repo.scm.stash.drop()
repo.scm.reset()

Expand Down
4 changes: 1 addition & 3 deletions dvc/scm/git/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ def set_ref(
"""

@abstractmethod
def get_ref(
self, name: str, follow: Optional[bool] = True
) -> Optional[str]:
def get_ref(self, name: str, follow: bool = True) -> Optional[str]:
"""Return the value of specified ref.
If follow is false, symbolic refs will not be dereferenced.
Expand Down
2 changes: 1 addition & 1 deletion dvc/scm/git/backend/dulwich.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def set_ref(
):
raise SCMError(f"Failed to set '{name}'")

def get_ref(self, name, follow: Optional[bool] = True) -> Optional[str]:
def get_ref(self, name, follow: bool = True) -> Optional[str]:
from dulwich.refs import parse_symref_value

name_b = os.fsencode(name)
Expand Down
12 changes: 7 additions & 5 deletions dvc/scm/git/backend/gitpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,14 +370,15 @@ def set_ref(
except GitCommandError as exc:
raise SCMError(f"Failed to set ref '{name}'") from exc

def get_ref(
self, name: str, follow: Optional[bool] = True
) -> Optional[str]:
def get_ref(self, name: str, follow: bool = True) -> Optional[str]:
from git.exc import GitCommandError

if name == "HEAD":
try:
return self.repo.head.commit.hexsha
if follow or self.repo.head.is_detached:
return self.repo.head.commit.hexsha
else:
return f"refs/heads/{self.repo.active_branch}"
except (GitCommandError, ValueError):
return None
elif name.startswith("refs/heads/"):
Expand Down Expand Up @@ -495,7 +496,8 @@ def _stash_apply(self, rev: str):
try:
self.git.stash("apply", rev)
except GitCommandError as exc:
if "CONFLICT" in str(exc):
out = str(exc)
if "CONFLICT" in out or "already exists" in out:
raise MergeConflictError(
"Stash apply resulted in merge conflicts"
) from exc
Expand Down
91 changes: 82 additions & 9 deletions dvc/scm/git/backend/pygit2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from io import BytesIO, StringIO
from typing import Callable, Iterable, List, Mapping, Optional, Tuple

from dvc.scm.base import MergeConflictError, SCMError
from dvc.scm.base import MergeConflictError, RevError, SCMError
from dvc.utils import relpath

from ..objects import GitObject
Expand Down Expand Up @@ -93,7 +93,25 @@ def commit(self, msg: str, no_verify: bool = False):
def checkout(
self, branch: str, create_new: Optional[bool] = False, **kwargs,
):
raise NotImplementedError
from pygit2 import GitError

if create_new:
commit = self.repo.revparse_single("HEAD")
new_branch = self.repo.branches.local.create(branch, commit)
self.repo.checkout(new_branch)
else:
if branch == "-":
branch = "@{-1}"
try:
commit, ref = self.repo.resolve_refish(branch)
except (KeyError, GitError):
raise RevError(f"unknown Git revision '{branch}'")
self.repo.checkout_tree(commit)
detach = kwargs.get("detach", False)
if ref and not detach:
self.repo.set_head(ref.name)
else:
self.repo.set_head(commit.id)

def pull(self, **kwargs):
raise NotImplementedError
Expand Down Expand Up @@ -136,7 +154,24 @@ def get_rev(self) -> str:
raise NotImplementedError

def resolve_rev(self, rev: str) -> str:
raise NotImplementedError
from pygit2 import GitError

try:
commit, _ref = self.repo.resolve_refish(rev)
return str(commit.id)
except (KeyError, GitError):
pass

# Look for single exact match in remote refs
shas = {
self.get_ref(f"refs/remotes/{remote.name}/{rev}")
for remote in self.repo.remotes
} - {None}
if len(shas) > 1:
raise RevError(f"ambiguous Git revision '{rev}'")
if len(shas) == 1:
return shas.pop() # type: ignore
raise RevError(f"unknown Git revision '{rev}'")

def resolve_commit(self, rev: str) -> str:
raise NotImplementedError
Expand All @@ -158,16 +193,38 @@ def set_ref(
message: Optional[str] = None,
symbolic: Optional[bool] = False,
):
raise NotImplementedError
if old_ref and old_ref != self.get_ref(name, follow=False):
raise SCMError(f"Failed to set '{name}'")

def get_ref(self, name, follow: Optional[bool] = True) -> Optional[str]:
raise NotImplementedError
if symbolic:
ref = self.repo.create_reference_symbolic(name, new_ref, True)
else:
ref = self.repo.create_reference_direct(name, new_ref, True)
if message:
ref.set_target(new_ref, message)

def get_ref(self, name, follow: bool = True) -> Optional[str]:
from pygit2 import GIT_REF_SYMBOLIC

ref = self.repo.references.get(name)
if not ref:
return None
if follow and ref.type == GIT_REF_SYMBOLIC:
ref = ref.resolve()
return str(ref.target)

def remove_ref(self, name: str, old_ref: Optional[str] = None):
raise NotImplementedError
ref = self.repo.references.get(name)
if not ref:
raise SCMError(f"Ref '{name}' does not exist")
if old_ref and old_ref != str(ref.target):
raise SCMError(f"Failed to remove '{name}'")
ref.delete()

def iter_refs(self, base: Optional[str] = None):
raise NotImplementedError
for ref in self.repo.references:
if ref.startswith(base):
yield ref

def get_refs_containing(self, rev: str, pattern: Optional[str] = None):
raise NotImplementedError
Expand Down Expand Up @@ -223,7 +280,20 @@ def diff(self, rev_a: str, rev_b: str, binary=False) -> str:
raise NotImplementedError

def reset(self, hard: bool = False, paths: Iterable[str] = None):
raise NotImplementedError
from pygit2 import GIT_RESET_HARD, GIT_RESET_MIXED, IndexEntry

self.repo.index.read(False)
if paths is not None:
tree = self.repo.revparse_single("HEAD").tree
for path in paths:
rel = relpath(path, self.root_dir)
obj = tree[relpath(rel, self.root_dir)]
self.repo.index.add(IndexEntry(rel, obj.oid, obj.filemode))
self.repo.index.write()
elif hard:
self.repo.reset(self.repo.head.target, GIT_RESET_HARD)
else:
self.repo.reset(self.repo.head.target, GIT_RESET_MIXED)

def checkout_index(
self,
Expand Down Expand Up @@ -299,7 +369,9 @@ def merge(
raise SCMError("Merge commit message is required")

try:
self.repo.index.read(False)
self.repo.merge(rev)
self.repo.index.write()
except GitError as exc:
raise SCMError("Merge failed") from exc

Expand All @@ -316,4 +388,5 @@ def merge(
if squash:
self.repo.reset(self.repo.head.target, GIT_RESET_MIXED)
self.repo.state_cleanup()
self.repo.index.write()
return None
2 changes: 1 addition & 1 deletion tests/func/test_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_post_checkout(self, tmp_dir, scm, dvc):
os.unlink("file")
scm.install()

scm.checkout("new_branch", create_new=True)
scm.gitpython.git.checkout("-b", "new_branch")

assert os.path.isfile("file")

Expand Down
106 changes: 97 additions & 9 deletions tests/unit/scm/test_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,6 @@ def test_branch_revs(tmp_dir, scm):


def test_set_ref(tmp_dir, git):
if git.test_backend == "pygit2":
pytest.skip()

tmp_dir.scm_gen({"file": "0"}, commit="init")
init_rev = tmp_dir.scm.get_rev()
tmp_dir.scm_gen({"file": "1"}, commit="commit")
Expand Down Expand Up @@ -171,9 +168,6 @@ def test_set_ref(tmp_dir, git):


def test_get_ref(tmp_dir, git):
if git.test_backend == "pygit2":
pytest.skip()

tmp_dir.scm_gen({"file": "0"}, commit="init")
init_rev = tmp_dir.scm.get_rev()
tmp_dir.gen(
Expand All @@ -192,9 +186,6 @@ def test_get_ref(tmp_dir, git):


def test_remove_ref(tmp_dir, git):
if git.test_backend == "pygit2":
pytest.skip()

tmp_dir.scm_gen({"file": "0"}, commit="init")
init_rev = tmp_dir.scm.get_rev()
tmp_dir.gen(os.path.join(".git", "refs", "foo", "bar"), init_rev)
Expand Down Expand Up @@ -410,3 +401,100 @@ def test_checkout_index_conflicts(tmp_dir, scm, git, strategy, expected):
else:
git.checkout_index(theirs=True)
assert (tmp_dir / "file").read_text() == expected


def test_resolve_rev(tmp_dir, scm, make_tmp_dir, git):
from dvc.scm.base import RevError

if git.test_backend == "dulwich":
pytest.skip()

remote_dir = make_tmp_dir("git-remote", scm=True)
url = "file://{}".format(remote_dir.resolve().as_posix())
scm.gitpython.repo.create_remote("origin", url)
scm.gitpython.repo.create_remote("upstream", url)

tmp_dir.scm_gen({"file": "0"}, commit="init")
init_rev = scm.get_rev()
tmp_dir.scm_gen({"file": "1"}, commit="1")
rev = scm.get_rev()
scm.checkout("branch", create_new=True)
tmp_dir.gen(
{
os.path.join(".git", "refs", "foo"): rev,
os.path.join(".git", "refs", "remotes", "origin", "bar"): rev,
os.path.join(".git", "refs", "remotes", "origin", "baz"): rev,
os.path.join(
".git", "refs", "remotes", "upstream", "baz"
): init_rev,
}
)

assert git.resolve_rev(rev) == rev
assert git.resolve_rev(rev[:7]) == rev
assert git.resolve_rev("HEAD") == rev
assert git.resolve_rev("branch") == rev
assert git.resolve_rev("refs/foo") == rev
assert git.resolve_rev("bar") == rev
assert git.resolve_rev("origin/baz") == rev

with pytest.raises(RevError):
git.resolve_rev("qux")

with pytest.raises(RevError):
git.resolve_rev("baz")


def test_checkout(tmp_dir, scm, git):
if git.test_backend == "dulwich":
pytest.skip()

tmp_dir.scm_gen({"foo": "foo"}, commit="foo")
foo_rev = scm.get_rev()
tmp_dir.scm_gen("foo", "bar", commit="bar")
bar_rev = scm.get_rev()

git.checkout("branch", create_new=True)
assert git.get_ref("HEAD", follow=False) == "refs/heads/branch"
assert (tmp_dir / "foo").read_text() == "bar"

git.checkout("master", detach=True)
assert git.get_ref("HEAD", follow=False) == bar_rev

git.checkout("master")
assert git.get_ref("HEAD", follow=False) == "refs/heads/master"

git.checkout(foo_rev[:7])
assert git.get_ref("HEAD", follow=False) == foo_rev
assert (tmp_dir / "foo").read_text() == "foo"


def test_reset(tmp_dir, scm, git):
if git.test_backend == "dulwich":
pytest.skip()

tmp_dir.scm_gen({"foo": "foo"}, commit="init")

tmp_dir.gen("foo", "bar")
scm.add(["foo"])
git.reset()
assert (tmp_dir / "foo").read_text() == "bar"
staged, unstaged, _ = scm.status()
assert len(staged) == 0
assert list(unstaged) == ["foo"]

scm.add(["foo"])
git.reset(hard=True)
assert (tmp_dir / "foo").read_text() == "foo"
staged, unstaged, _ = scm.status()
assert len(staged) == 0
assert len(unstaged) == 0

tmp_dir.gen({"foo": "bar", "bar": "bar"})
scm.add(["foo", "bar"])
git.reset(paths=["foo"])
assert (tmp_dir / "foo").read_text() == "bar"
assert (tmp_dir / "bar").read_text() == "bar"
staged, unstaged, _ = scm.status()
assert len(staged) == 1
assert len(unstaged) == 1

0 comments on commit 407f560

Please sign in to comment.