Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scm: implement pygit2 resolve rev and native checkout #5270

Merged
merged 12 commits into from
Jan 18, 2021
8 changes: 5 additions & 3 deletions dvc/repo/experiments/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
@scm_context
def apply(repo, rev, force=False, **kwargs):
from dvc.repo.checkout import checkout as dvc_checkout
from dvc.scm.base import SCMError
from dvc.scm.base import MergeConflictError, SCMError

exps = repo.experiments

Expand All @@ -48,7 +48,7 @@ def apply(repo, rev, force=False, **kwargs):
if workspace:
try:
repo.scm.stash.apply(workspace)
except SCMError:
except MergeConflictError as exc:
# Applied experiment conflicts with user's workspace changes
if force:
# prefer applied experiment changes over prior stashed changes
Expand All @@ -57,7 +57,9 @@ def apply(repo, rev, force=False, **kwargs):
# revert applied changes and restore user's workspace
repo.scm.reset(hard=True)
repo.scm.stash.pop()
raise ApplyConflictError(rev)
raise ApplyConflictError(rev) from exc
except SCMError as exc:
raise ApplyConflictError(rev) from exc
repo.scm.stash.drop()
repo.scm.reset()

Expand Down
4 changes: 1 addition & 3 deletions dvc/scm/git/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ def set_ref(
"""

@abstractmethod
def get_ref(
self, name: str, follow: Optional[bool] = True
) -> Optional[str]:
def get_ref(self, name: str, follow: bool = True) -> Optional[str]:
"""Return the value of specified ref.

If follow is false, symbolic refs will not be dereferenced.
Expand Down
2 changes: 1 addition & 1 deletion dvc/scm/git/backend/dulwich.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def set_ref(
):
raise SCMError(f"Failed to set '{name}'")

def get_ref(self, name, follow: Optional[bool] = True) -> Optional[str]:
def get_ref(self, name, follow: bool = True) -> Optional[str]:
from dulwich.refs import parse_symref_value

name_b = os.fsencode(name)
Expand Down
12 changes: 7 additions & 5 deletions dvc/scm/git/backend/gitpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,14 +370,15 @@ def set_ref(
except GitCommandError as exc:
raise SCMError(f"Failed to set ref '{name}'") from exc

def get_ref(
self, name: str, follow: Optional[bool] = True
) -> Optional[str]:
def get_ref(self, name: str, follow: bool = True) -> Optional[str]:
from git.exc import GitCommandError

if name == "HEAD":
try:
return self.repo.head.commit.hexsha
if follow or self.repo.head.is_detached:
return self.repo.head.commit.hexsha
else:
return f"refs/heads/{self.repo.active_branch}"
Comment on lines +380 to +381
Copy link
Contributor

@efiop efiop Jan 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit offtopic and not important in such a small else statement, but what do you guys think about enabling no-else-after-return check in our linters? I seem to stumble upon it regularly.

Suggested change
else:
return f"refs/heads/{self.repo.active_branch}"
return f"refs/heads/{self.repo.active_branch}"

EDIT: add to our retro to discuss πŸ™‚

except (GitCommandError, ValueError):
return None
elif name.startswith("refs/heads/"):
Expand Down Expand Up @@ -495,7 +496,8 @@ def _stash_apply(self, rev: str):
try:
self.git.stash("apply", rev)
except GitCommandError as exc:
if "CONFLICT" in str(exc):
out = str(exc)
if "CONFLICT" in out or "already exists" in out:
raise MergeConflictError(
"Stash apply resulted in merge conflicts"
) from exc
Expand Down
91 changes: 82 additions & 9 deletions dvc/scm/git/backend/pygit2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from io import BytesIO, StringIO
from typing import Callable, Iterable, List, Mapping, Optional, Tuple

from dvc.scm.base import MergeConflictError, SCMError
from dvc.scm.base import MergeConflictError, RevError, SCMError
from dvc.utils import relpath

from ..objects import GitObject
Expand Down Expand Up @@ -93,7 +93,25 @@ def commit(self, msg: str, no_verify: bool = False):
def checkout(
self, branch: str, create_new: Optional[bool] = False, **kwargs,
):
raise NotImplementedError
from pygit2 import GitError

if create_new:
commit = self.repo.revparse_single("HEAD")
new_branch = self.repo.branches.local.create(branch, commit)
self.repo.checkout(new_branch)
else:
if branch == "-":
branch = "@{-1}"
Comment on lines +103 to +104
Copy link
Contributor Author

@pmrowla pmrowla Jan 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For whatever reason, libgit2's resolve_refish doesn't handle - shorthand, probably because it's specific to git checkout/switch and not a general git ref shorthand symbol

try:
commit, ref = self.repo.resolve_refish(branch)
except (KeyError, GitError):
raise RevError(f"unknown Git revision '{branch}'")
self.repo.checkout_tree(commit)
detach = kwargs.get("detach", False)
if ref and not detach:
self.repo.set_head(ref.name)
else:
self.repo.set_head(commit.id)

def pull(self, **kwargs):
raise NotImplementedError
Expand Down Expand Up @@ -136,7 +154,24 @@ def get_rev(self) -> str:
raise NotImplementedError

def resolve_rev(self, rev: str) -> str:
raise NotImplementedError
from pygit2 import GitError

try:
commit, _ref = self.repo.resolve_refish(rev)
return str(commit.id)
except (KeyError, GitError):
pass

# Look for single exact match in remote refs
shas = {
self.get_ref(f"refs/remotes/{remote.name}/{rev}")
for remote in self.repo.remotes
} - {None}
if len(shas) > 1:
raise RevError(f"ambiguous Git revision '{rev}'")
if len(shas) == 1:
return shas.pop() # type: ignore
raise RevError(f"unknown Git revision '{rev}'")

def resolve_commit(self, rev: str) -> str:
raise NotImplementedError
Expand All @@ -158,16 +193,38 @@ def set_ref(
message: Optional[str] = None,
symbolic: Optional[bool] = False,
):
raise NotImplementedError
if old_ref and old_ref != self.get_ref(name, follow=False):
raise SCMError(f"Failed to set '{name}'")

def get_ref(self, name, follow: Optional[bool] = True) -> Optional[str]:
raise NotImplementedError
if symbolic:
ref = self.repo.create_reference_symbolic(name, new_ref, True)
else:
ref = self.repo.create_reference_direct(name, new_ref, True)
if message:
ref.set_target(new_ref, message)

def get_ref(self, name, follow: bool = True) -> Optional[str]:
from pygit2 import GIT_REF_SYMBOLIC

ref = self.repo.references.get(name)
if not ref:
return None
if follow and ref.type == GIT_REF_SYMBOLIC:
ref = ref.resolve()
return str(ref.target)

def remove_ref(self, name: str, old_ref: Optional[str] = None):
raise NotImplementedError
ref = self.repo.references.get(name)
if not ref:
raise SCMError(f"Ref '{name}' does not exist")
if old_ref and old_ref != str(ref.target):
raise SCMError(f"Failed to remove '{name}'")
ref.delete()

def iter_refs(self, base: Optional[str] = None):
raise NotImplementedError
for ref in self.repo.references:
if ref.startswith(base):
yield ref

def get_refs_containing(self, rev: str, pattern: Optional[str] = None):
raise NotImplementedError
Expand Down Expand Up @@ -223,7 +280,20 @@ def diff(self, rev_a: str, rev_b: str, binary=False) -> str:
raise NotImplementedError

def reset(self, hard: bool = False, paths: Iterable[str] = None):
raise NotImplementedError
from pygit2 import GIT_RESET_HARD, GIT_RESET_MIXED, IndexEntry

self.repo.index.read(False)
if paths is not None:
tree = self.repo.revparse_single("HEAD").tree
for path in paths:
rel = relpath(path, self.root_dir)
obj = tree[relpath(rel, self.root_dir)]
self.repo.index.add(IndexEntry(rel, obj.oid, obj.filemode))
self.repo.index.write()
elif hard:
self.repo.reset(self.repo.head.target, GIT_RESET_HARD)
else:
self.repo.reset(self.repo.head.target, GIT_RESET_MIXED)

def checkout_index(
self,
Expand Down Expand Up @@ -299,7 +369,9 @@ def merge(
raise SCMError("Merge commit message is required")

try:
self.repo.index.read(False)
self.repo.merge(rev)
self.repo.index.write()
except GitError as exc:
raise SCMError("Merge failed") from exc

Expand All @@ -316,4 +388,5 @@ def merge(
if squash:
self.repo.reset(self.repo.head.target, GIT_RESET_MIXED)
self.repo.state_cleanup()
self.repo.index.write()
return None
2 changes: 1 addition & 1 deletion tests/func/test_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_post_checkout(self, tmp_dir, scm, dvc):
os.unlink("file")
scm.install()

scm.checkout("new_branch", create_new=True)
scm.gitpython.git.checkout("-b", "new_branch")
Copy link
Contributor Author

@pmrowla pmrowla Jan 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

libgit2 does not run git hooks on any command/function call, and internally we don't need to run this hook.

For this test we want to make sure that the installed DVC post-checkout hook is run when a user does command-line git checkout, not that it is run on DVC scm.checkout.


assert os.path.isfile("file")

Expand Down
106 changes: 97 additions & 9 deletions tests/unit/scm/test_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,6 @@ def test_branch_revs(tmp_dir, scm):


def test_set_ref(tmp_dir, git):
if git.test_backend == "pygit2":
pytest.skip()

tmp_dir.scm_gen({"file": "0"}, commit="init")
init_rev = tmp_dir.scm.get_rev()
tmp_dir.scm_gen({"file": "1"}, commit="commit")
Expand Down Expand Up @@ -171,9 +168,6 @@ def test_set_ref(tmp_dir, git):


def test_get_ref(tmp_dir, git):
if git.test_backend == "pygit2":
pytest.skip()

tmp_dir.scm_gen({"file": "0"}, commit="init")
init_rev = tmp_dir.scm.get_rev()
tmp_dir.gen(
Expand All @@ -192,9 +186,6 @@ def test_get_ref(tmp_dir, git):


def test_remove_ref(tmp_dir, git):
if git.test_backend == "pygit2":
pytest.skip()

tmp_dir.scm_gen({"file": "0"}, commit="init")
init_rev = tmp_dir.scm.get_rev()
tmp_dir.gen(os.path.join(".git", "refs", "foo", "bar"), init_rev)
Expand Down Expand Up @@ -410,3 +401,100 @@ def test_checkout_index_conflicts(tmp_dir, scm, git, strategy, expected):
else:
git.checkout_index(theirs=True)
assert (tmp_dir / "file").read_text() == expected


def test_resolve_rev(tmp_dir, scm, make_tmp_dir, git):
from dvc.scm.base import RevError

if git.test_backend == "dulwich":
pytest.skip()

remote_dir = make_tmp_dir("git-remote", scm=True)
url = "file://{}".format(remote_dir.resolve().as_posix())
scm.gitpython.repo.create_remote("origin", url)
scm.gitpython.repo.create_remote("upstream", url)

tmp_dir.scm_gen({"file": "0"}, commit="init")
init_rev = scm.get_rev()
tmp_dir.scm_gen({"file": "1"}, commit="1")
rev = scm.get_rev()
scm.checkout("branch", create_new=True)
tmp_dir.gen(
{
os.path.join(".git", "refs", "foo"): rev,
os.path.join(".git", "refs", "remotes", "origin", "bar"): rev,
os.path.join(".git", "refs", "remotes", "origin", "baz"): rev,
os.path.join(
".git", "refs", "remotes", "upstream", "baz"
): init_rev,
}
)

assert git.resolve_rev(rev) == rev
assert git.resolve_rev(rev[:7]) == rev
assert git.resolve_rev("HEAD") == rev
assert git.resolve_rev("branch") == rev
assert git.resolve_rev("refs/foo") == rev
assert git.resolve_rev("bar") == rev
assert git.resolve_rev("origin/baz") == rev

with pytest.raises(RevError):
git.resolve_rev("qux")

with pytest.raises(RevError):
git.resolve_rev("baz")


def test_checkout(tmp_dir, scm, git):
if git.test_backend == "dulwich":
pytest.skip()

tmp_dir.scm_gen({"foo": "foo"}, commit="foo")
foo_rev = scm.get_rev()
tmp_dir.scm_gen("foo", "bar", commit="bar")
bar_rev = scm.get_rev()

git.checkout("branch", create_new=True)
assert git.get_ref("HEAD", follow=False) == "refs/heads/branch"
assert (tmp_dir / "foo").read_text() == "bar"

git.checkout("master", detach=True)
assert git.get_ref("HEAD", follow=False) == bar_rev

git.checkout("master")
assert git.get_ref("HEAD", follow=False) == "refs/heads/master"

git.checkout(foo_rev[:7])
assert git.get_ref("HEAD", follow=False) == foo_rev
assert (tmp_dir / "foo").read_text() == "foo"


def test_reset(tmp_dir, scm, git):
if git.test_backend == "dulwich":
pytest.skip()

tmp_dir.scm_gen({"foo": "foo"}, commit="init")

tmp_dir.gen("foo", "bar")
scm.add(["foo"])
git.reset()
assert (tmp_dir / "foo").read_text() == "bar"
staged, unstaged, _ = scm.status()
assert len(staged) == 0
assert list(unstaged) == ["foo"]

scm.add(["foo"])
git.reset(hard=True)
assert (tmp_dir / "foo").read_text() == "foo"
staged, unstaged, _ = scm.status()
assert len(staged) == 0
assert len(unstaged) == 0

tmp_dir.gen({"foo": "bar", "bar": "bar"})
scm.add(["foo", "bar"])
git.reset(paths=["foo"])
assert (tmp_dir / "foo").read_text() == "bar"
assert (tmp_dir / "bar").read_text() == "bar"
staged, unstaged, _ = scm.status()
assert len(staged) == 1
assert len(unstaged) == 1