diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index 2f368725..ce1d5e5c 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -567,6 +567,8 @@ def __check_optional_meta_data(self): By Applying this, provide properly mapped column information on UPDATE,DELETE,INSERT. """, ) + else: + self.__optional_meta_data = True def fetchone(self): while True: diff --git a/pymysqlreplication/event.py b/pymysqlreplication/event.py index 760b01e0..2d6afbbd 100644 --- a/pymysqlreplication/event.py +++ b/pymysqlreplication/event.py @@ -6,6 +6,7 @@ from pymysqlreplication.constants.STATUS_VAR_KEY import * from pymysqlreplication.exceptions import StatusVariableMismatch +from pymysqlreplication.util.bytes import parse_decimal_from_bytes from typing import Union, Optional @@ -781,9 +782,7 @@ def _read_decimal(self, buffer: bytes) -> decimal.Decimal: self.precision = self.temp_value_buffer[0] self.decimals = self.temp_value_buffer[1] raw_decimal = self.temp_value_buffer[2:] - return self._parse_decimal_from_bytes( - raw_decimal, self.precision, self.decimals - ) + return parse_decimal_from_bytes(raw_decimal, self.precision, self.decimals) def _read_default(self) -> bytes: """ @@ -792,57 +791,6 @@ def _read_default(self) -> bytes: """ return self.packet.read(self.value_len) - @staticmethod - def _parse_decimal_from_bytes( - raw_decimal: bytes, precision: int, decimals: int - ) -> decimal.Decimal: - """ - Parse decimal from bytes. - """ - digits_per_integer = 9 - compressed_bytes = [0, 1, 1, 2, 2, 3, 3, 4, 4, 4] - integral = precision - decimals - - uncomp_integral, comp_integral = divmod(integral, digits_per_integer) - uncomp_fractional, comp_fractional = divmod(decimals, digits_per_integer) - - res = "-" if not raw_decimal[0] & 0x80 else "" - mask = -1 if res == "-" else 0 - raw_decimal = bytearray([raw_decimal[0] ^ 0x80]) + raw_decimal[1:] - - def decode_decimal_decompress_value(comp_indx, data, mask): - size = compressed_bytes[comp_indx] - if size > 0: - databuff = bytearray(data[:size]) - for i in range(size): - databuff[i] = (databuff[i] ^ mask) & 0xFF - return size, int.from_bytes(databuff, byteorder="big") - return 0, 0 - - pointer, value = decode_decimal_decompress_value( - comp_integral, raw_decimal, mask - ) - res += str(value) - - for _ in range(uncomp_integral): - value = struct.unpack(">i", raw_decimal[pointer : pointer + 4])[0] ^ mask - res += "%09d" % value - pointer += 4 - - res += "." 
-
-        for _ in range(uncomp_fractional):
-            value = struct.unpack(">i", raw_decimal[pointer : pointer + 4])[0] ^ mask
-            res += "%09d" % value
-            pointer += 4
-
-        size, value = decode_decimal_decompress_value(
-            comp_fractional, raw_decimal[pointer:], mask
-        )
-        if size > 0:
-            res += "%0*d" % (comp_fractional, value)
-        return decimal.Decimal(res)
-
     def _dump(self) -> None:
         super(UserVarEvent, self)._dump()
         print("User variable name: %s" % self.name)
diff --git a/pymysqlreplication/packet.py b/pymysqlreplication/packet.py
index 80423200..3edbd74f 100644
--- a/pymysqlreplication/packet.py
+++ b/pymysqlreplication/packet.py
@@ -1,6 +1,8 @@
 import struct
 
 from pymysqlreplication import constants, event, row_event
+from pymysqlreplication.constants import FIELD_TYPE
+from pymysqlreplication.util.bytes import *
 
 # Constants from PyMYSQL source code
 NULL_COLUMN = 251
@@ -13,7 +15,6 @@
 UNSIGNED_INT24_LENGTH = 3
 UNSIGNED_INT64_LENGTH = 8
 
-
 JSONB_TYPE_SMALL_OBJECT = 0x0
 JSONB_TYPE_LARGE_OBJECT = 0x1
 JSONB_TYPE_SMALL_ARRAY = 0x2
@@ -33,18 +34,142 @@
 JSONB_LITERAL_TRUE = 0x1
 JSONB_LITERAL_FALSE = 0x2
 
+JSONB_SMALL_OFFSET_SIZE = 2
+JSONB_LARGE_OFFSET_SIZE = 4
+JSONB_KEY_ENTRY_SIZE_SMALL = 2 + JSONB_SMALL_OFFSET_SIZE
+JSONB_KEY_ENTRY_SIZE_LARGE = 2 + JSONB_LARGE_OFFSET_SIZE
+JSONB_VALUE_ENTRY_SIZE_SMALL = 1 + JSONB_SMALL_OFFSET_SIZE
+JSONB_VALUE_ENTRY_SIZE_LARGE = 1 + JSONB_LARGE_OFFSET_SIZE
+
+
+def is_json_inline_value(type: int, is_small: bool) -> bool:
+    if type in [JSONB_TYPE_UINT16, JSONB_TYPE_INT16, JSONB_TYPE_LITERAL]:
+        return True
+    elif type in [JSONB_TYPE_INT32, JSONB_TYPE_UINT32]:
+        return not is_small
+    return False
+
+
+def parse_json(type: int, data: bytes):
+    if type == JSONB_TYPE_SMALL_OBJECT:
+        v = parse_json_object_or_array(data, True, True)
+    elif type == JSONB_TYPE_LARGE_OBJECT:
+        v = parse_json_object_or_array(data, False, True)
+    elif type == JSONB_TYPE_SMALL_ARRAY:
+        v = parse_json_object_or_array(data, True, False)
+    elif type == JSONB_TYPE_LARGE_ARRAY:
+        v = parse_json_object_or_array(data, False, False)
+    elif type == JSONB_TYPE_LITERAL:
+        v = parse_literal(data)
+    elif type == JSONB_TYPE_INT16:
+        v = parse_int16(data)
+    elif type == JSONB_TYPE_UINT16:
+        v = parse_uint16(data)
+    elif type == JSONB_TYPE_INT32:
+        v = parse_int32(data)
+    elif type == JSONB_TYPE_UINT32:
+        v = parse_uint32(data)
+    elif type == JSONB_TYPE_INT64:
+        v = parse_int64(data)
+    elif type == JSONB_TYPE_UINT64:
+        v = parse_uint64(data)
+    elif type == JSONB_TYPE_DOUBLE:
+        v = parse_double(data)
+    elif type == JSONB_TYPE_STRING:
+        length, n = decode_variable_length(data)
+        v = parse_string(n, length, data)
+    elif type == JSONB_TYPE_OPAQUE:
+        v = parse_opaque(data)
+    else:
+        raise ValueError("Json type %d is not handled" % type)
+    return v
+
+
+def parse_json_object_or_array(bytes, is_small, is_object):
+    offset_size = JSONB_SMALL_OFFSET_SIZE if is_small else JSONB_LARGE_OFFSET_SIZE
+    count = decode_count(bytes, is_small)
+    size = decode_count(bytes[offset_size:], is_small)
+    if is_small:
+        key_entry_size = JSONB_KEY_ENTRY_SIZE_SMALL
+        value_entry_size = JSONB_VALUE_ENTRY_SIZE_SMALL
+    else:
+        key_entry_size = JSONB_KEY_ENTRY_SIZE_LARGE
+        value_entry_size = JSONB_VALUE_ENTRY_SIZE_LARGE
+    if is_data_short(bytes, size):
+        raise ValueError(
+            "Before MySQL 5.7.22, json type generated column may have invalid value"
+        )
 
-def read_offset_or_inline(packet, large):
-    t = packet.read_uint8()
-
-    if t in (JSONB_TYPE_LITERAL, JSONB_TYPE_INT16, JSONB_TYPE_UINT16):
-        return (t, None, packet.read_binary_json_type_inlined(t, large))
-    if large and t in (JSONB_TYPE_INT32, JSONB_TYPE_UINT32):
-        return (t, None, packet.read_binary_json_type_inlined(t, large))
-
-    if large:
-        return (t, packet.read_uint32(), None)
-    return (t, packet.read_uint16(), None)
+    # header: count + size fields, one key entry per member (objects only),
+    # then one value entry per element
+    header_size = 2 * offset_size + count * value_entry_size
+
+    if is_object:
+        header_size += count * key_entry_size
+
+    if header_size > size:
+        raise ValueError("header size > size")
+
+    keys = []
+    if is_object:
+        for i in range(count):
+            entry_offset = 2 * offset_size + key_entry_size * i
+            key_offset = decode_count(bytes[entry_offset:], is_small)
+            key_length = decode_uint(bytes[entry_offset + offset_size :])
+            keys.append(bytes[key_offset : key_offset + key_length])
+
+    values = {}
+    for i in range(count):
+        entry_offset = 2 * offset_size + value_entry_size * i
+        if is_object:
+            entry_offset += key_entry_size * count
+        json_type = bytes[entry_offset]
+        if is_json_inline_value(json_type, is_small):
+            values[i] = parse_json(
+                json_type, bytes[entry_offset + 1 : entry_offset + value_entry_size]
+            )
+            continue
+        value_offset = decode_count(bytes[entry_offset + 1 :], is_small)
+        if is_data_short(bytes, value_offset):
+            return None
+        values[i] = parse_json(json_type, bytes[value_offset:])
+    if not is_object:
+        return list(values.values())
+    out = {}
+    for i in range(count):
+        out[keys[i]] = values[i]
+    return out
+
+
+def parse_literal(data: bytes):
+    json_type = data[0]
+    if json_type == JSONB_LITERAL_NULL:
+        return None
+    elif json_type == JSONB_LITERAL_TRUE:
+        return True
+    elif json_type == JSONB_LITERAL_FALSE:
+        return False
+
+    raise ValueError("NOT LITERAL TYPE")
+
+
+def parse_opaque(data: bytes):
+    if is_data_short(data, 1):
+        return None
+    type_ = data[0]
+    data = data[1:]
+
+    length, n = decode_variable_length(data)
+    data = data[n : n + length]
+
+    if type_ in [FIELD_TYPE.NEWDECIMAL, FIELD_TYPE.DECIMAL]:
+        return decode_decimal(data)
+    elif type_ in [FIELD_TYPE.TIME, FIELD_TYPE.TIME2]:
+        return decode_time(data)
+    elif type_ in [FIELD_TYPE.DATE, FIELD_TYPE.DATETIME, FIELD_TYPE.DATETIME2]:
+        return decode_datetime(data)
+    else:
+        return data.decode(errors="ignore")
 
 
 class BinLogPacketWrapper(object):
@@ -375,131 +499,8 @@ def read_binary_json(self, size):
         if length == 0:
             # handle NULL value
             return None
-        payload = self.read(length)
-        self.unread(payload)
-        t = self.read_uint8()
-
-        return self.read_binary_json_type(t, length)
-
-    def read_binary_json_type(self, t, length):
-        large = t in (JSONB_TYPE_LARGE_OBJECT, JSONB_TYPE_LARGE_ARRAY)
-        if t in (JSONB_TYPE_SMALL_OBJECT, JSONB_TYPE_LARGE_OBJECT):
-            return self.read_binary_json_object(length - 1, large)
-        elif t in (JSONB_TYPE_SMALL_ARRAY, JSONB_TYPE_LARGE_ARRAY):
-            return self.read_binary_json_array(length - 1, large)
-        elif t in (JSONB_TYPE_STRING,):
-            return self.read_variable_length_string()
-        elif t in (JSONB_TYPE_LITERAL,):
-            value = self.read_uint8()
-            if value == JSONB_LITERAL_NULL:
-                return None
-            elif value == JSONB_LITERAL_TRUE:
-                return True
-            elif value == JSONB_LITERAL_FALSE:
-                return False
-        elif t == JSONB_TYPE_INT16:
-            return self.read_int16()
-        elif t == JSONB_TYPE_UINT16:
-            return self.read_uint16()
-        elif t in (JSONB_TYPE_DOUBLE,):
-            return struct.unpack("<d", self.read(8))[0]
-        elif t == JSONB_TYPE_INT32:
-            return self.read_int32()
-        elif t == JSONB_TYPE_UINT32:
-            return self.read_uint32()
-        elif t == JSONB_TYPE_INT64:
-            return self.read_int64()
-        elif t == JSONB_TYPE_UINT64:
-            return self.read_uint64()
-
-        raise ValueError("Json type %d is not handled" % t)
-
-    def read_binary_json_type_inlined(self, t, large):
-        if t == JSONB_TYPE_LITERAL:
-            value = self.read_uint32() if large else self.read_uint16()
-            if value == JSONB_LITERAL_NULL:
-                return None
-            elif value == JSONB_LITERAL_TRUE:
-                return True
-            elif value == JSONB_LITERAL_FALSE:
-                return False
-        elif t == JSONB_TYPE_INT16:
-            return self.read_int32() if large else self.read_int16()
-        elif t == JSONB_TYPE_UINT16:
-            return self.read_uint32() if large else self.read_uint16()
-        elif t == JSONB_TYPE_INT32:
-            return self.read_int32()
-        elif t == JSONB_TYPE_UINT32:
-            return self.read_uint32()
-
-        raise ValueError("Json type %d is not handled" % t)
-
-    def read_binary_json_object(self, length, large):
-        if large:
-            elements = self.read_uint32()
-            size = self.read_uint32()
-        else:
-            elements = self.read_uint16()
-            size = self.read_uint16()
-
-        if size > length:
-            raise ValueError("Json length is larger than packet length")
-
-        if large:
-            key_offset_lengths = [
-                (
-                    self.read_uint32(),  # offset (we don't actually need that)
-                    self.read_uint16(),  # size of the key
-                )
-                for _ in range(elements)
-            ]
-        else:
-            key_offset_lengths = [
-                (
-                    self.read_uint16(),  # offset (we don't actually need that)
-                    self.read_uint16(),  # size of key
-                )
-                for _ in range(elements)
-            ]
-
-        value_type_inlined_lengths = [
-            read_offset_or_inline(self, large) for _ in range(elements)
-        ]
-
-        keys = [self.read(x[1]) for x in key_offset_lengths]
-
-        out = {}
-        for i in range(elements):
-            if value_type_inlined_lengths[i][1] is None:
-                data = value_type_inlined_lengths[i][2]
-            else:
-                t = value_type_inlined_lengths[i][0]
-                data = self.read_binary_json_type(t, length)
-            out[keys[i]] = data
-
-        return out
-
-    def read_binary_json_array(self, length, large):
-        if large:
-            elements = self.read_uint32()
-            size = self.read_uint32()
-        else:
-            elements = self.read_uint16()
-            size = self.read_uint16()
-
-        if size > length:
-            raise ValueError("Json length is larger than packet length")
-
-        values_type_offset_inline = [
-            read_offset_or_inline(self, large) for _ in range(elements)
-        ]
-
-        def _read(x):
-            if x[1] is None:
-                return x[2]
-            return self.read_binary_json_type(x[0], length)
-
-        return [_read(x) for x in values_type_offset_inline]
+        data = self.read(length)
+        return parse_json(data[0], data[1:])
 
     def read_string(self):
         """Read a 'Length Coded String' from the data buffer.
diff --git a/pymysqlreplication/tests/test_basic.py b/pymysqlreplication/tests/test_basic.py
index 84081b42..f38d4be7 100644
--- a/pymysqlreplication/tests/test_basic.py
+++ b/pymysqlreplication/tests/test_basic.py
@@ -596,6 +596,33 @@ def create_binlog_packet_wrapper(pkt):
         self.assertEqual(binlog_event.event._is_event_valid, True)
         self.assertNotEqual(wrong_event.event._is_event_valid, True)
 
+    def test_json_update(self):
+        self.stream.close()
+        self.stream = BinLogStreamReader(
+            self.database, server_id=1024, only_events=[UpdateRowsEvent]
+        )
+        create_query = (
+            "CREATE TABLE setting_table( id SERIAL AUTO_INCREMENT, setting JSON);"
+        )
+        insert_query = """INSERT INTO setting_table (setting) VALUES ('{"btn": true, "model": false}');"""
+
+        update_query = """UPDATE setting_table
+        SET setting = JSON_REMOVE(setting, '$.model')
+        WHERE id=1;
+        """
+        self.execute(create_query)
+        self.execute(insert_query)
+        self.execute(update_query)
+        self.execute("COMMIT;")
+        event = self.stream.fetchone()
+
+        if event.table_map[event.table_id].column_name_flag:
+            self.assertEqual(
+                event.rows[0]["before_values"]["setting"],
+                {b"btn": True, b"model": False},
+            )
+            self.assertEqual(event.rows[0]["after_values"]["setting"], {b"btn": True})
+
 
 class TestMultipleRowBinLogStreamReader(base.PyMySQLReplicationTestCase):
     def setUp(self):
@@ -1696,7 +1723,7 @@ def test_sync_drop_table_map_event_table_schema(self):
         event = self.stream.fetchone()
 
         self.assertIsInstance(event, TableMapEvent)
-        self.assertEqual(event.table_obj.data["columns"][0].name, None)
+        self.assertEqual(event.table_obj.data["columns"][0].name, "name")
         self.assertEqual(len(column_schemas), 0)
 
     def test_sync_column_drop_event_table_schema(self):
@@ -1727,9 +1754,9 @@ def test_sync_column_drop_event_table_schema(self):
         self.assertEqual(len(event.table_obj.data["columns"]), 3)
         self.assertEqual(column_schemas[0][0], "drop_column1")
         self.assertEqual(column_schemas[1][0], "drop_column3")
-        self.assertEqual(event.table_obj.data["columns"][0].name, None)
-        self.assertEqual(event.table_obj.data["columns"][1].name, None)
-        self.assertEqual(event.table_obj.data["columns"][2].name, None)
+        self.assertEqual(event.table_obj.data["columns"][0].name, "drop_column1")
+        self.assertEqual(event.table_obj.data["columns"][1].name, "drop_column2")
+        
self.assertEqual(event.table_obj.data["columns"][2].name, "drop_column3") def tearDown(self): self.execute("SET GLOBAL binlog_row_metadata='MINIMAL';") diff --git a/pymysqlreplication/tests/test_data_type.py b/pymysqlreplication/tests/test_data_type.py index 133a3862..6a18aca2 100644 --- a/pymysqlreplication/tests/test_data_type.py +++ b/pymysqlreplication/tests/test_data_type.py @@ -588,8 +588,6 @@ def test_geometry(self): ) def test_json(self): - if not self.isMySQL57(): - self.skipTest("Json is only supported in mysql 5.7") create_query = "CREATE TABLE test (id int, value json);" insert_query = """INSERT INTO test (id, value) VALUES (1, '{"my_key": "my_val", "my_key2": "my_val2"}');""" event = self.create_and_insert_value(create_query, insert_query) @@ -600,8 +598,6 @@ def test_json(self): ) def test_json_array(self): - if not self.isMySQL57(): - self.skipTest("Json is only supported in mysql 5.7") create_query = "CREATE TABLE test (id int, value json);" insert_query = ( """INSERT INTO test (id, value) VALUES (1, '["my_val", "my_val2"]');""" @@ -611,8 +607,6 @@ def test_json_array(self): self.assertEqual(event.rows[0]["values"]["value"], [b"my_val", b"my_val2"]) def test_json_large(self): - if not self.isMySQL57(): - self.skipTest("Json is only supported in mysql 5.7") data = dict( [("foooo%i" % i, "baaaaar%i" % i) for i in range(2560)] ) # Make it large enough to reach 2^16 length @@ -626,8 +620,6 @@ def test_json_large(self): def test_json_large_array(self): "Test json array larger than 64k bytes" - if not self.isMySQL57(): - self.skipTest("Json is only supported in mysql 5.7") create_query = "CREATE TABLE test (id int, value json);" large_array = dict(my_key=[i for i in range(100000)]) insert_query = "INSERT INTO test (id, value) VALUES (1, '%s');" % ( @@ -640,8 +632,6 @@ def test_json_large_array(self): ) def test_json_large_with_literal(self): - if not self.isMySQL57(): - self.skipTest("Json is only supported in mysql 5.7") data = dict( [("foooo%i" % i, "baaaaar%i" % i) for i in range(2560)], literal=True ) # Make it large with literal @@ -654,9 +644,6 @@ def test_json_large_with_literal(self): self.assertEqual(event.rows[0]["values"]["value"], to_binary_dict(data)) def test_json_types(self): - if not self.isMySQL57(): - self.skipTest("Json is only supported in mysql 5.7") - types = [ True, False, @@ -685,9 +672,6 @@ def test_json_types(self): self.setUp() def test_json_basic(self): - if not self.isMySQL57(): - self.skipTest("Json is only supported in mysql 5.7") - types = [ True, False, @@ -714,8 +698,6 @@ def test_json_basic(self): self.setUp() def test_json_unicode(self): - if not self.isMySQL57(): - self.skipTest("Json is only supported in mysql 5.7") create_query = "CREATE TABLE test (id int, value json);" insert_query = """INSERT INTO test (id, value) VALUES (1, '{"miam": "🍔"}');""" event = self.create_and_insert_value(create_query, insert_query) @@ -725,8 +707,6 @@ def test_json_unicode(self): ) def test_json_long_string(self): - if not self.isMySQL57(): - self.skipTest("Json is only supported in mysql 5.7") create_query = "CREATE TABLE test (id int, value json);" # The string length needs to be larger than what can fit in a single byte. 
         string_value = "super_long_string" * 100
@@ -735,12 +715,30 @@
             % (string_value,)
         )
         event = self.create_and_insert_value(create_query, insert_query)
         if event.table_map[event.table_id].column_name_flag:
             self.assertEqual(
                 event.rows[0]["values"]["value"],
                 to_binary_dict({"my_key": string_value}),
             )
 
+    def test_json_decimal_time_datetime(self):
+        create_query = """CREATE TABLE json_decimal_time_datetime_test (
+            id INT PRIMARY KEY AUTO_INCREMENT,
+            json_data JSON
+        );"""
+        insert_query = """
+        INSERT INTO json_decimal_time_datetime_test (json_data) VALUES (JSON_OBJECT('time_key', CAST('18:54:12' AS TIME), 'datetime_key', CAST('2023-09-24 18:54:12' AS DATETIME), 'decimal', CAST('99.99' AS DECIMAL(10, 2))));"""
+        event = self.create_and_insert_value(create_query, insert_query)
+        if event.table_map[event.table_id].column_name_flag:
+            self.assertEqual(
+                event.rows[0]["values"]["json_data"],
+                {
+                    b"decimal": Decimal("99.99"),
+                    b"time_key": datetime.time(18, 54, 12),
+                    b"datetime_key": datetime.datetime(2023, 9, 24, 18, 54, 12),
+                },
+            )
+
     def test_null(self):
         create_query = "CREATE TABLE test ( \
             test TINYINT NULL DEFAULT NULL, \
diff --git a/pymysqlreplication/util/__init__.py b/pymysqlreplication/util/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pymysqlreplication/util/bytes.py b/pymysqlreplication/util/bytes.py
new file mode 100644
index 00000000..d4162278
--- /dev/null
+++ b/pymysqlreplication/util/bytes.py
@@ -0,0 +1,177 @@
+import datetime
+import decimal
+import struct
+import sys
+
+
+def is_data_short(data: bytes, expected: int):
+    if len(data) < expected:
+        return True
+    return False
+
+
+def decode_count(data: bytes, is_small: bool):
+    if is_small:
+        return parse_uint16(data)
+    else:
+        return parse_uint32(data)
+
+
+def decode_uint(data: bytes):
+    if is_data_short(data, 2):
+        return 0
+    return parse_uint16(data)
+
+
+def decode_variable_length(data: bytes):
+    max_count = 5
+    if len(data) < max_count:
+        max_count = len(data)
+    pos = 0
+    length = 0
+    for _ in range(max_count):
+        v = data[pos]
+        length |= (v & 0x7F) << (7 * pos)
+        pos += 1
+        if v & 0x80 == 0:
+            if length > sys.maxsize - 1:
+                return 0, 0
+            return int(length), pos
+
+    return 0, 0
+
+
+def parse_decimal_from_bytes(
+    raw_decimal: bytes, precision: int, decimals: int
+) -> decimal.Decimal:
+    """
+    Parse decimal from bytes.
+    """
+    digits_per_integer = 9
+    compressed_bytes = [0, 1, 1, 2, 2, 3, 3, 4, 4, 4]
+    integral = precision - decimals
+
+    uncomp_integral, comp_integral = divmod(integral, digits_per_integer)
+    uncomp_fractional, comp_fractional = divmod(decimals, digits_per_integer)
+
+    res = "-" if not raw_decimal[0] & 0x80 else ""
+    mask = -1 if res == "-" else 0
+    raw_decimal = bytearray([raw_decimal[0] ^ 0x80]) + raw_decimal[1:]
+
+    def decode_decimal_decompress_value(comp_indx, data, mask):
+        size = compressed_bytes[comp_indx]
+        if size > 0:
+            databuff = bytearray(data[:size])
+            for i in range(size):
+                databuff[i] = (databuff[i] ^ mask) & 0xFF
+            return size, int.from_bytes(databuff, byteorder="big")
+        return 0, 0
+
+    pointer, value = decode_decimal_decompress_value(comp_integral, raw_decimal, mask)
+    res += str(value)
+
+    for _ in range(uncomp_integral):
+        value = struct.unpack(">i", raw_decimal[pointer : pointer + 4])[0] ^ mask
+        res += "%09d" % value
+        pointer += 4
+
+    res += "."
+
+    for _ in range(uncomp_fractional):
+        value = struct.unpack(">i", raw_decimal[pointer : pointer + 4])[0] ^ mask
+        res += "%09d" % value
+        pointer += 4
+
+    size, value = decode_decimal_decompress_value(
+        comp_fractional, raw_decimal[pointer:], mask
+    )
+    if size > 0:
+        res += "%0*d" % (comp_fractional, value)
+    return decimal.Decimal(res)
+
+
+def decode_decimal(data: bytes):
+    return parse_decimal_from_bytes(data[2:], data[0], data[1])
+
+
+def decode_time(data: bytes):
+    v = parse_int64(data)
+
+    if v == 0:
+        return datetime.time(hour=0, minute=0, second=0)
+
+    if v < 0:
+        v = -v
+    int_part = v >> 24
+    hour = (int_part >> 12) % (1 << 10)
+    minute = (int_part >> 6) % (1 << 6)
+    second = int_part % (1 << 6)
+    frac = v % (1 << 24)
+    return datetime.time(hour=hour, minute=minute, second=second, microsecond=frac)
+
+
+def decode_datetime(data: bytes):
+    v = parse_int64(data)
+
+    if v == 0:
+        # datetime parse error
+        return "0000-00-00 00:00:00"
+
+    if v < 0:
+        v = -v
+
+    int_part = v >> 24
+    ymd = int_part >> 17
+    ym = ymd >> 5
+    hms = int_part % (1 << 17)
+
+    year = ym // 13
+    month = ym % 13
+    day = ymd % (1 << 5)
+    hour = hms >> 12
+    minute = (hms >> 6) % (1 << 6)
+    second = hms % (1 << 6)
+    frac = v % (1 << 24)
+
+    return datetime.datetime(
+        year=year,
+        month=month,
+        day=day,
+        hour=hour,
+        minute=minute,
+        second=second,
+        microsecond=frac,
+    )
+
+
+def parse_int16(data: bytes):
+    return struct.unpack("<h", data[:2])[0]
+
+
+def parse_uint16(data: bytes):
+    return struct.unpack("<H", data[:2])[0]
+
+
+def parse_int32(data: bytes):
+    return struct.unpack("<i", data[:4])[0]
+
+
+def parse_uint32(data: bytes):
+    return struct.unpack("<I", data[:4])[0]
+
+
+def parse_int64(data: bytes):
+    return struct.unpack("<q", data[:8])[0]
+
+
+def parse_uint64(data: bytes):
+    return struct.unpack("<Q", data[:8])[0]
+
+
+def parse_double(data: bytes):
+    return struct.unpack("<d", data[:8])[0]
+
+
+def parse_string(n: int, length: int, data: bytes):
+    return data[n : n + length]
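Reviewer note: a quick, self-contained way to exercise the new offset-based JSONB decoder without a live binlog stream. This is a minimal sketch under the assumption that the constants and helpers land where this diff puts them (`pymysqlreplication.packet` and `pymysqlreplication.util.bytes`); the payload bytes are hand-assembled for illustration, not captured from a server.

```python
import datetime
import struct

from pymysqlreplication.packet import JSONB_TYPE_SMALL_OBJECT, parse_json
from pymysqlreplication.util.bytes import decode_datetime, decode_variable_length

# Hand-assembled JSONB small-object payload for {"a": 5}:
#   count (uint16 LE) | size (uint16 LE)
#   key entry: key offset (uint16) + key length (uint16)
#   value entry: type byte + 2 inline bytes (INT16 values are stored inline)
#   key bytes
payload = bytes(
    [
        0x01, 0x00,              # count = 1 member
        0x0C, 0x00,              # total size = 12 bytes
        0x0B, 0x00, 0x01, 0x00,  # key at offset 11, length 1
        0x05, 0x05, 0x00,        # JSONB_TYPE_INT16, inline value 5
    ]
) + b"a"

# Keys come back as raw bytes, matching the {b"btn": True}-style dicts
# expected by the new tests.
assert parse_json(JSONB_TYPE_SMALL_OBJECT, payload) == {b"a": 5}

# The varint used for string/opaque lengths: 0x81 0x01 -> 129, 2 bytes consumed.
assert decode_variable_length(b"\x81\x01") == (129, 2)

# decode_datetime unpacks MySQL's packed DATETIME layout from an int64.
ym = 2023 * 13 + 9                                   # year * 13 + month
int_part = (((ym << 5) | 24) << 17) | (18 << 12) | (54 << 6) | 12
assert decode_datetime(struct.pack("<q", int_part << 24)) == datetime.datetime(
    2023, 9, 24, 18, 54, 12
)
```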