From e40470e2145736d40cc5a2caa7874e24a8e55da2 Mon Sep 17 00:00:00 2001 From: Georgy Moiseev Date: Thu, 8 Sep 2022 14:09:38 +0300 Subject: [PATCH] msgpack: support datetime interval extended type Tarantool supports datetime interval type since version 2.10.0 [1]. This patch introduces the support of Tarantool interval type in msgpack decoders and encoders. Tarantool datetime interval objects are decoded to `tarantool.Interval` type. `tarantool.Interval` may be encoded to Tarantool interval objects. You can create `tarantool.Interval` objects either from msgpack data or by using the same API as in Tarantool: ``` di = tarantool.Interval(year=-1, month=2, day=3, hour=4, minute=-5, sec=6, nsec=308543321, adjust=tarantool.IntervalAdjust.NONE) ``` Its attributes (same as in init API) are exposed, so you can use them if needed. datetime, numpy and pandas tools don't seem to be sufficient to cover all adjust cases supported by Tarantool. This patch does not yet introduce the support of datetime interval arithmetic. 1. https://github.com/tarantool/tarantool/issues/5941 Part of #229 --- CHANGELOG.md | 19 ++ tarantool/__init__.py | 7 +- tarantool/msgpack_ext/interval.py | 9 + tarantool/msgpack_ext/packer.py | 3 + tarantool/msgpack_ext/types/interval.py | 149 +++++++++++++++ tarantool/msgpack_ext/unpacker.py | 2 + test/suites/__init__.py | 4 +- test/suites/test_interval.py | 233 ++++++++++++++++++++++++ 8 files changed, 424 insertions(+), 2 deletions(-) create mode 100644 tarantool/msgpack_ext/interval.py create mode 100644 tarantool/msgpack_ext/types/interval.py create mode 100644 test/suites/test_interval.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ad7c2cb6..c0658007 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -67,6 +67,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 You may use `tz` property to get timezone name of a datetime object. +- Datetime interval type support and tarantool.Interval type (#229). 
+ + Tarantool datetime interval objects are decoded to `tarantool.Interval` + type. `tarantool.Interval` may be encoded to Tarantool interval + objects. + + You can create `tarantool.Interval` objects either from msgpack + data or by using the same API as in Tarantool: + + ```python + di = tarantool.Interval(year=-1, month=2, day=3, + hour=4, minute=-5, sec=6, + nsec=308543321, + adjust=tarantool.IntervalAdjust.NONE) + ``` + + Its attributes (same as in init API) are exposed, so you can + use them if needed. + ### Changed - Bump msgpack requirement to 1.0.4 (PR #223). The only reason of this bump is various vulnerability fixes, diff --git a/tarantool/__init__.py b/tarantool/__init__.py index 6625b4eb..c0b50cfc 100644 --- a/tarantool/__init__.py +++ b/tarantool/__init__.py @@ -36,6 +36,11 @@ Datetime, ) +from tarantool.msgpack_ext.types.interval import ( + Adjust as IntervalAdjust, + Interval, +) + __version__ = "0.9.0" @@ -95,7 +100,7 @@ def connectmesh(addrs=({'host': 'localhost', 'port': 3301},), user=None, __all__ = ['connect', 'Connection', 'connectmesh', 'MeshConnection', 'Schema', 'Error', 'DatabaseError', 'NetworkError', 'NetworkWarning', - 'SchemaError', 'dbapi', 'Datetime'] + 'SchemaError', 'dbapi', 'Datetime', 'Interval', 'IntervalAdjust'] # ConnectionPool is supported only for Python 3.7 or newer. 
if sys.version_info.major >= 3 and sys.version_info.minor >= 7: diff --git a/tarantool/msgpack_ext/interval.py b/tarantool/msgpack_ext/interval.py new file mode 100644 index 00000000..79b5a8de --- /dev/null +++ b/tarantool/msgpack_ext/interval.py @@ -0,0 +1,9 @@ +from tarantool.msgpack_ext.types.interval import Interval + +EXT_ID = 6 + +def encode(obj): + return obj.msgpack_encode() + +def decode(data): + return Interval(data) diff --git a/tarantool/msgpack_ext/packer.py b/tarantool/msgpack_ext/packer.py index bff2b821..d41c411d 100644 --- a/tarantool/msgpack_ext/packer.py +++ b/tarantool/msgpack_ext/packer.py @@ -3,15 +3,18 @@ from msgpack import ExtType from tarantool.msgpack_ext.types.datetime import Datetime +from tarantool.msgpack_ext.types.interval import Interval import tarantool.msgpack_ext.decimal as ext_decimal import tarantool.msgpack_ext.uuid as ext_uuid import tarantool.msgpack_ext.datetime as ext_datetime +import tarantool.msgpack_ext.interval as ext_interval encoders = [ {'type': Decimal, 'ext': ext_decimal }, {'type': UUID, 'ext': ext_uuid }, {'type': Datetime, 'ext': ext_datetime}, + {'type': Interval, 'ext': ext_interval}, ] def default(obj): diff --git a/tarantool/msgpack_ext/types/interval.py b/tarantool/msgpack_ext/types/interval.py new file mode 100644 index 00000000..61cbdc27 --- /dev/null +++ b/tarantool/msgpack_ext/types/interval.py @@ -0,0 +1,149 @@ +import msgpack +from enum import Enum + +from tarantool.error import MsgpackError + +# https://www.tarantool.io/en/doc/latest/dev_guide/internals/msgpack_extensions/#the-interval-type +# +# The interval MessagePack representation looks like this: +# +--------+-------------------------+-------------+----------------+ +# | MP_EXT | Size of packed interval | MP_INTERVAL | PackedInterval | +# +--------+-------------------------+-------------+----------------+ +# Packed interval consists of: +# - Packed number of non-zero fields. +# - Packed non-null fields. 
#
# Each packed field has the following structure:
# +----------+=====================+
# | field ID |     field value     |
# +----------+=====================+
#
# The number of defined (non-null) fields can be zero. In this case,
# the packed interval will be encoded as integer 0.
#
# List of the field IDs:
# - 0 – year
# - 1 – month
# - 2 – week
# - 3 – day
# - 4 – hour
# - 5 – minute
# - 6 – second
# - 7 – nanosecond
# - 8 – adjust

# Maps a MessagePack field ID to the name of the Interval attribute that
# stores it. The insertion order is also the order in which fields are
# written by Interval.msgpack_encode().
id_map = {
    0: 'year',
    1: 'month',
    2: 'week',
    3: 'day',
    4: 'hour',
    5: 'minute',
    6: 'sec',
    7: 'nsec',
    8: 'adjust',
}


# https://github.com/tarantool/c-dt/blob/cec6acebb54d9e73ea0b99c63898732abd7683a6/dt_arithmetic.h#L34
class Adjust(Enum):
    """Interval adjust mode, stored on the wire as its integer value."""

    EXCESS = 0  # DT_EXCESS in c-dt, "excess" in Tarantool
    NONE = 1    # DT_LIMIT in c-dt, "none" in Tarantool
    LAST = 2    # DT_SNAP in c-dt, "last" in Tarantool


class Interval():
    """Python counterpart of the Tarantool datetime interval extension type.

    An object is built either from the packed interval payload (``data``)
    or from keyword arguments. All fields are exposed as attributes:
    ``year``, ``month``, ``week``, ``day``, ``hour``, ``minute``, ``sec``,
    ``nsec`` and ``adjust``.
    """

    def __init__(self, data=None, *, year=0, month=0, week=0,
                 day=0, hour=0, minute=0, sec=0,
                 nsec=0, adjust=Adjust.NONE):
        """Create an interval.

        :param data: packed interval payload (bytes-like). When it is not
            ``None`` the keyword arguments are ignored and the fields are
            read from the payload.
        :param year, month, week, day, hour, minute, sec, nsec: field
            values used when ``data`` is ``None``.
        :param adjust: ``Adjust`` member used when ``data`` is ``None``.
        :raises MsgpackError: on an unknown field ID or an unknown adjust
            value in ``data``.
        """
        if data is None:
            self.year = year
            self.month = month
            self.week = week
            self.day = day
            self.hour = hour
            self.minute = minute
            self.sec = sec
            self.nsec = nsec
            self.adjust = adjust
            return

        # If msgpack data does not contain a field value, it is zero.
        # That includes adjust: a missing adjust field means Adjust(0).
        self.year = 0
        self.month = 0
        self.week = 0
        self.day = 0
        self.hour = 0
        self.minute = 0
        self.sec = 0
        self.nsec = 0
        self.adjust = Adjust(0)

        if len(data) == 0:
            return

        # To create an unpacker is the only way to parse
        # a sequence of values in Python msgpack module.
        unpacker = msgpack.Unpacker()
        unpacker.feed(data)
        field_count = unpacker.unpack()
        for _ in range(field_count):
            field_id = unpacker.unpack()
            value = unpacker.unpack()

            if field_id not in id_map:
                raise MsgpackError(f'Unknown interval field id {field_id}')

            field_name = id_map[field_id]

            if field_name == 'adjust':
                try:
                    value = Adjust(value)
                except ValueError as e:
                    # Keep the original ValueError as the explicit cause.
                    raise MsgpackError(e) from e

            setattr(self, field_name, value)

    def __eq__(self, other):
        """Compare two intervals field by field.

        Returns ``False`` for any object that is not an ``Interval``.
        """
        if not isinstance(other, Interval):
            return False

        # Tarantool interval compare is naive too
        #
        # Tarantool 2.10.1-0-g482d91c66
        #
        # tarantool> datetime.interval.new{hour=1} == datetime.interval.new{min=60}
        # ---
        # - false
        # ...

        return all(getattr(self, field_name) == getattr(other, field_name)
                   for field_name in id_map.values())

    def __repr__(self):
        """Return a constructor-like representation listing every field."""
        # week is listed too: it takes part in comparison and encoding,
        # so leaving it out would make distinct intervals print the same.
        return f'tarantool.Interval(year={self.year}, month={self.month}, ' + \
               f'week={self.week}, day={self.day}, ' + \
               f'hour={self.hour}, minute={self.minute}, sec={self.sec}, ' + \
               f'nsec={self.nsec}, adjust={self.adjust})'

    __str__ = __repr__

    def msgpack_encode(self):
        """Return the packed interval payload as bytes.

        The payload is the packed number of non-zero fields followed by
        (field ID, value) pairs for those fields, in ``id_map`` order.
        """
        packed_fields = []
        for field_id, field_name in id_map.items():
            value = getattr(self, field_name)

            if field_name == 'adjust':
                value = value.value

            # Zero values are implied by absence and are not sent.
            if value != 0:
                packed_fields.append(
                    msgpack.packb(field_id) + msgpack.packb(value))

        return msgpack.packb(len(packed_fields)) + b''.join(packed_fields)
# -*- coding: utf-8 -*-

from __future__ import print_function

import sys
import unittest
import msgpack
import warnings
import tarantool
import pandas
import pytz

from tarantool.msgpack_ext.packer import default as packer_default
from tarantool.msgpack_ext.unpacker import ext_hook as unpacker_ext_hook

from .lib.tarantool_server import TarantoolServer
from .lib.skip import skip_or_run_datetime_test
from tarantool.error import MsgpackError


class TestSuite_Interval(unittest.TestCase):
    """Tests for the interval extension type (msgpack ext code 6)."""

    @classmethod
    def setUpClass(cls):
        """Start a Tarantool instance, create the test space and connect."""
        print(' INTERVAL EXT TYPE '.center(70, '='), file=sys.stderr)
        print('-' * 70, file=sys.stderr)

        server = TarantoolServer()
        cls.srv = server
        server.script = 'test/suites/box.lua'
        server.start()

        cls.adm = server.admin
        cls.adm(r"""
            _, datetime = pcall(require, 'datetime')

            box.schema.space.create('test')
            box.space['test']:create_index('primary', {
                type = 'tree',
                parts = {1, 'string'},
                unique = true})

            box.schema.user.create('test', {password = 'test', if_not_exists = true})
            box.schema.user.grant('test', 'read,write,execute', 'universe')
        """)

        cls.con = tarantool.Connection(server.host, server.args['primary'],
                                       user='test', password='test')

    def setUp(self):
        """Refresh the server lock and empty the test space."""
        # prevent a remote tarantool from clean our session
        if self.srv.is_started():
            self.srv.touch_lock()

        self.adm("box.space['test']:truncate()")

    def test_Interval_bytes_init(self):
        """Fields are read from the packed payload."""
        interval = tarantool.Interval(b'\x02\x00\x01\x08\x01')

        self.assertEqual(interval.year, 1)
        self.assertEqual(interval.month, 0)
        self.assertEqual(interval.day, 0)
        self.assertEqual(interval.hour, 0)
        self.assertEqual(interval.minute, 0)
        self.assertEqual(interval.sec, 0)
        self.assertEqual(interval.nsec, 0)
        self.assertEqual(interval.adjust, tarantool.IntervalAdjust.NONE)

    def test_Interval_bytes_init_ignore_other_fields(self):
        """Keyword arguments are ignored when a payload is given."""
        interval = tarantool.Interval(b'\x02\x00\x01\x08\x01',
                                      year=2, month=2, day=3, hour=1, minute=2,
                                      sec=3000, nsec=10000000,
                                      adjust=tarantool.IntervalAdjust.LAST)

        self.assertEqual(interval.year, 1)
        self.assertEqual(interval.month, 0)
        self.assertEqual(interval.day, 0)
        self.assertEqual(interval.hour, 0)
        self.assertEqual(interval.minute, 0)
        self.assertEqual(interval.sec, 0)
        self.assertEqual(interval.nsec, 0)
        self.assertEqual(interval.adjust, tarantool.IntervalAdjust.NONE)

    # Each case pairs a Python object with its packed payload and with
    # the Lua expression that builds the same interval in Tarantool.
    cases = {
        'year': {
            'python': tarantool.Interval(year=1),
            'msgpack': (b'\x02\x00\x01\x08\x01'),
            'tarantool': r"datetime.interval.new({year=1})",
        },
        'big_year': {
            'python': tarantool.Interval(year=1000),
            'msgpack': (b'\x02\x00\xcd\x03\xe8\x08\x01'),
            'tarantool': r"datetime.interval.new({year=1000})",
        },
        'date': {
            'python': tarantool.Interval(year=1, month=2, day=3),
            'msgpack': (b'\x04\x00\x01\x01\x02\x03\x03\x08\x01'),
            'tarantool': r"datetime.interval.new({year=1, month=2, day=3})",
        },
        'big_month_date': {
            'python': tarantool.Interval(year=1, month=100000, day=3),
            'msgpack': (b'\x04\x00\x01\x01\xce\x00\x01\x86\xa0\x03\x03\x08\x01'),
            'tarantool': r"datetime.interval.new({year=1, month=100000, day=3})",
        },
        'time': {
            'python': tarantool.Interval(hour=1, minute=2, sec=3),
            'msgpack': (b'\x04\x04\x01\x05\x02\x06\x03\x08\x01'),
            'tarantool': r"datetime.interval.new({hour=1, min=2, sec=3})",
        },
        'big_seconds_time': {
            'python': tarantool.Interval(hour=1, minute=2, sec=3000),
            'msgpack': (b'\x04\x04\x01\x05\x02\x06\xcd\x0b\xb8\x08\x01'),
            'tarantool': r"datetime.interval.new({hour=1, min=2, sec=3000})",
        },
        'datetime': {
            'python': tarantool.Interval(year=1, month=2, day=3, hour=1, minute=2, sec=3000),
            'msgpack': (b'\x07\x00\x01\x01\x02\x03\x03\x04\x01\x05\x02\x06\xcd\x0b\xb8\x08\x01'),
            'tarantool': r"datetime.interval.new({year=1, month=2, day=3, hour=1, min=2, sec=3000})",
        },
        'nanoseconds': {
            'python': tarantool.Interval(nsec=10000000),
            'msgpack': (b'\x02\x07\xce\x00\x98\x96\x80\x08\x01'),
            'tarantool': r"datetime.interval.new({nsec=10000000})",
        },
        'datetime_with_nanoseconds': {
            'python': tarantool.Interval(year=1, month=2, day=3, hour=1, minute=2,
                                         sec=3000, nsec=10000000),
            'msgpack': (b'\x08\x00\x01\x01\x02\x03\x03\x04\x01\x05\x02\x06\xcd\x0b\xb8\x07\xce' +
                        b'\x00\x98\x96\x80\x08\x01'),
            'tarantool': r"datetime.interval.new({year=1, month=2, day=3, hour=1, " +
                         r"min=2, sec=3000, nsec=10000000})",
        },
        'datetime_none_adjust': {
            'python': tarantool.Interval(year=1, month=2, day=3, hour=1, minute=2,
                                         sec=3000, nsec=10000000,
                                         adjust=tarantool.IntervalAdjust.NONE),
            'msgpack': (b'\x08\x00\x01\x01\x02\x03\x03\x04\x01\x05\x02\x06\xcd\x0b\xb8\x07\xce' +
                        b'\x00\x98\x96\x80\x08\x01'),
            'tarantool': r"datetime.interval.new({year=1, month=2, day=3, hour=1, " +
                         r"min=2, sec=3000, nsec=10000000, adjust='none'})",
        },
        'datetime_excess_adjust': {
            'python': tarantool.Interval(year=1, month=2, day=3, hour=1, minute=2,
                                         sec=3000, nsec=10000000,
                                         adjust=tarantool.IntervalAdjust.EXCESS),
            'msgpack': (b'\x07\x00\x01\x01\x02\x03\x03\x04\x01\x05\x02\x06\xcd\x0b\xb8\x07\xce' +
                        b'\x00\x98\x96\x80'),
            'tarantool': r"datetime.interval.new({year=1, month=2, day=3, hour=1, " +
                         r"min=2, sec=3000, nsec=10000000, adjust='excess'})",
        },
        'datetime_last_adjust': {
            'python': tarantool.Interval(year=1, month=2, day=3, hour=1, minute=2,
                                         sec=3000, nsec=10000000,
                                         adjust=tarantool.IntervalAdjust.LAST),
            'msgpack': (b'\x08\x00\x01\x01\x02\x03\x03\x04\x01\x05\x02\x06\xcd\x0b\xb8\x07\xce' +
                        b'\x00\x98\x96\x80\x08\x02'),
            'tarantool': r"datetime.interval.new({year=1, month=2, day=3, hour=1, " +
                         r"min=2, sec=3000, nsec=10000000, adjust='last'})",
        },
        'all_zeroes': {
            'python': tarantool.Interval(adjust=tarantool.IntervalAdjust.EXCESS),
            'msgpack': (b'\x00'),
            'tarantool': r"datetime.interval.new({adjust='excess'})",
        },
    }

    def test_msgpack_decode(self):
        """The ext hook turns every payload into the expected object."""
        for name, case in self.cases.items():
            with self.subTest(msg=name):
                decoded = unpacker_ext_hook(6, case['msgpack'])
                self.assertEqual(decoded, case['python'])

    @skip_or_run_datetime_test
    def test_tarantool_decode(self):
        """Intervals stored by Tarantool are selected as expected objects."""
        for name, case in self.cases.items():
            with self.subTest(msg=name):
                self.adm(f"box.space['test']:replace{{'{name}', {case['tarantool']}, 'field'}}")

                expected = [[name, case['python'], 'field']]
                self.assertSequenceEqual(self.con.select('test', name),
                                         expected)

    def test_msgpack_encode(self):
        """The packer hook turns every object into the expected ExtType."""
        for name, case in self.cases.items():
            with self.subTest(msg=name):
                expected = msgpack.ExtType(code=6, data=case['msgpack'])
                self.assertEqual(packer_default(case['python']), expected)

    @skip_or_run_datetime_test
    def test_tarantool_encode(self):
        """Inserted intervals are equal to those built on the Lua side."""
        for name, case in self.cases.items():
            with self.subTest(msg=name):
                self.con.insert('test', [name, case['python'], 'field'])

                lua_eval = f"""
                    local interval = {case['tarantool']}

                    local tuple = box.space['test']:get('{name}')
                    assert(tuple ~= nil)

                    if tuple[2] == interval then
                        return true
                    else
                        return nil, ('%s is not equal to expected %s'):format(
                            tostring(tuple[2]), tostring(interval))
                    end
                """

                self.assertSequenceEqual(self.adm(lua_eval), [True])

    def test_unknown_field_decode(self):
        """A field ID outside the known set is reported as MsgpackError."""
        case = b'\x01\x09\xce\x00\x98\x96\x80'
        with self.assertRaisesRegex(MsgpackError,
                                    'Unknown interval field id 9'):
            unpacker_ext_hook(6, case)

    def test_unknown_adjust_decode(self):
        """An adjust value outside the enum is reported as MsgpackError."""
        case = b'\x02\x07\xce\x00\x98\x96\x80\x08\x03'
        with self.assertRaisesRegex(MsgpackError,
                                    '3 is not a valid Adjust'):
            unpacker_ext_hook(6, case)

    @classmethod
    def tearDownClass(cls):
        """Close the connection and shut the Tarantool instance down."""
        cls.con.close()
        cls.srv.stop()
        cls.srv.clean()