Skip to content

Commit

Permalink
Merge pull request #1489 from sechkova/enable-mypy-ngclient
Browse files Browse the repository at this point in the history
Enable mypy for ngclient
  • Loading branch information
Jussi Kukkonen authored Aug 30, 2021
2 parents 2dd88d9 + 6ada96c commit 3028fb6
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 56 deletions.
1 change: 1 addition & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ isort
pylint
mypy
bandit
types-requests
8 changes: 7 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@ warn_unreachable = True
strict_equality = True
disallow_untyped_defs = True
disallow_untyped_calls = True
files = tuf/api/, tuf/exceptions.py
files =
tuf/api/,
tuf/ngclient,
tuf/exceptions.py

[mypy-securesystemslib.*]
ignore_missing_imports = True

[mypy-urllib3.*]
ignore_missing_imports = True
12 changes: 6 additions & 6 deletions tuf/api/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import (
IO,
Any,
BinaryIO,
ClassVar,
Dict,
Generic,
Expand Down Expand Up @@ -98,7 +98,7 @@ def __init__(self, signed: T, signatures: "OrderedDict[str, Signature]"):
self.signatures = signatures

@classmethod
def from_dict(cls, metadata: Dict[str, Any]) -> "Metadata":
def from_dict(cls, metadata: Dict[str, Any]) -> "Metadata[T]":
"""Creates Metadata object from its dict representation.
Arguments:
Expand Down Expand Up @@ -753,7 +753,7 @@ class BaseFile:

@staticmethod
def _verify_hashes(
data: Union[bytes, BinaryIO], expected_hashes: Dict[str, str]
data: Union[bytes, IO[bytes]], expected_hashes: Dict[str, str]
) -> None:
"""Verifies that the hash of 'data' matches 'expected_hashes'"""
is_bytes = isinstance(data, bytes)
Expand Down Expand Up @@ -782,7 +782,7 @@ def _verify_hashes(

@staticmethod
def _verify_length(
data: Union[bytes, BinaryIO], expected_length: int
data: Union[bytes, IO[bytes]], expected_length: int
) -> None:
"""Verifies that the length of 'data' matches 'expected_length'"""
if isinstance(data, bytes):
Expand Down Expand Up @@ -867,7 +867,7 @@ def to_dict(self) -> Dict[str, Any]:

return res_dict

def verify_length_and_hashes(self, data: Union[bytes, BinaryIO]) -> None:
def verify_length_and_hashes(self, data: Union[bytes, IO[bytes]]) -> None:
"""Verifies that the length and hashes of "data" match expected values.
Args:
Expand Down Expand Up @@ -1182,7 +1182,7 @@ def to_dict(self) -> Dict[str, Any]:
**self.unrecognized_fields,
}

def verify_length_and_hashes(self, data: Union[bytes, BinaryIO]) -> None:
def verify_length_and_hashes(self, data: Union[bytes, IO[bytes]]) -> None:
"""Verifies that length and hashes of "data" match expected values.
Args:
Expand Down
11 changes: 7 additions & 4 deletions tuf/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from urllib import parse

from typing import Any, Dict
from typing import Any, Dict, Optional

import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -206,16 +206,19 @@ def __repr__(self) -> str:
class SlowRetrievalError(DownloadError):
""""Indicate that downloading a file took an unreasonably long time."""

def __init__(self, average_download_speed: int):
def __init__(self, average_download_speed: Optional[int] = None):
super(SlowRetrievalError, self).__init__()

self.__average_download_speed = average_download_speed #bytes/second

def __str__(self) -> str:
return (
'Download was too slow. Average speed: ' +
msg = 'Download was too slow.'
if self.__average_download_speed is not None:
msg = ('Download was too slow. Average speed: ' +
repr(self.__average_download_speed) + ' bytes per second.')

return msg

def __repr__(self) -> str:
return self.__class__.__name__ + ' : ' + str(self)

Expand Down
10 changes: 5 additions & 5 deletions tuf/ngclient/_internal/requests_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import logging
import time
from typing import Iterator, Optional
from typing import Dict, Iterator, Optional
from urllib import parse

# Imports
Expand All @@ -31,7 +31,7 @@ class RequestsFetcher(FetcherInterface):
session per scheme+hostname combination.
"""

def __init__(self):
def __init__(self) -> None:
# http://docs.python-requests.org/en/master/user/advanced/#session-objects:
#
# "The Session object allows you to persist certain parameters across
Expand All @@ -46,7 +46,7 @@ def __init__(self):
# improve efficiency, but avoiding sharing state between different
# hosts-scheme combinations to minimize subtle security issues.
# Some cookies may not be HTTP-safe.
self._sessions = {}
self._sessions: Dict[str, requests.Session] = {}

# Default settings
self.socket_timeout: int = 4 # seconds
Expand Down Expand Up @@ -141,12 +141,12 @@ def _chunks(
)

except urllib3.exceptions.ReadTimeoutError as e:
raise exceptions.SlowRetrievalError(str(e))
raise exceptions.SlowRetrievalError from e

finally:
response.close()

def _get_session(self, url):
def _get_session(self, url: str) -> requests.Session:
"""Returns a different customized requests.Session per schema+hostname
combination.
"""
Expand Down
77 changes: 51 additions & 26 deletions tuf/ngclient/_internal/trusted_metadata_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
from typing import Dict, Iterator, Optional

from tuf import exceptions
from tuf.api.metadata import Metadata
from tuf.api.metadata import Metadata, Root, Snapshot, Targets, Timestamp
from tuf.api.serialization import DeserializationError

logger = logging.getLogger(__name__)
Expand All @@ -92,13 +92,13 @@ def __init__(self, root_data: bytes):
RepositoryError: Metadata failed to load or verify. The actual
error type and content will contain more details.
"""
self._trusted_set = {} # type: Dict[str: Metadata]
self._trusted_set: Dict[str, Metadata] = {}
self.reference_time = datetime.utcnow()

# Load and validate the local root metadata. Valid initial trusted root
# metadata is required
logger.debug("Updating initial trusted root")
self.update_root(root_data)
self._load_trusted_root(root_data)

def __getitem__(self, role: str) -> Metadata:
"""Returns current Metadata for 'role'"""
Expand All @@ -114,27 +114,27 @@ def __iter__(self) -> Iterator[Metadata]:

# Helper properties for top level metadata
@property
def root(self) -> Optional[Metadata]:
"""Current root Metadata or None"""
return self._trusted_set.get("root")
def root(self) -> Metadata[Root]:
"""Current root Metadata"""
return self._trusted_set["root"]

@property
def timestamp(self) -> Optional[Metadata]:
def timestamp(self) -> Optional[Metadata[Timestamp]]:
"""Current timestamp Metadata or None"""
return self._trusted_set.get("timestamp")

@property
def snapshot(self) -> Optional[Metadata]:
def snapshot(self) -> Optional[Metadata[Snapshot]]:
"""Current snapshot Metadata or None"""
return self._trusted_set.get("snapshot")

@property
def targets(self) -> Optional[Metadata]:
def targets(self) -> Optional[Metadata[Targets]]:
"""Current targets Metadata or None"""
return self._trusted_set.get("targets")

# Methods for updating metadata
def update_root(self, data: bytes):
def update_root(self, data: bytes) -> None:
"""Verifies and loads 'data' as new root metadata.
Note that an expired intermediate root is considered valid: expiry is
Expand All @@ -152,7 +152,7 @@ def update_root(self, data: bytes):
logger.debug("Updating root")

try:
new_root = Metadata.from_bytes(data)
new_root = Metadata[Root].from_bytes(data)
except DeserializationError as e:
raise exceptions.RepositoryError("Failed to load root") from e

Expand All @@ -161,21 +161,21 @@ def update_root(self, data: bytes):
f"Expected 'root', got '{new_root.signed.type}'"
)

if self.root is not None:
# We are not loading initial trusted root: verify the new one
self.root.verify_delegate("root", new_root)
# Verify that new root is signed by trusted root
self.root.verify_delegate("root", new_root)

if new_root.signed.version != self.root.signed.version + 1:
raise exceptions.ReplayedMetadataError(
"root", new_root.signed.version, self.root.signed.version
)
if new_root.signed.version != self.root.signed.version + 1:
raise exceptions.ReplayedMetadataError(
"root", new_root.signed.version, self.root.signed.version
)

# Verify that new root is signed by itself
new_root.verify_delegate("root", new_root)

self._trusted_set["root"] = new_root
logger.debug("Updated root")

def update_timestamp(self, data: bytes):
def update_timestamp(self, data: bytes) -> None:
"""Verifies and loads 'data' as new timestamp metadata.
Note that an expired intermediate timestamp is considered valid so it
Expand All @@ -199,7 +199,7 @@ def update_timestamp(self, data: bytes):
# timestamp/snapshot can not yet be loaded at this point

try:
new_timestamp = Metadata.from_bytes(data)
new_timestamp = Metadata[Timestamp].from_bytes(data)
except DeserializationError as e:
raise exceptions.RepositoryError("Failed to load timestamp") from e

Expand Down Expand Up @@ -237,7 +237,7 @@ def update_timestamp(self, data: bytes):
self._trusted_set["timestamp"] = new_timestamp
logger.debug("Updated timestamp")

def update_snapshot(self, data: bytes):
def update_snapshot(self, data: bytes) -> None:
"""Verifies and loads 'data' as new snapshot metadata.
Note that intermediate snapshot is considered valid even if it is
Expand Down Expand Up @@ -276,7 +276,7 @@ def update_snapshot(self, data: bytes):
) from e

try:
new_snapshot = Metadata.from_bytes(data)
new_snapshot = Metadata[Snapshot].from_bytes(data)
except DeserializationError as e:
raise exceptions.RepositoryError("Failed to load snapshot") from e

Expand Down Expand Up @@ -314,7 +314,11 @@ def update_snapshot(self, data: bytes):
self._trusted_set["snapshot"] = new_snapshot
logger.debug("Updated snapshot")

def _check_final_snapshot(self):
def _check_final_snapshot(self) -> None:
"""Check snapshot expiry and version before targets is updated"""

assert self.snapshot is not None # nosec
assert self.timestamp is not None # nosec
if self.snapshot.signed.is_expired(self.reference_time):
raise exceptions.ExpiredMetadataError("snapshot.json is expired")

Expand All @@ -328,7 +332,7 @@ def _check_final_snapshot(self):
f"got {self.snapshot.signed.version}"
)

def update_targets(self, data: bytes):
def update_targets(self, data: bytes) -> None:
"""Verifies and loads 'data' as new top-level targets metadata.
Args:
Expand All @@ -342,7 +346,7 @@ def update_targets(self, data: bytes):

def update_delegated_targets(
self, data: bytes, role_name: str, delegator_name: str
):
) -> None:
"""Verifies and loads 'data' as new metadata for target 'role_name'.
Args:
Expand Down Expand Up @@ -383,7 +387,7 @@ def update_delegated_targets(
) from e

try:
new_delegate = Metadata.from_bytes(data)
new_delegate = Metadata[Targets].from_bytes(data)
except DeserializationError as e:
raise exceptions.RepositoryError("Failed to load snapshot") from e

Expand All @@ -405,3 +409,24 @@ def update_delegated_targets(

self._trusted_set[role_name] = new_delegate
logger.debug("Updated %s delegated by %s", role_name, delegator_name)

def _load_trusted_root(self, data: bytes) -> None:
"""Verifies and loads 'data' as trusted root metadata.
Note that an expired initial root is considered valid: expiry is
only checked for the final root in update_timestamp().
"""
try:
new_root = Metadata[Root].from_bytes(data)
except DeserializationError as e:
raise exceptions.RepositoryError("Failed to load root") from e

if new_root.signed.type != "root":
raise exceptions.RepositoryError(
f"Expected 'root', got '{new_root.signed.type}'"
)

new_root.verify_delegate("root", new_root)

self._trusted_set["root"] = new_root
logger.debug("Loaded trusted root")
Loading

0 comments on commit 3028fb6

Please sign in to comment.