Skip to content

Commit

Permalink
support for complex numbers
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
chillenb committed Nov 17, 2024
1 parent d9cee77 commit 3ff08b3
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 0 deletions.
16 changes: 16 additions & 0 deletions cbor2/_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,21 @@ def decode_sharedref(self) -> Any:
else:
return shared

def decode_complex(self) -> complex:
# Semantic tag 43000
inputval = self._decode(immutable=True, unshared=True)
try:
value = complex(*inputval)
except TypeError as exc:
if not isinstance(inputval, tuple):
raise CBORDecodeValueError(
"error decoding complex: input value was not a tuple"
) from None

raise CBORDecodeValueError("error decoding complex") from exc

return self.set_shareable(value)

def decode_rational(self) -> Fraction:
# Semantic tag 30
from fractions import Fraction
Expand Down Expand Up @@ -780,6 +795,7 @@ def decode_float64(self) -> float:
260: CBORDecoder.decode_ipaddress,
261: CBORDecoder.decode_ipnetwork,
1004: CBORDecoder.decode_date_string,
43000: CBORDecoder.decode_complex,
55799: CBORDecoder.decode_self_describe_cbor,
}

Expand Down
6 changes: 6 additions & 0 deletions cbor2/_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,11 @@ def encode_float(self, value: float) -> None:
else:
self._fp_write(struct.pack(">Bd", 0xFB, value))

def encode_complex(self, value: complex) -> None:
# Semantic tag 43000
with self.disable_value_sharing():
self.encode_semantic(CBORTag(43000, [value.real, value.imag]))

def encode_minimal_float(self, value: float) -> None:
# Handle special values efficiently
if math.isnan(value):
Expand Down Expand Up @@ -652,6 +657,7 @@ def encode_undefined(self, value: UndefinedType) -> None:
str: CBOREncoder.encode_string,
int: CBOREncoder.encode_int,
float: CBOREncoder.encode_float,
complex: CBOREncoder.encode_complex,
("decimal", "Decimal"): CBOREncoder.encode_decimal,
bool: CBOREncoder.encode_boolean,
type(None): CBOREncoder.encode_none,
Expand Down
1 change: 1 addition & 0 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Tag Semantics Python type(s)
258 Set of unique items set
260 Network address :class:`ipaddress.IPv4Address` (or IPv6)
261 Network prefix :class:`ipaddress.IPv4Network` (or IPv6)
43000 Single complex number complex
55799 Self-Described CBOR object
===== ======================================== ====================================================

Expand Down
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning <https://semver.org/>`_.

- Dropped support for Python 3.8
(#247 <https://github.com/agronholm/cbor2/pull/247>_; PR by @hugovk)
- Added complex number support (tag 43000)
(#249 <https://github.com/agronholm/cbor2/pull/249>_; PR by @chillenb)

**5.6.5** (2024-10-09)

Expand Down
32 changes: 32 additions & 0 deletions source/decoder.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ static PyObject * CBORDecoder_decode_epoch_date(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_date_string(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_fraction(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_bigfloat(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_complex(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_rational(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_regexp(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_uuid(CBORDecoderObject *);
Expand Down Expand Up @@ -1172,6 +1173,7 @@ decode_semantic(CBORDecoderObject *self, uint8_t subtype)
case 260: ret = CBORDecoder_decode_ipaddress(self); break;
case 261: ret = CBORDecoder_decode_ipnetwork(self); break;
case 1004: ret = CBORDecoder_decode_date_string(self); break;
case 43000: ret = CBORDecoder_decode_complex(self); break;
case 55799: ret = CBORDecoder_decode_self_describe_cbor(self);
break;

Expand Down Expand Up @@ -1636,6 +1638,34 @@ CBORDecoder_decode_sharedref(CBORDecoderObject *self)
return ret;
}

// CBORDecoder.decode_complex(self)
static PyObject *
CBORDecoder_decode_complex(CBORDecoderObject *self)
{
// semantic type 43000
PyObject *payload_t, *real, *imag, *ret = NULL;
payload_t = decode(self, DECODE_IMMUTABLE | DECODE_UNSHARED);
if (payload_t) {
if (PyTuple_CheckExact(payload_t) && PyTuple_GET_SIZE(payload_t) == 2) {
real = PyTuple_GET_ITEM(payload_t, 0);
imag = PyTuple_GET_ITEM(payload_t, 1);
f(PyFloat_CheckExact(real) && PyFloat_CheckExact(imag)) {
ret = PyComplex_FromDoubles(PyFloat_AS_DOUBLE(real), PyFloat_AS_DOUBLE(imag));
} else {
PyErr_Format(
_CBOR2_CBORDecodeValueError,
"Incorrect tag 43000 payload: does not contain two floats");
}
} else {
PyErr_Format(
_CBOR2_CBORDecodeValueError,
"Incorrect tag 43000 payload: not an array of length 2");
}
Py_DECREF(payload_t);
}
set_shareable(self, ret);
return ret;
}

// CBORDecoder.decode_rational(self)
static PyObject *
Expand Down Expand Up @@ -2159,6 +2189,8 @@ static PyMethodDef CBORDecoder_methods[] = {
"decode a fractional number from the input"},
{"decode_rational", (PyCFunction) CBORDecoder_decode_rational, METH_NOARGS,
"decode a rational value from the input"},
{"decode_complex", (PyCFunction) CBORDecoder_decode_complex, METH_NOARGS,
"decode a complex value from the input"},
{"decode_bigfloat", (PyCFunction) CBORDecoder_decode_bigfloat, METH_NOARGS,
"decode a large floating-point value from the input"},
{"decode_regexp", (PyCFunction) CBORDecoder_decode_regexp, METH_NOARGS,
Expand Down
32 changes: 32 additions & 0 deletions source/encoder.c
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,36 @@ CBOREncoder_encode_rational(CBOREncoderObject *self, PyObject *value)
return ret;
}

// CBOREncoder.encode_complex(self, value)
static PyObject *
CBOREncoder_encode_complex(CBOREncoderObject *self, PyObject *value)
{
// semantic type 43000
PyObject *tuple, *real, *imag, *ret = NULL;
bool sharing;

real = PyObject_GetAttr(value, _CBOR2_str_real);
if (real) {
imag = PyObject_GetAttr(value, _CBOR2_str_imag);
if (imag) {
tuple = PyTuple_Pack(2, real, imag);
if (tuple) {
sharing = self->value_sharing;
self->value_sharing = false;
if (encode_semantic(self, 43000, tuple) == 0) {
Py_INCREF(Py_None);
ret = Py_None;
}
self->value_sharing = sharing;
Py_DECREF(tuple);
}
Py_DECREF(imag);
}
Py_DECREF(real);
}
return ret;
}


// CBOREncoder.encode_regexp(self, value)
static PyObject *
Expand Down Expand Up @@ -2118,6 +2148,8 @@ static PyMethodDef CBOREncoder_methods[] = {
"encode the specified integer *value* to the output"},
{"encode_float", (PyCFunction) CBOREncoder_encode_float, METH_O,
"encode the specified floating-point *value* to the output"},
{"encode_complex", (PyCFunction) CBOREncoder_encode_complex, METH_O,
"encode the specified complex *value* to the output"},
{"encode_boolean", (PyCFunction) CBOREncoder_encode_boolean, METH_O,
"encode the specified boolean *value* to the output"},
{"encode_none", (PyCFunction) CBOREncoder_encode_none, METH_O,
Expand Down
4 changes: 4 additions & 0 deletions source/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ PyObject *_CBOR2_str_FrozenDict = NULL;
PyObject *_CBOR2_str_fromordinal = NULL;
PyObject *_CBOR2_str_getvalue = NULL;
PyObject *_CBOR2_str_groups = NULL;
PyObject *_CBOR2_str_imag = NULL;
PyObject *_CBOR2_str_ip_address = NULL;
PyObject *_CBOR2_str_ip_network = NULL;
PyObject *_CBOR2_str_is_infinite = NULL;
Expand All @@ -637,6 +638,7 @@ PyObject *_CBOR2_str_parsestr = NULL;
PyObject *_CBOR2_str_pattern = NULL;
PyObject *_CBOR2_str_prefixlen = NULL;
PyObject *_CBOR2_str_read = NULL;
PyObject *_CBOR2_str_real = NULL;
PyObject *_CBOR2_str_s = NULL;
PyObject *_CBOR2_str_timestamp = NULL;
PyObject *_CBOR2_str_toordinal = NULL;
Expand Down Expand Up @@ -955,6 +957,7 @@ PyInit__cbor2(void)
INTERN_STRING(fromordinal);
INTERN_STRING(getvalue);
INTERN_STRING(groups);
INTERN_STRING(imag);
INTERN_STRING(ip_address);
INTERN_STRING(ip_network);
INTERN_STRING(is_infinite);
Expand All @@ -971,6 +974,7 @@ PyInit__cbor2(void)
INTERN_STRING(pattern);
INTERN_STRING(prefixlen);
INTERN_STRING(read);
INTERN_STRING(real);
INTERN_STRING(s);
INTERN_STRING(timestamp);
INTERN_STRING(toordinal);
Expand Down
2 changes: 2 additions & 0 deletions source/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ extern PyObject *_CBOR2_str_FrozenDict;
extern PyObject *_CBOR2_str_fromordinal;
extern PyObject *_CBOR2_str_getvalue;
extern PyObject *_CBOR2_str_groups;
extern PyObject *_CBOR2_str_imag;
extern PyObject *_CBOR2_str_ip_address;
extern PyObject *_CBOR2_str_ip_network;
extern PyObject *_CBOR2_str_is_infinite;
Expand All @@ -69,6 +70,7 @@ extern PyObject *_CBOR2_str_parsestr;
extern PyObject *_CBOR2_str_pattern;
extern PyObject *_CBOR2_str_prefixlen;
extern PyObject *_CBOR2_str_read;
extern PyObject *_CBOR2_str_real;
extern PyObject *_CBOR2_str_s;
extern PyObject *_CBOR2_str_timestamp;
extern PyObject *_CBOR2_str_toordinal;
Expand Down
29 changes: 29 additions & 0 deletions tests/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,35 @@ def test_bigfloat(impl):
assert decoded == Decimal("1.5")


@pytest.mark.parametrize(
"payload, expected",
[
("d9a7f882f90000f90000", 0.0j),
("d9a7f882fb0000000000000000fb0000000000000000", 0.0j),
("d9a7f882f98000f98000", -0.0j),
("d9a7f882f90000f93c00", 1.0j),
("d9a7f882fb0000000000000000fb3ff199999999999a", 1.1j),
("d9a7f882f93e00f93e00", 1.5 + 1.5j),
("d9a7f882f97bfff97bff", 65504.0 + 65504.0j),
("d9a7f882fa47c35000fa47c35000", 100000.0 + 100000.0j),
("fa7f7fffff", 3.4028234663852886e38),
("d9a7f882f90000fb7e37e43c8800759c", 1.0e300j),
("d9a7f882f90000f90001", 5.960464477539063e-8j),
("d9a7f882f90000f90400", 0.00006103515625j),
("d9a7f882f90000f9c400", -4.0j),
("d9a7f882f90000fbc010666666666666", -4.1j),
("d9a7f882f90000f97c00", complex(0.0, float("inf"))),
("d9a7f882f97c00f90000", complex(float("inf"), 0.0)),
("d9a7f882f90000f9fc00", complex(0.0, float("-inf"))),
("d9a7f882f90000fa7f800000", complex(0.0, float("inf"))),
("d9a7f882f90000faff800000", complex(0.0, float("-inf"))),
],
)
def test_complex(impl, payload, expected):
decoded = impl.loads(unhexlify(payload))
assert decoded == expected


def test_rational(impl):
decoded = impl.loads(unhexlify("d81e820205"))
assert decoded == Fraction(2, 5)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,22 @@ def test_decimal(impl, value, expected):
assert impl.dumps(value) == expected


@pytest.mark.parametrize(
"value, expected",
[
(3.1 + 2.1j, "d9a7f882fb4008cccccccccccdfb4000cccccccccccd"),
(1.0e300j, "d9a7f882fb0000000000000000fb7e37e43c8800759c"),
(0.0j, "d9a7f882fb0000000000000000fb0000000000000000"),
(complex(float("inf"), float("inf")), "d9a7f882f97c00f97c00"),
(complex(float("inf"), 0.0), "d9a7f882f97c00fb0000000000000000"),
(complex(float("nan"), float("inf")), "d9a7f882f97e00f97c00"),
],
)
def test_complex(impl, value, expected):
expected = unhexlify(expected)
assert impl.dumps(value) == expected


def test_rational(impl):
expected = unhexlify("d81e820205")
assert impl.dumps(Fraction(2, 5)) == expected
Expand Down

0 comments on commit 3ff08b3

Please sign in to comment.