Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ref handling #1393

Merged
merged 4 commits into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 36 additions & 16 deletions dulwich/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,21 @@
parse_capability,
pkt_line,
)
from .refs import PEELED_TAG_SUFFIX, _import_remote_refs, read_info_refs
from .refs import (
PEELED_TAG_SUFFIX,
Ref,
_import_remote_refs,
_set_default_branch,
_set_head,
_set_origin_head,
read_info_refs,
split_peeled_refs,
)
from .repo import Repo

ObjectID = bytes


# url2pathname is lazily imported
url2pathname = None

Expand Down Expand Up @@ -825,8 +837,6 @@ def clone(
protocol_version: Optional[int] = None,
) -> Repo:
"""Clone a repository."""
from .refs import _set_default_branch, _set_head, _set_origin_head

if mkdir:
os.mkdir(target_path)

Expand Down Expand Up @@ -2343,7 +2353,11 @@ def _http_request(self, url, headers=None, data=None):
"""
raise NotImplementedError(self._http_request)

def _discover_references(self, service, base_url, protocol_version=None):
def _discover_references(
self, service, base_url, protocol_version=None
) -> Tuple[
Dict[Ref, ObjectID], Set[bytes], str, Dict[Ref, Ref], Dict[Ref, ObjectID]
]:
if (
protocol_version is not None
and protocol_version not in GIT_PROTOCOL_VERSIONS
Expand Down Expand Up @@ -2413,11 +2427,10 @@ def begin_protocol_v2(proto):
self.protocol_version = server_protocol_version
if self.protocol_version == 2:
server_capabilities, resp, read, proto = begin_protocol_v2(proto)
(refs, _symrefs, _peeled) = read_pkt_refs_v2(proto.read_pkt_seq())
return refs, server_capabilities, base_url
(refs, symrefs, peeled) = read_pkt_refs_v2(proto.read_pkt_seq())
return refs, server_capabilities, base_url, symrefs, peeled

else:
server_capabilities = None # read_pkt_refs will find them
try:
[pkt] = list(proto.read_pkt_seq())
except ValueError as exc:
Expand Down Expand Up @@ -2446,14 +2459,21 @@ def begin_protocol_v2(proto):
server_capabilities, resp, read, proto = begin_protocol_v2(
proto
)
(
refs,
server_capabilities,
) = read_pkt_refs_v1(proto.read_pkt_seq())
return refs, server_capabilities, base_url
(refs, symrefs, peeled) = read_pkt_refs_v2(proto.read_pkt_seq())
else:
(
refs,
server_capabilities,
) = read_pkt_refs_v1(proto.read_pkt_seq())
(refs, peeled) = split_peeled_refs(refs)
(symrefs, agent) = _extract_symrefs_and_agent(
server_capabilities
)
return refs, server_capabilities, base_url, symrefs, peeled
else:
self.protocol_version = 0 # dumb servers only support protocol v0
return read_info_refs(resp), set(), base_url
(refs, peeled) = split_peeled_refs(read_info_refs(resp))
return refs, set(), base_url, {}, peeled
finally:
resp.close()

Expand Down Expand Up @@ -2501,7 +2521,7 @@ def send_pack(self, path, update_refs, generate_pack_data, progress=None):

"""
url = self._get_url(path)
old_refs, server_capabilities, url = self._discover_references(
old_refs, server_capabilities, url, symrefs, peeled = self._discover_references(
b"git-receive-pack", url
)
(
Expand Down Expand Up @@ -2584,7 +2604,7 @@ def fetch_pack(

"""
url = self._get_url(path)
refs, server_capabilities, url = self._discover_references(
refs, server_capabilities, url, symrefs, peeled = self._discover_references(
b"git-upload-pack", url, protocol_version
)
(
Expand Down Expand Up @@ -2651,7 +2671,7 @@ def fetch_pack(
def get_refs(self, path):
"""Retrieve the current refs from a git smart server."""
url = self._get_url(path)
refs, _, _ = self._discover_references(b"git-upload-pack", url)
refs, _, _, _, _ = self._discover_references(b"git-upload-pack", url)
return refs

def get_url(self, path):
Expand Down
3 changes: 2 additions & 1 deletion dulwich/contrib/swift.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
write_pack_object,
)
from ..protocol import TCP_GIT_PORT
from ..refs import InfoRefsContainer, read_info_refs, write_info_refs
from ..refs import InfoRefsContainer, read_info_refs, split_peeled_refs, write_info_refs
from ..repo import OBJECTDIR, BaseRepo
from ..server import Backend, TCPGitServer

Expand Down Expand Up @@ -809,6 +809,7 @@ def _load_check_ref(self, name, old_ref):
if not f:
return {}
refs = read_info_refs(f)
(refs, peeled) = split_peeled_refs(refs)
if old_ref is not None:
if refs[name] != old_ref:
return False
Expand Down
25 changes: 14 additions & 11 deletions dulwich/refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,17 +585,8 @@ class InfoRefsContainer(RefsContainer):
def __init__(self, f) -> None:
self._refs = {}
self._peeled = {}
for line in f.readlines():
sha, name = line.rstrip(b"\n").split(b"\t")
if name.endswith(PEELED_TAG_SUFFIX):
name = name[:-3]
if not check_ref_format(name):
raise ValueError(f"invalid ref name {name!r}")
self._peeled[name] = sha
else:
if not check_ref_format(name):
raise ValueError(f"invalid ref name {name!r}")
self._refs[name] = sha
refs = read_info_refs(f)
(self._refs, self._peeled) = split_peeled_refs(refs)

def allkeys(self):
return self._refs.keys()
Expand Down Expand Up @@ -1175,6 +1166,18 @@ def strip_peeled_refs(refs):
}


def split_peeled_refs(refs):
"""Split peeled refs from regular refs."""
peeled = {}
regular = {}
for ref, sha in refs.items():
if ref.endswith(PEELED_TAG_SUFFIX):
peeled[ref[: -len(PEELED_TAG_SUFFIX)]] = sha
else:
regular[ref] = sha
return regular, peeled


def _set_origin_head(refs, origin, origin_head):
# set refs/remotes/origin/HEAD
origin_base = b"refs/remotes/" + origin + b"/"
Expand Down
4 changes: 3 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,7 +1138,9 @@ def request(
# instantiate HttpGitClient with mocked pool manager
c = HttpGitClient(base_url, pool_manager=pool_manager, config=None)
# call method that detects url redirection
_, _, processed_url = c._discover_references(b"git-upload-pack", base_url)
_, _, processed_url, _, _ = c._discover_references(
b"git-upload-pack", base_url
)

# send the same request as the method above without redirection
resp = c.pool_manager.request("GET", base_url + tail, redirect=False)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
parse_symref_value,
read_packed_refs,
read_packed_refs_with_peeled,
split_peeled_refs,
strip_peeled_refs,
write_packed_refs,
)
Expand Down Expand Up @@ -814,3 +815,14 @@ class StripPeeledRefsTests(TestCase):
def test_strip_peeled_refs(self):
# Simple check of two dicts
self.assertEqual(strip_peeled_refs(self.all_refs), self.non_peeled_refs)

def test_split_peeled_refs(self):
(regular, peeled) = split_peeled_refs(self.all_refs)
self.assertEqual(regular, self.non_peeled_refs)
self.assertEqual(
peeled,
{
b"refs/tags/2.0.0": b"0749936d0956c661ac8f8d3483774509c165f89e",
b"refs/tags/1.0.0": b"a93db4b0360cc635a2b93675010bac8d101f73f0",
},
)
Loading