Skip to content

Commit

Permalink
🐛 Fix async pending error
Browse files Browse the repository at this point in the history
  • Loading branch information
holegots committed May 28, 2024
1 parent 30deec7 commit 994a6e9
Showing 1 changed file with 24 additions and 56 deletions.
80 changes: 24 additions & 56 deletions thriftpy2/contrib/aio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def __init__(self, host=None, port=None, unix_socket=None,
ciphers=ciphers)

if cafile or capath:
self.ssl_context.load_verify_locations(cafile=cafile,
capath=capath)
self.ssl_context.load_verify_locations(cafile=cafile, capath=capath)

if certfile:
self.ssl_context.load_cert_chain(certfile, keyfile=keyfile)
Expand All @@ -106,58 +105,23 @@ def __init__(self, host=None, port=None, unix_socket=None,
self.ssl_context = None
self.server_hostname = None

def _init_sock(self):
if self.unix_socket:
_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
else:
_sock = socket.socket(self.socket_family, socket.SOCK_STREAM)
_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

# socket options
linger = struct.pack('ii', 0, 0)
_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger)
_sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

self.raw_sock = _sock

def set_handle(self, sock):
self.raw_sock = sock

def set_timeout(self, ms):
"""Backward compat api, will bind the timeout to both connect_timeout
and socket_timeout.
"""
self.socket_timeout = ms / 1000 if (ms and ms > 0) else None
self.connect_timeout = self.socket_timeout

if self.raw_sock is not None:
self.raw_sock.settimeout(self.socket_timeout)

def is_open(self):
return bool(self.raw_sock)

async def open(self):
self._init_sock()

addr = self.unix_socket or (self.host, self.port)

try:
if self.connect_timeout:
self.raw_sock.settimeout(self.connect_timeout)

self.raw_sock.connect(addr)

if self.socket_timeout:
self.raw_sock.settimeout(self.socket_timeout)

kwargs = {'sock': self.raw_sock, 'ssl': self.ssl_context}
if self.server_hostname:
kwargs['server_hostname'] = self.server_hostname

self.reader, self.writer = await asyncio.wait_for(
self.sock_factory(**kwargs),
self.socket_timeout
)
if self.unix_socket:
self.reader, self.writer = await asyncio.wait_for(
asyncio.open_unix_connection(addr), self.connect_timeout
)
else:
self.reader, self.writer = await asyncio.wait_for(
asyncio.open_connection(self.host, self.port, ssl=self.ssl_context),
self.connect_timeout,
)
sock = self.writer.get_extra_info("socket")
# socket options
linger = struct.pack("ii", 0, 0)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

except (socket.error, OSError):
raise TTransportException(
Expand All @@ -166,9 +130,14 @@ async def open(self):

async def read(self, sz):
try:
buff = await asyncio.wait_for(
self.reader.read(sz),
self.connect_timeout
buff = await asyncio.wait_for(self.reader.read(sz), self.connect_timeout)
except asyncio.TimeoutError:
raise TTransportException(
type=TTransportException.TIMED_OUT, message="TSocket read timed out"
)
except asyncio.IncompleteReadError as e:
raise TTransportException(
type=TTransportException.END_OF_FILE, message="TSocket read 0 bytes"
)
except socket.error as e:
if e.errno == errno.ECONNRESET and MAC_OR_BSD:
Expand Down Expand Up @@ -199,7 +168,6 @@ def close(self):

try:
self.writer.close()
self.raw_sock.close()
self.raw_sock = None
except (socket.error, OSError):
pass
Expand Down Expand Up @@ -251,7 +219,7 @@ def __init__(self, host=None, port=None, unix_socket=None,
self.ssl_context = ssl_context
elif certfile:
if not os.access(certfile, os.R_OK):
raise IOError('No such certfile found: %s' % certfile)
raise IOError("No such certfile found: %s" % certfile)

self.ssl_context = create_thriftpy_context(server_side=True,
ciphers=ciphers)
Expand Down

0 comments on commit 994a6e9

Please sign in to comment.