Skip to content

Commit

Permalink
ipv6 support
Browse files Browse the repository at this point in the history
  • Loading branch information
cocolato committed Apr 15, 2024
1 parent a66b339 commit ba53fd6
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 14 deletions.
35 changes: 35 additions & 0 deletions tests/test_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

import pytest

import socket

import thriftpy2

from thriftpy2.contrib.aio.transport import (
Expand Down Expand Up @@ -113,6 +115,7 @@ class _TestAIO:
@classmethod
def setup_class(cls):
cls._start_server()
cls._start_ipv6_server()
cls.person = _create_person()

@classmethod
Expand All @@ -139,6 +142,22 @@ def _start_server(cls):
st.start()
time.sleep(0.1)

@classmethod
def _start_ipv6_server(cls):
cls.server = make_aio_server(
addressbook.AddressBookService,
Dispatcher(),
trans_factory=cls.TRANSPORT_FACTORY,
proto_factory=cls.PROTOCOL_FACTORY,
loop=asyncio.new_event_loop(),
socket_family=socket.AF_INET6,
**cls.server_kwargs(),
)
st = threading.Thread(target=cls.server.serve)
st.daemon = True
st.start()
time.sleep(0.1)

@classmethod
def server_kwargs(cls):
name = cls.__name__.lower()
Expand All @@ -157,12 +176,28 @@ async def client(self, timeout: int = 3000000):
**self.client_kwargs(),
)

async def ipv6_client(self, timeout: int = 3000000):
return await make_aio_client(
addressbook.AddressBookService,
trans_factory=self.TRANSPORT_FACTORY,
proto_factory=self.PROTOCOL_FACTORY,
timeout=timeout,
socket_family=socket.AF_INET6,
**self.client_kwargs(),
)

@pytest.mark.asyncio
async def test_void_api(self):
c = await self.client()
assert await c.ping() is None
c.close()

@pytest.mark.asyncio
async def test_api_ipv6(self):
c = await self.ipv6_client()
assert await c.ping() is None
c.close()

@pytest.mark.asyncio
async def test_string_api(self):
c = await self.client()
Expand Down
33 changes: 33 additions & 0 deletions tests/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,26 @@ def fin():
request.addfinalizer(fin)


@pytest.fixture(scope="module")
def ipv6_server(request):
server = make_server(addressbook.AddressBookService, Dispatcher(),
unix_socket=unix_sock, socket_family=socket.AF_INET6)
ps = multiprocessing.Process(target=server.serve)
ps.start()

time.sleep(0.1)

def fin():
if ps.is_alive():
ps.terminate()
try:
os.remove(unix_sock)
except IOError:
pass

request.addfinalizer(fin)


@pytest.fixture(scope="module")
def ssl_server(request):
ssl_server = make_server(addressbook.AddressBookService, Dispatcher(),
Expand Down Expand Up @@ -145,6 +165,14 @@ def client(timeout=3000):
unix_socket=unix_sock)


def ipv6_client(timeout=3000):
return client_context(addressbook.AddressBookService,
socket_timeout=timeout,
connect_timeout=timeout,
unix_socket=unix_sock,
socket_family=socket.AF_INET6)


def ssl_client(timeout=3000):
return client_context(addressbook.AddressBookService,
host='localhost', port=SSL_PORT,
Expand Down Expand Up @@ -175,6 +203,11 @@ def test_void_api(server):
assert c.ping() is None


def test_ipv6_api(ipv6_server):
with ipv6_client() as c:
assert c.ping() is None


def test_void_api_with_ssl(ssl_server):
with ssl_client() as c:
assert c.ping() is None
Expand Down
19 changes: 13 additions & 6 deletions thriftpy2/contrib/aio/rpc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import urllib
import warnings
import socket

from .client import TAsyncClient
from .processor import TAsyncProcessor
Expand All @@ -17,7 +18,7 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None,
cafile=None, ssl_context=None,
certfile=None, keyfile=None,
validate=True, url='',
socket_timeout=None):
socket_timeout=None, socket_family=socket.AF_INET):
if socket_timeout is not None:
warnings.warn(
"The 'socket_timeout' argument is deprecated. "
Expand All @@ -32,15 +33,17 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None,
if unix_socket:
socket = TAsyncSocket(unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=timeout)
socket_timeout=timeout,
socket_family=socket_family)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
socket = TAsyncSocket(
host, port,
socket_timeout=timeout, connect_timeout=connect_timeout,
cafile=cafile, ssl_context=ssl_context,
certfile=certfile, keyfile=keyfile, validate=validate)
certfile=certfile, keyfile=keyfile, validate=validate,
socket_family=socket_family)
else:
raise ValueError("Either host/port or unix_socket"
" or url must be provided.")
Expand All @@ -56,18 +59,22 @@ def make_server(service, handler,
proto_factory=TAsyncBinaryProtocolFactory(),
trans_factory=TAsyncBufferedTransportFactory(),
client_timeout=3000, certfile=None,
keyfile=None, ssl_context=None, loop=None):
keyfile=None, ssl_context=None, loop=None,
socket_family=socket.AF_INET):
processor = TAsyncProcessor(service, handler)

if unix_socket:
server_socket = TAsyncServerSocket(unix_socket=unix_socket)
server_socket = TAsyncServerSocket(
unix_socket=unix_socket,
socket_family=socket_family)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
server_socket = TAsyncServerSocket(
host=host, port=port,
client_timeout=client_timeout,
certfile=certfile, keyfile=keyfile, ssl_context=ssl_context)
certfile=certfile, keyfile=keyfile, ssl_context=ssl_context,
socket_family=socket_family)
else:
raise ValueError("Either host/port or unix_socket must be provided.")

Expand Down
23 changes: 15 additions & 8 deletions thriftpy2/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,25 @@ def make_server(service, handler,
host="localhost", port=9090, unix_socket=None,
proto_factory=TBinaryProtocolFactory(),
trans_factory=TBufferedTransportFactory(),
client_timeout=3000, certfile=None):
client_timeout=3000, certfile=None,
socket_family=socket.AF_INET):
processor = TProcessor(service, handler)

if unix_socket:
server_socket = TServerSocket(unix_socket=unix_socket)
server_socket = TServerSocket(
unix_socket=unix_socket,
socket_family=socket_family)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
if certfile:
server_socket = TSSLServerSocket(
host=host, port=port, client_timeout=client_timeout,
certfile=certfile)
certfile=certfile, socket_family=socket_family)
else:
server_socket = TServerSocket(
host=host, port=port, client_timeout=client_timeout)
host=host, port=port, client_timeout=client_timeout,
socket_family=socket_family)
else:
raise ValueError("Either host/port or unix_socket must be provided.")

Expand All @@ -82,7 +86,7 @@ def client_context(service, host="localhost", port=9090, unix_socket=None,
trans_factory=TBufferedTransportFactory(),
timeout=None, socket_timeout=3000, connect_timeout=3000,
cafile=None, ssl_context=None, certfile=None, keyfile=None,
url=""):
url="", socket_family=socket.AF_INET):
if url:
parsed_url = urllib.parse.urlparse(url)
host = parsed_url.hostname or host
Expand All @@ -96,7 +100,8 @@ def client_context(service, host="localhost", port=9090, unix_socket=None,
if unix_socket:
socket = TSocket(unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout)
socket_timeout=socket_timeout,
socket_family=socket_family)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
Expand All @@ -106,11 +111,13 @@ def client_context(service, host="localhost", port=9090, unix_socket=None,
socket_timeout=socket_timeout,
cafile=cafile,
certfile=certfile, keyfile=keyfile,
ssl_context=ssl_context)
ssl_context=ssl_context,
socket_family=socket_family)
else:
socket = TSocket(host, port,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout)
socket_timeout=socket_timeout,
socket_family=socket_family)
else:
raise ValueError("Either host/port or unix_socket"
" or url must be provided.")
Expand Down

0 comments on commit ba53fd6

Please sign in to comment.