Skip to content

Commit

Permalink
added binary support in dictionaries via base64 encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Dominik Andreas committed Jan 22, 2024
1 parent 1fb1687 commit f46d6b6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 16 deletions.
45 changes: 29 additions & 16 deletions capnp/lib/capnp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import array
import asyncio
import collections as _collections
import contextlib
import base64
import enum as _enum
import inspect as _inspect
import os as _os
Expand Down Expand Up @@ -957,17 +958,17 @@ cdef _DynamicStructBuilder temp_msg_b
cdef _DynamicStructReader temp_msg_r


cdef _to_dict(msg, bint verbose, bint ordered):
cdef _to_dict(msg, bint verbose, bint ordered, bint encode_bytes_as_base64=False):
msg_type = type(msg)
if msg_type is _DynamicListBuilder:
temp_list_b = msg
return [_to_dict(temp_list_b._get(i), verbose, ordered) for i in range(len(msg))]
return [_to_dict(temp_list_b._get(i), verbose, ordered, encode_bytes_as_base64) for i in range(len(msg))]
elif msg_type is _DynamicListReader:
temp_list_r = msg
return [_to_dict(temp_list_r._get(i), verbose, ordered) for i in range(len(msg))]
return [_to_dict(temp_list_r._get(i), verbose, ordered, encode_bytes_as_base64) for i in range(len(msg))]
elif msg_type is _DynamicResizableListBuilder:
temp_list_rb = msg
return [_to_dict(temp_list_rb._get(i), verbose, ordered) for i in range(len(msg))]
return [_to_dict(temp_list_rb._get(i), verbose, ordered, encode_bytes_as_base64) for i in range(len(msg))]

if msg_type is _DynamicStructBuilder or isinstance(msg, _Request):
temp_msg_b = msg
Expand All @@ -977,13 +978,13 @@ cdef _to_dict(msg, bint verbose, bint ordered):
ret = {}
try:
which = temp_msg_b.which()
ret[which] = _to_dict(temp_msg_b._get(which), verbose, ordered)
ret[which] = _to_dict(temp_msg_b._get(which), verbose, ordered, encode_bytes_as_base64)
except KjException:
pass

for field in temp_msg_b.schema.non_union_fields:
if verbose or temp_msg_b._has(field):
ret[field] = _to_dict(temp_msg_b._get(field), verbose, ordered)
ret[field] = _to_dict(temp_msg_b._get(field), verbose, ordered, encode_bytes_as_base64)

return ret
elif msg_type is _DynamicStructReader or isinstance(msg, _Response):
Expand All @@ -994,13 +995,13 @@ cdef _to_dict(msg, bint verbose, bint ordered):
ret = {}
try:
which = temp_msg_r.which()
ret[which] = _to_dict(temp_msg_r._get(which), verbose, ordered)
ret[which] = _to_dict(temp_msg_r._get(which), verbose, ordered, encode_bytes_as_base64)
except KjException:
pass

for field in temp_msg_r.schema.non_union_fields:
if verbose or temp_msg_r._has(field):
ret[field] = _to_dict(temp_msg_r._get(field), verbose, ordered)
ret[field] = _to_dict(temp_msg_r._get(field), verbose, ordered, encode_bytes_as_base64)

return ret

Expand All @@ -1010,6 +1011,10 @@ cdef _to_dict(msg, bint verbose, bint ordered):
if msg_type is _DynamicEnum:
return str(msg)

if encode_bytes_as_base64 and msg_type is bytes:
# encode the message as base64 and return utf-8 string
return base64.b64encode(msg).decode('utf-8')

return msg


Expand Down Expand Up @@ -1220,8 +1225,8 @@ cdef class _DynamicStructReader:
def __repr__(self):
return '<%s reader %s>' % (self.schema.node.displayName, <char*>strStructReader(self.thisptr).cStr())

def to_dict(self, verbose=False, ordered=False):
return _to_dict(self, verbose, ordered)
def to_dict(self, verbose=False, ordered=False, encode_bytes_as_base64=False):
return _to_dict(self, verbose, ordered, encode_bytes_as_base64)

cpdef as_builder(self, num_first_segment_words=None):
"""A method for casting this Reader to a Builder
Expand Down Expand Up @@ -1598,12 +1603,20 @@ cdef class _DynamicStructBuilder:
def __repr__(self):
return '<%s builder %s>' % (self.schema.node.displayName, <char*>strStructBuilder(self.thisptr).cStr())

def to_dict(self, verbose=False, ordered=False):
return _to_dict(self, verbose, ordered)
def to_dict(self, verbose=False, ordered=False, encode_bytes_as_base64=False):
return _to_dict(self, verbose, ordered, encode_bytes_as_base64)

def from_dict(self, dict d):
for key, val in d.iteritems():
if key != 'which':
field = self.schema.fields.get(key)
print("attribute: ", field)
if isinstance(val, str):
dtype = field.proto.slot.type.which()
print("dtype: ", dtype)
if dtype == "data":
# decode bytes from utf-8 base64 encoding
val = base64.b64decode(val)
try:
self._set(key, val)
except Exception as e:
Expand Down Expand Up @@ -1683,8 +1696,8 @@ cdef class _DynamicStructPipeline:
# def __repr__(self):
# return '<%s reader %s>' % (self.schema.node.displayName, strStructReader(self.thisptr).cStr())

def to_dict(self, verbose=False, ordered=False):
return _to_dict(self, verbose, ordered)
def to_dict(self, verbose=False, ordered=False, encode_bytes_as_base64=False):
return _to_dict(self, verbose, ordered, encode_bytes_as_base64)


cdef class _DynamicOrphan:
Expand Down Expand Up @@ -2065,8 +2078,8 @@ cdef class _RemotePromise:
def __dir__(self):
return list(set(self.schema.fieldnames + tuple(dir(self.__class__))))

def to_dict(self, verbose=False, ordered=False):
return _to_dict(self, verbose, ordered)
def to_dict(self, verbose=False, ordered=False, encode_bytes_as_base64=False):
return _to_dict(self, verbose, ordered, encode_bytes_as_base64)

cpdef cancel(self) except +reraise_kj_exception:
self.thisptr = Own[RemotePromise]()
Expand Down
21 changes: 21 additions & 0 deletions test/test_blob_to_dict_base64.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os
import capnp
import base64
import pytest

this_dir = os.path.dirname(__file__)


@pytest.fixture
def blob_schema():
return capnp.load(os.path.join(this_dir, "blob_test.capnp"))


def test_blob_to_dict(blob_schema):
blob_value = b"hello world"
blob = blob_schema.BlobTest(blob=blob_value)
blob_dict = blob.to_dict(encode_bytes_as_base64=True)
assert base64.b64decode(blob_dict["blob"]) == blob_value
msg = blob_schema.BlobTest.new_message()
msg.from_dict(blob_dict)
assert blob.blob == blob_value

0 comments on commit f46d6b6

Please sign in to comment.