diff --git a/cbor2/_decoder.py b/cbor2/_decoder.py index c8f1a8f..42a9740 100644 --- a/cbor2/_decoder.py +++ b/cbor2/_decoder.py @@ -623,6 +623,21 @@ def decode_sharedref(self) -> Any: else: return shared + def decode_complex(self) -> complex: + # Semantic tag 43000 + inputval = self._decode(immutable=True, unshared=True) + try: + value = complex(*inputval) + except TypeError as exc: + if not isinstance(inputval, tuple): + raise CBORDecodeValueError( + "error decoding complex: input value was not a tuple" + ) from None + + raise CBORDecodeValueError("error decoding complex") from exc + + return self.set_shareable(value) + def decode_rational(self) -> Fraction: # Semantic tag 30 from fractions import Fraction @@ -780,6 +795,7 @@ def decode_float64(self) -> float: 260: CBORDecoder.decode_ipaddress, 261: CBORDecoder.decode_ipnetwork, 1004: CBORDecoder.decode_date_string, + 43000: CBORDecoder.decode_complex, 55799: CBORDecoder.decode_self_describe_cbor, } diff --git a/cbor2/_encoder.py b/cbor2/_encoder.py index b92e098..0d3aa43 100644 --- a/cbor2/_encoder.py +++ b/cbor2/_encoder.py @@ -614,6 +614,11 @@ def encode_float(self, value: float) -> None: else: self._fp_write(struct.pack(">Bd", 0xFB, value)) + def encode_complex(self, value: complex) -> None: + # Semantic tag 43000 + with self.disable_value_sharing(): + self.encode_semantic(CBORTag(43000, [value.real, value.imag])) + def encode_minimal_float(self, value: float) -> None: # Handle special values efficiently if math.isnan(value): @@ -652,6 +657,7 @@ def encode_undefined(self, value: UndefinedType) -> None: str: CBOREncoder.encode_string, int: CBOREncoder.encode_int, float: CBOREncoder.encode_float, + complex: CBOREncoder.encode_complex, ("decimal", "Decimal"): CBOREncoder.encode_decimal, bool: CBOREncoder.encode_boolean, type(None): CBOREncoder.encode_none, diff --git a/docs/usage.rst b/docs/usage.rst index 797db59..127d206 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -99,6 +99,7 @@ Tag Semantics Python type(s) 258 Set of unique items set 260 Network address :class:`ipaddress.IPv4Address` (or IPv6) 261 Network prefix :class:`ipaddress.IPv4Network` (or IPv6) +43000 Single complex number complex 55799 Self-Described CBOR object ===== ======================================== ==================================================== diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index f2cdcbd..644b36b 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning `_. - Dropped support for Python 3.8 (#247 _; PR by @hugovk) +- Added complex number support (tag 43000) + (#249 _; PR by @chillenb) **5.6.5** (2024-10-09) diff --git a/source/decoder.c b/source/decoder.c index 918e22d..782dfbc 100644 --- a/source/decoder.c +++ b/source/decoder.c @@ -55,6 +55,7 @@ static PyObject * CBORDecoder_decode_epoch_date(CBORDecoderObject *); static PyObject * CBORDecoder_decode_date_string(CBORDecoderObject *); static PyObject * CBORDecoder_decode_fraction(CBORDecoderObject *); static PyObject * CBORDecoder_decode_bigfloat(CBORDecoderObject *); +static PyObject * CBORDecoder_decode_complex(CBORDecoderObject *); static PyObject * CBORDecoder_decode_rational(CBORDecoderObject *); static PyObject * CBORDecoder_decode_regexp(CBORDecoderObject *); static PyObject * CBORDecoder_decode_uuid(CBORDecoderObject *); @@ -1172,6 +1173,7 @@ decode_semantic(CBORDecoderObject *self, uint8_t subtype) case 260: ret = CBORDecoder_decode_ipaddress(self); break; case 261: ret = CBORDecoder_decode_ipnetwork(self); break; case 1004: ret = CBORDecoder_decode_date_string(self); break; + case 43000: ret = CBORDecoder_decode_complex(self); break; case 55799: ret = CBORDecoder_decode_self_describe_cbor(self); break; @@ -1636,6 +1638,34 @@ CBORDecoder_decode_sharedref(CBORDecoderObject *self) return ret; } +// CBORDecoder.decode_complex(self) +static PyObject * +CBORDecoder_decode_complex(CBORDecoderObject *self) +{ + // semantic type 43000 + PyObject *payload_t, *real, *imag, *ret = NULL; + payload_t = decode(self, DECODE_IMMUTABLE | DECODE_UNSHARED); + if (payload_t) { + if (PyTuple_CheckExact(payload_t) && PyTuple_GET_SIZE(payload_t) == 2) { + real = PyTuple_GET_ITEM(payload_t, 0); + imag = PyTuple_GET_ITEM(payload_t, 1); + if(PyFloat_CheckExact(real) && PyFloat_CheckExact(imag)) { + ret = PyComplex_FromDoubles(PyFloat_AS_DOUBLE(real), PyFloat_AS_DOUBLE(imag)); + } else { + PyErr_Format( + _CBOR2_CBORDecodeValueError, + "Incorrect tag 43000 payload: does not contain two floats"); + } + } else { + PyErr_Format( + _CBOR2_CBORDecodeValueError, + "Incorrect tag 43000 payload: not an array of length 2"); + } + Py_DECREF(payload_t); + } + set_shareable(self, ret); + return ret; +} // CBORDecoder.decode_rational(self) static PyObject * @@ -2159,6 +2189,8 @@ static PyMethodDef CBORDecoder_methods[] = { "decode a fractional number from the input"}, {"decode_rational", (PyCFunction) CBORDecoder_decode_rational, METH_NOARGS, "decode a rational value from the input"}, + {"decode_complex", (PyCFunction) CBORDecoder_decode_complex, METH_NOARGS, + "decode a complex value from the input"}, {"decode_bigfloat", (PyCFunction) CBORDecoder_decode_bigfloat, METH_NOARGS, "decode a large floating-point value from the input"}, {"decode_regexp", (PyCFunction) CBORDecoder_decode_regexp, METH_NOARGS, diff --git a/source/encoder.c b/source/encoder.c index a0670aa..2f88326 100644 --- a/source/encoder.c +++ b/source/encoder.c @@ -1362,6 +1362,36 @@ CBOREncoder_encode_rational(CBOREncoderObject *self, PyObject *value) return ret; } +// CBOREncoder.encode_complex(self, value) +static PyObject * +CBOREncoder_encode_complex(CBOREncoderObject *self, PyObject *value) +{ + // semantic type 43000 + PyObject *tuple, *real, *imag, *ret = NULL; + bool sharing; + + real = PyObject_GetAttr(value, _CBOR2_str_real); + if (real) { + imag = PyObject_GetAttr(value, _CBOR2_str_imag); + if (imag) { + tuple = PyTuple_Pack(2, real, imag); + if (tuple) { + sharing = self->value_sharing; + self->value_sharing = false; + if (encode_semantic(self, 43000, tuple) == 0) { + Py_INCREF(Py_None); + ret = Py_None; + } + self->value_sharing = sharing; + Py_DECREF(tuple); + } + Py_DECREF(imag); + } + Py_DECREF(real); + } + return ret; +} + // CBOREncoder.encode_regexp(self, value) static PyObject * @@ -2118,6 +2148,8 @@ static PyMethodDef CBOREncoder_methods[] = { "encode the specified integer *value* to the output"}, {"encode_float", (PyCFunction) CBOREncoder_encode_float, METH_O, "encode the specified floating-point *value* to the output"}, + {"encode_complex", (PyCFunction) CBOREncoder_encode_complex, METH_O, + "encode the specified complex *value* to the output"}, {"encode_boolean", (PyCFunction) CBOREncoder_encode_boolean, METH_O, "encode the specified boolean *value* to the output"}, {"encode_none", (PyCFunction) CBOREncoder_encode_none, METH_O, diff --git a/source/module.c b/source/module.c index 47fc10c..e59bbc5 100644 --- a/source/module.c +++ b/source/module.c @@ -621,6 +621,7 @@ PyObject *_CBOR2_str_FrozenDict = NULL; PyObject *_CBOR2_str_fromordinal = NULL; PyObject *_CBOR2_str_getvalue = NULL; PyObject *_CBOR2_str_groups = NULL; +PyObject *_CBOR2_str_imag = NULL; PyObject *_CBOR2_str_ip_address = NULL; PyObject *_CBOR2_str_ip_network = NULL; PyObject *_CBOR2_str_is_infinite = NULL; @@ -637,6 +638,7 @@ PyObject *_CBOR2_str_parsestr = NULL; PyObject *_CBOR2_str_pattern = NULL; PyObject *_CBOR2_str_prefixlen = NULL; PyObject *_CBOR2_str_read = NULL; +PyObject *_CBOR2_str_real = NULL; PyObject *_CBOR2_str_s = NULL; PyObject *_CBOR2_str_timestamp = NULL; PyObject *_CBOR2_str_toordinal = NULL; @@ -955,6 +957,7 @@ PyInit__cbor2(void) INTERN_STRING(fromordinal); INTERN_STRING(getvalue); INTERN_STRING(groups); + INTERN_STRING(imag); INTERN_STRING(ip_address); INTERN_STRING(ip_network); INTERN_STRING(is_infinite); @@ -971,6 +974,7 @@ PyInit__cbor2(void) INTERN_STRING(pattern); INTERN_STRING(prefixlen); INTERN_STRING(read); + INTERN_STRING(real); INTERN_STRING(s); INTERN_STRING(timestamp); INTERN_STRING(toordinal); diff --git a/source/module.h b/source/module.h index 72bc6b5..f478386 100644 --- a/source/module.h +++ b/source/module.h @@ -53,6 +53,7 @@ extern PyObject *_CBOR2_str_FrozenDict; extern PyObject *_CBOR2_str_fromordinal; extern PyObject *_CBOR2_str_getvalue; extern PyObject *_CBOR2_str_groups; +extern PyObject *_CBOR2_str_imag; extern PyObject *_CBOR2_str_ip_address; extern PyObject *_CBOR2_str_ip_network; extern PyObject *_CBOR2_str_is_infinite; @@ -69,6 +70,7 @@ extern PyObject *_CBOR2_str_parsestr; extern PyObject *_CBOR2_str_pattern; extern PyObject *_CBOR2_str_prefixlen; extern PyObject *_CBOR2_str_read; +extern PyObject *_CBOR2_str_real; extern PyObject *_CBOR2_str_s; extern PyObject *_CBOR2_str_timestamp; extern PyObject *_CBOR2_str_toordinal; diff --git a/tests/test_decoder.py b/tests/test_decoder.py index 84ef1d7..77ded86 100644 --- a/tests/test_decoder.py +++ b/tests/test_decoder.py @@ -539,6 +539,35 @@ def test_bigfloat(impl): assert decoded == Decimal("1.5") +@pytest.mark.parametrize( + "payload, expected", + [ + ("d9a7f882f90000f90000", 0.0j), + ("d9a7f882fb0000000000000000fb0000000000000000", 0.0j), + ("d9a7f882f98000f98000", -0.0j), + ("d9a7f882f90000f93c00", 1.0j), + ("d9a7f882fb0000000000000000fb3ff199999999999a", 1.1j), + ("d9a7f882f93e00f93e00", 1.5 + 1.5j), + ("d9a7f882f97bfff97bff", 65504.0 + 65504.0j), + ("d9a7f882fa47c35000fa47c35000", 100000.0 + 100000.0j), + ("fa7f7fffff", 3.4028234663852886e38), + ("d9a7f882f90000fb7e37e43c8800759c", 1.0e300j), + ("d9a7f882f90000f90001", 5.960464477539063e-8j), + ("d9a7f882f90000f90400", 0.00006103515625j), + ("d9a7f882f90000f9c400", -4.0j), + ("d9a7f882f90000fbc010666666666666", -4.1j), + ("d9a7f882f90000f97c00", complex(0.0, float("inf"))), + ("d9a7f882f97c00f90000", complex(float("inf"), 0.0)), + ("d9a7f882f90000f9fc00", complex(0.0, float("-inf"))), + ("d9a7f882f90000fa7f800000", complex(0.0, float("inf"))), + ("d9a7f882f90000faff800000", complex(0.0, float("-inf"))), + ], +) +def test_complex(impl, payload, expected): + decoded = impl.loads(unhexlify(payload)) + assert decoded == expected + + def test_rational(impl): decoded = impl.loads(unhexlify("d81e820205")) assert decoded == Fraction(2, 5) diff --git a/tests/test_encoder.py b/tests/test_encoder.py index f2ef248..94e2f24 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -324,6 +324,22 @@ def test_decimal(impl, value, expected): assert impl.dumps(value) == expected +@pytest.mark.parametrize( + "value, expected", + [ + (3.1 + 2.1j, "d9a7f882fb4008cccccccccccdfb4000cccccccccccd"), + (1.0e300j, "d9a7f882fb0000000000000000fb7e37e43c8800759c"), + (0.0j, "d9a7f882fb0000000000000000fb0000000000000000"), + (complex(float("inf"), float("inf")), "d9a7f882f97c00f97c00"), + (complex(float("inf"), 0.0), "d9a7f882f97c00fb0000000000000000"), + (complex(float("nan"), float("inf")), "d9a7f882f97e00f97c00"), + ], +) +def test_complex(impl, value, expected): + expected = unhexlify(expected) + assert impl.dumps(value) == expected + + def test_rational(impl): expected = unhexlify("d81e820205") assert impl.dumps(Fraction(2, 5)) == expected