Skip to content

Commit

Permalink
Switched from black to ruff-format
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed Nov 1, 2023
1 parent ae25fd0 commit 9486245
Show file tree
Hide file tree
Showing 13 changed files with 53 additions and 161 deletions.
6 changes: 1 addition & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@ repos:
hooks:
- id: ruff
args: [--fix, --show-fixes]

- repo: https://github.com/psf/black
rev: 23.10.1
hooks:
- id: black
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.1
Expand Down
34 changes: 9 additions & 25 deletions cbor2/_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
T = TypeVar("T")

timestamp_re = re.compile(
r"^(\d{4})-(\d\d)-(\d\d)T(\d\d):(\d\d):(\d\d)"
r"(?:\.(\d{1,6})\d*)?(?:Z|([+-])(\d\d):(\d\d))$"
r"^(\d{4})-(\d\d)-(\d\d)T(\d\d):(\d\d):(\d\d)" r"(?:\.(\d{1,6})\d*)?(?:Z|([+-])(\d\d):(\d\d))$"
)


Expand Down Expand Up @@ -140,9 +139,7 @@ def object_hook(self) -> Callable[[CBORDecoder, dict[Any, Any]], Any] | None:
return self._object_hook

@object_hook.setter
def object_hook(
self, value: Callable[[CBORDecoder, Mapping[Any, Any]], Any] | None
) -> None:
def object_hook(self, value: Callable[[CBORDecoder, Mapping[Any, Any]], Any] | None) -> None:
if value is None or callable(value):
self._object_hook = value
else:
Expand Down Expand Up @@ -255,14 +252,10 @@ def _decode_length(self, subtype: int) -> int:
...

@overload
def _decode_length(
self, subtype: int, allow_indefinite: Literal[True]
) -> int | None:
def _decode_length(self, subtype: int, allow_indefinite: Literal[True]) -> int | None:
...

def _decode_length(
self, subtype: int, allow_indefinite: bool = False
) -> int | None:
def _decode_length(self, subtype: int, allow_indefinite: bool = False) -> int | None:
if subtype < 24:
return subtype
elif subtype == 24:
Expand All @@ -276,9 +269,7 @@ def _decode_length(
elif subtype == 31 and allow_indefinite:
return None
else:
raise CBORDecodeValueError(
"unknown unsigned integer subtype 0x%x" % subtype
)
raise CBORDecodeValueError("unknown unsigned integer subtype 0x%x" % subtype)

def decode_uint(self, subtype: int) -> int:
# Major tag 0
Expand All @@ -303,8 +294,7 @@ def decode_bytestring(self, subtype: int) -> bytes:
length = self._decode_length(initial_byte & 0x1F)
if length is None or length > sys.maxsize:
raise CBORDecodeValueError(
"invalid length for indefinite bytestring chunk 0x%x"
% length
"invalid length for indefinite bytestring chunk 0x%x" % length
)
value = self.read(length)
buf.append(value)
Expand All @@ -314,9 +304,7 @@ def decode_bytestring(self, subtype: int) -> bytes:
)
else:
if length > sys.maxsize:
raise CBORDecodeValueError(
"invalid length for bytestring 0x%x" % length
)
raise CBORDecodeValueError("invalid length for bytestring 0x%x" % length)

result = self.read(length)
self._stringref_namespace_add(result, length)
Expand Down Expand Up @@ -357,9 +345,7 @@ def decode_string(self, subtype: int) -> str:
value = self.read(length).decode("utf-8", self._str_errors)
buf.append(value)
else:
raise CBORDecodeValueError(
"non-string found in indefinite length string"
)
raise CBORDecodeValueError("non-string found in indefinite length string")
else:
if length > sys.maxsize:
raise CBORDecodeValueError("invalid length for string 0x%x" % length)
Expand Down Expand Up @@ -584,9 +570,7 @@ def decode_sharedref(self) -> Any:
raise CBORDecodeValueError("shared reference %d not found" % value)

if shared is None:
raise CBORDecodeValueError(
"shared value %d has not been initialized" % value
)
raise CBORDecodeValueError("shared value %d has not been initialized" % value)
else:
return shared

Expand Down
33 changes: 8 additions & 25 deletions cbor2/_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,18 +180,14 @@ def __init__(
self._shared_containers: dict[
int, tuple[object, int | None]
] = {} # indexes used for value sharing
self._string_references: dict[
str | bytes, int
] = {} # indexes used for string references
self._string_references: dict[str | bytes, int] = {} # indexes used for string references
self._encoders = default_encoders.copy()
if canonical:
self._encoders.update(canonical_encoders)
if date_as_datetime:
self._encoders[date] = CBOREncoder.encode_date

def _find_encoder(
self, obj_type: type
) -> Callable[[CBOREncoder, Any], None] | None:
def _find_encoder(self, obj_type: type) -> Callable[[CBOREncoder, Any], None] | None:
for type_or_tuple, enc in list(self._encoders.items()):
if type(type_or_tuple) is tuple:
try:
Expand Down Expand Up @@ -310,11 +306,7 @@ def encode(self, obj: Any) -> None:
the object to encode
"""
obj_type = obj.__class__
encoder = (
self._encoders.get(obj_type)
or self._find_encoder(obj_type)
or self._default
)
encoder = self._encoders.get(obj_type) or self._find_encoder(obj_type) or self._default
if not encoder:
raise CBOREncodeTypeError("cannot serialize type %s" % obj_type.__name__)

Expand All @@ -335,19 +327,15 @@ def encode_to_bytes(self, obj: Any) -> bytes:
self.fp = old_fp
return fp.getvalue()

def encode_container(
self, encoder: Callable[[CBOREncoder, Any], Any], value: Any
) -> None:
def encode_container(self, encoder: Callable[[CBOREncoder, Any], Any], value: Any) -> None:
if self.string_namespacing:
# Create a new string reference domain
self.encode_length(6, 256)

with self.disable_string_namespacing():
self.encode_shared(encoder, value)

def encode_shared(
self, encoder: Callable[[CBOREncoder, Any], Any], value: Any
) -> None:
def encode_shared(self, encoder: Callable[[CBOREncoder, Any], Any], value: Any) -> None:
value_id = id(value)
try:
index = self._shared_containers[id(value)][1]
Expand Down Expand Up @@ -482,9 +470,7 @@ def encode_sortable_key(self, value: Any) -> tuple[int, bytes]:
@container_encoder
def encode_canonical_map(self, value: Mapping[Any, Any]) -> None:
"""Reorder keys according to Canonical CBOR specification"""
keyed_keys = (
(self.encode_sortable_key(key), key, value) for key, value in value.items()
)
keyed_keys = ((self.encode_sortable_key(key), key, value) for key, value in value.items())
self.encode_length(5, len(value))
for sortkey, realkey, value in sorted(keyed_keys):
if self.string_referencing:
Expand Down Expand Up @@ -521,8 +507,7 @@ def encode_datetime(self, value: datetime) -> None:
value = value.replace(tzinfo=self._timezone)
else:
raise CBOREncodeValueError(
f"naive datetime {value!r} encountered and no default timezone "
"has been set"
f"naive datetime {value!r} encountered and no default timezone " "has been set"
)

if self.datetime_as_timestamp:
Expand Down Expand Up @@ -600,9 +585,7 @@ def encode_ipaddress(self, value: IPv4Address | IPv6Address) -> None:

def encode_ipnetwork(self, value: IPv4Network | IPv6Network) -> None:
# Semantic tag 261
self.encode_semantic(
CBORTag(261, {value.network_address.packed: value.prefixlen})
)
self.encode_semantic(CBORTag(261, {value.network_address.packed: value.prefixlen}))

#
# Special encoders (major tag 7)
Expand Down
4 changes: 1 addition & 3 deletions cbor2/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,4 @@
from ._decoder import load as load
from ._decoder import loads as loads

warn(
"The cbor.decoder module has been deprecated. Instead import everything directly from cbor2."
)
warn("The cbor.decoder module has been deprecated. Instead import everything directly from cbor2.")
16 changes: 4 additions & 12 deletions cbor2/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
from typing import Literal, TypeAlias

T = TypeVar("T")
JSONValue: TypeAlias = (
"str | float | bool | None | list[JSONValue] | dict[str, JSONValue]"
)
JSONValue: TypeAlias = "str | float | bool | None | list[JSONValue] | dict[str, JSONValue]"

default_encoders: dict[type, Callable[[Any], Any]] = {
bytes: lambda x: x.decode(encoding="utf-8", errors="backslashreplace"),
Expand All @@ -45,9 +43,7 @@
}


def tag_hook(
decoder: CBORDecoder, tag: CBORTag, ignore_tags: Collection[int] = ()
) -> object:
def tag_hook(decoder: CBORDecoder, tag: CBORTag, ignore_tags: Collection[int] = ()) -> object:
if tag.tag in ignore_tags:
return tag.value

Expand Down Expand Up @@ -75,19 +71,15 @@ def iterdecode(
object_hook: Callable[[CBORDecoder, dict[Any, Any]], Any] | None = None,
str_errors: Literal["strict", "error", "replace"] = "strict",
) -> Iterator[Any]:
decoder = CBORDecoder(
f, tag_hook=tag_hook, object_hook=object_hook, str_errors=str_errors
)
decoder = CBORDecoder(f, tag_hook=tag_hook, object_hook=object_hook, str_errors=str_errors)
while True:
try:
yield decoder.decode()
except EOFError:
return


def key_to_str(
d: T, dict_ids: set[int] | None = None
) -> str | list[Any] | dict[str, Any] | T:
def key_to_str(d: T, dict_ids: set[int] | None = None) -> str | list[Any] | dict[str, Any] | T:
dict_ids = set(dict_ids or [])
rval: dict[str, Any] = {}
if not isinstance(d, dict):
Expand Down
4 changes: 1 addition & 3 deletions cbor2/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,4 @@
from ._types import FrozenDict as FrozenDict
from ._types import undefined as undefined

warn(
"The cbor2.types module has been deprecated. Instead import everything directly from cbor2."
)
warn("The cbor2.types module has been deprecated. Instead import everything directly from cbor2.")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ select = [
"RUF100", # unused noqa (yesqa)
"UP", # pyupgrade
]
ignore = ["ISC001"]

[tool.mypy]
strict = true
Expand Down
6 changes: 1 addition & 5 deletions scripts/half_float_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@ def grouper(iterable, n, fillvalue=None):
def sigtable():
print("static const uint32_t sigtable[] = {")
values = (
0
if i == 0
else convertsig(i)
if 1 <= i < 1024
else 0x38000000 + ((i - 1024) << 13)
0 if i == 0 else convertsig(i) if 1 <= i < 1024 else 0x38000000 + ((i - 1024) << 13)
for i in range(2048)
)
values = (f"{i:#010x}" for i in values)
Expand Down
3 changes: 1 addition & 2 deletions scripts/ref_leak_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ class Module:
(
"objectdictlist",
{"timezone": UTC},
[{"name": "Foo", "species": "cat", "dob": datetime(2013, 5, 20), "weight": 4.1}]
* 100,
[{"name": "Foo", "species": "cat", "dob": datetime(2013, 5, 20), "weight": 4.1}] * 100,
),
("tag", {}, c_cbor2.CBORTag(1, 1)),
("nestedtag", {}, {c_cbor2.CBORTag(1, 1): 1}),
Expand Down
27 changes: 7 additions & 20 deletions scripts/speed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ class Module:
(
"objectdictlist",
{"timezone": UTC},
[{"name": "Foo", "species": "cat", "dob": datetime(2013, 5, 20), "weight": 4.1}]
* 100,
[{"name": "Foo", "species": "cat", "dob": datetime(2013, 5, 20), "weight": 4.1}] * 100,
),
]

Expand Down Expand Up @@ -122,9 +121,7 @@ def time(op, repeat=3):
return Timing(min(t.repeat(repeat, number)) / number, repeat, number)


def format_time(
t, suffixes=("s", "ms", "µs", "ns"), zero="0s", template="{time:.1f}{suffix}"
):
def format_time(t, suffixes=("s", "ms", "µs", "ns"), zero="0s", template="{time:.1f}{suffix}"):
if isinstance(t, Exception):
return "-"
else:
Expand All @@ -133,9 +130,7 @@ def format_time(
except ValueError:
return zero
else:
return template.format(
time=t.time * 2 ** (index * 10), suffix=suffixes[index]
)
return template.format(time=t.time * 2 ** (index * 10), suffix=suffixes[index])


def print_len(s):
Expand Down Expand Up @@ -204,13 +199,9 @@ def output_table(results):
" ",
" " * col_widths[0],
" | ",
"{value:^{width}}".format(
value="Encoding", width=sum(col_widths[1:4]) + 6
),
"{value:^{width}}".format(value="Encoding", width=sum(col_widths[1:4]) + 6),
" | ",
"{value:^{width}}".format(
value="Decoding", width=sum(col_widths[4:7]) + 6
),
"{value:^{width}}".format(value="Decoding", width=sum(col_widths[4:7]) + 6),
" |",
)
)
Expand Down Expand Up @@ -264,14 +255,10 @@ def output_csv(results):
writer.writerow(
(
title,
result.cbor.encoding.time
if isinstance(result.cbor.encoding, Timing)
else None,
result.cbor.encoding.time if isinstance(result.cbor.encoding, Timing) else None,
result.c_cbor2.encoding.time,
result.py_cbor2.encoding.time,
result.cbor.decoding.time
if isinstance(result.cbor.encoding, Timing)
else None,
result.cbor.decoding.time if isinstance(result.cbor.encoding, Timing) else None,
result.c_cbor2.decoding.time,
result.py_cbor2.decoding.time,
)
Expand Down
Loading

0 comments on commit 9486245

Please sign in to comment.