Skip to content

Commit

Permalink
Make sure a new connection is established upon retry, before marking …
Browse files Browse the repository at this point in the history
…dead, to avoid reusing a stale socket.
  • Loading branch information
Ricardo Alves committed Apr 9, 2017
1 parent 288c159 commit d26c66b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
8 changes: 4 additions & 4 deletions memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ def _unsafe_set():
except _ConnectionDeadError:
# retry once
try:
if server._get_socket():
if server._get_socket(reconnect=True):
return _unsafe_set()
except (_ConnectionDeadError, socket.error) as msg:
server.mark_dead(msg)
Expand Down Expand Up @@ -1101,7 +1101,7 @@ def _unsafe_get():
except _ConnectionDeadError:
# retry once
try:
if server.connect():
if server._get_socket(reconnect=True):
return _unsafe_get()
return None
except (_ConnectionDeadError, socket.error) as msg:
Expand Down Expand Up @@ -1386,10 +1386,10 @@ def mark_dead(self, reason):
self.flush_on_next_connect = 1
self.close_socket()

def _get_socket(self):
def _get_socket(self, reconnect=False):
if self._check_dead():
return None
if self.socket:
if self.socket and not reconnect:
return self.socket
s = socket.socket(self.family, socket.SOCK_STREAM)
if hasattr(s, 'settimeout'):
Expand Down
48 changes: 48 additions & 0 deletions tests/test_memcache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import print_function

import socket
import unittest

import six
Expand Down Expand Up @@ -167,5 +168,52 @@ def test_disconnect_all_delete_multi(self):
self.assertEqual(ret, 1)


class TestMemcacheMarkDead(unittest.TestCase):

def setUp(self):
self.status = locals()
self.address = ("127.0.0.1", 11213)
self._start_stub_server()
self.client = Client(["127.0.0.1:11213"], debug=1)

def tearDown(self):
self._stop_stub_server()

def _start_stub_server(self):
# setup stub server
stub_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
stub_socket.bind(self.address)
stub_socket.listen(1)
self.stub_socket = stub_socket

def _stop_stub_server(self):
self.stub_socket.close()

def test_mark_server_dead(self):
mc_host = self.client._get_server('foo'.encode('utf8'))[0]
client_socket = mc_host._get_socket()

# make sure the server is not marked dead
self.assertEqual(0, mc_host._check_dead())

# stop the stub server
self._stop_stub_server()

# host is not yet marked as dead
self.assertEqual(0, mc_host._check_dead())

# create a new stub socket again
self._start_stub_server()

# The client will try to re-use the old socket and if it fails
# then should re-establish a new connection
# so the socket must be a new one
new_client_socket = mc_host._get_socket(reconnect=True)
self.assertNotEqual(new_client_socket, client_socket)

# server is not marked dead, because connection succeeded
self.assertEqual(0, mc_host._check_dead())


if __name__ == '__main__':
unittest.main()

0 comments on commit d26c66b

Please sign in to comment.