diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c314cb1..226909d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Decimal type support (#203). +- UUID type support (#202). ### Changed - Bump msgpack requirement to 1.0.4 (PR #223). diff --git a/tarantool/msgpack_ext/packer.py b/tarantool/msgpack_ext/packer.py index db4aa710..e8dd74db 100644 --- a/tarantool/msgpack_ext/packer.py +++ b/tarantool/msgpack_ext/packer.py @@ -1,9 +1,17 @@ from decimal import Decimal +from uuid import UUID from msgpack import ExtType import tarantool.msgpack_ext.decimal as ext_decimal +import tarantool.msgpack_ext.uuid as ext_uuid + +encoders = [ + {'type': Decimal, 'ext': ext_decimal}, + {'type': UUID, 'ext': ext_uuid }, +] def default(obj): - if isinstance(obj, Decimal): - return ExtType(ext_decimal.EXT_ID, ext_decimal.encode(obj)) + for encoder in encoders: + if isinstance(obj, encoder['type']): + return ExtType(encoder['ext'].EXT_ID, encoder['ext'].encode(obj)) raise TypeError("Unknown type: %r" % (obj,)) diff --git a/tarantool/msgpack_ext/unpacker.py b/tarantool/msgpack_ext/unpacker.py index dd6c0112..44bfdb63 100644 --- a/tarantool/msgpack_ext/unpacker.py +++ b/tarantool/msgpack_ext/unpacker.py @@ -1,6 +1,12 @@ import tarantool.msgpack_ext.decimal as ext_decimal +import tarantool.msgpack_ext.uuid as ext_uuid + +decoders = { + ext_decimal.EXT_ID: ext_decimal.decode, + ext_uuid.EXT_ID : ext_uuid.decode , +} def ext_hook(code, data): - if code == ext_decimal.EXT_ID: - return ext_decimal.decode(data) + if code in decoders: + return decoders[code](data) raise NotImplementedError("Unknown msgpack type: %d" % (code,)) diff --git a/tarantool/msgpack_ext/uuid.py b/tarantool/msgpack_ext/uuid.py new file mode 100644 index 00000000..c489a3fc --- /dev/null +++ b/tarantool/msgpack_ext/uuid.py @@ -0,0 +1,17 @@ +from uuid import UUID + +# https://www.tarantool.io/en/doc/latest/dev_guide/internals/msgpack_extensions/#the-uuid-type +# +# The UUID MessagePack representation looks like this: +# +--------+------------+-----------------+ +# | MP_EXT | MP_UUID | UuidValue | +# | = d8 | = 2 | = 16-byte value | +# +--------+------------+-----------------+ + +EXT_ID = 2 + +def encode(obj): + return obj.bytes + +def decode(data): + return UUID(bytes=data) diff --git a/test/suites/__init__.py b/test/suites/__init__.py index 984665b6..94357c8e 100644 --- a/test/suites/__init__.py +++ b/test/suites/__init__.py @@ -15,14 +15,15 @@ from .test_dbapi import TestSuite_DBAPI from .test_encoding import TestSuite_Encoding from .test_ssl import TestSuite_Ssl -from .test_msgpack_ext import TestSuite_MsgpackExt +from .test_decimal import TestSuite_Decimal +from .test_uuid import TestSuite_UUID test_cases = (TestSuite_Schema_UnicodeConnection, TestSuite_Schema_BinaryConnection, TestSuite_Request, TestSuite_Protocol, TestSuite_Reconnect, TestSuite_Mesh, TestSuite_Execute, TestSuite_DBAPI, TestSuite_Encoding, TestSuite_Pool, TestSuite_Ssl, - TestSuite_MsgpackExt) + TestSuite_Decimal, TestSuite_UUID) def load_tests(loader, tests, pattern): suite = unittest.TestSuite() diff --git a/test/suites/lib/skip.py b/test/suites/lib/skip.py index 9ac5fe9e..9ce76991 100644 --- a/test/suites/lib/skip.py +++ b/test/suites/lib/skip.py @@ -143,3 +143,14 @@ def skip_or_run_decimal_test(func): return skip_or_run_test_pcall_require(func, 'decimal', 'does not support decimal type') + +def skip_or_run_UUID_test(func): + """Decorator to skip or run UUID-related tests depending on + the tarantool version. + + Tarantool supports UUID type only since 2.4.1 version. + See https://github.com/tarantool/tarantool/issues/4268 + """ + + return skip_or_run_test_tarantool(func, '2.4.1', + 'does not support UUID type') diff --git a/test/suites/test_msgpack_ext.py b/test/suites/test_decimal.py similarity index 86% rename from test/suites/test_msgpack_ext.py rename to test/suites/test_decimal.py index 42b937fe..ad633ddb 100644 --- a/test/suites/test_msgpack_ext.py +++ b/test/suites/test_decimal.py @@ -16,10 +16,10 @@ from .lib.skip import skip_or_run_decimal_test from tarantool.error import MsgpackError, MsgpackWarning -class TestSuite_MsgpackExt(unittest.TestCase): +class TestSuite_Decimal(unittest.TestCase): @classmethod def setUpClass(self): - print(' MSGPACK EXT TYPES '.center(70, '='), file=sys.stderr) + print(' DECIMAL EXT TYPE '.center(70, '='), file=sys.stderr) print('-' * 70, file=sys.stderr) self.srv = TarantoolServer() self.srv.script = 'test/suites/box.lua' @@ -50,7 +50,7 @@ def setUp(self): self.adm("box.space['test']:truncate()") - valid_decimal_cases = { + valid_cases = { 'simple_decimal_1': { 'python': decimal.Decimal('0.7'), 'msgpack': (b'\x01\x7c'), @@ -219,47 +219,47 @@ def setUp(self): }, } - def test_decimal_msgpack_decode(self): - for name in self.valid_decimal_cases.keys(): + def test_msgpack_decode(self): + for name in self.valid_cases.keys(): with self.subTest(msg=name): - decimal_case = self.valid_decimal_cases[name] + case = self.valid_cases[name] - self.assertEqual(unpacker_ext_hook(1, decimal_case['msgpack']), - decimal_case['python']) + self.assertEqual(unpacker_ext_hook(1, case['msgpack']), + case['python']) @skip_or_run_decimal_test - def test_decimal_tarantool_decode(self): - for name in self.valid_decimal_cases.keys(): + def test_tarantool_decode(self): + for name in self.valid_cases.keys(): with self.subTest(msg=name): - decimal_case = self.valid_decimal_cases[name] + case = self.valid_cases[name] - self.adm(f"box.space['test']:replace{{'{name}', {decimal_case['tarantool']}}}") + self.adm(f"box.space['test']:replace{{'{name}', {case['tarantool']}}}") self.assertSequenceEqual( self.con.select('test', name), - [[name, decimal_case['python']]]) + [[name, case['python']]]) - def test_decimal_msgpack_encode(self): - for name in self.valid_decimal_cases.keys(): + def test_msgpack_encode(self): + for name in self.valid_cases.keys(): with self.subTest(msg=name): - decimal_case = self.valid_decimal_cases[name] + case = self.valid_cases[name] - self.assertEqual(packer_default(decimal_case['python']), - msgpack.ExtType(code=1, data=decimal_case['msgpack'])) + self.assertEqual(packer_default(case['python']), + msgpack.ExtType(code=1, data=case['msgpack'])) @skip_or_run_decimal_test - def test_decimal_tarantool_encode(self): - for name in self.valid_decimal_cases.keys(): + def test_tarantool_encode(self): + for name in self.valid_cases.keys(): with self.subTest(msg=name): - decimal_case = self.valid_decimal_cases[name] + case = self.valid_cases[name] - self.con.insert('test', [name, decimal_case['python']]) + self.con.insert('test', [name, case['python']]) lua_eval = f""" local tuple = box.space['test']:get('{name}') assert(tuple ~= nil) - local dec = {decimal_case['tarantool']} + local dec = {case['tarantool']} if tuple[2] == dec then return true else @@ -271,7 +271,7 @@ def test_decimal_tarantool_encode(self): self.assertSequenceEqual(self.con.eval(lua_eval), [True]) - error_decimal_cases = { + error_cases = { 'decimal_limit_break_head_1': { 'python': decimal.Decimal('999999999999999999999999999999999999999'), }, @@ -298,31 +298,31 @@ def test_decimal_tarantool_encode(self): }, } - def test_decimal_msgpack_encode_error(self): - for name in self.error_decimal_cases.keys(): + def test_msgpack_encode_error(self): + for name in self.error_cases.keys(): with self.subTest(msg=name): - decimal_case = self.error_decimal_cases[name] + case = self.error_cases[name] msg = 'Decimal cannot be encoded: Tarantool decimal ' + \ 'supports a maximum of 38 digits.' self.assertRaisesRegex( MsgpackError, msg, - lambda: packer_default(decimal_case['python'])) + lambda: packer_default(case['python'])) @skip_or_run_decimal_test - def test_decimal_tarantool_encode_error(self): - for name in self.error_decimal_cases.keys(): + def test_tarantool_encode_error(self): + for name in self.error_cases.keys(): with self.subTest(msg=name): - decimal_case = self.error_decimal_cases[name] + case = self.error_cases[name] msg = 'Decimal cannot be encoded: Tarantool decimal ' + \ 'supports a maximum of 38 digits.' self.assertRaisesRegex( MsgpackError, msg, - lambda: self.con.insert('test', [name, decimal_case['python']])) + lambda: self.con.insert('test', [name, case['python']])) - precision_loss_decimal_cases = { + precision_loss_cases = { 'decimal_limit_break_tail_1': { 'python': decimal.Decimal('1.00000000000000000000000000000000000001'), 'msgpack': (b'\x00\x1c'), @@ -379,10 +379,10 @@ def test_decimal_tarantool_encode_error(self): }, } - def test_decimal_msgpack_encode_with_precision_loss(self): - for name in self.precision_loss_decimal_cases.keys(): + def test_msgpack_encode_with_precision_loss(self): + for name in self.precision_loss_cases.keys(): with self.subTest(msg=name): - decimal_case = self.precision_loss_decimal_cases[name] + case = self.precision_loss_cases[name] msg = 'Decimal encoded with loss of precision: ' + \ 'Tarantool decimal supports a maximum of 38 digits.' @@ -390,30 +390,30 @@ def test_decimal_msgpack_encode_with_precision_loss(self): self.assertWarnsRegex( MsgpackWarning, msg, lambda: self.assertEqual( - packer_default(decimal_case['python']), - msgpack.ExtType(code=1, data=decimal_case['msgpack']) + packer_default(case['python']), + msgpack.ExtType(code=1, data=case['msgpack']) ) ) @skip_or_run_decimal_test - def test_decimal_tarantool_encode_with_precision_loss(self): - for name in self.precision_loss_decimal_cases.keys(): + def test_tarantool_encode_with_precision_loss(self): + for name in self.precision_loss_cases.keys(): with self.subTest(msg=name): - decimal_case = self.precision_loss_decimal_cases[name] + case = self.precision_loss_cases[name] msg = 'Decimal encoded with loss of precision: ' + \ 'Tarantool decimal supports a maximum of 38 digits.' self.assertWarnsRegex( MsgpackWarning, msg, - lambda: self.con.insert('test', [name, decimal_case['python']])) + lambda: self.con.insert('test', [name, case['python']])) lua_eval = f""" local tuple = box.space['test']:get('{name}') assert(tuple ~= nil) - local dec = {decimal_case['tarantool']} + local dec = {case['tarantool']} if tuple[2] == dec then return true else @@ -424,6 +424,7 @@ def test_decimal_tarantool_encode_with_precision_loss(self): self.assertSequenceEqual(self.con.eval(lua_eval), [True]) + @classmethod def tearDownClass(self): self.con.close() diff --git a/test/suites/test_uuid.py b/test/suites/test_uuid.py new file mode 100644 index 00000000..c51c8c82 --- /dev/null +++ b/test/suites/test_uuid.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- + +from __future__ import print_function + +import sys +import unittest +import uuid +import msgpack +import warnings +import tarantool + +from tarantool.msgpack_ext.packer import default as packer_default +from tarantool.msgpack_ext.unpacker import ext_hook as unpacker_ext_hook + +from .lib.tarantool_server import TarantoolServer +from .lib.skip import skip_or_run_UUID_test +from tarantool.error import MsgpackError, MsgpackWarning + +class TestSuite_UUID(unittest.TestCase): + @classmethod + def setUpClass(self): + print(' UUID EXT TYPE '.center(70, '='), file=sys.stderr) + print('-' * 70, file=sys.stderr) + self.srv = TarantoolServer() + self.srv.script = 'test/suites/box.lua' + self.srv.start() + + self.adm = self.srv.admin + self.adm(r""" + _, uuid = pcall(require, 'uuid') + + box.schema.space.create('test') + box.space['test']:create_index('primary', { + type = 'tree', + parts = {1, 'string'}, + unique = true}) + + box.schema.user.create('test', {password = 'test', if_not_exists = true}) + box.schema.user.grant('test', 'read,write,execute', 'universe') + """) + + self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'], + user='test', password='test') + + def setUp(self): + # prevent a remote tarantool from clean our session + if self.srv.is_started(): + self.srv.touch_lock() + + self.adm("box.space['test']:truncate()") + + + cases = { + 'uuid_1': { + 'python': uuid.UUID('ae28d4f6-076c-49dd-8227-7f9fae9592d0'), + 'msgpack': (b'\xae\x28\xd4\xf6\x07\x6c\x49\xdd\x82\x27\x7f\x9f\xae\x95\x92\xd0'), + 'tarantool': "uuid.fromstr('ae28d4f6-076c-49dd-8227-7f9fae9592d0')", + }, + 'uuid_2': { + 'python': uuid.UUID('b3121301-9300-4038-a652-ead943fb9c39'), + 'msgpack': (b'\xb3\x12\x13\x01\x93\x00\x40\x38\xa6\x52\xea\xd9\x43\xfb\x9c\x39'), + 'tarantool': "uuid.fromstr('b3121301-9300-4038-a652-ead943fb9c39')", + }, + 'uuid_3': { + 'python': uuid.UUID('dfa69f02-92e6-44a5-abb5-84b39292ff93'), + 'msgpack': (b'\xdf\xa6\x9f\x02\x92\xe6\x44\xa5\xab\xb5\x84\xb3\x92\x92\xff\x93'), + 'tarantool': "uuid.fromstr('dfa69f02-92e6-44a5-abb5-84b39292ff93')", + }, + 'uuid_4': { + 'python': uuid.UUID('8b69a1ce-094a-4e21-a5dc-4cdae7cd8960'), + 'msgpack': (b'\x8b\x69\xa1\xce\x09\x4a\x4e\x21\xa5\xdc\x4c\xda\xe7\xcd\x89\x60'), + 'tarantool': "uuid.fromstr('8b69a1ce-094a-4e21-a5dc-4cdae7cd8960')", + }, + 'uuid_5': { + 'python': uuid.UUID('25932334-1d42-4686-9299-ec1a7165227c'), + 'msgpack': (b'\x25\x93\x23\x34\x1d\x42\x46\x86\x92\x99\xec\x1a\x71\x65\x22\x7c'), + 'tarantool': "uuid.fromstr('25932334-1d42-4686-9299-ec1a7165227c')", + }, + } + + def test_msgpack_decode(self): + for name in self.cases.keys(): + with self.subTest(msg=name): + case = self.cases[name] + + self.assertEqual(unpacker_ext_hook(2, case['msgpack']), + case['python']) + + @skip_or_run_UUID_test + def test_tarantool_decode(self): + for name in self.cases.keys(): + with self.subTest(msg=name): + case = self.cases[name] + + self.adm(f"box.space['test']:replace{{'{name}', {case['tarantool']}}}") + + self.assertSequenceEqual(self.con.select('test', name), + [[name, case['python']]]) + + def test_msgpack_encode(self): + for name in self.cases.keys(): + with self.subTest(msg=name): + case = self.cases[name] + + self.assertEqual(packer_default(case['python']), + msgpack.ExtType(code=2, data=case['msgpack'])) + + @skip_or_run_UUID_test + def test_tarantool_encode(self): + for name in self.cases.keys(): + with self.subTest(msg=name): + case = self.cases[name] + + self.con.insert('test', [name, case['python']]) + + lua_eval = f""" + local tuple = box.space['test']:get('{name}') + assert(tuple ~= nil) + + local id = {case['tarantool']} + if tuple[2] == id then + return true + else + return nil, ('%s is not equal to expected %s'):format( + tostring(tuple[2]), tostring(id)) + end + """ + + self.assertSequenceEqual(self.con.eval(lua_eval), [True]) + + + @classmethod + def tearDownClass(self): + self.con.close() + self.srv.stop() + self.srv.clean()