Add more typing
jelmer committed May 10, 2023
1 parent 574e529 commit e6ccc98
Showing 16 changed files with 71 additions and 49 deletions.
12 changes: 7 additions & 5 deletions dulwich/client.py
@@ -190,7 +190,7 @@ class ReportStatusParser:
     def __init__(self) -> None:
         self._done = False
         self._pack_status = None
-        self._ref_statuses = []
+        self._ref_statuses: List[bytes] = []
 
     def check(self):
         """Check if there were any errors and, if so, raise exceptions.
@@ -427,8 +427,8 @@ def _read_shallow_updates(pkt_seq):
 class _v1ReceivePackHeader:
 
     def __init__(self, capabilities, old_refs, new_refs) -> None:
-        self.want = []
-        self.have = []
+        self.want: List[bytes] = []
+        self.have: List[bytes] = []
         self._it = self._handle_receive_pack_head(capabilities, old_refs, new_refs)
         self.sent_capabilities = False
 
@@ -646,7 +646,7 @@ def __init__(
             to
         """
         self._report_activity = report_activity
-        self._report_status_parser = None
+        self._report_status_parser: Optional[ReportStatusParser] = None
         self._fetch_capabilities = set(UPLOAD_CAPABILITIES)
         self._fetch_capabilities.add(capability_agent())
         self._send_capabilities = set(RECEIVE_CAPABILITIES)
@@ -915,6 +915,7 @@ def progress(x):
                     pass
 
             if CAPABILITY_REPORT_STATUS in capabilities:
+                assert self._report_status_parser is not None
                 pktline_parser = PktLineParser(self._report_status_parser.handle_packet)
             for chan, data in _read_side_band64k_data(proto.read_pkt_seq()):
                 if chan == SIDE_BAND_CHANNEL_DATA:
@@ -927,6 +928,7 @@ def progress(x):
                         "Invalid sideband channel %d" % chan)
         else:
             if CAPABILITY_REPORT_STATUS in capabilities:
+                assert self._report_status_parser
                 for pkt in proto.read_pkt_seq():
                     self._report_status_parser.handle_packet(pkt)
         if self._report_status_parser is not None:
@@ -1729,7 +1731,7 @@ def __init__(
             "GIT_SSH_COMMAND", os.environ.get("GIT_SSH")
         )
         super().__init__(**kwargs)
-        self.alternative_paths = {}
+        self.alternative_paths: Dict[bytes, bytes] = {}
         if vendor is not None:
            self.ssh_vendor = vendor
         else:
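
Note on the client.py hunks: mypy infers an attribute assigned only None as having type None, so later use as an object is rejected. The commit therefore annotates such attributes as Optional[...] and narrows them with an assert at each use site. A minimal sketch of the pattern, with hypothetical names rather than dulwich's API:

    from typing import Optional


    class ReportParser:
        def handle_packet(self, pkt: bytes) -> None:
            ...


    class Client:
        def __init__(self) -> None:
            # Annotated as Optional so it may start as None and later
            # hold a ReportParser.
            self._parser: Optional[ReportParser] = None

        def start(self) -> None:
            self._parser = ReportParser()

        def feed(self, pkt: bytes) -> None:
            # The assert narrows Optional[ReportParser] to ReportParser,
            # so mypy accepts the attribute access on the next line.
            assert self._parser is not None
            self._parser.handle_packet(pkt)
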
6 changes: 4 additions & 2 deletions dulwich/config.py
@@ -40,6 +40,8 @@
     Tuple,
     Union,
     overload,
+    Any,
+    Dict,
 )
 
 from .file import GitFile
@@ -60,8 +62,8 @@ def lower_key(key):
 class CaseInsensitiveOrderedMultiDict(MutableMapping):
 
     def __init__(self) -> None:
-        self._real = []
-        self._keyed = {}
+        self._real: List[Any] = []
+        self._keyed: Dict[Any, Any] = {}
 
     @classmethod
     def make(cls, dict_in=None):
7 changes: 4 additions & 3 deletions dulwich/fastexport.py
@@ -22,6 +22,7 @@
 """Fast export/import functionality."""
 
 import stat
+from typing import Dict, Tuple
 
 from fastimport import commands, parser, processor
 from fastimport import errors as fastimport_errors
@@ -42,7 +43,7 @@ class GitFastExporter:
     def __init__(self, outf, store) -> None:
         self.outf = outf
         self.store = store
-        self.markers = {}
+        self.markers: Dict[bytes, bytes] = {}
         self._marker_idx = 0
 
     def print_cmd(self, cmd):
@@ -125,8 +126,8 @@ def __init__(self, repo, params=None, verbose=False, outf=None) -> None:
         processor.ImportProcessor.__init__(self, params, verbose)
         self.repo = repo
         self.last_commit = ZERO_SHA
-        self.markers = {}
-        self._contents = {}
+        self.markers: Dict[bytes, bytes] = {}
+        self._contents: Dict[bytes, Tuple[int, bytes]] = {}
 
     def lookup_object(self, objectish):
         if objectish.startswith(b":"):
8 changes: 5 additions & 3 deletions dulwich/greenthreads.py
@@ -25,12 +25,14 @@
 import gevent
 from gevent import pool
 
+from typing import Set, Tuple, Optional, FrozenSet
+
 from .object_store import (
     MissingObjectFinder,
     _collect_ancestors,
     _collect_filetree_revs,
 )
-from .objects import Commit, Tag
+from .objects import Commit, Tag, ObjectID
 
 
 def _split_commits_and_tags(obj_store, lst, *, ignore_unknown=False, pool=None):
@@ -89,7 +91,7 @@ def collect_tree_sha(sha):
 
         have_commits, have_tags = _split_commits_and_tags(object_store, haves, ignore_unknown=True, pool=p)
         want_commits, want_tags = _split_commits_and_tags(object_store, wants, ignore_unknown=False, pool=p)
-        all_ancestors = _collect_ancestors(object_store, have_commits)[0]
+        all_ancestors: FrozenSet[ObjectID] = frozenset(_collect_ancestors(object_store, have_commits)[0])
         missing_commits, common_commits = _collect_ancestors(
             object_store, want_commits, all_ancestors
         )
@@ -101,7 +103,7 @@ def collect_tree_sha(sha):
             self.sha_done.add(t)
         missing_tags = want_tags.difference(have_tags)
         wants = missing_commits.union(missing_tags)
-        self.objects_to_send = {(w, None, False) for w in wants}
+        self.objects_to_send: Set[Tuple[ObjectID, Optional[bytes], Optional[int], bool]] = {(w, None, 0, False) for w in wants}
         if progress is None:
             self.progress = lambda x: None
         else:
4 changes: 3 additions & 1 deletion dulwich/mailmap.py
@@ -20,6 +20,8 @@
 
 """Mailmap file reader."""
 
+from typing import Dict, Tuple, Optional
+
 
 def parse_identity(text):
     # TODO(jelmer): Integrate this with dulwich.fastexport.split_email and
@@ -62,7 +64,7 @@ class Mailmap:
     """Class for accessing a mailmap file."""
 
     def __init__(self, map=None) -> None:
-        self._table = {}
+        self._table: Dict[Tuple[Optional[str], str], Tuple[str, str]] = {}
         if map:
             for (canonical_identity, from_identity) in map:
                 self.add_entry(canonical_identity, from_identity)
21 changes: 13 additions & 8 deletions dulwich/object_store.py
@@ -31,6 +31,7 @@
 from typing import (
     Callable,
     Dict,
+    FrozenSet,
     Iterable,
     Iterator,
     List,
@@ -360,7 +361,7 @@ def close(self):
 
 class PackBasedObjectStore(BaseObjectStore):
     def __init__(self, pack_compression_level=-1) -> None:
-        self._pack_cache = {}
+        self._pack_cache: Dict[str, Pack] = {}
         self.pack_compression_level = pack_compression_level
 
     def add_pack(
@@ -995,7 +996,7 @@ class MemoryObjectStore(BaseObjectStore):
 
     def __init__(self) -> None:
         super().__init__()
-        self._data = {}
+        self._data: Dict[str, ShaFile] = {}
         self.pack_compression_level = -1
 
     def _to_hexsha(self, sha):
@@ -1269,7 +1270,7 @@ def __init__(
 
         # in fact, what we 'want' is commits, tags, and others
         # we've found missing
-        self.objects_to_send = {
+        self.objects_to_send: Set[Tuple[ObjectID, Optional[bytes], Optional[int], bool]] = {
             (w, None, Commit.type_num, False)
             for w in missing_commits}
         missing_tags = want_tags.difference(have_tags)
@@ -1293,7 +1294,7 @@ def get_remote_has(self):
     def add_todo(self, entries: Iterable[Tuple[ObjectID, Optional[bytes], Optional[int], bool]]):
         self.objects_to_send.update([e for e in entries if e[0] not in self.sha_done])
 
-    def __next__(self) -> Tuple[bytes, PackHint]:
+    def __next__(self) -> Tuple[bytes, Optional[PackHint]]:
         while True:
             if not self.objects_to_send:
                 self.progress(("counting objects: %d, done.\n" % len(self.sha_done)).encode("ascii"))
@@ -1321,7 +1322,11 @@ def __next__(self) -> Tuple[bytes, PackHint]:
         self.sha_done.add(sha)
         if len(self.sha_done) % 1000 == 0:
             self.progress(("counting objects: %d\r" % len(self.sha_done)).encode("ascii"))
-        return (sha, (type_num, name))
+        if type_num is None:
+            pack_hint = None
+        else:
+            pack_hint = (type_num, name)
+        return (sha, pack_hint)
 
     def __iter__(self):
         return self
@@ -1344,7 +1349,7 @@ def __init__(self, local_heads, get_parents, shallow=None) -> None:
         """
         self.heads = set(local_heads)
         self.get_parents = get_parents
-        self.parents = {}
+        self.parents: Dict[ObjectID, Optional[List[ObjectID]]] = {}
         if shallow is None:
             shallow = set()
         self.shallow = shallow
@@ -1610,8 +1615,8 @@ def commit():
 def _collect_ancestors(
     store: ObjectContainer,
     heads,
-    common=frozenset(),
-    shallow=frozenset(),
+    common: FrozenSet[ObjectID] = frozenset(),
+    shallow: FrozenSet[ObjectID] = frozenset(),
     get_parents=lambda commit: commit.parents,
 ):
     """Collect all ancestors of heads up to (excluding) those in common.
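
Note on the __next__ change above: the return type is widened to Optional[PackHint] so that an object with no known type yields None rather than a hint containing None. A short sketch of a caller handling both cases (hypothetical helper; in dulwich.pack, PackHint is a (type_num, name) pair):

    from typing import Optional, Tuple

    PackHint = Tuple[int, bytes]  # (type_num, name)


    def describe(sha: bytes, hint: Optional[PackHint]) -> str:
        # Callers must branch on None now that the hint is Optional.
        if hint is None:
            return "%s: no pack hint" % sha.decode("ascii")
        type_num, name = hint
        return "%s: object type %d" % (sha.decode("ascii"), type_num)
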
10 changes: 6 additions & 4 deletions dulwich/objects.py
@@ -1076,7 +1076,7 @@ class Tree(ShaFile):
 
     def __init__(self) -> None:
         super().__init__()
-        self._entries = {}
+        self._entries: Dict[bytes, Tuple[int, bytes]] = {}
 
     @classmethod
     def from_path(cls, filename):
@@ -1381,11 +1381,11 @@ class Commit(ShaFile):
 
     def __init__(self) -> None:
         super().__init__()
-        self._parents = []
+        self._parents: List[bytes] = []
         self._encoding = None
-        self._mergetag = []
+        self._mergetag: List[Tag] = []
         self._gpgsig = None
-        self._extra = []
+        self._extra: List[Tuple[bytes, bytes]] = []
         self._author_timezone_neg_utc = False
         self._commit_timezone_neg_utc = False
 
@@ -1412,6 +1412,7 @@ def _deserialize(self, chunks):
             if field == _TREE_HEADER:
                 self._tree = value
             elif field == _PARENT_HEADER:
+                assert value is not None
                 self._parents.append(value)
             elif field == _AUTHOR_HEADER:
                 author_info = parse_time_entry(value)
@@ -1420,6 +1421,7 @@
             elif field == _ENCODING_HEADER:
                 self._encoding = value
             elif field == _MERGETAG_HEADER:
+                assert value is not None
                 self._mergetag.append(Tag.from_string(value + b"\n"))
             elif field == _GPGSIG_HEADER:
                 self._gpgsig = value
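
The asserts added in _deserialize apply the same narrowing idea to a loop variable: a parsed header field may legitimately carry no value, so mypy sees it as Optional[bytes] and it must be narrowed before being stored in a List[bytes]. A minimal sketch with a hypothetical parser, not dulwich's:

    from typing import Iterator, List, Optional, Tuple


    def fields() -> Iterator[Tuple[bytes, Optional[bytes]]]:
        # A header line may carry no value, hence Optional[bytes].
        yield (b"parent", b"1234abcd")
        yield (b"done", None)


    parents: List[bytes] = []
    for field, value in fields():
        if field == b"parent":
            # Narrow Optional[bytes] to bytes before appending;
            # without the assert, mypy reports an error here.
            assert value is not None
            parents.append(value)
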
9 changes: 5 additions & 4 deletions dulwich/pack.py
@@ -201,6 +201,7 @@ class UnpackedObject:
     obj_chunks: Optional[List[bytes]]
     delta_base: Union[None, bytes, int]
     decomp_chunks: List[bytes]
+    comp_chunks: Optional[List[bytes]]
 
     # TODO(dborowitz): read_zlib_chunks and unpack_object could very well be
     # methods of this object.
@@ -1167,7 +1168,7 @@ def __init__(self, filename, file=None, size=None) -> None:
         else:
             self._file = file
         (version, self._num_objects) = read_pack_header(self._file.read)
-        self._offset_cache = LRUSizeCache(
+        self._offset_cache = LRUSizeCache[int, Tuple[int, OldUnpackedObject]](
             1024 * 1024 * 20, compute_size=_compute_object_size
         )
 
@@ -1239,7 +1240,7 @@ def iter_unpacked(self, *, include_comp: bool = False):
             # Back up over unused data.
             self._file.seek(-len(unused), SEEK_CUR)
 
-    def iterentries(self, progress: Optional[ProgressFn] = None, resolve_ext_ref: Optional[ResolveExtRefFn] = None):
+    def iterentries(self, progress=None, resolve_ext_ref: Optional[ResolveExtRefFn] = None):
         """Yield entries summarizing the contents of this pack.
 
         Args:
@@ -1957,7 +1958,7 @@ class PackChunkGenerator:
 
     def __init__(self, num_records=None, records=None, progress=None, compression_level=-1, reuse_compressed=True) -> None:
         self.cs = sha1(b"")
-        self.entries = {}
+        self.entries: Dict[Union[int, bytes], Tuple[int, int]] = {}
         self._it = self._pack_data_chunks(
             num_records=num_records, records=records, progress=progress, compression_level=compression_level, reuse_compressed=reuse_compressed)
 
@@ -2607,7 +2608,7 @@ def extend_pack(f: BinaryIO, object_ids: Set[ObjectID], get_raw, *, compression_
 
 
 try:
-    from dulwich._pack import (
+    from dulwich._pack import (  # type: ignore # noqa: F811
         apply_delta,  # type: ignore # noqa: F811
         bisect_find_sha,  # type: ignore # noqa: F811
     )
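
The LRUSizeCache change above relies on subscripting a user-defined generic class: LRUSizeCache[int, Tuple[int, OldUnpackedObject]](...) pins the key and value types at the construction site. A minimal sketch of the mechanism with an assumed stand-in class (not dulwich's actual LRUSizeCache):

    from typing import Dict, Generic, TypeVar

    K = TypeVar("K")
    V = TypeVar("V")


    class SizeCache(Generic[K, V]):
        def __init__(self, max_size: int) -> None:
            self.max_size = max_size
            self._entries: Dict[K, V] = {}

        def __setitem__(self, key: K, value: V) -> None:
            self._entries[key] = value


    # Subscripting at construction tells the checker the concrete types;
    # at runtime the subscripted alias still builds a plain SizeCache.
    cache = SizeCache[int, bytes](20 * 1024 * 1024)
    cache[42] = b"packed data"  # ok: int key, bytes value
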
6 changes: 3 additions & 3 deletions dulwich/refs.py
@@ -23,7 +23,7 @@
 import os
 import warnings
 from contextlib import suppress
-from typing import Dict, Optional
+from typing import Dict, Optional, Set, Any
 
 from .errors import PackedRefsException, RefFormatError
 from .file import GitFile, ensure_dir_exists
@@ -442,8 +442,8 @@ class DictRefsContainer(RefsContainer):
     def __init__(self, refs, logger=None) -> None:
         super().__init__(logger=logger)
         self._refs = refs
-        self._peeled = {}
-        self._watchers = set()
+        self._peeled: Dict[bytes, ObjectID] = {}
+        self._watchers: Set[Any] = set()
 
     def allkeys(self):
         return self._refs.keys()
7 changes: 4 additions & 3 deletions dulwich/repo.py
@@ -46,6 +46,7 @@
     Set,
     Tuple,
     Union,
+    Any
 )
 
 if TYPE_CHECKING:
@@ -1797,10 +1798,10 @@ class MemoryRepo(BaseRepo):
     def __init__(self) -> None:
         from .config import ConfigFile
 
-        self._reflog = []
+        self._reflog: List[Any] = []
         refs_container = DictRefsContainer({}, logger=self._append_reflog)
-        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)
-        self._named_files = {}
+        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)  # type: ignore
+        self._named_files: Dict[str, bytes] = {}
         self.bare = True
         self._config = ConfigFile()
         self._description = None
8 changes: 4 additions & 4 deletions dulwich/server.py
@@ -227,7 +227,7 @@ class PackHandler(Handler):
 
     def __init__(self, backend, proto, stateless_rpc=False) -> None:
         super().__init__(backend, proto, stateless_rpc)
-        self._client_capabilities = None
+        self._client_capabilities: Optional[Set[bytes]] = None
         # Flags needed for the no-done capability
         self._done_received = False
 
@@ -763,7 +763,7 @@ class SingleAckGraphWalkerImpl:
 
     def __init__(self, walker) -> None:
         self.walker = walker
-        self._common = []
+        self._common: List[bytes] = []
 
     def ack(self, have_ref):
         if not self._common:
@@ -808,7 +808,7 @@ class MultiAckGraphWalkerImpl:
     def __init__(self, walker) -> None:
         self.walker = walker
         self._found_base = False
-        self._common = []
+        self._common: List[bytes] = []
 
     def ack(self, have_ref):
         self._common.append(have_ref)
Expand Down Expand Up @@ -866,7 +866,7 @@ class MultiAckDetailedGraphWalkerImpl:

def __init__(self, walker) -> None:
self.walker = walker
self._common = []
self._common: List[bytes] = []

def ack(self, have_ref):
# Should only be called iff have_ref is common
Expand Down