diff --git a/dulwich/client.py b/dulwich/client.py index a7b9bdfc1..a79d7dd74 100644 --- a/dulwich/client.py +++ b/dulwich/client.py @@ -121,9 +121,21 @@ parse_capability, pkt_line, ) -from .refs import PEELED_TAG_SUFFIX, _import_remote_refs, read_info_refs +from .refs import ( + PEELED_TAG_SUFFIX, + Ref, + _import_remote_refs, + _set_default_branch, + _set_head, + _set_origin_head, + read_info_refs, + split_peeled_refs, +) from .repo import Repo +ObjectID = bytes + + # url2pathname is lazily imported url2pathname = None @@ -825,8 +837,6 @@ def clone( protocol_version: Optional[int] = None, ) -> Repo: """Clone a repository.""" - from .refs import _set_default_branch, _set_head, _set_origin_head - if mkdir: os.mkdir(target_path) @@ -2343,7 +2353,11 @@ def _http_request(self, url, headers=None, data=None): """ raise NotImplementedError(self._http_request) - def _discover_references(self, service, base_url, protocol_version=None): + def _discover_references( + self, service, base_url, protocol_version=None + ) -> Tuple[ + Dict[Ref, ObjectID], Set[bytes], str, Dict[Ref, Ref], Dict[Ref, ObjectID] + ]: if ( protocol_version is not None and protocol_version not in GIT_PROTOCOL_VERSIONS @@ -2413,11 +2427,10 @@ def begin_protocol_v2(proto): self.protocol_version = server_protocol_version if self.protocol_version == 2: server_capabilities, resp, read, proto = begin_protocol_v2(proto) - (refs, _symrefs, _peeled) = read_pkt_refs_v2(proto.read_pkt_seq()) - return refs, server_capabilities, base_url + (refs, symrefs, peeled) = read_pkt_refs_v2(proto.read_pkt_seq()) + return refs, server_capabilities, base_url, symrefs, peeled else: - server_capabilities = None # read_pkt_refs will find them try: [pkt] = list(proto.read_pkt_seq()) except ValueError as exc: @@ -2446,14 +2459,21 @@ def begin_protocol_v2(proto): server_capabilities, resp, read, proto = begin_protocol_v2( proto ) - ( - refs, - server_capabilities, - ) = read_pkt_refs_v1(proto.read_pkt_seq()) - return refs, server_capabilities, base_url + (refs, symrefs, peeled) = read_pkt_refs_v2(proto.read_pkt_seq()) + else: + ( + refs, + server_capabilities, + ) = read_pkt_refs_v1(proto.read_pkt_seq()) + (refs, peeled) = split_peeled_refs(refs) + (symrefs, agent) = _extract_symrefs_and_agent( + server_capabilities + ) + return refs, server_capabilities, base_url, symrefs, peeled else: self.protocol_version = 0 # dumb servers only support protocol v0 - return read_info_refs(resp), set(), base_url + (refs, peeled) = split_peeled_refs(read_info_refs(resp)) + return refs, set(), base_url, {}, peeled finally: resp.close() @@ -2501,7 +2521,7 @@ def send_pack(self, path, update_refs, generate_pack_data, progress=None): """ url = self._get_url(path) - old_refs, server_capabilities, url = self._discover_references( + old_refs, server_capabilities, url, symrefs, peeled = self._discover_references( b"git-receive-pack", url ) ( @@ -2584,7 +2604,7 @@ def fetch_pack( """ url = self._get_url(path) - refs, server_capabilities, url = self._discover_references( + refs, server_capabilities, url, symrefs, peeled = self._discover_references( b"git-upload-pack", url, protocol_version ) ( @@ -2651,7 +2671,7 @@ def fetch_pack( def get_refs(self, path): """Retrieve the current refs from a git smart server.""" url = self._get_url(path) - refs, _, _ = self._discover_references(b"git-upload-pack", url) + refs, _, _, _, _ = self._discover_references(b"git-upload-pack", url) return refs def get_url(self, path): diff --git a/dulwich/contrib/swift.py b/dulwich/contrib/swift.py index 57887ceb6..1fa48a9c0 100644 --- a/dulwich/contrib/swift.py +++ b/dulwich/contrib/swift.py @@ -59,7 +59,7 @@ write_pack_object, ) from ..protocol import TCP_GIT_PORT -from ..refs import InfoRefsContainer, read_info_refs, write_info_refs +from ..refs import InfoRefsContainer, read_info_refs, split_peeled_refs, write_info_refs from ..repo import OBJECTDIR, BaseRepo from ..server import Backend, TCPGitServer @@ -809,6 +809,7 @@ def _load_check_ref(self, name, old_ref): if not f: return {} refs = read_info_refs(f) + (refs, peeled) = split_peeled_refs(refs) if old_ref is not None: if refs[name] != old_ref: return False diff --git a/dulwich/refs.py b/dulwich/refs.py index 1a7f1ef60..fd7740898 100644 --- a/dulwich/refs.py +++ b/dulwich/refs.py @@ -585,17 +585,8 @@ class InfoRefsContainer(RefsContainer): def __init__(self, f) -> None: self._refs = {} self._peeled = {} - for line in f.readlines(): - sha, name = line.rstrip(b"\n").split(b"\t") - if name.endswith(PEELED_TAG_SUFFIX): - name = name[:-3] - if not check_ref_format(name): - raise ValueError(f"invalid ref name {name!r}") - self._peeled[name] = sha - else: - if not check_ref_format(name): - raise ValueError(f"invalid ref name {name!r}") - self._refs[name] = sha + refs = read_info_refs(f) + (self._refs, self._peeled) = split_peeled_refs(refs) def allkeys(self): return self._refs.keys() @@ -1175,6 +1166,18 @@ def strip_peeled_refs(refs): } +def split_peeled_refs(refs): + """Split peeled refs from regular refs.""" + peeled = {} + regular = {} + for ref, sha in refs.items(): + if ref.endswith(PEELED_TAG_SUFFIX): + peeled[ref[: -len(PEELED_TAG_SUFFIX)]] = sha + else: + regular[ref] = sha + return regular, peeled + + def _set_origin_head(refs, origin, origin_head): # set refs/remotes/origin/HEAD origin_base = b"refs/remotes/" + origin + b"/" diff --git a/tests/test_client.py b/tests/test_client.py index 83c9c59c8..9d2947c8b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1138,7 +1138,9 @@ def request( # instantiate HttpGitClient with mocked pool manager c = HttpGitClient(base_url, pool_manager=pool_manager, config=None) # call method that detects url redirection - _, _, processed_url = c._discover_references(b"git-upload-pack", base_url) + _, _, processed_url, _, _ = c._discover_references( + b"git-upload-pack", base_url + ) # send the same request as the method above without redirection resp = c.pool_manager.request("GET", base_url + tail, redirect=False) diff --git a/tests/test_refs.py b/tests/test_refs.py index 572be9d80..361e98a3c 100644 --- a/tests/test_refs.py +++ b/tests/test_refs.py @@ -38,6 +38,7 @@ parse_symref_value, read_packed_refs, read_packed_refs_with_peeled, + split_peeled_refs, strip_peeled_refs, write_packed_refs, ) @@ -814,3 +815,14 @@ class StripPeeledRefsTests(TestCase): def test_strip_peeled_refs(self): # Simple check of two dicts self.assertEqual(strip_peeled_refs(self.all_refs), self.non_peeled_refs) + + def test_split_peeled_refs(self): + (regular, peeled) = split_peeled_refs(self.all_refs) + self.assertEqual(regular, self.non_peeled_refs) + self.assertEqual( + peeled, + { + b"refs/tags/2.0.0": b"0749936d0956c661ac8f8d3483774509c165f89e", + b"refs/tags/1.0.0": b"a93db4b0360cc635a2b93675010bac8d101f73f0", + }, + )