From ba53fd6b946131a377607ca7ee5fdf1dbaff5128 Mon Sep 17 00:00:00 2001 From: cocolato Date: Mon, 15 Apr 2024 21:42:10 +0800 Subject: [PATCH] ipv6 support --- tests/test_aio.py | 35 +++++++++++++++++++++++++++++++++++ tests/test_rpc.py | 33 +++++++++++++++++++++++++++++++++ thriftpy2/contrib/aio/rpc.py | 19 +++++++++++++------ thriftpy2/rpc.py | 23 +++++++++++++++-------- 4 files changed, 96 insertions(+), 14 deletions(-) diff --git a/tests/test_aio.py b/tests/test_aio.py index b9de45b..68b6c7a 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -13,6 +13,8 @@ import pytest +import socket + import thriftpy2 from thriftpy2.contrib.aio.transport import ( @@ -113,6 +115,7 @@ class _TestAIO: @classmethod def setup_class(cls): cls._start_server() + cls._start_ipv6_server() cls.person = _create_person() @classmethod @@ -139,6 +142,22 @@ def _start_server(cls): st.start() time.sleep(0.1) + @classmethod + def _start_ipv6_server(cls): + cls.server = make_aio_server( + addressbook.AddressBookService, + Dispatcher(), + trans_factory=cls.TRANSPORT_FACTORY, + proto_factory=cls.PROTOCOL_FACTORY, + loop=asyncio.new_event_loop(), + socket_family=socket.AF_INET6, + **cls.server_kwargs(), + ) + st = threading.Thread(target=cls.server.serve) + st.daemon = True + st.start() + time.sleep(0.1) + @classmethod def server_kwargs(cls): name = cls.__name__.lower() @@ -157,12 +176,28 @@ async def client(self, timeout: int = 3000000): **self.client_kwargs(), ) + async def ipv6_client(self, timeout: int = 3000000): + return await make_aio_client( + addressbook.AddressBookService, + trans_factory=self.TRANSPORT_FACTORY, + proto_factory=self.PROTOCOL_FACTORY, + timeout=timeout, + socket_family=socket.AF_INET6, + **self.client_kwargs(), + ) + @pytest.mark.asyncio async def test_void_api(self): c = await self.client() assert await c.ping() is None c.close() + @pytest.mark.asyncio + async def test_api_ipv6(self): + c = await self.ipv6_client() + assert await c.ping() is None + c.close() + @pytest.mark.asyncio async def test_string_api(self): c = await self.client() diff --git a/tests/test_rpc.py b/tests/test_rpc.py index e7b2283..3db315a 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -102,6 +102,26 @@ def fin(): request.addfinalizer(fin) +@pytest.fixture(scope="module") +def ipv6_server(request): + server = make_server(addressbook.AddressBookService, Dispatcher(), + unix_socket=unix_sock, socket_family=socket.AF_INET6) + ps = multiprocessing.Process(target=server.serve) + ps.start() + + time.sleep(0.1) + + def fin(): + if ps.is_alive(): + ps.terminate() + try: + os.remove(unix_sock) + except IOError: + pass + + request.addfinalizer(fin) + + @pytest.fixture(scope="module") def ssl_server(request): ssl_server = make_server(addressbook.AddressBookService, Dispatcher(), @@ -145,6 +165,14 @@ def client(timeout=3000): unix_socket=unix_sock) +def ipv6_client(timeout=3000): + return client_context(addressbook.AddressBookService, + socket_timeout=timeout, + connect_timeout=timeout, + unix_socket=unix_sock, + socket_family=socket.AF_INET6) + + def ssl_client(timeout=3000): return client_context(addressbook.AddressBookService, host='localhost', port=SSL_PORT, @@ -175,6 +203,11 @@ def test_void_api(server): assert c.ping() is None +def test_ipv6_api(ipv6_server): + with ipv6_client() as c: + assert c.ping() is None + + def test_void_api_with_ssl(ssl_server): with ssl_client() as c: assert c.ping() is None diff --git a/thriftpy2/contrib/aio/rpc.py b/thriftpy2/contrib/aio/rpc.py index 6ee8445..f09bf22 100644 --- a/thriftpy2/contrib/aio/rpc.py +++ b/thriftpy2/contrib/aio/rpc.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import urllib import warnings +import socket from .client import TAsyncClient from .processor import TAsyncProcessor @@ -17,7 +18,7 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None, cafile=None, ssl_context=None, certfile=None, keyfile=None, validate=True, url='', - socket_timeout=None): + socket_timeout=None, socket_family=socket.AF_INET): if socket_timeout is not None: warnings.warn( "The 'socket_timeout' argument is deprecated. " @@ -32,7 +33,8 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None, if unix_socket: socket = TAsyncSocket(unix_socket=unix_socket, connect_timeout=connect_timeout, - socket_timeout=timeout) + socket_timeout=timeout, + socket_family=socket_family) if certfile: warnings.warn("SSL only works with host:port, not unix_socket.") elif host and port: @@ -40,7 +42,8 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None, host, port, socket_timeout=timeout, connect_timeout=connect_timeout, cafile=cafile, ssl_context=ssl_context, - certfile=certfile, keyfile=keyfile, validate=validate) + certfile=certfile, keyfile=keyfile, validate=validate, + socket_family=socket_family) else: raise ValueError("Either host/port or unix_socket" " or url must be provided.") @@ -56,18 +59,22 @@ def make_server(service, handler, proto_factory=TAsyncBinaryProtocolFactory(), trans_factory=TAsyncBufferedTransportFactory(), client_timeout=3000, certfile=None, - keyfile=None, ssl_context=None, loop=None): + keyfile=None, ssl_context=None, loop=None, + socket_family=socket.AF_INET): processor = TAsyncProcessor(service, handler) if unix_socket: - server_socket = TAsyncServerSocket(unix_socket=unix_socket) + server_socket = TAsyncServerSocket( + unix_socket=unix_socket, + socket_family=socket_family) if certfile: warnings.warn("SSL only works with host:port, not unix_socket.") elif host and port: server_socket = TAsyncServerSocket( host=host, port=port, client_timeout=client_timeout, - certfile=certfile, keyfile=keyfile, ssl_context=ssl_context) + certfile=certfile, keyfile=keyfile, ssl_context=ssl_context, + socket_family=socket_family) else: raise ValueError("Either host/port or unix_socket must be provided.") diff --git a/thriftpy2/rpc.py b/thriftpy2/rpc.py index ab6d567..fac5e61 100644 --- a/thriftpy2/rpc.py +++ b/thriftpy2/rpc.py @@ -52,21 +52,25 @@ def make_server(service, handler, host="localhost", port=9090, unix_socket=None, proto_factory=TBinaryProtocolFactory(), trans_factory=TBufferedTransportFactory(), - client_timeout=3000, certfile=None): + client_timeout=3000, certfile=None, + socket_family=socket.AF_INET): processor = TProcessor(service, handler) if unix_socket: - server_socket = TServerSocket(unix_socket=unix_socket) + server_socket = TServerSocket( + unix_socket=unix_socket, + socket_family=socket_family) if certfile: warnings.warn("SSL only works with host:port, not unix_socket.") elif host and port: if certfile: server_socket = TSSLServerSocket( host=host, port=port, client_timeout=client_timeout, - certfile=certfile) + certfile=certfile, socket_family=socket_family) else: server_socket = TServerSocket( - host=host, port=port, client_timeout=client_timeout) + host=host, port=port, client_timeout=client_timeout, + socket_family=socket_family) else: raise ValueError("Either host/port or unix_socket must be provided.") @@ -82,7 +86,7 @@ def client_context(service, host="localhost", port=9090, unix_socket=None, trans_factory=TBufferedTransportFactory(), timeout=None, socket_timeout=3000, connect_timeout=3000, cafile=None, ssl_context=None, certfile=None, keyfile=None, - url=""): + url="", socket_family=socket.AF_INET): if url: parsed_url = urllib.parse.urlparse(url) host = parsed_url.hostname or host @@ -96,7 +100,8 @@ def client_context(service, host="localhost", port=9090, unix_socket=None, if unix_socket: socket = TSocket(unix_socket=unix_socket, connect_timeout=connect_timeout, - socket_timeout=socket_timeout) + socket_timeout=socket_timeout, + socket_family=socket_family) if certfile: warnings.warn("SSL only works with host:port, not unix_socket.") elif host and port: @@ -106,11 +111,13 @@ def client_context(service, host="localhost", port=9090, unix_socket=None, socket_timeout=socket_timeout, cafile=cafile, certfile=certfile, keyfile=keyfile, - ssl_context=ssl_context) + ssl_context=ssl_context, + socket_family=socket_family) else: socket = TSocket(host, port, connect_timeout=connect_timeout, - socket_timeout=socket_timeout) + socket_timeout=socket_timeout, + socket_family=socket_family) else: raise ValueError("Either host/port or unix_socket" " or url must be provided.")