Skip to content

Commit

Permalink
Support using ipv6 in make_client/make_server method (#261)
Browse files Browse the repository at this point in the history
* ipv6 support

* Update thriftpy2/rpc.py

Co-authored-by: AN Long <[email protected]>

---------

Co-authored-by: AN Long <[email protected]>
  • Loading branch information
cocolato and aisk authored Apr 18, 2024
1 parent a66b339 commit 516cc30
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 57 deletions.
65 changes: 49 additions & 16 deletions tests/test_aio.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
# -*- coding: utf-8 -*-
import asyncio
import os
import random
import socket
import sys
import asyncio
# import uvloop
import threading
import random
import time
from unittest.mock import patch

import pytest

import thriftpy2
from thriftpy2.contrib.aio.protocol import (TAsyncBinaryProtocolFactory,
TAsyncCompactProtocolFactory)
from thriftpy2.contrib.aio.transport import (TAsyncBufferedTransportFactory,
TAsyncFramedTransportFactory)
from thriftpy2.rpc import make_aio_client, make_aio_server
from thriftpy2.thrift import TApplicationException
from thriftpy2.transport import TTransportException

# asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

import time

import pytest

import thriftpy2

from thriftpy2.contrib.aio.transport import (
TAsyncBufferedTransportFactory,
TAsyncFramedTransportFactory,
)
from thriftpy2.contrib.aio.protocol import (
TAsyncBinaryProtocolFactory,
TAsyncCompactProtocolFactory,
)
from thriftpy2.rpc import make_aio_server, make_aio_client
from thriftpy2.transport import TTransportException
from thriftpy2.thrift import TApplicationException



if sys.platform == "win32":
Expand Down Expand Up @@ -113,6 +113,7 @@ class _TestAIO:
@classmethod
def setup_class(cls):
cls._start_server()
cls._start_ipv6_server()
cls.person = _create_person()

@classmethod
Expand All @@ -139,6 +140,22 @@ def _start_server(cls):
st.start()
time.sleep(0.1)

@classmethod
def _start_ipv6_server(cls):
cls.server = make_aio_server(
addressbook.AddressBookService,
Dispatcher(),
trans_factory=cls.TRANSPORT_FACTORY,
proto_factory=cls.PROTOCOL_FACTORY,
loop=asyncio.new_event_loop(),
socket_family=socket.AF_INET6,
**cls.server_kwargs(),
)
st = threading.Thread(target=cls.server.serve)
st.daemon = True
st.start()
time.sleep(0.1)

@classmethod
def server_kwargs(cls):
name = cls.__name__.lower()
Expand All @@ -157,12 +174,28 @@ async def client(self, timeout: int = 3000000):
**self.client_kwargs(),
)

async def ipv6_client(self, timeout: int = 3000000):
return await make_aio_client(
addressbook.AddressBookService,
trans_factory=self.TRANSPORT_FACTORY,
proto_factory=self.PROTOCOL_FACTORY,
timeout=timeout,
socket_family=socket.AF_INET6,
**self.client_kwargs(),
)

@pytest.mark.asyncio
async def test_void_api(self):
c = await self.client()
assert await c.ping() is None
c.close()

@pytest.mark.asyncio
async def test_api_ipv6(self):
c = await self.ipv6_client()
assert await c.ping() is None
c.close()

@pytest.mark.asyncio
async def test_string_api(self):
c = await self.client()
Expand Down
38 changes: 35 additions & 3 deletions tests/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@

thriftpy2.install_import_hook()

from thriftpy2.rpc import make_server, client_context # noqa
from thriftpy2.transport import TTransportException # noqa
from thriftpy2.rpc import client_context, make_server # noqa
from thriftpy2.thrift import TApplicationException # noqa

from thriftpy2.transport import TTransportException # noqa

if sys.platform == "win32":
pytest.skip("requires unix domain socket", allow_module_level=True)
Expand Down Expand Up @@ -102,6 +101,26 @@ def fin():
request.addfinalizer(fin)


@pytest.fixture(scope="module")
def ipv6_server(request):
server = make_server(addressbook.AddressBookService, Dispatcher(),
unix_socket=unix_sock, socket_family=socket.AF_INET6)
ps = multiprocessing.Process(target=server.serve)
ps.start()

time.sleep(0.1)

def fin():
if ps.is_alive():
ps.terminate()
try:
os.remove(unix_sock)
except IOError:
pass

request.addfinalizer(fin)


@pytest.fixture(scope="module")
def ssl_server(request):
ssl_server = make_server(addressbook.AddressBookService, Dispatcher(),
Expand Down Expand Up @@ -145,6 +164,14 @@ def client(timeout=3000):
unix_socket=unix_sock)


def ipv6_client(timeout=3000):
return client_context(addressbook.AddressBookService,
socket_timeout=timeout,
connect_timeout=timeout,
unix_socket=unix_sock,
socket_family=socket.AF_INET6)


def ssl_client(timeout=3000):
return client_context(addressbook.AddressBookService,
host='localhost', port=SSL_PORT,
Expand Down Expand Up @@ -175,6 +202,11 @@ def test_void_api(server):
assert c.ping() is None


def test_ipv6_api(ipv6_server):
with ipv6_client() as c:
assert c.ping() is None


def test_void_api_with_ssl(ssl_server):
with ssl_client() as c:
assert c.ping() is None
Expand Down
42 changes: 29 additions & 13 deletions thriftpy2/contrib/aio/rpc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import socket
import urllib
import warnings

Expand All @@ -17,7 +18,7 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None,
cafile=None, ssl_context=None,
certfile=None, keyfile=None,
validate=True, url='',
socket_timeout=None):
socket_timeout=None, socket_family=socket.AF_INET):
if socket_timeout is not None:
warnings.warn(
"The 'socket_timeout' argument is deprecated. "
Expand All @@ -30,22 +31,31 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None,
host = parsed_url.hostname or host
port = parsed_url.port or port
if unix_socket:
socket = TAsyncSocket(unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=timeout)
client_socket = TAsyncSocket(
unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=timeout,
)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
socket = TAsyncSocket(
host, port,
socket_timeout=timeout, connect_timeout=connect_timeout,
cafile=cafile, ssl_context=ssl_context,
certfile=certfile, keyfile=keyfile, validate=validate)
client_socket = TAsyncSocket(
host,
port,
socket_timeout=timeout,
connect_timeout=connect_timeout,
cafile=cafile,
ssl_context=ssl_context,
certfile=certfile,
keyfile=keyfile,
validate=validate,
socket_family=socket_family,
)
else:
raise ValueError("Either host/port or unix_socket"
" or url must be provided.")

transport = trans_factory.get_transport(socket)
transport = trans_factory.get_transport(client_socket)
protocol = proto_factory.get_protocol(transport)
await transport.open()
return TAsyncClient(service, protocol)
Expand All @@ -56,7 +66,8 @@ def make_server(service, handler,
proto_factory=TAsyncBinaryProtocolFactory(),
trans_factory=TAsyncBufferedTransportFactory(),
client_timeout=3000, certfile=None,
keyfile=None, ssl_context=None, loop=None):
keyfile=None, ssl_context=None, loop=None,
socket_family=socket.AF_INET):
processor = TAsyncProcessor(service, handler)

if unix_socket:
Expand All @@ -65,9 +76,14 @@ def make_server(service, handler,
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
server_socket = TAsyncServerSocket(
host=host, port=port,
host=host,
port=port,
client_timeout=client_timeout,
certfile=certfile, keyfile=keyfile, ssl_context=ssl_context)
certfile=certfile,
keyfile=keyfile,
ssl_context=ssl_context,
socket_family=socket_family,
)
else:
raise ValueError("Either host/port or unix_socket must be provided.")

Expand Down
70 changes: 45 additions & 25 deletions thriftpy2/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,30 @@ def make_client(service, host="localhost", port=9090, unix_socket=None,
host = parsed_url.hostname or host
port = parsed_url.port or port
if unix_socket:
socket = TSocket(unix_socket=unix_socket, socket_timeout=timeout)
client_socket = TSocket(unix_socket=unix_socket, socket_timeout=timeout)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
if cafile or ssl_context:
socket = TSSLSocket(host, port, socket_timeout=timeout,
socket_family=socket_family, cafile=cafile,
certfile=certfile, keyfile=keyfile,
ssl_context=ssl_context)
client_socket = TSSLSocket(
host,
port,
socket_timeout=timeout,
socket_family=socket_family,
cafile=cafile,
certfile=certfile,
keyfile=keyfile,
ssl_context=ssl_context,
)
else:
socket = TSocket(host, port, socket_family=socket_family,
socket_timeout=timeout)
client_socket = TSocket(
host, port, socket_family=socket_family, socket_timeout=timeout
)
else:
raise ValueError("Either host/port or unix_socket"
" or url must be provided.")

transport = trans_factory.get_transport(socket)
transport = trans_factory.get_transport(client_socket)
protocol = proto_factory.get_protocol(transport)
transport.open()
return TClient(service, protocol)
Expand All @@ -52,7 +59,8 @@ def make_server(service, handler,
host="localhost", port=9090, unix_socket=None,
proto_factory=TBinaryProtocolFactory(),
trans_factory=TBufferedTransportFactory(),
client_timeout=3000, certfile=None):
client_timeout=3000, certfile=None,
socket_family=socket.AF_INET):
processor = TProcessor(service, handler)

if unix_socket:
Expand All @@ -63,10 +71,11 @@ def make_server(service, handler,
if certfile:
server_socket = TSSLServerSocket(
host=host, port=port, client_timeout=client_timeout,
certfile=certfile)
certfile=certfile, socket_family=socket_family)
else:
server_socket = TServerSocket(
host=host, port=port, client_timeout=client_timeout)
host=host, port=port, client_timeout=client_timeout,
socket_family=socket_family)
else:
raise ValueError("Either host/port or unix_socket must be provided.")

Expand All @@ -82,7 +91,7 @@ def client_context(service, host="localhost", port=9090, unix_socket=None,
trans_factory=TBufferedTransportFactory(),
timeout=None, socket_timeout=3000, connect_timeout=3000,
cafile=None, ssl_context=None, certfile=None, keyfile=None,
url=""):
url="", socket_family=socket.AF_INET):
if url:
parsed_url = urllib.parse.urlparse(url)
host = parsed_url.hostname or host
Expand All @@ -94,29 +103,40 @@ def client_context(service, host="localhost", port=9090, unix_socket=None,
socket_timeout = connect_timeout = timeout

if unix_socket:
socket = TSocket(unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout)
client_socket = TSocket(
unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout,
)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
if cafile or ssl_context:
socket = TSSLSocket(host, port,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout,
cafile=cafile,
certfile=certfile, keyfile=keyfile,
ssl_context=ssl_context)
client_socket = TSSLSocket(
host,
port,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout,
cafile=cafile,
certfile=certfile,
keyfile=keyfile,
ssl_context=ssl_context,
socket_family=socket_family,
)
else:
socket = TSocket(host, port,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout)
client_socket = TSocket(
host,
port,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout,
socket_family=socket_family,
)
else:
raise ValueError("Either host/port or unix_socket"
" or url must be provided.")

try:
transport = trans_factory.get_transport(socket)
transport = trans_factory.get_transport(client_socket)
protocol = proto_factory.get_protocol(transport)
transport.open()
yield TClient(service, protocol)
Expand Down

0 comments on commit 516cc30

Please sign in to comment.