diff --git a/NEWS b/NEWS index 3d71e24be..0bf0e1f3a 100644 --- a/NEWS +++ b/NEWS @@ -4,6 +4,10 @@ when creating symlinks fails due to a permission error. (Jelmer Vernooij, #1005) + * Add new ``ObjectID`` type in ``dulwich.objects``, + currently just an alias for ``bytes``. + (Jelmer Vernooij) + * Support repository format version 1. (Jelmer Vernooij, #1056) diff --git a/dulwich/objects.py b/dulwich/objects.py index fc26dd2fc..492a10ac8 100644 --- a/dulwich/objects.py +++ b/dulwich/objects.py @@ -33,6 +33,8 @@ Iterable, Union, Type, + Iterator, + List, ) import zlib from hashlib import sha1 @@ -75,6 +77,9 @@ BEGIN_PGP_SIGNATURE = b"-----BEGIN PGP SIGNATURE-----" +ObjectID = bytes + + class EmptyFileException(FileFormatException): """An unexpectedly empty file was encountered.""" @@ -153,7 +158,10 @@ def filename_to_hex(filename): def object_header(num_type: int, length: int) -> bytes: """Return an object header for the given numeric type and text length.""" - return object_class(num_type).type_name + b" " + str(length).encode("ascii") + b"\0" + cls = object_class(num_type) + if cls is None: + raise AssertionError("unsupported class type num: %d" % num_type) + return cls.type_name + b" " + str(length).encode("ascii") + b"\0" def serializable_property(name: str, docstring: Optional[str] = None): @@ -169,7 +177,7 @@ def get(obj): return property(get, set, doc=docstring) -def object_class(type): +def object_class(type: Union[bytes, int]) -> Optional[Type["ShaFile"]]: """Get the object class corresponding to the given type. Args: @@ -193,7 +201,7 @@ def check_hexsha(hex, error_msg): raise ObjectFormatException("%s %s" % (error_msg, hex)) -def check_identity(identity, error_msg): +def check_identity(identity: bytes, error_msg: str) -> None: """Check if the specified identity is valid. This will raise an exception if the identity is not valid. @@ -261,11 +269,13 @@ class ShaFile(object): __slots__ = ("_chunked_text", "_sha", "_needs_serialization") - type_name = None # type: bytes - type_num = None # type: int + _needs_serialization: bool + type_name: bytes + type_num: int + _chunked_text: Optional[List[bytes]] @staticmethod - def _parse_legacy_object_header(magic, f): + def _parse_legacy_object_header(magic, f) -> "ShaFile": """Parse a legacy object, creating it but not reading the file.""" bufsize = 1024 decomp = zlib.decompressobj() @@ -287,10 +297,10 @@ def _parse_legacy_object_header(magic, f): "Object size not an integer: %s" % exc) from exc obj_class = object_class(type_name) if not obj_class: - raise ObjectFormatException("Not a known type: %s" % type_name) + raise ObjectFormatException("Not a known type: %s" % type_name.decode('ascii')) return obj_class() - def _parse_legacy_object(self, map): + def _parse_legacy_object(self, map) -> None: """Parse a legacy object, setting the raw string.""" text = _decompress(map) header_end = text.find(b"\0") @@ -298,7 +308,8 @@ def _parse_legacy_object(self, map): raise ObjectFormatException("Invalid object header, no \\0") self.set_raw_string(text[header_end + 1 :]) - def as_legacy_object_chunks(self, compression_level=-1): + def as_legacy_object_chunks( + self, compression_level: int = -1) -> Iterator[bytes]: """Return chunks representing the object in the experimental format. Returns: List of strings @@ -309,13 +320,13 @@ def as_legacy_object_chunks(self, compression_level=-1): yield compobj.compress(chunk) yield compobj.flush() - def as_legacy_object(self, compression_level=-1): + def as_legacy_object(self, compression_level: int = -1) -> bytes: """Return string representing the object in the experimental format.""" return b"".join( self.as_legacy_object_chunks(compression_level=compression_level) ) - def as_raw_chunks(self): + def as_raw_chunks(self) -> List[bytes]: """Return chunks with serialization of the object. Returns: List of strings, not necessarily one per line @@ -324,16 +335,16 @@ def as_raw_chunks(self): self._sha = None self._chunked_text = self._serialize() self._needs_serialization = False - return self._chunked_text + return self._chunked_text # type: ignore - def as_raw_string(self): + def as_raw_string(self) -> bytes: """Return raw string with serialization of the object. Returns: String object """ return b"".join(self.as_raw_chunks()) - def __bytes__(self): + def __bytes__(self) -> bytes: """Return raw string serialization of this object.""" return self.as_raw_string() @@ -341,24 +352,27 @@ def __hash__(self): """Return unique hash for this object.""" return hash(self.id) - def as_pretty_string(self): + def as_pretty_string(self) -> bytes: """Return a string representing this object, fit for display.""" return self.as_raw_string() - def set_raw_string(self, text, sha=None): + def set_raw_string( + self, text: bytes, sha: Optional[ObjectID] = None) -> None: """Set the contents of this object from a serialized string.""" if not isinstance(text, bytes): raise TypeError("Expected bytes for text, got %r" % text) self.set_raw_chunks([text], sha) - def set_raw_chunks(self, chunks, sha=None): + def set_raw_chunks( + self, chunks: List[bytes], + sha: Optional[ObjectID] = None) -> None: """Set the contents of this object from a list of chunks.""" self._chunked_text = chunks self._deserialize(chunks) if sha is None: self._sha = None else: - self._sha = FixedSha(sha) + self._sha = FixedSha(sha) # type: ignore self._needs_serialization = False @staticmethod @@ -370,7 +384,7 @@ def _parse_object_header(magic, f): raise ObjectFormatException("Not a known type %d" % num_type) return obj_class() - def _parse_object(self, map): + def _parse_object(self, map) -> None: """Parse a new style object, setting self._text.""" # skip type and size; type must have already been determined, and # we trust zlib to fail if it's otherwise corrupted @@ -383,7 +397,7 @@ def _parse_object(self, map): self.set_raw_string(_decompress(raw)) @classmethod - def _is_legacy_object(cls, magic): + def _is_legacy_object(cls, magic: bytes) -> bool: b0 = ord(magic[0:1]) b1 = ord(magic[1:2]) word = (b0 << 8) + b1 @@ -445,7 +459,9 @@ def from_raw_string(type_num, string, sha=None): return obj @staticmethod - def from_raw_chunks(type_num, chunks, sha=None): + def from_raw_chunks( + type_num: int, chunks: List[bytes], + sha: Optional[ObjectID] = None): """Creates an object of the indicated type from the raw chunks given. Args: @@ -453,7 +469,10 @@ def from_raw_chunks(type_num, chunks, sha=None): chunks: An iterable of the raw uncompressed contents. sha: Optional known sha for the object """ - obj = object_class(type_num)() + cls = object_class(type_num) + if cls is None: + raise AssertionError("unsupported class type num: %d" % type_num) + obj = cls() obj.set_raw_chunks(chunks, sha) return obj @@ -477,7 +496,7 @@ def _check_has_member(self, member, error_msg): if getattr(self, member, None) is None: raise ObjectFormatException(error_msg) - def check(self): + def check(self) -> None: """Check this object for internal consistency. Raises: @@ -500,9 +519,9 @@ def check(self): raise ChecksumMismatch(new_sha, old_sha) def _header(self): - return object_header(self.type, self.raw_length()) + return object_header(self.type_num, self.raw_length()) - def raw_length(self): + def raw_length(self) -> int: """Returns the length of the raw string of this object.""" ret = 0 for chunk in self.as_raw_chunks(): @@ -522,25 +541,14 @@ def sha(self): def copy(self): """Create a new copy of this SHA1 object from its raw string""" - obj_class = object_class(self.get_type()) - return obj_class.from_raw_string(self.get_type(), self.as_raw_string(), self.id) + obj_class = object_class(self.type_num) + return obj_class.from_raw_string(self.type_num, self.as_raw_string(), self.id) @property def id(self): """The hex SHA of this object.""" return self.sha().hexdigest().encode("ascii") - def get_type(self): - """Return the type number for this object class.""" - return self.type_num - - def set_type(self, type): - """Set the type number for this object class.""" - self.type_num = type - - # DEPRECATED: use type_num or type_name as needed. - type = property(get_type, set_type) - def __repr__(self): return "<%s %s>" % (self.__class__.__name__, self.id) @@ -621,7 +629,7 @@ def check(self): """ super(Blob, self).check() - def splitlines(self): + def splitlines(self) -> List[bytes]: """Return list of lines in this blob. This preserves the original line endings. @@ -649,7 +657,7 @@ def splitlines(self): return ret -def _parse_message(chunks): +def _parse_message(chunks: Iterable[bytes]): """Parse a message with a list of fields and a body. Args: @@ -660,7 +668,7 @@ def _parse_message(chunks): """ f = BytesIO(b"".join(chunks)) k = None - v = "" + v = b"" eof = False def _strip_last_newline(value): @@ -1596,7 +1604,7 @@ def _get_extra(self): Tag, ) -_TYPE_MAP = {} # type: Dict[Union[bytes, int], Type[ShaFile]] +_TYPE_MAP: Dict[Union[bytes, int], Type[ShaFile]] = {} for cls in OBJECT_CLASSES: _TYPE_MAP[cls.type_name] = cls diff --git a/dulwich/refs.py b/dulwich/refs.py index 026fdca3d..91159b95a 100644 --- a/dulwich/refs.py +++ b/dulwich/refs.py @@ -34,12 +34,14 @@ valid_hexsha, ZERO_SHA, Tag, + ObjectID, ) from dulwich.file import ( GitFile, ensure_dir_exists, ) +Ref = bytes HEADREF = b"HEAD" SYMREF = b"ref: " @@ -69,7 +71,7 @@ def parse_symref_value(contents): raise ValueError(contents) -def check_ref_format(refname): +def check_ref_format(refname: Ref): """Check if a refname is correctly formatted. Implements all the same rules as git-check-ref-format[1]. @@ -166,8 +168,8 @@ def get_peeled(self, name): def import_refs( self, - base: bytes, - other: Dict[bytes, bytes], + base: Ref, + other: Dict[Ref, ObjectID], committer: Optional[bytes] = None, timestamp: Optional[bytes] = None, timezone: Optional[bytes] = None, @@ -455,8 +457,8 @@ def _notify(self, ref, newsha): def set_symbolic_ref( self, - name, - other, + name: Ref, + other: Ref, committer=None, timestamp=None, timezone=None, @@ -507,8 +509,8 @@ def set_if_equals( def add_if_new( self, - name: bytes, - ref: bytes, + name: Ref, + ref: ObjectID, committer=None, timestamp=None, timezone=None, diff --git a/dulwich/repo.py b/dulwich/repo.py index 0fa4b1fd4..89439b4b2 100644 --- a/dulwich/repo.py +++ b/dulwich/repo.py @@ -741,7 +741,8 @@ def get_peeled(self, ref): return cached return self.object_store.peel_sha(self.refs[ref]).id - def get_walker(self, include=None, *args, **kwargs): + def get_walker(self, include: Optional[List[bytes]] = None, + *args, **kwargs): """Obtain a walker for this repository. Args: @@ -771,8 +772,6 @@ def get_walker(self, include=None, *args, **kwargs): if include is None: include = [self.head()] - if isinstance(include, str): - include = [include] kwargs["get_parents"] = lambda commit: self.get_parents(commit.id, commit) diff --git a/dulwich/walk.py b/dulwich/walk.py index 45a0f38f4..5a38f7177 100644 --- a/dulwich/walk.py +++ b/dulwich/walk.py @@ -24,6 +24,7 @@ import collections import heapq from itertools import chain +from typing import List, Tuple, Set from dulwich.diff_tree import ( RENAME_CHANGE_TYPES, @@ -35,7 +36,9 @@ MissingCommitError, ) from dulwich.objects import ( + Commit, Tag, + ObjectID, ) ORDER_DATE = "date" @@ -128,15 +131,15 @@ def __repr__(self): class _CommitTimeQueue(object): """Priority queue of WalkEntry objects by commit time.""" - def __init__(self, walker): + def __init__(self, walker: "Walker"): self._walker = walker self._store = walker.store self._get_parents = walker.get_parents self._excluded = walker.excluded - self._pq = [] - self._pq_set = set() - self._seen = set() - self._done = set() + self._pq: List[Tuple[int, Commit]] = [] + self._pq_set: Set[ObjectID] = set() + self._seen: Set[ObjectID] = set() + self._done: Set[ObjectID] = set() self._min_time = walker.since self._last = None self._extra_commits_left = _MAX_EXTRA_COMMITS @@ -145,7 +148,7 @@ def __init__(self, walker): for commit_id in chain(walker.include, walker.excluded): self._push(commit_id) - def _push(self, object_id): + def _push(self, object_id: bytes): try: obj = self._store[object_id] except KeyError as exc: