Skip to content

Commit

Permalink
msgpack: support UUID extended type
Browse files Browse the repository at this point in the history
Tarantool supports UUID type since version 2.4.1 [1]. This patch
introduced the support of Tarantool UUID type in msgpack decoders and
encoders. The Tarantool UUID type is mapped to the native Python
uuid.UUID type.

1. tarantool/tarantool#4268

Closed #202
  • Loading branch information
DifferentialOrange committed Sep 7, 2022
1 parent b0ed8f0 commit c70dfa6
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 49 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Decimal type support (#203).
- UUID type support (#202).

### Changed
- Bump msgpack requirement to 1.0.4 (PR #223).
Expand Down
12 changes: 10 additions & 2 deletions tarantool/msgpack_ext/packer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from decimal import Decimal
from uuid import UUID
from msgpack import ExtType

import tarantool.msgpack_ext.decimal as ext_decimal
import tarantool.msgpack_ext.uuid as ext_uuid

encoders = [
{'type': Decimal, 'ext': ext_decimal},
{'type': UUID, 'ext': ext_uuid },
]

def default(obj):
if isinstance(obj, Decimal):
return ExtType(ext_decimal.EXT_ID, ext_decimal.encode(obj))
for encoder in encoders:
if isinstance(obj, encoder['type']):
return ExtType(encoder['ext'].EXT_ID, encoder['ext'].encode(obj))
raise TypeError("Unknown type: %r" % (obj,))
10 changes: 8 additions & 2 deletions tarantool/msgpack_ext/unpacker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import tarantool.msgpack_ext.decimal as ext_decimal
import tarantool.msgpack_ext.uuid as ext_uuid

decoders = {
ext_decimal.EXT_ID: ext_decimal.decode,
ext_uuid.EXT_ID : ext_uuid.decode ,
}

def ext_hook(code, data):
if code == ext_decimal.EXT_ID:
return ext_decimal.decode(data)
if code in decoders:
return decoders[code](data)
raise NotImplementedError("Unknown msgpack type: %d" % (code,))
17 changes: 17 additions & 0 deletions tarantool/msgpack_ext/uuid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from uuid import UUID

# https://www.tarantool.io/en/doc/latest/dev_guide/internals/msgpack_extensions/#the-uuid-type
#
# The UUID MessagePack representation looks like this:
# +--------+------------+-----------------+
# | MP_EXT | MP_UUID | UuidValue |
# | = d8 | = 2 | = 16-byte value |
# +--------+------------+-----------------+

EXT_ID = 2

def encode(obj):
return obj.bytes

def decode(data):
return UUID(bytes=data)
5 changes: 3 additions & 2 deletions test/suites/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
from .test_dbapi import TestSuite_DBAPI
from .test_encoding import TestSuite_Encoding
from .test_ssl import TestSuite_Ssl
from .test_msgpack_ext import TestSuite_MsgpackExt
from .test_decimal import TestSuite_Decimal
from .test_uuid import TestSuite_UUID

test_cases = (TestSuite_Schema_UnicodeConnection,
TestSuite_Schema_BinaryConnection,
TestSuite_Request, TestSuite_Protocol, TestSuite_Reconnect,
TestSuite_Mesh, TestSuite_Execute, TestSuite_DBAPI,
TestSuite_Encoding, TestSuite_Pool, TestSuite_Ssl,
TestSuite_MsgpackExt)
TestSuite_Decimal, TestSuite_UUID)

def load_tests(loader, tests, pattern):
suite = unittest.TestSuite()
Expand Down
11 changes: 11 additions & 0 deletions test/suites/lib/skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,14 @@ def skip_or_run_decimal_test(func):

return skip_or_run_test_pcall_require(func, 'decimal',
'does not support decimal type')

def skip_or_run_UUID_test(func):
"""Decorator to skip or run UUID-related tests depending on
the tarantool version.
Tarantool supports UUID type only since 2.4.1 version.
See https://github.com/tarantool/tarantool/issues/4268
"""

return skip_or_run_test_tarantool(func, '2.4.1',
'does not support UUID type')
87 changes: 44 additions & 43 deletions test/suites/test_msgpack_ext.py → test/suites/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from .lib.skip import skip_or_run_decimal_test
from tarantool.error import MsgpackError, MsgpackWarning

class TestSuite_MsgpackExt(unittest.TestCase):
class TestSuite_Decimal(unittest.TestCase):
@classmethod
def setUpClass(self):
print(' MSGPACK EXT TYPES '.center(70, '='), file=sys.stderr)
print(' DECIMAL EXT TYPE '.center(70, '='), file=sys.stderr)
print('-' * 70, file=sys.stderr)
self.srv = TarantoolServer()
self.srv.script = 'test/suites/box.lua'
Expand Down Expand Up @@ -50,7 +50,7 @@ def setUp(self):
self.adm("box.space['test']:truncate()")


valid_decimal_cases = {
valid_cases = {
'simple_decimal_1': {
'python': decimal.Decimal('0.7'),
'msgpack': (b'\x01\x7c'),
Expand Down Expand Up @@ -219,47 +219,47 @@ def setUp(self):
},
}

def test_decimal_msgpack_decode(self):
for name in self.valid_decimal_cases.keys():
def test_msgpack_decode(self):
for name in self.valid_cases.keys():
with self.subTest(msg=name):
decimal_case = self.valid_decimal_cases[name]
case = self.valid_cases[name]

self.assertEqual(unpacker_ext_hook(1, decimal_case['msgpack']),
decimal_case['python'])
self.assertEqual(unpacker_ext_hook(1, case['msgpack']),
case['python'])

@skip_or_run_decimal_test
def test_decimal_tarantool_decode(self):
for name in self.valid_decimal_cases.keys():
def test_tarantool_decode(self):
for name in self.valid_cases.keys():
with self.subTest(msg=name):
decimal_case = self.valid_decimal_cases[name]
case = self.valid_cases[name]

self.adm(f"box.space['test']:replace{{'{name}', {decimal_case['tarantool']}}}")
self.adm(f"box.space['test']:replace{{'{name}', {case['tarantool']}}}")

self.assertSequenceEqual(
self.con.select('test', name),
[[name, decimal_case['python']]])
[[name, case['python']]])

def test_decimal_msgpack_encode(self):
for name in self.valid_decimal_cases.keys():
def test_msgpack_encode(self):
for name in self.valid_cases.keys():
with self.subTest(msg=name):
decimal_case = self.valid_decimal_cases[name]
case = self.valid_cases[name]

self.assertEqual(packer_default(decimal_case['python']),
msgpack.ExtType(code=1, data=decimal_case['msgpack']))
self.assertEqual(packer_default(case['python']),
msgpack.ExtType(code=1, data=case['msgpack']))

@skip_or_run_decimal_test
def test_decimal_tarantool_encode(self):
for name in self.valid_decimal_cases.keys():
def test_tarantool_encode(self):
for name in self.valid_cases.keys():
with self.subTest(msg=name):
decimal_case = self.valid_decimal_cases[name]
case = self.valid_cases[name]

self.con.insert('test', [name, decimal_case['python']])
self.con.insert('test', [name, case['python']])

lua_eval = f"""
local tuple = box.space['test']:get('{name}')
assert(tuple ~= nil)
local dec = {decimal_case['tarantool']}
local dec = {case['tarantool']}
if tuple[2] == dec then
return true
else
Expand All @@ -271,7 +271,7 @@ def test_decimal_tarantool_encode(self):
self.assertSequenceEqual(self.con.eval(lua_eval), [True])


error_decimal_cases = {
error_cases = {
'decimal_limit_break_head_1': {
'python': decimal.Decimal('999999999999999999999999999999999999999'),
},
Expand All @@ -298,31 +298,31 @@ def test_decimal_tarantool_encode(self):
},
}

def test_decimal_msgpack_encode_error(self):
for name in self.error_decimal_cases.keys():
def test_msgpack_encode_error(self):
for name in self.error_cases.keys():
with self.subTest(msg=name):
decimal_case = self.error_decimal_cases[name]
case = self.error_cases[name]

msg = 'Decimal cannot be encoded: Tarantool decimal ' + \
'supports a maximum of 38 digits.'
self.assertRaisesRegex(
MsgpackError, msg,
lambda: packer_default(decimal_case['python']))
lambda: packer_default(case['python']))

@skip_or_run_decimal_test
def test_decimal_tarantool_encode_error(self):
for name in self.error_decimal_cases.keys():
def test_tarantool_encode_error(self):
for name in self.error_cases.keys():
with self.subTest(msg=name):
decimal_case = self.error_decimal_cases[name]
case = self.error_cases[name]

msg = 'Decimal cannot be encoded: Tarantool decimal ' + \
'supports a maximum of 38 digits.'
self.assertRaisesRegex(
MsgpackError, msg,
lambda: self.con.insert('test', [name, decimal_case['python']]))
lambda: self.con.insert('test', [name, case['python']]))


precision_loss_decimal_cases = {
precision_loss_cases = {
'decimal_limit_break_tail_1': {
'python': decimal.Decimal('1.00000000000000000000000000000000000001'),
'msgpack': (b'\x00\x1c'),
Expand Down Expand Up @@ -379,41 +379,41 @@ def test_decimal_tarantool_encode_error(self):
},
}

def test_decimal_msgpack_encode_with_precision_loss(self):
for name in self.precision_loss_decimal_cases.keys():
def test_msgpack_encode_with_precision_loss(self):
for name in self.precision_loss_cases.keys():
with self.subTest(msg=name):
decimal_case = self.precision_loss_decimal_cases[name]
case = self.precision_loss_cases[name]

msg = 'Decimal encoded with loss of precision: ' + \
'Tarantool decimal supports a maximum of 38 digits.'

self.assertWarnsRegex(
MsgpackWarning, msg,
lambda: self.assertEqual(
packer_default(decimal_case['python']),
msgpack.ExtType(code=1, data=decimal_case['msgpack'])
packer_default(case['python']),
msgpack.ExtType(code=1, data=case['msgpack'])
)
)


@skip_or_run_decimal_test
def test_decimal_tarantool_encode_with_precision_loss(self):
for name in self.precision_loss_decimal_cases.keys():
def test_tarantool_encode_with_precision_loss(self):
for name in self.precision_loss_cases.keys():
with self.subTest(msg=name):
decimal_case = self.precision_loss_decimal_cases[name]
case = self.precision_loss_cases[name]

msg = 'Decimal encoded with loss of precision: ' + \
'Tarantool decimal supports a maximum of 38 digits.'

self.assertWarnsRegex(
MsgpackWarning, msg,
lambda: self.con.insert('test', [name, decimal_case['python']]))
lambda: self.con.insert('test', [name, case['python']]))

lua_eval = f"""
local tuple = box.space['test']:get('{name}')
assert(tuple ~= nil)
local dec = {decimal_case['tarantool']}
local dec = {case['tarantool']}
if tuple[2] == dec then
return true
else
Expand All @@ -424,6 +424,7 @@ def test_decimal_tarantool_encode_with_precision_loss(self):

self.assertSequenceEqual(self.con.eval(lua_eval), [True])


@classmethod
def tearDownClass(self):
self.con.close()
Expand Down
Loading

0 comments on commit c70dfa6

Please sign in to comment.