diff --git a/cbor2/_decoder.py b/cbor2/_decoder.py
index c8f1a8f..42a9740 100644
--- a/cbor2/_decoder.py
+++ b/cbor2/_decoder.py
@@ -623,6 +623,21 @@ def decode_sharedref(self) -> Any:
else:
return shared
+ def decode_complex(self) -> complex:
+ # Semantic tag 43000
+ inputval = self._decode(immutable=True, unshared=True)
+ try:
+ value = complex(*inputval)
+ except TypeError as exc:
+ if not isinstance(inputval, tuple):
+ raise CBORDecodeValueError(
+ "error decoding complex: input value was not a tuple"
+ ) from None
+
+ raise CBORDecodeValueError("error decoding complex") from exc
+
+ return self.set_shareable(value)
+
def decode_rational(self) -> Fraction:
# Semantic tag 30
from fractions import Fraction
@@ -780,6 +795,7 @@ def decode_float64(self) -> float:
260: CBORDecoder.decode_ipaddress,
261: CBORDecoder.decode_ipnetwork,
1004: CBORDecoder.decode_date_string,
+ 43000: CBORDecoder.decode_complex,
55799: CBORDecoder.decode_self_describe_cbor,
}
diff --git a/cbor2/_encoder.py b/cbor2/_encoder.py
index b92e098..0d3aa43 100644
--- a/cbor2/_encoder.py
+++ b/cbor2/_encoder.py
@@ -614,6 +614,11 @@ def encode_float(self, value: float) -> None:
else:
self._fp_write(struct.pack(">Bd", 0xFB, value))
+ def encode_complex(self, value: complex) -> None:
+ # Semantic tag 43000
+ with self.disable_value_sharing():
+ self.encode_semantic(CBORTag(43000, [value.real, value.imag]))
+
def encode_minimal_float(self, value: float) -> None:
# Handle special values efficiently
if math.isnan(value):
@@ -652,6 +657,7 @@ def encode_undefined(self, value: UndefinedType) -> None:
str: CBOREncoder.encode_string,
int: CBOREncoder.encode_int,
float: CBOREncoder.encode_float,
+ complex: CBOREncoder.encode_complex,
("decimal", "Decimal"): CBOREncoder.encode_decimal,
bool: CBOREncoder.encode_boolean,
type(None): CBOREncoder.encode_none,
diff --git a/docs/usage.rst b/docs/usage.rst
index 797db59..127d206 100644
--- a/docs/usage.rst
+++ b/docs/usage.rst
@@ -99,6 +99,7 @@ Tag Semantics Python type(s)
258 Set of unique items set
260 Network address :class:`ipaddress.IPv4Address` (or IPv6)
261 Network prefix :class:`ipaddress.IPv4Network` (or IPv6)
+43000 Single complex number complex
55799 Self-Described CBOR object
===== ======================================== ====================================================
diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index f2cdcbd..644b36b 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning `_.
- Dropped support for Python 3.8
(#247 _; PR by @hugovk)
+- Added complex number support (tag 43000)
+ (#249 _; PR by @chillenb)
**5.6.5** (2024-10-09)
diff --git a/source/decoder.c b/source/decoder.c
index 918e22d..782dfbc 100644
--- a/source/decoder.c
+++ b/source/decoder.c
@@ -55,6 +55,7 @@ static PyObject * CBORDecoder_decode_epoch_date(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_date_string(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_fraction(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_bigfloat(CBORDecoderObject *);
+static PyObject * CBORDecoder_decode_complex(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_rational(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_regexp(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_uuid(CBORDecoderObject *);
@@ -1172,6 +1173,7 @@ decode_semantic(CBORDecoderObject *self, uint8_t subtype)
case 260: ret = CBORDecoder_decode_ipaddress(self); break;
case 261: ret = CBORDecoder_decode_ipnetwork(self); break;
case 1004: ret = CBORDecoder_decode_date_string(self); break;
+ case 43000: ret = CBORDecoder_decode_complex(self); break;
case 55799: ret = CBORDecoder_decode_self_describe_cbor(self);
break;
@@ -1636,6 +1638,34 @@ CBORDecoder_decode_sharedref(CBORDecoderObject *self)
return ret;
}
+// CBORDecoder.decode_complex(self)
+static PyObject *
+CBORDecoder_decode_complex(CBORDecoderObject *self)
+{
+ // semantic type 43000
+ PyObject *payload_t, *real, *imag, *ret = NULL;
+ payload_t = decode(self, DECODE_IMMUTABLE | DECODE_UNSHARED);
+ if (payload_t) {
+ if (PyTuple_CheckExact(payload_t) && PyTuple_GET_SIZE(payload_t) == 2) {
+ real = PyTuple_GET_ITEM(payload_t, 0);
+ imag = PyTuple_GET_ITEM(payload_t, 1);
+ if(PyFloat_CheckExact(real) && PyFloat_CheckExact(imag)) {
+ ret = PyComplex_FromDoubles(PyFloat_AS_DOUBLE(real), PyFloat_AS_DOUBLE(imag));
+ } else {
+ PyErr_Format(
+ _CBOR2_CBORDecodeValueError,
+ "Incorrect tag 43000 payload: does not contain two floats");
+ }
+ } else {
+ PyErr_Format(
+ _CBOR2_CBORDecodeValueError,
+ "Incorrect tag 43000 payload: not an array of length 2");
+ }
+ Py_DECREF(payload_t);
+ }
+ set_shareable(self, ret);
+ return ret;
+}
// CBORDecoder.decode_rational(self)
static PyObject *
@@ -2159,6 +2189,8 @@ static PyMethodDef CBORDecoder_methods[] = {
"decode a fractional number from the input"},
{"decode_rational", (PyCFunction) CBORDecoder_decode_rational, METH_NOARGS,
"decode a rational value from the input"},
+ {"decode_complex", (PyCFunction) CBORDecoder_decode_complex, METH_NOARGS,
+ "decode a complex value from the input"},
{"decode_bigfloat", (PyCFunction) CBORDecoder_decode_bigfloat, METH_NOARGS,
"decode a large floating-point value from the input"},
{"decode_regexp", (PyCFunction) CBORDecoder_decode_regexp, METH_NOARGS,
diff --git a/source/encoder.c b/source/encoder.c
index a0670aa..2f88326 100644
--- a/source/encoder.c
+++ b/source/encoder.c
@@ -1362,6 +1362,36 @@ CBOREncoder_encode_rational(CBOREncoderObject *self, PyObject *value)
return ret;
}
+// CBOREncoder.encode_complex(self, value)
+static PyObject *
+CBOREncoder_encode_complex(CBOREncoderObject *self, PyObject *value)
+{
+ // semantic type 43000
+ PyObject *tuple, *real, *imag, *ret = NULL;
+ bool sharing;
+
+ real = PyObject_GetAttr(value, _CBOR2_str_real);
+ if (real) {
+ imag = PyObject_GetAttr(value, _CBOR2_str_imag);
+ if (imag) {
+ tuple = PyTuple_Pack(2, real, imag);
+ if (tuple) {
+ sharing = self->value_sharing;
+ self->value_sharing = false;
+ if (encode_semantic(self, 43000, tuple) == 0) {
+ Py_INCREF(Py_None);
+ ret = Py_None;
+ }
+ self->value_sharing = sharing;
+ Py_DECREF(tuple);
+ }
+ Py_DECREF(imag);
+ }
+ Py_DECREF(real);
+ }
+ return ret;
+}
+
// CBOREncoder.encode_regexp(self, value)
static PyObject *
@@ -2118,6 +2148,8 @@ static PyMethodDef CBOREncoder_methods[] = {
"encode the specified integer *value* to the output"},
{"encode_float", (PyCFunction) CBOREncoder_encode_float, METH_O,
"encode the specified floating-point *value* to the output"},
+ {"encode_complex", (PyCFunction) CBOREncoder_encode_complex, METH_O,
+ "encode the specified complex *value* to the output"},
{"encode_boolean", (PyCFunction) CBOREncoder_encode_boolean, METH_O,
"encode the specified boolean *value* to the output"},
{"encode_none", (PyCFunction) CBOREncoder_encode_none, METH_O,
diff --git a/source/module.c b/source/module.c
index 47fc10c..e59bbc5 100644
--- a/source/module.c
+++ b/source/module.c
@@ -621,6 +621,7 @@ PyObject *_CBOR2_str_FrozenDict = NULL;
PyObject *_CBOR2_str_fromordinal = NULL;
PyObject *_CBOR2_str_getvalue = NULL;
PyObject *_CBOR2_str_groups = NULL;
+PyObject *_CBOR2_str_imag = NULL;
PyObject *_CBOR2_str_ip_address = NULL;
PyObject *_CBOR2_str_ip_network = NULL;
PyObject *_CBOR2_str_is_infinite = NULL;
@@ -637,6 +638,7 @@ PyObject *_CBOR2_str_parsestr = NULL;
PyObject *_CBOR2_str_pattern = NULL;
PyObject *_CBOR2_str_prefixlen = NULL;
PyObject *_CBOR2_str_read = NULL;
+PyObject *_CBOR2_str_real = NULL;
PyObject *_CBOR2_str_s = NULL;
PyObject *_CBOR2_str_timestamp = NULL;
PyObject *_CBOR2_str_toordinal = NULL;
@@ -955,6 +957,7 @@ PyInit__cbor2(void)
INTERN_STRING(fromordinal);
INTERN_STRING(getvalue);
INTERN_STRING(groups);
+ INTERN_STRING(imag);
INTERN_STRING(ip_address);
INTERN_STRING(ip_network);
INTERN_STRING(is_infinite);
@@ -971,6 +974,7 @@ PyInit__cbor2(void)
INTERN_STRING(pattern);
INTERN_STRING(prefixlen);
INTERN_STRING(read);
+ INTERN_STRING(real);
INTERN_STRING(s);
INTERN_STRING(timestamp);
INTERN_STRING(toordinal);
diff --git a/source/module.h b/source/module.h
index 72bc6b5..f478386 100644
--- a/source/module.h
+++ b/source/module.h
@@ -53,6 +53,7 @@ extern PyObject *_CBOR2_str_FrozenDict;
extern PyObject *_CBOR2_str_fromordinal;
extern PyObject *_CBOR2_str_getvalue;
extern PyObject *_CBOR2_str_groups;
+extern PyObject *_CBOR2_str_imag;
extern PyObject *_CBOR2_str_ip_address;
extern PyObject *_CBOR2_str_ip_network;
extern PyObject *_CBOR2_str_is_infinite;
@@ -69,6 +70,7 @@ extern PyObject *_CBOR2_str_parsestr;
extern PyObject *_CBOR2_str_pattern;
extern PyObject *_CBOR2_str_prefixlen;
extern PyObject *_CBOR2_str_read;
+extern PyObject *_CBOR2_str_real;
extern PyObject *_CBOR2_str_s;
extern PyObject *_CBOR2_str_timestamp;
extern PyObject *_CBOR2_str_toordinal;
diff --git a/tests/test_decoder.py b/tests/test_decoder.py
index 84ef1d7..77ded86 100644
--- a/tests/test_decoder.py
+++ b/tests/test_decoder.py
@@ -539,6 +539,35 @@ def test_bigfloat(impl):
assert decoded == Decimal("1.5")
+@pytest.mark.parametrize(
+ "payload, expected",
+ [
+ ("d9a7f882f90000f90000", 0.0j),
+ ("d9a7f882fb0000000000000000fb0000000000000000", 0.0j),
+ ("d9a7f882f98000f98000", -0.0j),
+ ("d9a7f882f90000f93c00", 1.0j),
+ ("d9a7f882fb0000000000000000fb3ff199999999999a", 1.1j),
+ ("d9a7f882f93e00f93e00", 1.5 + 1.5j),
+ ("d9a7f882f97bfff97bff", 65504.0 + 65504.0j),
+ ("d9a7f882fa47c35000fa47c35000", 100000.0 + 100000.0j),
+ ("fa7f7fffff", 3.4028234663852886e38),
+ ("d9a7f882f90000fb7e37e43c8800759c", 1.0e300j),
+ ("d9a7f882f90000f90001", 5.960464477539063e-8j),
+ ("d9a7f882f90000f90400", 0.00006103515625j),
+ ("d9a7f882f90000f9c400", -4.0j),
+ ("d9a7f882f90000fbc010666666666666", -4.1j),
+ ("d9a7f882f90000f97c00", complex(0.0, float("inf"))),
+ ("d9a7f882f97c00f90000", complex(float("inf"), 0.0)),
+ ("d9a7f882f90000f9fc00", complex(0.0, float("-inf"))),
+ ("d9a7f882f90000fa7f800000", complex(0.0, float("inf"))),
+ ("d9a7f882f90000faff800000", complex(0.0, float("-inf"))),
+ ],
+)
+def test_complex(impl, payload, expected):
+ decoded = impl.loads(unhexlify(payload))
+ assert decoded == expected
+
+
def test_rational(impl):
decoded = impl.loads(unhexlify("d81e820205"))
assert decoded == Fraction(2, 5)
diff --git a/tests/test_encoder.py b/tests/test_encoder.py
index f2ef248..94e2f24 100644
--- a/tests/test_encoder.py
+++ b/tests/test_encoder.py
@@ -324,6 +324,22 @@ def test_decimal(impl, value, expected):
assert impl.dumps(value) == expected
+@pytest.mark.parametrize(
+ "value, expected",
+ [
+ (3.1 + 2.1j, "d9a7f882fb4008cccccccccccdfb4000cccccccccccd"),
+ (1.0e300j, "d9a7f882fb0000000000000000fb7e37e43c8800759c"),
+ (0.0j, "d9a7f882fb0000000000000000fb0000000000000000"),
+ (complex(float("inf"), float("inf")), "d9a7f882f97c00f97c00"),
+ (complex(float("inf"), 0.0), "d9a7f882f97c00fb0000000000000000"),
+ (complex(float("nan"), float("inf")), "d9a7f882f97e00f97c00"),
+ ],
+)
+def test_complex(impl, value, expected):
+ expected = unhexlify(expected)
+ assert impl.dumps(value) == expected
+
+
def test_rational(impl):
expected = unhexlify("d81e820205")
assert impl.dumps(Fraction(2, 5)) == expected