Skip to content

Commit

Permalink
Merge pull request #1076 from jelmer/more-typing
Browse files Browse the repository at this point in the history
Add more typing
  • Loading branch information
jelmer authored Oct 22, 2022
2 parents eb409c0 + 260a6d9 commit 1997f83
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 58 deletions.
4 changes: 4 additions & 0 deletions NEWS
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
when creating symlinks fails due to a permission
error. (Jelmer Vernooij, #1005)

* Add new ``ObjectID`` type in ``dulwich.objects``,
currently just an alias for ``bytes``.
(Jelmer Vernooij)

* Support repository format version 1.
(Jelmer Vernooij, #1056)

Expand Down
92 changes: 50 additions & 42 deletions dulwich/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
Iterable,
Union,
Type,
Iterator,
List,
)
import zlib
from hashlib import sha1
Expand Down Expand Up @@ -75,6 +77,9 @@
BEGIN_PGP_SIGNATURE = b"-----BEGIN PGP SIGNATURE-----"


ObjectID = bytes


class EmptyFileException(FileFormatException):
"""An unexpectedly empty file was encountered."""

Expand Down Expand Up @@ -153,7 +158,10 @@ def filename_to_hex(filename):

def object_header(num_type: int, length: int) -> bytes:
"""Return an object header for the given numeric type and text length."""
return object_class(num_type).type_name + b" " + str(length).encode("ascii") + b"\0"
cls = object_class(num_type)
if cls is None:
raise AssertionError("unsupported class type num: %d" % num_type)
return cls.type_name + b" " + str(length).encode("ascii") + b"\0"


def serializable_property(name: str, docstring: Optional[str] = None):
Expand All @@ -169,7 +177,7 @@ def get(obj):
return property(get, set, doc=docstring)


def object_class(type):
def object_class(type: Union[bytes, int]) -> Optional[Type["ShaFile"]]:
"""Get the object class corresponding to the given type.
Args:
Expand All @@ -193,7 +201,7 @@ def check_hexsha(hex, error_msg):
raise ObjectFormatException("%s %s" % (error_msg, hex))


def check_identity(identity, error_msg):
def check_identity(identity: bytes, error_msg: str) -> None:
"""Check if the specified identity is valid.
This will raise an exception if the identity is not valid.
Expand Down Expand Up @@ -261,11 +269,13 @@ class ShaFile(object):

__slots__ = ("_chunked_text", "_sha", "_needs_serialization")

type_name = None # type: bytes
type_num = None # type: int
_needs_serialization: bool
type_name: bytes
type_num: int
_chunked_text: Optional[List[bytes]]

@staticmethod
def _parse_legacy_object_header(magic, f):
def _parse_legacy_object_header(magic, f) -> "ShaFile":
"""Parse a legacy object, creating it but not reading the file."""
bufsize = 1024
decomp = zlib.decompressobj()
Expand All @@ -287,18 +297,19 @@ def _parse_legacy_object_header(magic, f):
"Object size not an integer: %s" % exc) from exc
obj_class = object_class(type_name)
if not obj_class:
raise ObjectFormatException("Not a known type: %s" % type_name)
raise ObjectFormatException("Not a known type: %s" % type_name.decode('ascii'))
return obj_class()

def _parse_legacy_object(self, map):
def _parse_legacy_object(self, map) -> None:
"""Parse a legacy object, setting the raw string."""
text = _decompress(map)
header_end = text.find(b"\0")
if header_end < 0:
raise ObjectFormatException("Invalid object header, no \\0")
self.set_raw_string(text[header_end + 1 :])

def as_legacy_object_chunks(self, compression_level=-1):
def as_legacy_object_chunks(
self, compression_level: int = -1) -> Iterator[bytes]:
"""Return chunks representing the object in the experimental format.
Returns: List of strings
Expand All @@ -309,13 +320,13 @@ def as_legacy_object_chunks(self, compression_level=-1):
yield compobj.compress(chunk)
yield compobj.flush()

def as_legacy_object(self, compression_level=-1):
def as_legacy_object(self, compression_level: int = -1) -> bytes:
"""Return string representing the object in the experimental format."""
return b"".join(
self.as_legacy_object_chunks(compression_level=compression_level)
)

def as_raw_chunks(self):
def as_raw_chunks(self) -> List[bytes]:
"""Return chunks with serialization of the object.
Returns: List of strings, not necessarily one per line
Expand All @@ -324,41 +335,44 @@ def as_raw_chunks(self):
self._sha = None
self._chunked_text = self._serialize()
self._needs_serialization = False
return self._chunked_text
return self._chunked_text # type: ignore

def as_raw_string(self):
def as_raw_string(self) -> bytes:
"""Return raw string with serialization of the object.
Returns: String object
"""
return b"".join(self.as_raw_chunks())

def __bytes__(self):
def __bytes__(self) -> bytes:
"""Return raw string serialization of this object."""
return self.as_raw_string()

def __hash__(self):
"""Return unique hash for this object."""
return hash(self.id)

def as_pretty_string(self):
def as_pretty_string(self) -> bytes:
"""Return a string representing this object, fit for display."""
return self.as_raw_string()

def set_raw_string(self, text, sha=None):
def set_raw_string(
self, text: bytes, sha: Optional[ObjectID] = None) -> None:
"""Set the contents of this object from a serialized string."""
if not isinstance(text, bytes):
raise TypeError("Expected bytes for text, got %r" % text)
self.set_raw_chunks([text], sha)

def set_raw_chunks(self, chunks, sha=None):
def set_raw_chunks(
self, chunks: List[bytes],
sha: Optional[ObjectID] = None) -> None:
"""Set the contents of this object from a list of chunks."""
self._chunked_text = chunks
self._deserialize(chunks)
if sha is None:
self._sha = None
else:
self._sha = FixedSha(sha)
self._sha = FixedSha(sha) # type: ignore
self._needs_serialization = False

@staticmethod
Expand All @@ -370,7 +384,7 @@ def _parse_object_header(magic, f):
raise ObjectFormatException("Not a known type %d" % num_type)
return obj_class()

def _parse_object(self, map):
def _parse_object(self, map) -> None:
"""Parse a new style object, setting self._text."""
# skip type and size; type must have already been determined, and
# we trust zlib to fail if it's otherwise corrupted
Expand All @@ -383,7 +397,7 @@ def _parse_object(self, map):
self.set_raw_string(_decompress(raw))

@classmethod
def _is_legacy_object(cls, magic):
def _is_legacy_object(cls, magic: bytes) -> bool:
b0 = ord(magic[0:1])
b1 = ord(magic[1:2])
word = (b0 << 8) + b1
Expand Down Expand Up @@ -445,15 +459,20 @@ def from_raw_string(type_num, string, sha=None):
return obj

@staticmethod
def from_raw_chunks(type_num, chunks, sha=None):
def from_raw_chunks(
type_num: int, chunks: List[bytes],
sha: Optional[ObjectID] = None):
"""Creates an object of the indicated type from the raw chunks given.
Args:
type_num: The numeric type of the object.
chunks: An iterable of the raw uncompressed contents.
sha: Optional known sha for the object
"""
obj = object_class(type_num)()
cls = object_class(type_num)
if cls is None:
raise AssertionError("unsupported class type num: %d" % type_num)
obj = cls()
obj.set_raw_chunks(chunks, sha)
return obj

Expand All @@ -477,7 +496,7 @@ def _check_has_member(self, member, error_msg):
if getattr(self, member, None) is None:
raise ObjectFormatException(error_msg)

def check(self):
def check(self) -> None:
"""Check this object for internal consistency.
Raises:
Expand All @@ -500,9 +519,9 @@ def check(self):
raise ChecksumMismatch(new_sha, old_sha)

def _header(self):
return object_header(self.type, self.raw_length())
return object_header(self.type_num, self.raw_length())

def raw_length(self):
def raw_length(self) -> int:
"""Returns the length of the raw string of this object."""
ret = 0
for chunk in self.as_raw_chunks():
Expand All @@ -522,25 +541,14 @@ def sha(self):

def copy(self):
"""Create a new copy of this SHA1 object from its raw string"""
obj_class = object_class(self.get_type())
return obj_class.from_raw_string(self.get_type(), self.as_raw_string(), self.id)
obj_class = object_class(self.type_num)
return obj_class.from_raw_string(self.type_num, self.as_raw_string(), self.id)

@property
def id(self):
"""The hex SHA of this object."""
return self.sha().hexdigest().encode("ascii")

def get_type(self):
"""Return the type number for this object class."""
return self.type_num

def set_type(self, type):
"""Set the type number for this object class."""
self.type_num = type

# DEPRECATED: use type_num or type_name as needed.
type = property(get_type, set_type)

def __repr__(self):
return "<%s %s>" % (self.__class__.__name__, self.id)

Expand Down Expand Up @@ -621,7 +629,7 @@ def check(self):
"""
super(Blob, self).check()

def splitlines(self):
def splitlines(self) -> List[bytes]:
"""Return list of lines in this blob.
This preserves the original line endings.
Expand Down Expand Up @@ -649,7 +657,7 @@ def splitlines(self):
return ret


def _parse_message(chunks):
def _parse_message(chunks: Iterable[bytes]):
"""Parse a message with a list of fields and a body.
Args:
Expand All @@ -660,7 +668,7 @@ def _parse_message(chunks):
"""
f = BytesIO(b"".join(chunks))
k = None
v = ""
v = b""
eof = False

def _strip_last_newline(value):
Expand Down Expand Up @@ -1596,7 +1604,7 @@ def _get_extra(self):
Tag,
)

_TYPE_MAP = {} # type: Dict[Union[bytes, int], Type[ShaFile]]
_TYPE_MAP: Dict[Union[bytes, int], Type[ShaFile]] = {}

for cls in OBJECT_CLASSES:
_TYPE_MAP[cls.type_name] = cls
Expand Down
16 changes: 9 additions & 7 deletions dulwich/refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@
valid_hexsha,
ZERO_SHA,
Tag,
ObjectID,
)
from dulwich.file import (
GitFile,
ensure_dir_exists,
)

Ref = bytes

HEADREF = b"HEAD"
SYMREF = b"ref: "
Expand Down Expand Up @@ -69,7 +71,7 @@ def parse_symref_value(contents):
raise ValueError(contents)


def check_ref_format(refname):
def check_ref_format(refname: Ref):
"""Check if a refname is correctly formatted.
Implements all the same rules as git-check-ref-format[1].
Expand Down Expand Up @@ -166,8 +168,8 @@ def get_peeled(self, name):

def import_refs(
self,
base: bytes,
other: Dict[bytes, bytes],
base: Ref,
other: Dict[Ref, ObjectID],
committer: Optional[bytes] = None,
timestamp: Optional[bytes] = None,
timezone: Optional[bytes] = None,
Expand Down Expand Up @@ -455,8 +457,8 @@ def _notify(self, ref, newsha):

def set_symbolic_ref(
self,
name,
other,
name: Ref,
other: Ref,
committer=None,
timestamp=None,
timezone=None,
Expand Down Expand Up @@ -507,8 +509,8 @@ def set_if_equals(

def add_if_new(
self,
name: bytes,
ref: bytes,
name: Ref,
ref: ObjectID,
committer=None,
timestamp=None,
timezone=None,
Expand Down
5 changes: 2 additions & 3 deletions dulwich/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,8 @@ def get_peeled(self, ref):
return cached
return self.object_store.peel_sha(self.refs[ref]).id

def get_walker(self, include=None, *args, **kwargs):
def get_walker(self, include: Optional[List[bytes]] = None,
*args, **kwargs):
"""Obtain a walker for this repository.
Args:
Expand Down Expand Up @@ -771,8 +772,6 @@ def get_walker(self, include=None, *args, **kwargs):

if include is None:
include = [self.head()]
if isinstance(include, str):
include = [include]

kwargs["get_parents"] = lambda commit: self.get_parents(commit.id, commit)

Expand Down
Loading

0 comments on commit 1997f83

Please sign in to comment.