diff --git a/thriftpy2/contrib/aio/socket.py b/thriftpy2/contrib/aio/socket.py index 8d1e49d..9bc2714 100644 --- a/thriftpy2/contrib/aio/socket.py +++ b/thriftpy2/contrib/aio/socket.py @@ -64,6 +64,8 @@ def __init__(self, host=None, port=None, unix_socket=None, to persist SSLContext object. Caution it's easy to get wrong, only use if you know what you're doing. """ + self.read_sock = None + self.write_sock = None if sock: self.raw_sock = sock elif unix_socket: @@ -93,8 +95,7 @@ def __init__(self, host=None, port=None, unix_socket=None, ciphers=ciphers) if cafile or capath: - self.ssl_context.load_verify_locations(cafile=cafile, - capath=capath) + self.ssl_context.load_verify_locations(cafile=cafile, capath=capath) if certfile: self.ssl_context.load_cert_chain(certfile, keyfile=keyfile) @@ -106,19 +107,34 @@ def __init__(self, host=None, port=None, unix_socket=None, self.ssl_context = None self.server_hostname = None - def _init_sock(self): - if self.unix_socket: - _sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - else: - _sock = socket.socket(self.socket_family, socket.SOCK_STREAM) - _sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + async def open(self): + addr = self.unix_socket or (self.host, self.port) + try: + if self.unix_socket: + self.reader, self.writer = await asyncio.wait_for( + asyncio.open_unix_connection(addr), self.connect_timeout + ) + else: + self.reader, self.writer = await asyncio.wait_for( + asyncio.open_connection(self.host, self.port, ssl=self.ssl_context), + self.connect_timeout, + ) - # socket options - linger = struct.pack('ii', 0, 0) - _sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger) - _sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + write_sock = self.writer.get_extra_info("socket") + read_sock = self.reader._transport.get_extra_info("socket") + self.write_sock = write_sock + self.read_sock = read_sock + linger = struct.pack("ii", 0, 0) + for sock in [self.read_sock, self.write_sock]: + # socket options + sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - self.raw_sock = _sock + except (socket.error, OSError): + raise TTransportException( + type=TTransportException.NOT_OPEN, + message="Could not connect to %s" % str(addr), + ) def set_handle(self, sock): self.raw_sock = sock @@ -129,46 +145,22 @@ def set_timeout(self, ms): """ self.socket_timeout = ms / 1000 if (ms and ms > 0) else None self.connect_timeout = self.socket_timeout + if self.read_sock: + self.read_sock.settimeout(self.socket_timeout) + if self.write_sock: + self.write_sock.settimeout(self.socket_timeout) - if self.raw_sock is not None: - self.raw_sock.settimeout(self.socket_timeout) - - def is_open(self): - return bool(self.raw_sock) - - async def open(self): - self._init_sock() - - addr = self.unix_socket or (self.host, self.port) + async def read(self, sz): try: - if self.connect_timeout: - self.raw_sock.settimeout(self.connect_timeout) - - self.raw_sock.connect(addr) - - if self.socket_timeout: - self.raw_sock.settimeout(self.socket_timeout) - - kwargs = {'sock': self.raw_sock, 'ssl': self.ssl_context} - if self.server_hostname: - kwargs['server_hostname'] = self.server_hostname - - self.reader, self.writer = await asyncio.wait_for( - self.sock_factory(**kwargs), - self.socket_timeout + buff = await asyncio.wait_for(self.reader.read(sz), self.connect_timeout) + except asyncio.TimeoutError: + raise TTransportException( + type=TTransportException.TIMED_OUT, message="TSocket read timed out" ) - - except (socket.error, OSError): + except asyncio.IncompleteReadError as e: raise TTransportException( - type=TTransportException.NOT_OPEN, - message="Could not connect to %s" % str(addr)) - - async def read(self, sz): - try: - buff = await asyncio.wait_for( - self.reader.read(sz), - self.connect_timeout + type=TTransportException.END_OF_FILE, message="TSocket read 0 bytes" ) except socket.error as e: if e.errno == errno.ECONNRESET and MAC_OR_BSD: @@ -194,16 +186,18 @@ async def flush(self): await asyncio.wait_for(self.writer.drain(), self.connect_timeout) def close(self): - if not self.raw_sock: - return - - try: + if self.writer: self.writer.close() - self.raw_sock.close() - self.raw_sock = None - except (socket.error, OSError): - pass + # await self.writer.wait_closed() + + if self.reader: + self.reader._transport.close() + + if self.read_sock: + self.read_sock.close() + if self.write_sock: + self.write_sock.close() class TAsyncServerSocket(object): """Socket implementation for server side.""" @@ -251,7 +245,7 @@ def __init__(self, host=None, port=None, unix_socket=None, self.ssl_context = ssl_context elif certfile: if not os.access(certfile, os.R_OK): - raise IOError('No such certfile found: %s' % certfile) + raise IOError("No such certfile found: %s" % certfile) self.ssl_context = create_thriftpy_context(server_side=True, ciphers=ciphers)