From d2b7d435fbf9c70aa7a83b4ccadb0e143097ebb1 Mon Sep 17 00:00:00 2001 From: brimstoned Date: Wed, 11 Sep 2019 13:04:39 -0500 Subject: [PATCH] Adding feature to process bytearrays on set/get Adding feature to overwrite _FLAG_* with alternate values Adding round trip test Adding a test to verify the flags are overwritten in object instance --- memcache.py | 55 +++++++++++++++++++++++++++++++++--------- tests/test_memcache.py | 23 ++++++++++++++++++ 2 files changed, 66 insertions(+), 12 deletions(-) diff --git a/memcache.py b/memcache.py index 05b6657..ec51048 100644 --- a/memcache.py +++ b/memcache.py @@ -135,6 +135,7 @@ class Client(threading.local): _FLAG_LONG = 1 << 2 _FLAG_COMPRESSED = 1 << 3 _FLAG_TEXT = 1 << 4 + _FLAG_BYTE_ARRAY = 1 << 5 _SERVER_RETRIES = 10 # how many times to try finding a free server. @@ -163,7 +164,8 @@ def __init__(self, servers, debug=0, pickleProtocol=0, pload=None, pid=None, server_max_key_length=None, server_max_value_length=None, dead_retry=_DEAD_RETRY, socket_timeout=_SOCKET_TIMEOUT, - cache_cas=False, flush_on_reconnect=0, check_keys=True): + cache_cas=False, flush_on_reconnect=0, check_keys=True, + flags_to_overwrite={}): """Create a new Client object with the given list of servers. @param servers: C{servers} is passed to L{set_servers}. @@ -206,6 +208,10 @@ def __init__(self, servers, debug=0, pickleProtocol=0, @param check_keys: (default True) If True, the key is checked to ensure it is the correct length and composed of the right characters. + @param flags_to_overwrite: (default {}) A dictionary mapping + _FLAG_* values to new values. This allows compatibility + with other memcached clients which have different flags for + specific types. """ super(Client, self).__init__() self.debug = debug @@ -241,6 +247,26 @@ def __init__(self, servers, debug=0, pickleProtocol=0, except TypeError: self.picklerIsKeyword = False + # Allow users to overwrite _FLAG_ with custom values + if flags_to_overwrite: + if self._FLAG_PICKLE in flags_to_overwrite: + self._FLAG_PICKLE = flags_to_overwrite[self._FLAG_PICKLE] + + if self._FLAG_INTEGER in flags_to_overwrite: + self._FLAG_INTEGER = flags_to_overwrite[self._FLAG_INTEGER] + + if self._FLAG_LONG in flags_to_overwrite: + self._FLAG_LONG = flags_to_overwrite[self._FLAG_LONG] + + if self._FLAG_COMPRESSED in flags_to_overwrite: + self._FLAG_COMPRESSED = flags_to_overwrite[self._FLAG_COMPRESSED] + + if self._FLAG_TEXT in flags_to_overwrite: + self._FLAG_TEXT = flags_to_overwrite[self._FLAG_TEXT] + + if self._FLAG_BYTE_ARRAY in flags_to_overwrite: + self._FLAG_BYTE_ARRAY = flags_to_overwrite[self._FLAG_BYTE_ARRAY] + def _encode_key(self, key): if isinstance(key, tuple): if isinstance(key[1], six.text_type): @@ -963,24 +989,27 @@ def _val_to_store_info(self, val, min_compress_len): if val_type == six.binary_type: pass elif val_type == six.text_type: - flags |= Client._FLAG_TEXT + flags |= self._FLAG_TEXT val = val.encode('utf-8') elif val_type == int: - flags |= Client._FLAG_INTEGER + flags |= self._FLAG_INTEGER val = '%d' % val if six.PY3: val = val.encode('ascii') # force no attempt to compress this silly string. min_compress_len = 0 elif six.PY2 and isinstance(val, long): # noqa: F821 - flags |= Client._FLAG_LONG + flags |= self._FLAG_LONG val = str(val) if six.PY3: val = val.encode('ascii') # force no attempt to compress this silly string. min_compress_len = 0 + elif isinstance(val, bytearray): + flags |= self._FLAG_BYTE_ARRAY + val = bytes(val) else: - flags |= Client._FLAG_PICKLE + flags |= self._FLAG_PICKLE file = BytesIO() if self.picklerIsKeyword: pickler = self.pickler(file, protocol=self.pickleProtocol) @@ -999,7 +1028,7 @@ def _val_to_store_info(self, val, min_compress_len): # Only retain the result if the compression result is smaller # than the original. if len(comp_val) < lv: - flags |= Client._FLAG_COMPRESSED + flags |= self._FLAG_COMPRESSED val = comp_val # silently do not store if value length exceeds maximum @@ -1253,22 +1282,22 @@ def _recv_value(self, server, flags, rlen): if len(buf) == rlen: buf = buf[:-2] # strip \r\n - if flags & Client._FLAG_COMPRESSED: + if flags & self._FLAG_COMPRESSED: buf = self.decompressor(buf) - flags &= ~Client._FLAG_COMPRESSED + flags &= ~self._FLAG_COMPRESSED if flags == 0: # Bare bytes val = buf - elif flags & Client._FLAG_TEXT: + elif flags & self._FLAG_TEXT: val = buf.decode('utf-8') - elif flags & Client._FLAG_INTEGER: + elif flags & self._FLAG_INTEGER: val = int(buf) - elif flags & Client._FLAG_LONG: + elif flags & self._FLAG_LONG: if six.PY3: val = int(buf) else: val = long(buf) # noqa: F821 - elif flags & Client._FLAG_PICKLE: + elif flags & self._FLAG_PICKLE: try: file = BytesIO(buf) unpickler = self.unpickler(file) @@ -1278,6 +1307,8 @@ def _recv_value(self, server, flags, rlen): except Exception as e: self.debuglog('Pickle error: %s\n' % e) return None + elif flags & self._FLAG_BYTE_ARRAY: + val = bytearray(buf) else: self.debuglog("unknown flags on get: %x\n" % flags) raise ValueError('Unknown flags on get: %x' % flags) diff --git a/tests/test_memcache.py b/tests/test_memcache.py index 40b6524..451759f 100644 --- a/tests/test_memcache.py +++ b/tests/test_memcache.py @@ -163,6 +163,10 @@ def test_binary_string(self): self.assertEqual(compressed_value, compressed_result) self.assertEqual(value, zlib.decompress(compressed_result).decode()) + def test_setget_bytearray(self): + val = bytearray(b'a string') + self.check_setget("bytearray", val) + def test_ignore_too_large_value(self): # NOTE: "MemCached: while expecting[...]" is normal... key = 'keyhere' @@ -219,6 +223,25 @@ def test_disconnect_all_delete_multi(self): "'NOT_FOUND'\n" ) + def test_flag_overwrite_on_instance(self): + """Testing that flags are overwritten in object instance""" + servers = ["127.0.0.1:11211"] + flags_to_overwrite = { + Client._FLAG_INTEGER: 1 << 3, # 0000 0000 0000 1000 == 8 + Client._FLAG_COMPRESSED: 1 << 1, # 0000 0000 0000 0010 == 2 + } + another_mc = Client(servers, flags_to_overwrite=flags_to_overwrite, debug=1) + + self.assertEqual(another_mc._FLAG_INTEGER, 1 << 3) + self.assertEqual(another_mc._FLAG_COMPRESSED, 1 << 1) + self.assertEqual(self.mc._FLAG_INTEGER, 1 << 1) + self.assertEqual(self.mc._FLAG_COMPRESSED, 1 << 3) + self.assertEqual(Client._FLAG_INTEGER, 1 << 1) + self.assertEqual(Client._FLAG_COMPRESSED, 1 << 3) + + another_mc.flush_all() + another_mc.disconnect_all() + @mock.patch.object(_Host, 'send_cmd') # Don't send any commands. @mock.patch.object(_Host, 'readline') def test_touch_unexpected_reply(self, mock_readline, mock_send_cmd):