From 516cc30231f9982206980b2ed8f5d75d10f34f2a Mon Sep 17 00:00:00 2001 From: Hai Zhu <35182391+cocolato@users.noreply.github.com> Date: Fri, 19 Apr 2024 00:18:16 +0800 Subject: [PATCH] Support using ipv6 in make_client/make_server method (#261) * ipv6 support * Update thriftpy2/rpc.py Co-authored-by: AN Long --------- Co-authored-by: AN Long --- tests/test_aio.py | 65 ++++++++++++++++++++++++--------- tests/test_rpc.py | 38 ++++++++++++++++++-- thriftpy2/contrib/aio/rpc.py | 42 +++++++++++++++------- thriftpy2/rpc.py | 70 +++++++++++++++++++++++------------- 4 files changed, 158 insertions(+), 57 deletions(-) diff --git a/tests/test_aio.py b/tests/test_aio.py index b9de45b..c1d6d85 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -1,31 +1,31 @@ # -*- coding: utf-8 -*- +import asyncio import os +import random +import socket import sys -import asyncio # import uvloop import threading -import random +import time from unittest.mock import patch +import pytest + +import thriftpy2 +from thriftpy2.contrib.aio.protocol import (TAsyncBinaryProtocolFactory, + TAsyncCompactProtocolFactory) +from thriftpy2.contrib.aio.transport import (TAsyncBufferedTransportFactory, + TAsyncFramedTransportFactory) +from thriftpy2.rpc import make_aio_client, make_aio_server +from thriftpy2.thrift import TApplicationException +from thriftpy2.transport import TTransportException + # asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -import time -import pytest -import thriftpy2 -from thriftpy2.contrib.aio.transport import ( - TAsyncBufferedTransportFactory, - TAsyncFramedTransportFactory, -) -from thriftpy2.contrib.aio.protocol import ( - TAsyncBinaryProtocolFactory, - TAsyncCompactProtocolFactory, -) -from thriftpy2.rpc import make_aio_server, make_aio_client -from thriftpy2.transport import TTransportException -from thriftpy2.thrift import TApplicationException + if sys.platform == "win32": @@ -113,6 +113,7 @@ class _TestAIO: @classmethod def setup_class(cls): cls._start_server() + cls._start_ipv6_server() cls.person = _create_person() @classmethod @@ -139,6 +140,22 @@ def _start_server(cls): st.start() time.sleep(0.1) + @classmethod + def _start_ipv6_server(cls): + cls.server = make_aio_server( + addressbook.AddressBookService, + Dispatcher(), + trans_factory=cls.TRANSPORT_FACTORY, + proto_factory=cls.PROTOCOL_FACTORY, + loop=asyncio.new_event_loop(), + socket_family=socket.AF_INET6, + **cls.server_kwargs(), + ) + st = threading.Thread(target=cls.server.serve) + st.daemon = True + st.start() + time.sleep(0.1) + @classmethod def server_kwargs(cls): name = cls.__name__.lower() @@ -157,12 +174,28 @@ async def client(self, timeout: int = 3000000): **self.client_kwargs(), ) + async def ipv6_client(self, timeout: int = 3000000): + return await make_aio_client( + addressbook.AddressBookService, + trans_factory=self.TRANSPORT_FACTORY, + proto_factory=self.PROTOCOL_FACTORY, + timeout=timeout, + socket_family=socket.AF_INET6, + **self.client_kwargs(), + ) + @pytest.mark.asyncio async def test_void_api(self): c = await self.client() assert await c.ping() is None c.close() + @pytest.mark.asyncio + async def test_api_ipv6(self): + c = await self.ipv6_client() + assert await c.ping() is None + c.close() + @pytest.mark.asyncio async def test_string_api(self): c = await self.client() diff --git a/tests/test_rpc.py b/tests/test_rpc.py index e7b2283..3864500 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -15,10 +15,9 @@ thriftpy2.install_import_hook() -from thriftpy2.rpc import make_server, client_context # noqa -from thriftpy2.transport import TTransportException # noqa +from thriftpy2.rpc import client_context, make_server # noqa from thriftpy2.thrift import TApplicationException # noqa - +from thriftpy2.transport import TTransportException # noqa if sys.platform == "win32": pytest.skip("requires unix domain socket", allow_module_level=True) @@ -102,6 +101,26 @@ def fin(): request.addfinalizer(fin) +@pytest.fixture(scope="module") +def ipv6_server(request): + server = make_server(addressbook.AddressBookService, Dispatcher(), + unix_socket=unix_sock, socket_family=socket.AF_INET6) + ps = multiprocessing.Process(target=server.serve) + ps.start() + + time.sleep(0.1) + + def fin(): + if ps.is_alive(): + ps.terminate() + try: + os.remove(unix_sock) + except IOError: + pass + + request.addfinalizer(fin) + + @pytest.fixture(scope="module") def ssl_server(request): ssl_server = make_server(addressbook.AddressBookService, Dispatcher(), @@ -145,6 +164,14 @@ def client(timeout=3000): unix_socket=unix_sock) +def ipv6_client(timeout=3000): + return client_context(addressbook.AddressBookService, + socket_timeout=timeout, + connect_timeout=timeout, + unix_socket=unix_sock, + socket_family=socket.AF_INET6) + + def ssl_client(timeout=3000): return client_context(addressbook.AddressBookService, host='localhost', port=SSL_PORT, @@ -175,6 +202,11 @@ def test_void_api(server): assert c.ping() is None +def test_ipv6_api(ipv6_server): + with ipv6_client() as c: + assert c.ping() is None + + def test_void_api_with_ssl(ssl_server): with ssl_client() as c: assert c.ping() is None diff --git a/thriftpy2/contrib/aio/rpc.py b/thriftpy2/contrib/aio/rpc.py index 6ee8445..6f88025 100644 --- a/thriftpy2/contrib/aio/rpc.py +++ b/thriftpy2/contrib/aio/rpc.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import socket import urllib import warnings @@ -17,7 +18,7 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None, cafile=None, ssl_context=None, certfile=None, keyfile=None, validate=True, url='', - socket_timeout=None): + socket_timeout=None, socket_family=socket.AF_INET): if socket_timeout is not None: warnings.warn( "The 'socket_timeout' argument is deprecated. " @@ -30,22 +31,31 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None, host = parsed_url.hostname or host port = parsed_url.port or port if unix_socket: - socket = TAsyncSocket(unix_socket=unix_socket, - connect_timeout=connect_timeout, - socket_timeout=timeout) + client_socket = TAsyncSocket( + unix_socket=unix_socket, + connect_timeout=connect_timeout, + socket_timeout=timeout, + ) if certfile: warnings.warn("SSL only works with host:port, not unix_socket.") elif host and port: - socket = TAsyncSocket( - host, port, - socket_timeout=timeout, connect_timeout=connect_timeout, - cafile=cafile, ssl_context=ssl_context, - certfile=certfile, keyfile=keyfile, validate=validate) + client_socket = TAsyncSocket( + host, + port, + socket_timeout=timeout, + connect_timeout=connect_timeout, + cafile=cafile, + ssl_context=ssl_context, + certfile=certfile, + keyfile=keyfile, + validate=validate, + socket_family=socket_family, + ) else: raise ValueError("Either host/port or unix_socket" " or url must be provided.") - transport = trans_factory.get_transport(socket) + transport = trans_factory.get_transport(client_socket) protocol = proto_factory.get_protocol(transport) await transport.open() return TAsyncClient(service, protocol) @@ -56,7 +66,8 @@ def make_server(service, handler, proto_factory=TAsyncBinaryProtocolFactory(), trans_factory=TAsyncBufferedTransportFactory(), client_timeout=3000, certfile=None, - keyfile=None, ssl_context=None, loop=None): + keyfile=None, ssl_context=None, loop=None, + socket_family=socket.AF_INET): processor = TAsyncProcessor(service, handler) if unix_socket: @@ -65,9 +76,14 @@ def make_server(service, handler, warnings.warn("SSL only works with host:port, not unix_socket.") elif host and port: server_socket = TAsyncServerSocket( - host=host, port=port, + host=host, + port=port, client_timeout=client_timeout, - certfile=certfile, keyfile=keyfile, ssl_context=ssl_context) + certfile=certfile, + keyfile=keyfile, + ssl_context=ssl_context, + socket_family=socket_family, + ) else: raise ValueError("Either host/port or unix_socket must be provided.") diff --git a/thriftpy2/rpc.py b/thriftpy2/rpc.py index ab6d567..c2c8257 100644 --- a/thriftpy2/rpc.py +++ b/thriftpy2/rpc.py @@ -26,23 +26,30 @@ def make_client(service, host="localhost", port=9090, unix_socket=None, host = parsed_url.hostname or host port = parsed_url.port or port if unix_socket: - socket = TSocket(unix_socket=unix_socket, socket_timeout=timeout) + client_socket = TSocket(unix_socket=unix_socket, socket_timeout=timeout) if certfile: warnings.warn("SSL only works with host:port, not unix_socket.") elif host and port: if cafile or ssl_context: - socket = TSSLSocket(host, port, socket_timeout=timeout, - socket_family=socket_family, cafile=cafile, - certfile=certfile, keyfile=keyfile, - ssl_context=ssl_context) + client_socket = TSSLSocket( + host, + port, + socket_timeout=timeout, + socket_family=socket_family, + cafile=cafile, + certfile=certfile, + keyfile=keyfile, + ssl_context=ssl_context, + ) else: - socket = TSocket(host, port, socket_family=socket_family, - socket_timeout=timeout) + client_socket = TSocket( + host, port, socket_family=socket_family, socket_timeout=timeout + ) else: raise ValueError("Either host/port or unix_socket" " or url must be provided.") - transport = trans_factory.get_transport(socket) + transport = trans_factory.get_transport(client_socket) protocol = proto_factory.get_protocol(transport) transport.open() return TClient(service, protocol) @@ -52,7 +59,8 @@ def make_server(service, handler, host="localhost", port=9090, unix_socket=None, proto_factory=TBinaryProtocolFactory(), trans_factory=TBufferedTransportFactory(), - client_timeout=3000, certfile=None): + client_timeout=3000, certfile=None, + socket_family=socket.AF_INET): processor = TProcessor(service, handler) if unix_socket: @@ -63,10 +71,11 @@ def make_server(service, handler, if certfile: server_socket = TSSLServerSocket( host=host, port=port, client_timeout=client_timeout, - certfile=certfile) + certfile=certfile, socket_family=socket_family) else: server_socket = TServerSocket( - host=host, port=port, client_timeout=client_timeout) + host=host, port=port, client_timeout=client_timeout, + socket_family=socket_family) else: raise ValueError("Either host/port or unix_socket must be provided.") @@ -82,7 +91,7 @@ def client_context(service, host="localhost", port=9090, unix_socket=None, trans_factory=TBufferedTransportFactory(), timeout=None, socket_timeout=3000, connect_timeout=3000, cafile=None, ssl_context=None, certfile=None, keyfile=None, - url=""): + url="", socket_family=socket.AF_INET): if url: parsed_url = urllib.parse.urlparse(url) host = parsed_url.hostname or host @@ -94,29 +103,40 @@ def client_context(service, host="localhost", port=9090, unix_socket=None, socket_timeout = connect_timeout = timeout if unix_socket: - socket = TSocket(unix_socket=unix_socket, - connect_timeout=connect_timeout, - socket_timeout=socket_timeout) + client_socket = TSocket( + unix_socket=unix_socket, + connect_timeout=connect_timeout, + socket_timeout=socket_timeout, + ) if certfile: warnings.warn("SSL only works with host:port, not unix_socket.") elif host and port: if cafile or ssl_context: - socket = TSSLSocket(host, port, - connect_timeout=connect_timeout, - socket_timeout=socket_timeout, - cafile=cafile, - certfile=certfile, keyfile=keyfile, - ssl_context=ssl_context) + client_socket = TSSLSocket( + host, + port, + connect_timeout=connect_timeout, + socket_timeout=socket_timeout, + cafile=cafile, + certfile=certfile, + keyfile=keyfile, + ssl_context=ssl_context, + socket_family=socket_family, + ) else: - socket = TSocket(host, port, - connect_timeout=connect_timeout, - socket_timeout=socket_timeout) + client_socket = TSocket( + host, + port, + connect_timeout=connect_timeout, + socket_timeout=socket_timeout, + socket_family=socket_family, + ) else: raise ValueError("Either host/port or unix_socket" " or url must be provided.") try: - transport = trans_factory.get_transport(socket) + transport = trans_factory.get_transport(client_socket) protocol = proto_factory.get_protocol(transport) transport.open() yield TClient(service, protocol)