Skip to content

Commit

Permalink
Add mypy checks to CI. (#60)
Browse files Browse the repository at this point in the history
Fixes #10
  • Loading branch information
plietar authored Jan 11, 2023
1 parent e34ec6f commit 566de22
Show file tree
Hide file tree
Showing 20 changed files with 175 additions and 163 deletions.
11 changes: 11 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[mypy]
[mypy-cbor2]
ignore_missing_imports = True
[mypy-pycose.*]
ignore_missing_imports = True
[mypy-ccf.*]
ignore_missing_imports = True
[mypy-aiotools.*]
ignore_missing_imports = True
[mypy-setuptools.*]
ignore_missing_imports = True
20 changes: 1 addition & 19 deletions pyscitt/pyscitt/cli/create_did_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,6 @@
DID_FILENAME = "did.json"


def get_did_web_doc_url_from_did(did: str) -> str:
rest = did.replace(DID_WEB_PREFIX, "")
try:
rest.index(":")
except:
return (
DID_WEB_DOC_URL_PREFIX
+ rest.replace(ENCODED_COLON, ":")
+ DID_WEB_DOC_WELLKNOWN_PATH
+ DID_WEB_DOC_URL_SUFFIX
)
else:
return (
DID_WEB_DOC_URL_PREFIX
+ rest.replace(":", "/").replace(ENCODED_COLON, ":")
+ DID_WEB_DOC_URL_SUFFIX
)


def write_file(path: Path, contents: str):
print(f"Writing {path}")
path.write_text(contents)
Expand All @@ -46,6 +27,7 @@ def create_did_web(
):

parsed = urlsplit(base_url)
assert parsed.hostname
did = format_did_web(
host=parsed.hostname, port=parsed.port, path=parsed.path.lstrip("/")
)
Expand Down
2 changes: 1 addition & 1 deletion pyscitt/pyscitt/cli/prefix_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def prefix_tree_debug(client: Client):

def prefix_tree_get_receipt(
client: Client,
claim_path: Optional[str],
claim_path: Optional[Path],
issuer: Optional[str],
feed: Optional[str],
output: Optional[Path],
Expand Down
6 changes: 3 additions & 3 deletions pyscitt/pyscitt/cli/sign_claims.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def value(self) -> crypto.RegistrationInfoValue:
else:
data = self.content.encode("ascii")

if self.type == RegistrationInfoType.INT:
if self.type is RegistrationInfoType.INT:
return int(data.decode("utf-8"))
elif self.type == RegistrationInfoType.TEXT:
elif self.type is RegistrationInfoType.TEXT:
return data.decode("utf-8")
elif self.type == RegistrationInfoType.BYTES:
elif self.type is RegistrationInfoType.BYTES:
return data


Expand Down
13 changes: 7 additions & 6 deletions pyscitt/pyscitt/cli/submit_signed_claims.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from pathlib import Path
from typing import Optional

Expand All @@ -23,15 +24,15 @@ def submit_signed_claimset(
with open(path, "rb") as f:
signed_claimset = f.read()

submission = client.submit_claim(
signed_claimset, skip_confirmation=skip_confirmation, decode=False
)

print(f"Submitted {path} as transaction {submission.tx}")
if skip_confirmation:
print("Confirmation fo submission was skipped! Claim may not be registered.")
tx = client.submit_claim(signed_claimset, skip_confirmation=True).tx
print(f"Submitted {path} as transaction {tx}")
print("Confirmation of submission was skipped! Claim may not be registered.")
return

submission = client.submit_claim(signed_claimset)
print(f"Submitted {path} as transaction {submission.tx}")

if receipt_path:
with open(receipt_path, "wb") as f:
f.write(submission.receipt)
Expand Down
42 changes: 34 additions & 8 deletions pyscitt/pyscitt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
from dataclasses import dataclass
from http import HTTPStatus
from typing import Iterable, Optional, Tuple, Union
from typing import Generic, Iterable, Literal, Optional, Tuple, TypeVar, Union, overload
from urllib.parse import urlencode

import httpx
Expand Down Expand Up @@ -66,6 +66,9 @@ def __str__(self):
return f"{self.code}: {self.message}"


SelfClient = TypeVar("SelfClient", bound="BaseClient")


class BaseClient:
"""
Wrapper around an HTTP client, with facilities to interact with a CCF-based
Expand Down Expand Up @@ -128,14 +131,14 @@ def __init__(
base_url=url, headers=headers, verify=not development
)

def replace(self, **kwargs):
def replace(self: SelfClient, **kwargs) -> SelfClient:
"""
Create a new instance with certain parameters modified. Any parameters
that weren't specified will be inherited from the current instance.
The accepted keyword arguments are the same as those of the constructor.
"""
values = {
values: dict = {
"url": self.url,
"auth_token": self.auth_token,
"member_auth": self.member_auth,
Expand Down Expand Up @@ -269,16 +272,19 @@ def get_historical(self, *args, retry_on=[], **kwargs):
)


T = TypeVar("T", bound=Optional[bytes], covariant=True)


@dataclass
class Submission:
class Submission(Generic[T]):
"""
The result of submitting a claim to the service.
The presence and format of the receipt is depends on arguments passed to the `submit_claim`
method.
"""

tx: str
receipt: Optional[Union[bytes, Receipt]]
receipt: T

@property
def seqno(self) -> int:
Expand Down Expand Up @@ -307,9 +313,21 @@ def get_constitution(self) -> str:
def get_version(self) -> dict:
return self.get("/version").json()

@overload
def submit_claim(
self, claim: bytes, *, skip_confirmation: Literal[False] = False
) -> Submission[bytes]:
...

@overload
def submit_claim(
self, claim: bytes, *, skip_confirmation: Literal[True]
) -> Submission[None]:
...

def submit_claim(
self, claim: bytes, *, skip_confirmation=False, decode=True
) -> Submission:
self, claim: bytes, *, skip_confirmation=False
) -> Union[Submission[bytes], Submission[None]]:
headers = {"Content-Type": "application/cose"}
response = self.post(
"/entries",
Expand All @@ -324,7 +342,7 @@ def submit_claim(
if skip_confirmation:
return Submission(tx, None)
else:
receipt = self.get_receipt(tx, decode=decode)
receipt = self.get_receipt(tx, decode=False)
return Submission(tx, receipt)

def get_claim(self, tx: str, *, embed_receipt=False) -> bytes:
Expand All @@ -333,6 +351,14 @@ def get_claim(self, tx: str, *, embed_receipt=False) -> bytes:
)
return response.content

@overload
def get_receipt(self, tx: str, *, decode: Literal[True] = True) -> Receipt:
...

@overload
def get_receipt(self, tx: str, *, decode: Literal[False]) -> bytes:
...

def get_receipt(self, tx: str, *, decode=True) -> Union[bytes, Receipt]:
response = self.get_historical(f"/entries/{tx}/receipt")
if decode:
Expand Down
45 changes: 25 additions & 20 deletions pyscitt/pyscitt/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import warnings
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
from uuid import uuid4

warnings.filterwarnings("ignore", category=Warning)
Expand Down Expand Up @@ -174,7 +174,12 @@ def generate_cert(
if not cn:
cn = str(uuid4())
subject_priv = load_pem_private_key(private_key_pem.encode("ascii"), None)
assert isinstance(
subject_priv, (RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey)
)

subject_pub_key = subject_priv.public_key()

subject = x509.Name(
[
x509.NameAttribute(NameOID.COMMON_NAME, cn),
Expand All @@ -196,6 +201,9 @@ def generate_cert(
issuer_private_key_pem.encode("ascii"),
None,
)
assert isinstance(
issuer_priv_key, (RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey)
)
else:
issuer_priv_key = subject_priv
cert = (
Expand Down Expand Up @@ -234,15 +242,6 @@ def get_pub_key_type(pub_pem: str) -> str:
raise NotImplementedError("unsupported key type")


def get_cert_key_type(cert_pem: str) -> str:
cert = load_pem_x509_certificate(cert_pem.encode("ascii"))
if isinstance(cert.public_key(), RSAPublicKey):
return "rsa"
elif isinstance(cert.public_key(), EllipticCurvePublicKey):
return "ec"
raise NotImplementedError("unsupported key type")


def get_cert_info(pem: str) -> dict:
cert = load_pem_x509_certificate(pem.encode("ascii"))
cn = cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
Expand Down Expand Up @@ -323,13 +322,15 @@ def default_algorithm_for_private_key(key_pem: Pem) -> str:


def verify_cose_sign1(buf: bytes, cert_pem: str):
key_type = get_cert_key_type(cert_pem)
cert = load_pem_x509_certificate(cert_pem.encode("ascii"))
key = cert.public_key()
if key_type == "rsa":
if isinstance(key, RSAPublicKey):
cose_key = from_cryptography_rsakey_obj(key)
else:
elif isinstance(key, EllipticCurvePublicKey):
cose_key = from_cryptography_eckey_obj(key)
else:
raise NotImplementedError("unsupported key type")

msg = Sign1Message.decode(buf)
msg.key = cose_key
if not msg.verify_signature():
Expand Down Expand Up @@ -386,7 +387,7 @@ def pretty_cose_sign1(buf: bytes) -> str:

# temporary, from https://github.com/BrianSipos/pycose/blob/rsa_keys_algs/cose/keys/rsa.py
# until https://github.com/TimothyClaeys/pycose/issues/44 is implemented
def from_cryptography_rsakey_obj(ext_key) -> RSAKey:
def from_cryptography_rsakey_obj(ext_key: Union[RSAPrivateKey, RSAPublicKey]) -> RSAKey:
"""
Returns an initialized COSE Key object of type RSAKey.
:param ext_key: Python cryptography key.
Expand Down Expand Up @@ -426,7 +427,9 @@ def to_bstr(dec):
return RSAKey.from_dict(cose_key)


def from_cryptography_eckey_obj(ext_key) -> EC2Key:
def from_cryptography_eckey_obj(
ext_key: Union[EllipticCurvePrivateKey, EllipticCurvePublicKey]
) -> EC2Key:
"""
Returns an initialized COSE Key object of type EC2Key.
:param ext_key: Python cryptography key.
Expand Down Expand Up @@ -463,7 +466,9 @@ def from_cryptography_eckey_obj(ext_key) -> EC2Key:
return EC2Key.from_dict(cose_key)


def from_cryptography_ed25519key_obj(ext_key) -> OKPKey:
def from_cryptography_ed25519key_obj(
ext_key: Union[Ed25519PrivateKey, Ed25519PublicKey]
) -> OKPKey:
"""
Returns an initialized COSE Key object of type OKPKey.
:param ext_key: Python cryptography key.
Expand Down Expand Up @@ -654,15 +659,15 @@ class Signer:
issuer: Optional[str]
kid: Optional[str]
algorithm: str
x5c: Optional[list]
x5c: Optional[List[Pem]]

def __init__(
self,
private_key: Pem,
issuer: Optional[str] = None,
kid: Optional[str] = None,
algorithm: Optional[str] = None,
x5c: Optional[str] = None,
x5c: Optional[List[Pem]] = None,
):
"""
If no algorithm is specified, a sensible default is inferred from the private key.
Expand All @@ -682,7 +687,7 @@ def sign_claimset(
feed: Optional[str] = None,
registration_info: RegistrationInfo = {},
) -> bytes:
headers = {}
headers: dict = {}
headers[pycose.headers.Algorithm] = signer.algorithm
headers[pycose.headers.ContentType] = content_type

Expand All @@ -704,7 +709,7 @@ def sign_claimset(

def sign_json_claimset(
signer: Signer,
claims: json,
claims: dict,
content_type: str = "application/vnd.dummy+json",
feed: Optional[str] = None,
) -> bytes:
Expand Down
6 changes: 3 additions & 3 deletions pyscitt/pyscitt/prefix_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ def __post_init__(self):
def hash(self, index: bytes, leaf: bytes) -> bytes:
positions = bitvector(self.positions)
hashes = reversed(self.hashes)
index = bitvector(index)
index_bits = bitvector(index)

current = leaf
for i in reversed(range(256)):
if positions[i]:
node = hashlib.sha256(index.prefix(i))
if index[i]:
node = hashlib.sha256(index_bits.prefix(i))
if index_bits[i]:
node.update(next(hashes))
node.update(current)
else:
Expand Down
Empty file added pyscitt/pyscitt/py.typed
Empty file.
15 changes: 7 additions & 8 deletions pyscitt/pyscitt/receipt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
TREE_ALGORITHM_CCF = "CCF"


def hdr_as_dict(phdr: list) -> dict:
def hdr_as_dict(phdr: dict) -> dict:
"""
Return a representation of a list of COSE header parameters that
is amenable to pretty-printing.
Expand Down Expand Up @@ -58,7 +58,12 @@ def as_dict(self) -> dict:
@classmethod
def from_cose_obj(self, headers: dict, cose_obj: Any) -> "ReceiptContents":
if headers.get(HEADER_PARAM_TREE_ALGORITHM) == TREE_ALGORITHM_CCF:
return CCFReceiptContents.from_cose_obj(cose_obj)
return CCFReceiptContents(
cose_obj[0],
cose_obj[1],
cose_obj[2],
LeafInfo.from_cose_obj(cose_obj[3]),
)
else:
raise ValueError("unsupported tree algorithm, cannot decode receipt")

Expand All @@ -70,12 +75,6 @@ class CCFReceiptContents(ReceiptContents):
inclusion_proof: list
leaf_info: LeafInfo

@classmethod
def from_cose_obj(cls, cose_obj: list) -> "ReceiptContents":
return cls(
cose_obj[0], cose_obj[1], cose_obj[2], LeafInfo.from_cose_obj(cose_obj[3])
)

def root(self, claims_digest: bytes) -> bytes:
leaf = self.leaf_info.digest(claims_digest).hex()

Expand Down
Loading

0 comments on commit 566de22

Please sign in to comment.