diff --git a/pyproject.toml b/pyproject.toml index 2a08b68c..1d1b27f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = [ + "nacl.encoding", "nacl.exceptions", ] disallow_any_unimported = true diff --git a/src/nacl/encoding.py b/src/nacl/encoding.py index 848be630..31bdb11b 100644 --- a/src/nacl/encoding.py +++ b/src/nacl/encoding.py @@ -15,69 +15,87 @@ import base64 import binascii -from typing import SupportsBytes +import sys +from typing import SupportsBytes, Type + +if sys.version_info >= (3, 8): + from typing import Protocol + + class _Encoder(Protocol): + @staticmethod + def encode(data: bytes) -> bytes: + ... + + @staticmethod + def decode(data: bytes) -> bytes: + ... + + # We pass around the encoder classes themselves (rather than an instance). + Encoder = Type[_Encoder] +else: + Encoder = "Encoder" class RawEncoder: @staticmethod - def encode(data): + def encode(data: bytes) -> bytes: return data @staticmethod - def decode(data): + def decode(data: bytes) -> bytes: return data class HexEncoder: @staticmethod - def encode(data): + def encode(data: bytes) -> bytes: return binascii.hexlify(data) @staticmethod - def decode(data): + def decode(data: bytes) -> bytes: return binascii.unhexlify(data) class Base16Encoder: @staticmethod - def encode(data): + def encode(data: bytes) -> bytes: return base64.b16encode(data) @staticmethod - def decode(data): + def decode(data: bytes) -> bytes: return base64.b16decode(data) class Base32Encoder: @staticmethod - def encode(data): + def encode(data: bytes) -> bytes: return base64.b32encode(data) @staticmethod - def decode(data): + def decode(data: bytes) -> bytes: return base64.b32decode(data) class Base64Encoder: @staticmethod - def encode(data): + def encode(data: bytes) -> bytes: return base64.b64encode(data) @staticmethod - def decode(data): + def decode(data: bytes) -> bytes: return base64.b64decode(data) class URLSafeBase64Encoder: @staticmethod - def encode(data): + def encode(data: bytes) -> bytes: return base64.urlsafe_b64encode(data) @staticmethod - def decode(data): + def decode(data: bytes) -> bytes: return base64.urlsafe_b64decode(data) class Encodable: - def encode(self: SupportsBytes, encoder=RawEncoder): + def encode(self: SupportsBytes, encoder: Encoder = RawEncoder) -> bytes: return encoder.encode(bytes(self))