diff --git a/scmrepo/git/backend/dulwich/__init__.py b/scmrepo/git/backend/dulwich/__init__.py index d09237c3..b3a308a6 100644 --- a/scmrepo/git/backend/dulwich/__init__.py +++ b/scmrepo/git/backend/dulwich/__init__.py @@ -3,6 +3,7 @@ import logging import os import stat +from enum import Enum from functools import partial from io import BytesIO, StringIO from typing import ( @@ -37,6 +38,13 @@ logger = logging.getLogger(__name__) +class SyncStatus(Enum): + SUCCESS = 0 + DUPLICATED = 1 + DIVERGED = 2 + FAILED = 3 + + class DulwichObject(GitObject): def __init__(self, repo, name, mode, sha): self.repo = repo @@ -491,23 +499,20 @@ def get_refs_containing(self, rev: str, pattern: Optional[str] = None): def push_refspec( self, url: str, - src: Optional[str], - dest: str, + refspecs: Union[str, Iterable[str]], force: bool = False, - on_diverged: Optional[Callable[[str, str], bool]] = None, progress: Callable[["GitProgressEvent"], None] = None, **kwargs, - ): + ) -> Mapping[str, int]: from dulwich.client import HTTPUnauthorized, get_transport_and_path from dulwich.errors import NotGitRepository, SendPackError + from dulwich.objectspec import parse_reftuples from dulwich.porcelain import ( DivergedBranches, check_diverged, get_remote_repo, ) - dest_refs, values = self._push_dest_refs(src, dest) - try: _remote, location = get_remote_repo(self.repo, url) client, path = get_transport_and_path(location, **kwargs) @@ -516,26 +521,40 @@ def push_refspec( f"'{url}' is not a valid Git remote or URL" ) from exc + change_result = {} + selected_refs = [] + def update_refs(refs): from dulwich.objects import ZERO_SHA + selected_refs.extend( + parse_reftuples(self.repo.refs, refs, refspecs, force=force) + ) new_refs = {} - for ref, value in zip(dest_refs, values): - if ref in refs and value != ZERO_SHA: - local_sha = self.repo.refs[ref] - remote_sha = refs[ref] + for (lh, rh, _) in selected_refs: + refname = os.fsdecode(rh) + if rh in refs and lh is not None: + if refs[rh] == self.repo.refs[lh]: + change_result[refname] = SyncStatus.DUPLICATED + continue try: - check_diverged(self.repo, remote_sha, local_sha) + check_diverged(self.repo, refs[rh], self.repo.refs[lh]) except DivergedBranches: if not force: - overwrite = False - if on_diverged: - overwrite = on_diverged( - os.fsdecode(ref), os.fsdecode(remote_sha) - ) - if not overwrite: - continue - new_refs[ref] = value + change_result[refname] = SyncStatus.DIVERGED + continue + except Exception: + change_result[refname] = SyncStatus.FAILED + continue + + if lh is None: + value = ZERO_SHA + else: + value = self.repo.refs[lh] + + new_refs[rh] = value + change_result[refname] = SyncStatus.SUCCESS + return new_refs try: @@ -548,38 +567,21 @@ def update_refs(refs): ), ) except (NotGitRepository, SendPackError) as exc: - raise SCMError("Git failed to push '{src}' to '{url}'") from exc + raise SCMError(f"Git failed to push ref to '{url}'") from exc except HTTPUnauthorized: raise AuthError(url) - - def _push_dest_refs( - self, src: Optional[str], dest: str - ) -> Tuple[Iterable[bytes], Iterable[bytes]]: - from dulwich.objects import ZERO_SHA - - if src is not None and src.endswith("/"): - src_b = os.fsencode(src) - keys = self.repo.refs.subkeys(src_b) - values = [self.repo.refs[b"".join([src_b, key])] for key in keys] - dest_refs = [b"".join([os.fsencode(dest), key]) for key in keys] - else: - if src is None: - values = [ZERO_SHA] - else: - values = [self.repo.refs[os.fsencode(src)]] - dest_refs = [os.fsencode(dest)] - return dest_refs, values + return change_result def fetch_refspecs( self, url: str, - refspecs: Iterable[str], + refspecs: Union[str, Iterable[str]], force: Optional[bool] = False, - on_diverged: Optional[Callable[[str, str], bool]] = None, progress: Callable[["GitProgressEvent"], None] = None, **kwargs, - ): + ) -> Mapping[str, int]: from dulwich.client import get_transport_and_path + from dulwich.errors import NotGitRepository from dulwich.objectspec import parse_reftuples from dulwich.porcelain import ( DivergedBranches, @@ -594,7 +596,7 @@ def determine_wants(remote_refs): parse_reftuples( remote_refs, self.repo.refs, - [os.fsencode(refspec) for refspec in refspecs], + refspecs, force=force, ) ) @@ -612,28 +614,40 @@ def determine_wants(remote_refs): f"'{url}' is not a valid Git remote or URL" ) from exc - fetch_result = client.fetch( - path, - self.repo, - progress=DulwichProgressReporter(progress) if progress else None, - determine_wants=determine_wants, - ) + try: + fetch_result = client.fetch( + path, + self.repo, + progress=DulwichProgressReporter(progress) + if progress + else None, + determine_wants=determine_wants, + ) + except NotGitRepository as exc: + raise SCMError(f"Git failed to fetch ref from '{url}'") from exc + + result = {} + for (lh, rh, _) in fetch_refs: - try: - if rh in self.repo.refs: + refname = os.fsdecode(rh) + if rh in self.repo.refs: + if self.repo.refs[rh] == fetch_result.refs[lh]: + result[refname] = SyncStatus.DUPLICATED + continue + try: check_diverged( self.repo, self.repo.refs[rh], fetch_result.refs[lh] ) - except DivergedBranches: - if not force: - overwrite = False - if on_diverged: - overwrite = on_diverged( - os.fsdecode(rh), os.fsdecode(fetch_result.refs[lh]) - ) - if not overwrite: + except DivergedBranches: + if not force: + result[refname] = SyncStatus.DIVERGED continue + except Exception: + result[refname] = SyncStatus.FAILED + continue self.repo.refs[rh] = fetch_result.refs[lh] + result[refname] = SyncStatus.SUCCESS + return result def _stash_iter(self, ref: str): stash = self._get_stash(ref) diff --git a/tests/test_git.py b/tests/test_git.py index bf51bb03..ecf57ab6 100644 --- a/tests/test_git.py +++ b/tests/test_git.py @@ -334,13 +334,20 @@ def test_push_refspec( remote_git_dir: TmpDir, use_url: str, ): + from scmrepo.git.backend.dulwich import SyncStatus + tmp_dir.gen({"file": "0"}) scm.add_commit("file", message="init") init_rev = scm.get_rev() + scm.add_commit("file", message="bar") + bar_rev = scm.get_rev() + scm.checkout(init_rev) + scm.add_commit("file", message="baz") + baz_rev = scm.get_rev() tmp_dir.gen( { - os.path.join(".git", "refs", "foo", "bar"): init_rev, - os.path.join(".git", "refs", "foo", "baz"): init_rev, + os.path.join(".git", "refs", "foo", "bar"): bar_rev, + os.path.join(".git", "refs", "foo", "baz"): baz_rev, } ) @@ -349,53 +356,89 @@ def test_push_refspec( scm.gitpython.repo.create_remote("origin", url) with pytest.raises(SCMError): - git.push_refspec("bad-remote", "refs/foo/bar", "refs/foo/bar") + git.push_refspec("bad-remote", "refs/foo/bar:refs/foo/bar") remote = url if use_url else "origin" - git.push_refspec(remote, "refs/foo/bar", "refs/foo/bar") - assert init_rev == remote_scm.get_ref("refs/foo/bar") + assert git.push_refspec(remote, "refs/foo/bar:refs/foo/bar") == { + "refs/foo/bar": SyncStatus.SUCCESS + } + assert bar_rev == remote_scm.get_ref("refs/foo/bar") remote_scm.checkout("refs/foo/bar") - assert init_rev == remote_scm.get_rev() + assert bar_rev == remote_scm.get_rev() assert (remote_git_dir / "file").read_text() == "0" - git.push_refspec(remote, "refs/foo/", "refs/foo/") - assert init_rev == remote_scm.get_ref("refs/foo/baz") - - git.push_refspec(remote, None, "refs/foo/baz") + assert git.push_refspec( + remote, ["refs/foo/bar:refs/foo/bar", "refs/foo/baz:refs/foo/baz"] + ) == { + "refs/foo/bar": SyncStatus.DUPLICATED, + "refs/foo/baz": SyncStatus.SUCCESS, + } + assert baz_rev == remote_scm.get_ref("refs/foo/baz") + + assert git.push_refspec(remote, ["refs/foo/bar:refs/foo/baz"]) == { + "refs/foo/baz": SyncStatus.DIVERGED + } + assert baz_rev == remote_scm.get_ref("refs/foo/baz") + + assert git.push_refspec(remote, ":refs/foo/baz") == { + "refs/foo/baz": SyncStatus.SUCCESS + } assert remote_scm.get_ref("refs/foo/baz") is None @pytest.mark.skip_git_backend("pygit2", "gitpython") +@pytest.mark.parametrize("use_url", [True, False]) def test_fetch_refspecs( + tmp_dir: TmpDir, scm: Git, git: Git, remote_git_dir: TmpDir, + use_url: bool, ): - url = f"file://{remote_git_dir.resolve().as_posix()}" + from scmrepo.git.backend.dulwich import SyncStatus + + url = f"file://{remote_git_dir.resolve().as_posix()}" + scm.gitpython.repo.create_remote("origin", url) remote_scm = Git(remote_git_dir) remote_git_dir.gen("file", "0") - remote_scm.add_commit("file", message="init") + remote_scm.add_commit("file", message="init") init_rev = remote_scm.get_rev() - + remote_scm.add_commit("file", message="bar") + bar_rev = remote_scm.get_rev() + remote_scm.checkout(init_rev) + remote_scm.add_commit("file", message="baz") + baz_rev = remote_scm.get_rev() remote_git_dir.gen( { - os.path.join(".git", "refs", "foo", "bar"): init_rev, - os.path.join(".git", "refs", "foo", "baz"): init_rev, + os.path.join(".git", "refs", "foo", "bar"): bar_rev, + os.path.join(".git", "refs", "foo", "baz"): baz_rev, } ) - git.fetch_refspecs( - url, ["refs/foo/bar:refs/foo/bar", "refs/foo/baz:refs/foo/baz"] - ) - assert init_rev == scm.get_ref("refs/foo/bar") - assert init_rev == scm.get_ref("refs/foo/baz") + with pytest.raises(SCMError): + git.fetch_refspecs("bad-remote", "refs/foo/bar:refs/foo/bar") - remote_scm.checkout("refs/foo/bar") - assert init_rev == remote_scm.get_rev() - assert (remote_git_dir / "file").read_text() == "0" + remote = url if use_url else "origin" + assert git.fetch_refspecs(remote, "refs/foo/bar:refs/foo/bar") == { + "refs/foo/bar": SyncStatus.SUCCESS + } + assert bar_rev == scm.get_ref("refs/foo/bar") + + assert git.fetch_refspecs( + remote, ["refs/foo/bar:refs/foo/bar", "refs/foo/baz:refs/foo/baz"] + ) == { + "refs/foo/bar": SyncStatus.DUPLICATED, + "refs/foo/baz": SyncStatus.SUCCESS, + } + assert baz_rev == scm.get_ref("refs/foo/baz") + + assert git.fetch_refspecs(remote, ["refs/foo/bar:refs/foo/baz"]) == { + "refs/foo/baz": SyncStatus.DIVERGED + } + assert baz_rev == scm.get_ref("refs/foo/baz") @pytest.mark.skip_git_backend("dulwich", "pygit2")