Skip to content

Commit

Permalink
Merge branch 'master' into dismiss-build-error
Browse files Browse the repository at this point in the history
  • Loading branch information
aisk authored May 7, 2024
2 parents 5dc0172 + a16e0db commit 61f268a
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 111 deletions.
15 changes: 15 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
Changelog
=========

0.5.0
~~~~~

Version 0.5.0
-------------

Released on May 7, 2024.

- Dropped Python2 and Python3.5 Support.
- Added SASL transport client.
- Add submodule to sys.path when loading child idl file.
- Support cythonized module on Windows.
- Support using ipv6 in make_client/make_server method.
- Basic multi-thread support in parser.

0.4.x
~~~~~

Expand Down
65 changes: 49 additions & 16 deletions tests/test_aio.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
# -*- coding: utf-8 -*-
import asyncio
import os
import random
import socket
import sys
import asyncio
# import uvloop
import threading
import random
import time
from unittest.mock import patch

import pytest

import thriftpy2
from thriftpy2.contrib.aio.protocol import (TAsyncBinaryProtocolFactory,
TAsyncCompactProtocolFactory)
from thriftpy2.contrib.aio.transport import (TAsyncBufferedTransportFactory,
TAsyncFramedTransportFactory)
from thriftpy2.rpc import make_aio_client, make_aio_server
from thriftpy2.thrift import TApplicationException
from thriftpy2.transport import TTransportException

# asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

import time

import pytest

import thriftpy2

from thriftpy2.contrib.aio.transport import (
TAsyncBufferedTransportFactory,
TAsyncFramedTransportFactory,
)
from thriftpy2.contrib.aio.protocol import (
TAsyncBinaryProtocolFactory,
TAsyncCompactProtocolFactory,
)
from thriftpy2.rpc import make_aio_server, make_aio_client
from thriftpy2.transport import TTransportException
from thriftpy2.thrift import TApplicationException



if sys.platform == "win32":
Expand Down Expand Up @@ -113,6 +113,7 @@ class _TestAIO:
@classmethod
def setup_class(cls):
cls._start_server()
cls._start_ipv6_server()
cls.person = _create_person()

@classmethod
Expand All @@ -139,6 +140,22 @@ def _start_server(cls):
st.start()
time.sleep(0.1)

@classmethod
def _start_ipv6_server(cls):
cls.server = make_aio_server(
addressbook.AddressBookService,
Dispatcher(),
trans_factory=cls.TRANSPORT_FACTORY,
proto_factory=cls.PROTOCOL_FACTORY,
loop=asyncio.new_event_loop(),
socket_family=socket.AF_INET6,
**cls.server_kwargs(),
)
st = threading.Thread(target=cls.server.serve)
st.daemon = True
st.start()
time.sleep(0.1)

@classmethod
def server_kwargs(cls):
name = cls.__name__.lower()
Expand All @@ -157,12 +174,28 @@ async def client(self, timeout: int = 3000000):
**self.client_kwargs(),
)

async def ipv6_client(self, timeout: int = 3000000):
return await make_aio_client(
addressbook.AddressBookService,
trans_factory=self.TRANSPORT_FACTORY,
proto_factory=self.PROTOCOL_FACTORY,
timeout=timeout,
socket_family=socket.AF_INET6,
**self.client_kwargs(),
)

@pytest.mark.asyncio
async def test_void_api(self):
c = await self.client()
assert await c.ping() is None
c.close()

@pytest.mark.asyncio
async def test_api_ipv6(self):
c = await self.ipv6_client()
assert await c.ping() is None
c.close()

@pytest.mark.asyncio
async def test_string_api(self):
c = await self.client()
Expand Down
38 changes: 35 additions & 3 deletions tests/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@

thriftpy2.install_import_hook()

from thriftpy2.rpc import make_server, client_context # noqa
from thriftpy2.transport import TTransportException # noqa
from thriftpy2.rpc import client_context, make_server # noqa
from thriftpy2.thrift import TApplicationException # noqa

from thriftpy2.transport import TTransportException # noqa

if sys.platform == "win32":
pytest.skip("requires unix domain socket", allow_module_level=True)
Expand Down Expand Up @@ -102,6 +101,26 @@ def fin():
request.addfinalizer(fin)


@pytest.fixture(scope="module")
def ipv6_server(request):
server = make_server(addressbook.AddressBookService, Dispatcher(),
unix_socket=unix_sock, socket_family=socket.AF_INET6)
ps = multiprocessing.Process(target=server.serve)
ps.start()

time.sleep(0.1)

def fin():
if ps.is_alive():
ps.terminate()
try:
os.remove(unix_sock)
except IOError:
pass

request.addfinalizer(fin)


@pytest.fixture(scope="module")
def ssl_server(request):
ssl_server = make_server(addressbook.AddressBookService, Dispatcher(),
Expand Down Expand Up @@ -145,6 +164,14 @@ def client(timeout=3000):
unix_socket=unix_sock)


def ipv6_client(timeout=3000):
return client_context(addressbook.AddressBookService,
socket_timeout=timeout,
connect_timeout=timeout,
unix_socket=unix_sock,
socket_family=socket.AF_INET6)


def ssl_client(timeout=3000):
return client_context(addressbook.AddressBookService,
host='localhost', port=SSL_PORT,
Expand Down Expand Up @@ -175,6 +202,11 @@ def test_void_api(server):
assert c.ping() is None


def test_ipv6_api(ipv6_server):
with ipv6_client() as c:
assert c.ping() is None


def test_void_api_with_ssl(ssl_server):
with ssl_client() as c:
assert c.ping() is None
Expand Down
2 changes: 1 addition & 1 deletion thriftpy2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .hook import install_import_hook, remove_import_hook
from .parser import load, load_module, load_fp

__version__ = '0.4.20'
__version__ = '0.5.0'
__python__ = sys.version_info
__all__ = ["install_import_hook", "remove_import_hook", "load", "load_module",
"load_fp"]
42 changes: 29 additions & 13 deletions thriftpy2/contrib/aio/rpc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import socket
import urllib
import warnings

Expand All @@ -17,7 +18,7 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None,
cafile=None, ssl_context=None,
certfile=None, keyfile=None,
validate=True, url='',
socket_timeout=None):
socket_timeout=None, socket_family=socket.AF_INET):
if socket_timeout is not None:
warnings.warn(
"The 'socket_timeout' argument is deprecated. "
Expand All @@ -30,22 +31,31 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None,
host = parsed_url.hostname or host
port = parsed_url.port or port
if unix_socket:
socket = TAsyncSocket(unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=timeout)
client_socket = TAsyncSocket(
unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=timeout,
)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
socket = TAsyncSocket(
host, port,
socket_timeout=timeout, connect_timeout=connect_timeout,
cafile=cafile, ssl_context=ssl_context,
certfile=certfile, keyfile=keyfile, validate=validate)
client_socket = TAsyncSocket(
host,
port,
socket_timeout=timeout,
connect_timeout=connect_timeout,
cafile=cafile,
ssl_context=ssl_context,
certfile=certfile,
keyfile=keyfile,
validate=validate,
socket_family=socket_family,
)
else:
raise ValueError("Either host/port or unix_socket"
" or url must be provided.")

transport = trans_factory.get_transport(socket)
transport = trans_factory.get_transport(client_socket)
protocol = proto_factory.get_protocol(transport)
await transport.open()
return TAsyncClient(service, protocol)
Expand All @@ -56,7 +66,8 @@ def make_server(service, handler,
proto_factory=TAsyncBinaryProtocolFactory(),
trans_factory=TAsyncBufferedTransportFactory(),
client_timeout=3000, certfile=None,
keyfile=None, ssl_context=None, loop=None):
keyfile=None, ssl_context=None, loop=None,
socket_family=socket.AF_INET):
processor = TAsyncProcessor(service, handler)

if unix_socket:
Expand All @@ -65,9 +76,14 @@ def make_server(service, handler,
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
server_socket = TAsyncServerSocket(
host=host, port=port,
host=host,
port=port,
client_timeout=client_timeout,
certfile=certfile, keyfile=keyfile, ssl_context=ssl_context)
certfile=certfile,
keyfile=keyfile,
ssl_context=ssl_context,
socket_family=socket_family,
)
else:
raise ValueError("Either host/port or unix_socket must be provided.")

Expand Down
30 changes: 15 additions & 15 deletions thriftpy2/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import sys
import types

from .parser import parse, parse_fp, incomplete_type, _cast
from .parser import parse, parse_fp, threadlocal, _cast
from .exc import ThriftParserError
from ..thrift import TPayloadMeta

Expand All @@ -35,7 +35,7 @@ def load(path,
real_module = bool(module_name)
thrift = parse(path, module_name, include_dirs=include_dirs,
include_dir=include_dir, encoding=encoding)
if incomplete_type:
if threadlocal.incomplete_type:
fill_incomplete_ttype(thrift, thrift)

# add sub modules to sys.modules recursively
Expand All @@ -58,18 +58,18 @@ def fill_incomplete_ttype(tmodule, definition):
# construct const value
if definition[0] == 'UNKNOWN_CONST':
ttype = get_definition(
tmodule, incomplete_type[definition[1]][0], definition[3])
tmodule, threadlocal.incomplete_type[definition[1]][0], definition[3])
return _cast(ttype)(definition[2])
# construct incomplete alias type
elif definition[1] in incomplete_type:
elif definition[1] in threadlocal.incomplete_type:
return (
definition[0],
get_definition(tmodule, *incomplete_type[definition[1]])
get_definition(tmodule, *threadlocal.incomplete_type[definition[1]])
)
# construct incomplete type which is contained in service method's args
elif definition[0] in incomplete_type:
elif definition[0] in threadlocal.incomplete_type:
real_type = get_definition(
tmodule, *incomplete_type[definition[0]]
tmodule, *threadlocal.incomplete_type[definition[0]]
)
return (real_type[0], definition[1], real_type[1], definition[2])
# construct incomplete compound type
Expand All @@ -88,10 +88,10 @@ def fill_incomplete_ttype(tmodule, definition):
elif isinstance(definition, TPayloadMeta):
for index, value in definition.thrift_spec.items():
# if the ttype of the field is a single type and it is incompleted
if value[0] in incomplete_type:
if value[0] in threadlocal.incomplete_type:
real_type = fill_incomplete_ttype(
tmodule, get_definition(
tmodule, *incomplete_type[value[0]]
tmodule, *threadlocal.incomplete_type[value[0]]
)
)
# if the incomplete ttype is a compound type
Expand All @@ -107,19 +107,19 @@ def fill_incomplete_ttype(tmodule, definition):
definition.thrift_spec[index] = (
fill_incomplete_ttype(
tmodule, get_definition(
tmodule, *incomplete_type[value[0]]
tmodule, *threadlocal.incomplete_type[value[0]]
)
),
) + tuple(value[1:])
# if the field's ttype is a compound type
# and it contains incomplete types
elif value[2] in incomplete_type:
elif value[2] in threadlocal.incomplete_type:
definition.thrift_spec[index] = (
value[0],
value[1],
fill_incomplete_ttype(
tmodule, get_definition(
tmodule, *incomplete_type[value[2]]
tmodule, *threadlocal.incomplete_type[value[2]]
)
),
value[3])
Expand All @@ -129,8 +129,8 @@ def fill_incomplete_ttype(tmodule, definition):
def walk(part):
if isinstance(part, tuple):
return tuple(walk(x) for x in part)
if part in incomplete_type:
return get_definition(tmodule, *incomplete_type[part])
if part in threadlocal.incomplete_type:
return get_definition(tmodule, *threadlocal.incomplete_type[part])
return part
definition.thrift_spec[index] = (
value[0],
Expand Down Expand Up @@ -158,7 +158,7 @@ def get_definition(thrift, name, lineno):
(name, lineno))
if isinstance(ref_type, int) and ref_type < 0:
raise ThriftParserError('No type found: %r, at line %d' %
incomplete_type[ref_type])
threadlocal.incomplete_type[ref_type])
if hasattr(ref_type, '_ttype'):
return (getattr(ref_type, '_ttype'), ref_type)
else:
Expand Down
Loading

0 comments on commit 61f268a

Please sign in to comment.