Skip to content

Commit

Permalink
Some review changes on push and fetch refspec
Browse files Browse the repository at this point in the history
1. rename `push_refspec` to `push_refspecs` as it can receive a list of
   refspecs
2. Move SyncEnum to base.py
3. Modify abstract methods args of `push_refspecs` and `fetch_refspecs`
   to match the new api.
  • Loading branch information
karajan1001 committed Apr 15, 2022
1 parent 435f67d commit e2a39e4
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 70 deletions.
2 changes: 1 addition & 1 deletion scmrepo/git/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def add_commit(
iter_refs = partialmethod(_backend_func, "iter_refs")
iter_remote_refs = partialmethod(_backend_func, "iter_remote_refs")
get_refs_containing = partialmethod(_backend_func, "get_refs_containing")
push_refspec = partialmethod(_backend_func, "push_refspec")
push_refspecs = partialmethod(_backend_func, "push_refspecs")
fetch_refspecs = partialmethod(_backend_func, "fetch_refspecs")
_stash_iter = partialmethod(_backend_func, "_stash_iter")
_stash_push = partialmethod(_backend_func, "_stash_push")
Expand Down
40 changes: 16 additions & 24 deletions scmrepo/git/backend/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
Callable,
Expand All @@ -25,6 +26,13 @@ def __init__(self, func):
super().__init__(f"No valid Git backend for '{func}'")


class SyncStatus(Enum):
SUCCESS = 0
DUPLICATED = 1
DIVERGED = 2
FAILED = 3


class BaseGitBackend(ABC):
"""Base Git backend class."""

Expand Down Expand Up @@ -206,43 +214,32 @@ def get_refs_containing(self, rev: str, pattern: Optional[str] = None):
"""Iterate over all git refs containing the specified revision."""

@abstractmethod
def push_refspec(
def push_refspecs(
self,
url: str,
src: Optional[str],
dest: str,
refspecs: Union[str, Iterable[str]],
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
"""Push refspec to a remote Git repo.
Args:
url: Git remote name or absolute Git URL.
src: Local refspec. If src ends with "/" it will be treated as a
prefix, and all refs inside src will be pushed using dest
as destination refspec prefix. If src is None, dest will be
deleted from the remote.
dest: Remote refspec.
refspecs: Iterable containing refspecs to fetch.
Note that this will not match subkeys.
force: If True, remote refs will be overwritten.
on_diverged: Callback function which will be called if local ref
and remote have diverged and force is False. If the callback
returns True the remote ref will be overwritten.
Callback will be of the form:
on_diverged(local_refname, remote_sha)
"""

@abstractmethod
def fetch_refspecs(
self,
url: str,
refspecs: Iterable[str],
force: Optional[bool] = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
refspecs: Union[str, Iterable[str]],
force: bool = False,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
"""Fetch refspecs from a remote Git repo.
Args:
Expand All @@ -251,11 +248,6 @@ def fetch_refspecs(
refspecs: Iterable containing refspecs to fetch.
Note that this will not match subkeys.
force: If True, local refs will be overwritten.
on_diverged: Callback function which will be called if local ref
and remote have diverged and force is False. If the callback
returns True the local ref will be overwritten.
Callback will be of the form:
on_diverged(local_refname, remote_sha)
"""

@abstractmethod
Expand Down
21 changes: 5 additions & 16 deletions scmrepo/git/backend/dulwich/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
import os
import stat
from enum import Enum
from functools import partial
from io import BytesIO, StringIO
from typing import (
Expand All @@ -25,7 +24,7 @@
from scmrepo.utils import relpath

from ...objects import GitObject
from ..base import BaseGitBackend
from ..base import BaseGitBackend, SyncStatus

if TYPE_CHECKING:
from dulwich.repo import Repo
Expand All @@ -38,13 +37,6 @@
logger = logging.getLogger(__name__)


class SyncStatus(Enum):
SUCCESS = 0
DUPLICATED = 1
DIVERGED = 2
FAILED = 3


class DulwichObject(GitObject):
def __init__(self, repo, name, mode, sha):
self.repo = repo
Expand Down Expand Up @@ -496,14 +488,14 @@ def iter_remote_refs(self, url: str, base: Optional[str] = None, **kwargs):
def get_refs_containing(self, rev: str, pattern: Optional[str] = None):
raise NotImplementedError

def push_refspec(
def push_refspecs(
self,
url: str,
refspecs: Union[str, Iterable[str]],
force: bool = False,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
) -> Mapping[str, int]:
) -> Mapping[str, SyncStatus]:
from dulwich.client import HTTPUnauthorized, get_transport_and_path
from dulwich.errors import NotGitRepository, SendPackError
from dulwich.objectspec import parse_reftuples
Expand Down Expand Up @@ -543,7 +535,7 @@ def update_refs(refs):
if not force:
change_result[refname] = SyncStatus.DIVERGED
continue
except Exception:
except SendPackError:
change_result[refname] = SyncStatus.FAILED
continue

Expand Down Expand Up @@ -579,7 +571,7 @@ def fetch_refspecs(
force: Optional[bool] = False,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
) -> Mapping[str, int]:
) -> Mapping[str, SyncStatus]:
from dulwich.client import get_transport_and_path
from dulwich.errors import NotGitRepository
from dulwich.objectspec import parse_reftuples
Expand Down Expand Up @@ -642,9 +634,6 @@ def determine_wants(remote_refs):
if not force:
result[refname] = SyncStatus.DIVERGED
continue
except Exception:
result[refname] = SyncStatus.FAILED
continue
self.repo.refs[rh] = fetch_result.refs[lh]
result[refname] = SyncStatus.SUCCESS
return result
Expand Down
17 changes: 7 additions & 10 deletions scmrepo/git/backend/gitpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from scmrepo.utils import relpath

from ..objects import GitCommit, GitObject
from .base import BaseGitBackend
from .base import BaseGitBackend, SyncStatus

if TYPE_CHECKING:
from scmrepo.progress import GitProgressEvent
Expand Down Expand Up @@ -474,27 +474,24 @@ def get_refs_containing(self, rev: str, pattern: Optional[str] = None):
except GitCommandError:
pass

def push_refspec(
def push_refspecs(
self,
url: str,
src: Optional[str],
dest: str,
refspecs: Union[str, Iterable[str]],
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
raise NotImplementedError

def fetch_refspecs(
self,
url: str,
refspecs: Iterable[str],
force: Optional[bool] = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
refspecs: Union[str, Iterable[str]],
force: bool = False,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
raise NotImplementedError

def _stash_iter(self, ref: str):
Expand Down
17 changes: 7 additions & 10 deletions scmrepo/git/backend/pygit2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from scmrepo.utils import relpath

from ..objects import GitCommit, GitObject
from .base import BaseGitBackend
from .base import BaseGitBackend, SyncStatus

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -414,27 +414,24 @@ def _contains(repo, ref, search_commit):
) and _contains(self.repo, ref, search_commit):
yield ref

def push_refspec(
def push_refspecs(
self,
url: str,
src: Optional[str],
dest: str,
refspecs: Union[str, Iterable[str]],
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
raise NotImplementedError

def fetch_refspecs(
self,
url: str,
refspecs: Iterable[str],
force: Optional[bool] = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
refspecs: Union[str, Iterable[str]],
force: bool = False,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
raise NotImplementedError

def _stash_iter(self, ref: str):
Expand Down
17 changes: 8 additions & 9 deletions tests/test_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_refs_containing(tmp_dir: TmpDir, scm: Git, git: Git):

@pytest.mark.skip_git_backend("pygit2", "gitpython")
@pytest.mark.parametrize("use_url", [True, False])
def test_push_refspec(
def test_push_refspecs(
tmp_dir: TmpDir,
scm: Git,
git: Git,
Expand Down Expand Up @@ -356,10 +356,10 @@ def test_push_refspec(
scm.gitpython.repo.create_remote("origin", url)

with pytest.raises(SCMError):
git.push_refspec("bad-remote", "refs/foo/bar:refs/foo/bar")
git.push_refspecs("bad-remote", "refs/foo/bar:refs/foo/bar")

remote = url if use_url else "origin"
assert git.push_refspec(remote, "refs/foo/bar:refs/foo/bar") == {
assert git.push_refspecs(remote, "refs/foo/bar:refs/foo/bar") == {
"refs/foo/bar": SyncStatus.SUCCESS
}
assert bar_rev == remote_scm.get_ref("refs/foo/bar")
Expand All @@ -368,20 +368,20 @@ def test_push_refspec(
assert bar_rev == remote_scm.get_rev()
assert (remote_git_dir / "file").read_text() == "0"

assert git.push_refspec(
assert git.push_refspecs(
remote, ["refs/foo/bar:refs/foo/bar", "refs/foo/baz:refs/foo/baz"]
) == {
"refs/foo/bar": SyncStatus.DUPLICATED,
"refs/foo/baz": SyncStatus.SUCCESS,
}
assert baz_rev == remote_scm.get_ref("refs/foo/baz")

assert git.push_refspec(remote, ["refs/foo/bar:refs/foo/baz"]) == {
assert git.push_refspecs(remote, ["refs/foo/bar:refs/foo/baz"]) == {
"refs/foo/baz": SyncStatus.DIVERGED
}
assert baz_rev == remote_scm.get_ref("refs/foo/baz")

assert git.push_refspec(remote, ":refs/foo/baz") == {
assert git.push_refspecs(remote, ":refs/foo/baz") == {
"refs/foo/baz": SyncStatus.SUCCESS
}
assert remote_scm.get_ref("refs/foo/baz") is None
Expand Down Expand Up @@ -905,10 +905,9 @@ async def test_git_ssh(
scm.add_commit("foo", message="init")
rev = scm.get_rev()

git.push_refspec(
git.push_refspecs(
url,
"refs/heads/master",
"refs/heads/master",
"refs/heads/master:refs/heads/master",
force=True,
key_filename=key_filename,
)
Expand Down

0 comments on commit e2a39e4

Please sign in to comment.