Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature request to process bytearrays and custom _FLAG_ #167

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class Client(threading.local):
_FLAG_LONG = 1 << 2
_FLAG_COMPRESSED = 1 << 3
_FLAG_TEXT = 1 << 4
_FLAG_BYTE_ARRAY = 1 << 5

_SERVER_RETRIES = 10 # how many times to try finding a free server.

Expand Down Expand Up @@ -163,7 +164,8 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
pload=None, pid=None,
server_max_key_length=None, server_max_value_length=None,
dead_retry=_DEAD_RETRY, socket_timeout=_SOCKET_TIMEOUT,
cache_cas=False, flush_on_reconnect=0, check_keys=True):
cache_cas=False, flush_on_reconnect=0, check_keys=True,
flags_to_overwrite={}):
"""Create a new Client object with the given list of servers.

@param servers: C{servers} is passed to L{set_servers}.
Expand Down Expand Up @@ -206,6 +208,10 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
@param check_keys: (default True) If True, the key is checked
to ensure it is the correct length and composed of the right
characters.
@param flags_to_overwrite: (default {}) A dictionary mapping
_FLAG_* values to new values. This allows compatibility
with other memcached clients which have different flags for
specific types.
"""
super(Client, self).__init__()
self.debug = debug
Expand Down Expand Up @@ -241,6 +247,26 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
except TypeError:
self.picklerIsKeyword = False

# Allow users to overwrite _FLAG_ with custom values
if flags_to_overwrite:
if self._FLAG_PICKLE in flags_to_overwrite:
self._FLAG_PICKLE = flags_to_overwrite[self._FLAG_PICKLE]

if self._FLAG_INTEGER in flags_to_overwrite:
self._FLAG_INTEGER = flags_to_overwrite[self._FLAG_INTEGER]

if self._FLAG_LONG in flags_to_overwrite:
self._FLAG_LONG = flags_to_overwrite[self._FLAG_LONG]

if self._FLAG_COMPRESSED in flags_to_overwrite:
self._FLAG_COMPRESSED = flags_to_overwrite[self._FLAG_COMPRESSED]

if self._FLAG_TEXT in flags_to_overwrite:
self._FLAG_TEXT = flags_to_overwrite[self._FLAG_TEXT]

if self._FLAG_BYTE_ARRAY in flags_to_overwrite:
self._FLAG_BYTE_ARRAY = flags_to_overwrite[self._FLAG_BYTE_ARRAY]

def _encode_key(self, key):
if isinstance(key, tuple):
if isinstance(key[1], six.text_type):
Expand Down Expand Up @@ -963,24 +989,27 @@ def _val_to_store_info(self, val, min_compress_len):
if val_type == six.binary_type:
pass
elif val_type == six.text_type:
flags |= Client._FLAG_TEXT
flags |= self._FLAG_TEXT
val = val.encode('utf-8')
elif val_type == int:
flags |= Client._FLAG_INTEGER
flags |= self._FLAG_INTEGER
val = '%d' % val
if six.PY3:
val = val.encode('ascii')
# force no attempt to compress this silly string.
min_compress_len = 0
elif six.PY2 and isinstance(val, long): # noqa: F821
flags |= Client._FLAG_LONG
flags |= self._FLAG_LONG
val = str(val)
if six.PY3:
val = val.encode('ascii')
# force no attempt to compress this silly string.
min_compress_len = 0
elif isinstance(val, bytearray):
flags |= self._FLAG_BYTE_ARRAY
val = bytes(val)
else:
flags |= Client._FLAG_PICKLE
flags |= self._FLAG_PICKLE
file = BytesIO()
if self.picklerIsKeyword:
pickler = self.pickler(file, protocol=self.pickleProtocol)
Expand All @@ -999,7 +1028,7 @@ def _val_to_store_info(self, val, min_compress_len):
# Only retain the result if the compression result is smaller
# than the original.
if len(comp_val) < lv:
flags |= Client._FLAG_COMPRESSED
flags |= self._FLAG_COMPRESSED
val = comp_val

# silently do not store if value length exceeds maximum
Expand Down Expand Up @@ -1253,22 +1282,22 @@ def _recv_value(self, server, flags, rlen):
if len(buf) == rlen:
buf = buf[:-2] # strip \r\n

if flags & Client._FLAG_COMPRESSED:
if flags & self._FLAG_COMPRESSED:
buf = self.decompressor(buf)
flags &= ~Client._FLAG_COMPRESSED
flags &= ~self._FLAG_COMPRESSED
if flags == 0:
# Bare bytes
val = buf
elif flags & Client._FLAG_TEXT:
elif flags & self._FLAG_TEXT:
val = buf.decode('utf-8')
elif flags & Client._FLAG_INTEGER:
elif flags & self._FLAG_INTEGER:
val = int(buf)
elif flags & Client._FLAG_LONG:
elif flags & self._FLAG_LONG:
if six.PY3:
val = int(buf)
else:
val = long(buf) # noqa: F821
elif flags & Client._FLAG_PICKLE:
elif flags & self._FLAG_PICKLE:
try:
file = BytesIO(buf)
unpickler = self.unpickler(file)
Expand All @@ -1278,6 +1307,8 @@ def _recv_value(self, server, flags, rlen):
except Exception as e:
self.debuglog('Pickle error: %s\n' % e)
return None
elif flags & self._FLAG_BYTE_ARRAY:
val = bytearray(buf)
else:
self.debuglog("unknown flags on get: %x\n" % flags)
raise ValueError('Unknown flags on get: %x' % flags)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ def test_binary_string(self):
self.assertEqual(compressed_value, compressed_result)
self.assertEqual(value, zlib.decompress(compressed_result).decode())

def test_setget_bytearray(self):
val = bytearray(b'a string')
self.check_setget("bytearray", val)

def test_ignore_too_large_value(self):
# NOTE: "MemCached: while expecting[...]" is normal...
key = 'keyhere'
Expand Down Expand Up @@ -219,6 +223,25 @@ def test_disconnect_all_delete_multi(self):
"'NOT_FOUND'\n"
)

def test_flag_overwrite_on_instance(self):
"""Testing that flags are overwritten in object instance"""
servers = ["127.0.0.1:11211"]
flags_to_overwrite = {
Client._FLAG_INTEGER: 1 << 3, # 0000 0000 0000 1000 == 8
Client._FLAG_COMPRESSED: 1 << 1, # 0000 0000 0000 0010 == 2
}
another_mc = Client(servers, flags_to_overwrite=flags_to_overwrite, debug=1)

self.assertEqual(another_mc._FLAG_INTEGER, 1 << 3)
self.assertEqual(another_mc._FLAG_COMPRESSED, 1 << 1)
self.assertEqual(self.mc._FLAG_INTEGER, 1 << 1)
self.assertEqual(self.mc._FLAG_COMPRESSED, 1 << 3)
self.assertEqual(Client._FLAG_INTEGER, 1 << 1)
self.assertEqual(Client._FLAG_COMPRESSED, 1 << 3)

another_mc.flush_all()
another_mc.disconnect_all()

@mock.patch.object(_Host, 'send_cmd') # Don't send any commands.
@mock.patch.object(_Host, 'readline')
def test_touch_unexpected_reply(self, mock_readline, mock_send_cmd):
Expand Down