Skip to content

Commit

Permalink
Work-in-progress: Adds support for typed arrays from RFC8746
Browse files Browse the repository at this point in the history
This currently only has support for little-endian 64-bit floats, but the
idea is there.
  • Loading branch information
tgockel committed May 25, 2021
1 parent 6f8311d commit 98c252f
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 15 deletions.
22 changes: 22 additions & 0 deletions cbor2/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,27 @@ def decode_uuid(self):
from uuid import UUID
return self.set_shareable(UUID(bytes=self._decode()))

def _decode_typed_array_impl(self, tag, element_size, format):
"""Helper function for decoding typed arrays described by RFC 8746"""
buf = self.decode()
if not isinstance(buf, bytes):
raise CBORDecodeValueError("invalid major type for tag %r" % tag)
elif len(buf) % element_size != 0:
raise CBORDecodeValueError(
"invalid length for tag %r -- must be multiple of element size %r, but is %r"
% tag % element_size % len(buf))

out = struct.unpack(format % (len(buf) // element_size), buf)

if self._immutable:
return self.set_shareable(out)
else:
return self.set_shareable(list(out))

def decode_array_float64_le(self):
# Semantic tag 85
return self._decode_typed_array_impl(85, 8, '<%id')

def decode_set(self):
# Semantic tag 258
if self._immutable:
Expand Down Expand Up @@ -535,6 +556,7 @@ def decode_float64(self):
35: CBORDecoder.decode_regexp,
36: CBORDecoder.decode_mime,
37: CBORDecoder.decode_uuid,
85: CBORDecoder.decode_array_float64_le,
258: CBORDecoder.decode_set,
260: CBORDecoder.decode_ipaddress,
261: CBORDecoder.decode_ipnetwork,
Expand Down
99 changes: 84 additions & 15 deletions source/decoder.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,17 @@
#define be16toh(x) OSSwapBigToHostInt16(x)
#define be32toh(x) OSSwapBigToHostInt32(x)
#define be64toh(x) OSSwapBigToHostInt64(x)
#define le16toh(x) OSSwapLittleToHostInt16(x)
#define le32toh(x) OSSwapLittleToHostInt32(x)
#define le64toh(x) OSSwapLittleToHostInt64(x)
#elif _WIN32
// All windows platforms are (currently) little-endian so byteswap is required
#define be16toh(x) _byteswap_ushort(x)
#define be32toh(x) _byteswap_ulong(x)
#define be64toh(x) _byteswap_uint64(x)
#define le16toh(x) (x)
#define le32toh(x) (x)
#define le64toh(x) (x)
#endif

enum DecodeOption {
Expand All @@ -52,6 +58,7 @@ static PyObject * CBORDecoder_decode_bigfloat(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_rational(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_regexp(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_uuid(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float64_le(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_mime(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_positive_bignum(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_negative_bignum(CBORDecoderObject *);
Expand Down Expand Up @@ -891,21 +898,22 @@ decode_semantic(CBORDecoderObject *self, uint8_t subtype)

if (decode_length(self, subtype, &tagnum, NULL) == 0) {
switch (tagnum) {
case 0: ret = CBORDecoder_decode_datetime_string(self); break;
case 1: ret = CBORDecoder_decode_epoch_datetime(self); break;
case 2: ret = CBORDecoder_decode_positive_bignum(self); break;
case 3: ret = CBORDecoder_decode_negative_bignum(self); break;
case 4: ret = CBORDecoder_decode_fraction(self); break;
case 5: ret = CBORDecoder_decode_bigfloat(self); break;
case 28: ret = CBORDecoder_decode_shareable(self); break;
case 29: ret = CBORDecoder_decode_sharedref(self); break;
case 30: ret = CBORDecoder_decode_rational(self); break;
case 35: ret = CBORDecoder_decode_regexp(self); break;
case 36: ret = CBORDecoder_decode_mime(self); break;
case 37: ret = CBORDecoder_decode_uuid(self); break;
case 258: ret = CBORDecoder_decode_set(self); break;
case 260: ret = CBORDecoder_decode_ipaddress(self); break;
case 261: ret = CBORDecoder_decode_ipnetwork(self); break;
case 0: ret = CBORDecoder_decode_datetime_string(self); break;
case 1: ret = CBORDecoder_decode_epoch_datetime(self); break;
case 2: ret = CBORDecoder_decode_positive_bignum(self); break;
case 3: ret = CBORDecoder_decode_negative_bignum(self); break;
case 4: ret = CBORDecoder_decode_fraction(self); break;
case 5: ret = CBORDecoder_decode_bigfloat(self); break;
case 28: ret = CBORDecoder_decode_shareable(self); break;
case 29: ret = CBORDecoder_decode_sharedref(self); break;
case 30: ret = CBORDecoder_decode_rational(self); break;
case 35: ret = CBORDecoder_decode_regexp(self); break;
case 36: ret = CBORDecoder_decode_mime(self); break;
case 37: ret = CBORDecoder_decode_uuid(self); break;
case 85: ret = CBORDecoder_decode_array_float64_le(self); break;
case 258: ret = CBORDecoder_decode_set(self); break;
case 260: ret = CBORDecoder_decode_ipaddress(self); break;
case 261: ret = CBORDecoder_decode_ipnetwork(self); break;
case 55799: ret = CBORDecoder_decode_self_describe_cbor(self);
break;

Expand Down Expand Up @@ -1326,6 +1334,65 @@ CBORDecoder_decode_uuid(CBORDecoderObject *self)
}


// CBORDecoder.decode_array_float64_le
static PyObject *
CBORDecoder_decode_array_float64_le(CBORDecoderObject *self)
{
// semantic type 85
PyObject *bytes, *list, *ret = NULL;
Py_ssize_t bytes_size, element_count, element_idx;
char *bytes_direct;
size_t element_size = 8U;

bytes = decode(self, DECODE_UNSHARED);
if (bytes) {
if (PyBytes_CheckExact(bytes)) {
bytes_size = PyBytes_GET_SIZE(bytes);
if (bytes_size % element_size == 0) {
element_count = bytes_size / element_size;
list = PyList_New(element_count);
if (list) {
set_shareable(self, list);
bytes_direct = PyBytes_AS_STRING(bytes);

for (element_idx = 0; element_idx < element_count; bytes_direct += element_size, ++element_idx) {
uint64_t i_repr;
double value;

memcpy(&i_repr, bytes_direct, sizeof i_repr);
i_repr = le64toh(i_repr);
memcpy(&value, &i_repr, sizeof value);

PyList_SET_ITEM(list, element_idx, PyFloat_FromDouble(value));
}

if (self->immutable) {
ret = PyList_AsTuple(list);
if (ret) {
Py_DECREF(list);
set_shareable(self, ret);
}
} else {
ret = list;
}
}
} else {
PyErr_Format(
_CBOR2_CBORDecodeValueError,
"invalid float64 typed array %R", bytes);
}
} else {
PyErr_Format(
_CBOR2_CBORDecodeValueError,
"invalid float64 typed array %R", bytes);
}
Py_DECREF(bytes);
}

return ret;
}


// CBORDecoder.decode_set(self)
static PyObject *
CBORDecoder_decode_set(CBORDecoderObject *self)
Expand Down Expand Up @@ -1738,6 +1805,8 @@ static PyMethodDef CBORDecoder_methods[] = {
"decode a shareable value from the input"},
{"decode_sharedref", (PyCFunction) CBORDecoder_decode_sharedref, METH_NOARGS,
"decode a shared reference from the input"},
{"decode_array_float64_le", (PyCFunction) CBORDecoder_decode_array_float64_le, METH_NOARGS,
"decode a typed array of little-endian double-precision floating-point values"},
{"decode_set", (PyCFunction) CBORDecoder_decode_set, METH_NOARGS,
"decode a set or frozenset from the input"},
{"decode_ipaddress", (PyCFunction) CBORDecoder_decode_ipaddress, METH_NOARGS,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,3 +623,15 @@ def test_huge_truncated_bytes(impl):
def test_huge_truncated_string(impl):
with pytest.raises((impl.CBORDecodeEOF, MemoryError)):
impl.loads(unhexlify('7B37388519251ae9ca'))


@pytest.mark.parametrize('payload, expected', [
('d8555820000000000000f83f0000000000000440000000000000234000000000000014c0',
[1.5, 2.5, 9.5, -5.0])
])
def test_typed_array_float64_le(impl, payload, expected):
decoded = impl.loads(unhexlify(payload))
print(impl)
print('DECODED: ')
print(repr(decoded))
assert decoded == expected

0 comments on commit 98c252f

Please sign in to comment.