diff --git a/CHANGELOG.md b/CHANGELOG.md index ad7c2cb6..c0658007 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -67,6 +67,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 You may use `tz` property to get timezone name of a datetime object. +- Datetime interval type support and tarantool.Interval type (#229). + + Tarantool datetime interval objects are decoded to `tarantool.Interval` + type. `tarantool.Interval` may be encoded to Tarantool interval + objects. + + You can create `tarantool.Interval` objects either from msgpack + data or by using the same API as in Tarantool: + + ```python + di = tarantool.Interval(year=-1, month=2, day=3, + hour=4, minute=-5, sec=6, + nsec=308543321, + adjust=tarantool.IntervalAdjust.NONE) + ``` + + Its attributes (same as in init API) are exposed, so you can + use them if needed. + ### Changed - Bump msgpack requirement to 1.0.4 (PR #223). The only reason of this bump is various vulnerability fixes, diff --git a/tarantool/__init__.py b/tarantool/__init__.py index 6625b4eb..c0b50cfc 100644 --- a/tarantool/__init__.py +++ b/tarantool/__init__.py @@ -36,6 +36,11 @@ Datetime, ) +from tarantool.msgpack_ext.types.interval import ( + Adjust as IntervalAdjust, + Interval, +) + __version__ = "0.9.0" @@ -95,7 +100,7 @@ def connectmesh(addrs=({'host': 'localhost', 'port': 3301},), user=None, __all__ = ['connect', 'Connection', 'connectmesh', 'MeshConnection', 'Schema', 'Error', 'DatabaseError', 'NetworkError', 'NetworkWarning', - 'SchemaError', 'dbapi', 'Datetime'] + 'SchemaError', 'dbapi', 'Datetime', 'Interval', 'IntervalAdjust'] # ConnectionPool is supported only for Python 3.7 or newer. if sys.version_info.major >= 3 and sys.version_info.minor >= 7: diff --git a/tarantool/msgpack_ext/interval.py b/tarantool/msgpack_ext/interval.py new file mode 100644 index 00000000..79b5a8de --- /dev/null +++ b/tarantool/msgpack_ext/interval.py @@ -0,0 +1,9 @@ +from tarantool.msgpack_ext.types.interval import Interval + +EXT_ID = 6 + +def encode(obj): + return obj.msgpack_encode() + +def decode(data): + return Interval(data) diff --git a/tarantool/msgpack_ext/packer.py b/tarantool/msgpack_ext/packer.py index bff2b821..d41c411d 100644 --- a/tarantool/msgpack_ext/packer.py +++ b/tarantool/msgpack_ext/packer.py @@ -3,15 +3,18 @@ from msgpack import ExtType from tarantool.msgpack_ext.types.datetime import Datetime +from tarantool.msgpack_ext.types.interval import Interval import tarantool.msgpack_ext.decimal as ext_decimal import tarantool.msgpack_ext.uuid as ext_uuid import tarantool.msgpack_ext.datetime as ext_datetime +import tarantool.msgpack_ext.interval as ext_interval encoders = [ {'type': Decimal, 'ext': ext_decimal }, {'type': UUID, 'ext': ext_uuid }, {'type': Datetime, 'ext': ext_datetime}, + {'type': Interval, 'ext': ext_interval}, ] def default(obj): diff --git a/tarantool/msgpack_ext/types/interval.py b/tarantool/msgpack_ext/types/interval.py new file mode 100644 index 00000000..61cbdc27 --- /dev/null +++ b/tarantool/msgpack_ext/types/interval.py @@ -0,0 +1,149 @@ +import msgpack +from enum import Enum + +from tarantool.error import MsgpackError + +# https://www.tarantool.io/en/doc/latest/dev_guide/internals/msgpack_extensions/#the-interval-type +# +# The interval MessagePack representation looks like this: +# +--------+-------------------------+-------------+----------------+ +# | MP_EXT | Size of packed interval | MP_INTERVAL | PackedInterval | +# +--------+-------------------------+-------------+----------------+ +# Packed interval consists of: +# - Packed number of non-zero fields. +# - Packed non-null fields. +# +# Each packed field has the following structure: +# +----------+=====================+ +# | field ID | field value | +# +----------+=====================+ +# +# The number of defined (non-null) fields can be zero. In this case, +# the packed interval will be encoded as integer 0. +# +# List of the field IDs: +# - 0 – year +# - 1 – month +# - 2 – week +# - 3 – day +# - 4 – hour +# - 5 – minute +# - 6 – second +# - 7 – nanosecond +# - 8 – adjust + +id_map = { + 0: 'year', + 1: 'month', + 2: 'week', + 3: 'day', + 4: 'hour', + 5: 'minute', + 6: 'sec', + 7: 'nsec', + 8: 'adjust', +} + +# https://github.com/tarantool/c-dt/blob/cec6acebb54d9e73ea0b99c63898732abd7683a6/dt_arithmetic.h#L34 +class Adjust(Enum): + EXCESS = 0 # DT_EXCESS in c-dt, "excess" in Tarantool + NONE = 1 # DT_LIMIT in c-dt, "none" in Tarantool + LAST = 2 # DT_SNAP in c-dt, "last" in Tarantool + +class Interval(): + def __init__(self, data=None, *, year=0, month=0, week=0, + day=0, hour=0, minute=0, sec=0, + nsec=0, adjust=Adjust.NONE): + # If msgpack data does not contain a field value, it is zero. + # If built not from msgpack data, set argument values later. + self.year = 0 + self.month = 0 + self.week = 0 + self.day = 0 + self.hour = 0 + self.minute = 0 + self.sec = 0 + self.nsec = 0 + self.adjust = Adjust(0) + + if data is not None: + if len(data) == 0: + return + + # To create an unpacker is the only way to parse + # a sequence of values in Python msgpack module. + unpacker = msgpack.Unpacker() + unpacker.feed(data) + field_count = unpacker.unpack() + for _ in range(field_count): + field_id = unpacker.unpack() + value = unpacker.unpack() + + if field_id not in id_map: + raise MsgpackError(f'Unknown interval field id {field_id}') + + field_name = id_map[field_id] + + if field_name == 'adjust': + try: + value = Adjust(value) + except ValueError as e: + raise MsgpackError(e) + + setattr(self, id_map[field_id], value) + else: + self.year = year + self.month = month + self.week = week + self.day = day + self.hour = hour + self.minute = minute + self.sec = sec + self.nsec = nsec + self.adjust = adjust + + def __eq__(self, other): + if not isinstance(other, Interval): + return False + + # Tarantool interval compare is naive too + # + # Tarantool 2.10.1-0-g482d91c66 + # + # tarantool> datetime.interval.new{hour=1} == datetime.interval.new{min=60} + # --- + # - false + # ... + + for field_id in id_map.keys(): + field_name = id_map[field_id] + if getattr(self, field_name) != getattr(other, field_name): + return False + + return True + + def __repr__(self): + return f'tarantool.Interval(year={self.year}, month={self.month}, day={self.day}, ' + \ + f'hour={self.hour}, minute={self.minute}, sec={self.sec}, ' + \ + f'nsec={self.nsec}, adjust={self.adjust})' + + __str__ = __repr__ + + def msgpack_encode(self): + buf = bytes() + + count = 0 + for field_id in id_map.keys(): + field_name = id_map[field_id] + value = getattr(self, field_name) + + if field_name == 'adjust': + value = value.value + + if value != 0: + buf = buf + msgpack.packb(field_id) + msgpack.packb(value) + count = count + 1 + + buf = msgpack.packb(count) + buf + + return buf diff --git a/tarantool/msgpack_ext/unpacker.py b/tarantool/msgpack_ext/unpacker.py index b303e18d..ff3bdcb8 100644 --- a/tarantool/msgpack_ext/unpacker.py +++ b/tarantool/msgpack_ext/unpacker.py @@ -1,11 +1,13 @@ import tarantool.msgpack_ext.decimal as ext_decimal import tarantool.msgpack_ext.uuid as ext_uuid import tarantool.msgpack_ext.datetime as ext_datetime +import tarantool.msgpack_ext.interval as ext_interval decoders = { ext_decimal.EXT_ID : ext_decimal.decode , ext_uuid.EXT_ID : ext_uuid.decode , ext_datetime.EXT_ID: ext_datetime.decode, + ext_interval.EXT_ID: ext_interval.decode, } def ext_hook(code, data): diff --git a/test/suites/__init__.py b/test/suites/__init__.py index c5792bdd..7096cad9 100644 --- a/test/suites/__init__.py +++ b/test/suites/__init__.py @@ -18,13 +18,15 @@ from .test_decimal import TestSuite_Decimal from .test_uuid import TestSuite_UUID from .test_datetime import TestSuite_Datetime +from .test_interval import TestSuite_Interval test_cases = (TestSuite_Schema_UnicodeConnection, TestSuite_Schema_BinaryConnection, TestSuite_Request, TestSuite_Protocol, TestSuite_Reconnect, TestSuite_Mesh, TestSuite_Execute, TestSuite_DBAPI, TestSuite_Encoding, TestSuite_Pool, TestSuite_Ssl, - TestSuite_Decimal, TestSuite_UUID, TestSuite_Datetime) + TestSuite_Decimal, TestSuite_UUID, TestSuite_Datetime, + TestSuite_Interval) def load_tests(loader, tests, pattern): suite = unittest.TestSuite() diff --git a/test/suites/test_interval.py b/test/suites/test_interval.py new file mode 100644 index 00000000..9ed87652 --- /dev/null +++ b/test/suites/test_interval.py @@ -0,0 +1,206 @@ +# -*- coding: utf-8 -*- + +from __future__ import print_function + +import sys +import unittest +import msgpack +import warnings +import tarantool +import pandas +import pytz + +from tarantool.msgpack_ext.packer import default as packer_default +from tarantool.msgpack_ext.unpacker import ext_hook as unpacker_ext_hook + +from .lib.tarantool_server import TarantoolServer +from .lib.skip import skip_or_run_datetime_test +from tarantool.error import MsgpackError + +class TestSuite_Interval(unittest.TestCase): + @classmethod + def setUpClass(self): + print(' INTERVAL EXT TYPE '.center(70, '='), file=sys.stderr) + print('-' * 70, file=sys.stderr) + self.srv = TarantoolServer() + self.srv.script = 'test/suites/box.lua' + self.srv.start() + + self.adm = self.srv.admin + self.adm(r""" + _, datetime = pcall(require, 'datetime') + + box.schema.space.create('test') + box.space['test']:create_index('primary', { + type = 'tree', + parts = {1, 'string'}, + unique = true}) + + box.schema.user.create('test', {password = 'test', if_not_exists = true}) + box.schema.user.grant('test', 'read,write,execute', 'universe') + """) + + self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'], + user='test', password='test') + + def setUp(self): + # prevent a remote tarantool from clean our session + if self.srv.is_started(): + self.srv.touch_lock() + + self.adm("box.space['test']:truncate()") + + + cases = { + 'year': { + 'python': tarantool.Interval(year=1), + 'msgpack': (b'\x02\x00\x01\x08\x01'), + 'tarantool': r"datetime.interval.new({year=1})", + }, + 'big_year': { + 'python': tarantool.Interval(year=1000), + 'msgpack': (b'\x02\x00\xcd\x03\xe8\x08\x01'), + 'tarantool': r"datetime.interval.new({year=1000})", + }, + 'date': { + 'python': tarantool.Interval(year=1, month=2, day=3), + 'msgpack': (b'\x04\x00\x01\x01\x02\x03\x03\x08\x01'), + 'tarantool': r"datetime.interval.new({year=1, month=2, day=3})", + }, + 'big_month_date': { + 'python': tarantool.Interval(year=1, month=100000, day=3), + 'msgpack': (b'\x04\x00\x01\x01\xce\x00\x01\x86\xa0\x03\x03\x08\x01'), + 'tarantool': r"datetime.interval.new({year=1, month=100000, day=3})", + }, + 'time': { + 'python': tarantool.Interval(hour=1, minute=2, sec=3), + 'msgpack': (b'\x04\x04\x01\x05\x02\x06\x03\x08\x01'), + 'tarantool': r"datetime.interval.new({hour=1, min=2, sec=3})", + }, + 'big_seconds_time': { + 'python': tarantool.Interval(hour=1, minute=2, sec=3000), + 'msgpack': (b'\x04\x04\x01\x05\x02\x06\xcd\x0b\xb8\x08\x01'), + 'tarantool': r"datetime.interval.new({hour=1, min=2, sec=3000})", + }, + 'datetime': { + 'python': tarantool.Interval(year=1, month=2, day=3, hour=1, minute=2, sec=3000), + 'msgpack': (b'\x07\x00\x01\x01\x02\x03\x03\x04\x01\x05\x02\x06\xcd\x0b\xb8\x08\x01'), + 'tarantool': r"datetime.interval.new({year=1, month=2, day=3, hour=1, min=2, sec=3000})", + }, + 'nanoseconds': { + 'python': tarantool.Interval(nsec=10000000), + 'msgpack': (b'\x02\x07\xce\x00\x98\x96\x80\x08\x01'), + 'tarantool': r"datetime.interval.new({nsec=10000000})", + }, + 'datetime_with_nanoseconds': { + 'python': tarantool.Interval(year=1, month=2, day=3, hour=1, minute=2, + sec=3000, nsec=10000000), + 'msgpack': (b'\x08\x00\x01\x01\x02\x03\x03\x04\x01\x05\x02\x06\xcd\x0b\xb8\x07\xce' + + b'\x00\x98\x96\x80\x08\x01'), + 'tarantool': r"datetime.interval.new({year=1, month=2, day=3, hour=1, " + + r"min=2, sec=3000, nsec=10000000})", + }, + 'datetime_none_adjust': { + 'python': tarantool.Interval(year=1, month=2, day=3, hour=1, minute=2, + sec=3000, nsec=10000000, + adjust=tarantool.IntervalAdjust.NONE), + 'msgpack': (b'\x08\x00\x01\x01\x02\x03\x03\x04\x01\x05\x02\x06\xcd\x0b\xb8\x07\xce' + + b'\x00\x98\x96\x80\x08\x01'), + 'tarantool': r"datetime.interval.new({year=1, month=2, day=3, hour=1, " + + r"min=2, sec=3000, nsec=10000000, adjust='none'})", + }, + 'datetime_excess_adjust': { + 'python': tarantool.Interval(year=1, month=2, day=3, hour=1, minute=2, + sec=3000, nsec=10000000, + adjust=tarantool.IntervalAdjust.EXCESS), + 'msgpack': (b'\x07\x00\x01\x01\x02\x03\x03\x04\x01\x05\x02\x06\xcd\x0b\xb8\x07\xce' + + b'\x00\x98\x96\x80'), + 'tarantool': r"datetime.interval.new({year=1, month=2, day=3, hour=1, " + + r"min=2, sec=3000, nsec=10000000, adjust='excess'})", + }, + 'datetime_last_adjust': { + 'python': tarantool.Interval(year=1, month=2, day=3, hour=1, minute=2, + sec=3000, nsec=10000000, + adjust=tarantool.IntervalAdjust.LAST), + 'msgpack': (b'\x08\x00\x01\x01\x02\x03\x03\x04\x01\x05\x02\x06\xcd\x0b\xb8\x07\xce' + + b'\x00\x98\x96\x80\x08\x02'), + 'tarantool': r"datetime.interval.new({year=1, month=2, day=3, hour=1, " + + r"min=2, sec=3000, nsec=10000000, adjust='last'})", + }, + 'all_zeroes': { + 'python': tarantool.Interval(adjust=tarantool.IntervalAdjust.EXCESS), + 'msgpack': (b'\x00'), + 'tarantool': r"datetime.interval.new({adjust='excess'})", + }, + } + + def test_msgpack_decode(self): + for name in self.cases.keys(): + with self.subTest(msg=name): + case = self.cases[name] + + self.assertEqual(unpacker_ext_hook(6, case['msgpack']), + case['python']) + + @skip_or_run_datetime_test + def test_tarantool_decode(self): + for name in self.cases.keys(): + with self.subTest(msg=name): + case = self.cases[name] + + self.adm(f"box.space['test']:replace{{'{name}', {case['tarantool']}, 'field'}}") + + self.assertSequenceEqual(self.con.select('test', name), + [[name, case['python'], 'field']]) + + def test_msgpack_encode(self): + for name in self.cases.keys(): + with self.subTest(msg=name): + case = self.cases[name] + + self.assertEqual(packer_default(case['python']), + msgpack.ExtType(code=6, data=case['msgpack'])) + + @skip_or_run_datetime_test + def test_tarantool_encode(self): + for name in self.cases.keys(): + with self.subTest(msg=name): + case = self.cases[name] + + self.con.insert('test', [name, case['python'], 'field']) + + lua_eval = f""" + local interval = {case['tarantool']} + + local tuple = box.space['test']:get('{name}') + assert(tuple ~= nil) + + if tuple[2] == interval then + return true + else + return nil, ('%s is not equal to expected %s'):format( + tostring(tuple[2]), tostring(interval)) + end + """ + + self.assertSequenceEqual(self.adm(lua_eval), [True]) + + + def test_unknown_field_decode(self): + case = b'\x01\x09\xce\x00\x98\x96\x80' + self.assertRaisesRegex( + MsgpackError, 'Unknown interval field id 9', + lambda: unpacker_ext_hook(6, case)) + + def test_unknown_adjust_decode(self): + case = b'\x02\x07\xce\x00\x98\x96\x80\x08\x03' + self.assertRaisesRegex( + MsgpackError, '3 is not a valid Adjust', + lambda: unpacker_ext_hook(6, case)) + + + @classmethod + def tearDownClass(self): + self.con.close() + self.srv.stop() + self.srv.clean()