diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f518c55..f1035363 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -143,6 +143,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Update documentation index, quick start and guide pages (#67). ### Fixed +- Allow any MessagePack supported type as a request key (#240). ## 0.9.0 - 2022-06-20 diff --git a/tarantool/connection.py b/tarantool/connection.py index f971367e..c8ba9c3d 100644 --- a/tarantool/connection.py +++ b/tarantool/connection.py @@ -78,9 +78,9 @@ ) from tarantool.schema import Schema from tarantool.utils import ( - check_key, greeting_decode, version_id, + wrap_key, ENCODING_DEFAULT, ) @@ -1202,7 +1202,7 @@ def delete(self, space_name, key, *, index=0): .. _delete: https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_space/delete/ """ - key = check_key(key) + key = wrap_key(key) if isinstance(space_name, str): space_name = self.schema.get_space(space_name).sid if isinstance(index, str): @@ -1331,7 +1331,7 @@ def update(self, space_name, key, op_list, *, index=0): .. _update: https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_space/update/ """ - key = check_key(key) + key = wrap_key(key) if isinstance(space_name, str): space_name = self.schema.get_space(space_name).sid if isinstance(index, str): @@ -1516,7 +1516,7 @@ def select(self, space_name, key=None, *, offset=0, limit=0xffffffff, index=0, i # Perform smart type checking (scalar / list of scalars / list of # tuples) - key = check_key(key, select=True) + key = wrap_key(key, select=True) if isinstance(space_name, str): space_name = self.schema.get_space(space_name).sid diff --git a/tarantool/utils.py b/tarantool/utils.py index 3f8275ba..e5238707 100644 --- a/tarantool/utils.py +++ b/tarantool/utils.py @@ -1,8 +1,6 @@ import sys import uuid -supported_types = (int, str, bytes, float,) - ENCODING_DEFAULT = "utf-8" from base64 import decodebytes as base64_decode @@ -22,33 +20,30 @@ def strxor(rhs, lhs): return bytes([x ^ y for x, y in zip(rhs, lhs)]) -def check_key(*args, **kwargs): +def wrap_key(*args, first=True, select=False): """ - Validate request key types and map. + Wrap request key in list, if needed. :param args: Method args. :type args: :obj:`tuple` - :param kwargs: Method kwargs. - :type kwargs: :obj:`dict` + :param first: ``True`` if this is the first recursion iteration. + :type first: :obj:`bool` + + :param select: ``True`` if wrapping SELECT request key. + :type select: :obj:`bool` :rtype: :obj:`list` """ - if 'first' not in kwargs: - kwargs['first'] = True - if 'select' not in kwargs: - kwargs['select'] = False - if len(args) == 0 and kwargs['select']: + if len(args) == 0 and select: return [] if len(args) == 1: - if isinstance(args[0], (list, tuple)) and kwargs['first']: - kwargs['first'] = False - return check_key(*args[0], **kwargs) - elif args[0] is None and kwargs['select']: + if isinstance(args[0], (list, tuple)) and first: + return wrap_key(*args[0], first=False, select=select) + elif args[0] is None and select: return [] - for key in args: - assert isinstance(key, supported_types) + return list(args) diff --git a/test/suites/test_datetime.py b/test/suites/test_datetime.py index a6ce0341..8ecbc7f8 100644 --- a/test/suites/test_datetime.py +++ b/test/suites/test_datetime.py @@ -24,7 +24,7 @@ def setUpClass(self): self.adm = self.srv.admin self.adm(r""" - _, datetime = pcall(require, 'datetime') + is_supported, datetime = pcall(require, 'datetime') box.schema.space.create('test') box.space['test']:create_index('primary', { @@ -32,6 +32,14 @@ def setUpClass(self): parts = {1, 'string'}, unique = true}) + if is_supported then + box.schema.space.create('test_pk') + box.space['test_pk']:create_index('primary', { + type = 'tree', + parts = {1, 'datetime'}, + unique = true}) + end + box.schema.user.create('test', {password = 'test', if_not_exists = true}) box.schema.user.grant('test', 'read,write,execute', 'universe') @@ -528,6 +536,13 @@ def test_tarantool_datetime_addition_winter_time_switch(self): [case['res']]) + @skip_or_run_datetime_test + def test_primary_key(self): + data = [tarantool.Datetime(year=1970, month=1, day=1), 'content'] + + self.assertSequenceEqual(self.con.insert('test_pk', data), [data]) + self.assertSequenceEqual(self.con.select('test_pk', data[0]), [data]) + @classmethod def tearDownClass(self): self.con.close() diff --git a/test/suites/test_decimal.py b/test/suites/test_decimal.py index 4745669f..daf102a8 100644 --- a/test/suites/test_decimal.py +++ b/test/suites/test_decimal.py @@ -31,6 +31,12 @@ def setUpClass(self): parts = {1, 'string'}, unique = true}) + box.schema.space.create('test_pk') + box.space['test_pk']:create_index('primary', { + type = 'tree', + parts = {1, 'decimal'}, + unique = true}) + box.schema.user.create('test', {password = 'test', if_not_exists = true}) box.schema.user.grant('test', 'read,write,execute', 'universe') """) @@ -421,6 +427,14 @@ def test_tarantool_encode_with_precision_loss(self): self.assertSequenceEqual(self.con.eval(lua_eval), [True]) + @skip_or_run_decimal_test + def test_primary_key(self): + data = [decimal.Decimal('0'), 'content'] + + self.assertSequenceEqual(self.con.insert('test_pk', data), [data]) + self.assertSequenceEqual(self.con.select('test_pk', data[0]), [data]) + + @classmethod def tearDownClass(self): self.con.close() diff --git a/test/suites/test_uuid.py b/test/suites/test_uuid.py index a0203538..ac21795b 100644 --- a/test/suites/test_uuid.py +++ b/test/suites/test_uuid.py @@ -23,7 +23,7 @@ def setUpClass(self): self.adm = self.srv.admin self.adm(r""" - _, uuid = pcall(require, 'uuid') + is_supported, uuid = pcall(require, 'uuid') box.schema.space.create('test') box.space['test']:create_index('primary', { @@ -31,6 +31,14 @@ def setUpClass(self): parts = {1, 'string'}, unique = true}) + if is_supported then + box.schema.space.create('test_pk') + box.space['test_pk']:create_index('primary', { + type = 'tree', + parts = {1, 'uuid'}, + unique = true}) + end + box.schema.user.create('test', {password = 'test', if_not_exists = true}) box.schema.user.grant('test', 'read,write,execute', 'universe') """) @@ -125,6 +133,14 @@ def test_tarantool_encode(self): self.assertSequenceEqual(self.con.eval(lua_eval), [True]) + @skip_or_run_UUID_test + def test_primary_key(self): + data = [uuid.UUID('ae28d4f6-076c-49dd-8227-7f9fae9592d0'), 'content'] + + self.assertSequenceEqual(self.con.insert('test_pk', data), [data]) + self.assertSequenceEqual(self.con.select('test_pk', data[0]), [data]) + + @classmethod def tearDownClass(self): self.con.close()