Skip to content

Commit

Permalink
Merge pull request #81 from iamsudip/enforce_req_args
Browse files Browse the repository at this point in the history
enforce required arguments, fixes #72
  • Loading branch information
ethe authored Sep 1, 2019
2 parents f4aa812 + 74b93d0 commit 90cbc37
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 12 deletions.
2 changes: 1 addition & 1 deletion tests/addressbook.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ exception PersonNotExistsError {

service AddressBookService {
void ping();
string hello(1: string name);
string hello(1: required string name);
bool add(1: Person person);
bool remove(1: string name) throws (1: PersonNotExistsError not_exists);
Person get(1: string name) throws (1: PersonNotExistsError not_exists);
Expand Down
11 changes: 11 additions & 0 deletions tests/test_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from thriftpy2.rpc import make_aio_server, make_aio_client # noqa
from thriftpy2.transport import TTransportException # noqa
from thriftpy2.thrift import TApplicationException # noqa

addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__),
"addressbook.thrift"))
Expand Down Expand Up @@ -164,6 +165,16 @@ async def test_string_api(aio_server):
c.close()


@pytest.mark.asyncio
async def test_required_argument(aio_server):
c = await client()
assert await c.hello("") == "hello "

with pytest.raises(TApplicationException):
await c.hello()
c.close()


@pytest.mark.asyncio
async def test_string_api_with_ssl(aio_ssl_server):
c = await client()
Expand Down
9 changes: 9 additions & 0 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
thriftpy2.install_import_hook() # noqa

from thriftpy2.http import make_server, client_context
from thriftpy2.thrift import TApplicationException


addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__),
Expand Down Expand Up @@ -115,6 +116,14 @@ def test_string_api(server):
assert c.hello("world") == "hello world"


def test_required_argument(server):
with client() as c:
with pytest.raises(TApplicationException):
c.hello()

assert c.hello(name="") == "hello "


def test_huge_res(server):
with client() as c:
big_str = "world" * 100000
Expand Down
8 changes: 8 additions & 0 deletions tests/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from thriftpy2._compat import PY3 # noqa
from thriftpy2.rpc import make_server, client_context # noqa
from thriftpy2.transport import TTransportException # noqa
from thriftpy2.thrift import TApplicationException # noqa


addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__),
Expand Down Expand Up @@ -158,6 +159,13 @@ def test_string_api(server):
assert c.hello("world") == "hello world"


def test_required_argument(server):
with client() as c:
assert c.hello("") == "hello "
with pytest.raises(TApplicationException):
c.hello()


def test_string_api_with_ssl(ssl_server):
with ssl_client() as c:
assert c.hello("world") == "hello world"
Expand Down
13 changes: 9 additions & 4 deletions thriftpy2/contrib/aio/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import asyncio
import functools
from thriftpy2.thrift import args2kwargs
from thriftpy2.thrift import args_to_kwargs
from thriftpy2.thrift import TApplicationException, TMessageType


Expand All @@ -26,9 +26,14 @@ def __dir__(self):

@asyncio.coroutine
def _req(self, _api, *args, **kwargs):
_kw = args2kwargs(getattr(self._service, _api + "_args").thrift_spec,
*args)
kwargs.update(_kw)
try:
kwargs = args_to_kwargs(getattr(self._service, _api + "_args").thrift_spec,
*args, **kwargs)
except ValueError as e:
raise TApplicationException(
TApplicationException.UNKNOWN_METHOD,
'missing required argument {arg} for {service}.{api}'.format(
arg=e.args[0], service=self._service.__name__, api=_api))
result_cls = getattr(self._service, _api + "_result")

yield from self._send(_api, **kwargs)
Expand Down
30 changes: 23 additions & 7 deletions thriftpy2/thrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,22 @@
import linecache
import types

from ._compat import with_metaclass
from ._compat import with_metaclass, PY3
if PY3:
from itertools import zip_longest
else:
from itertools import izip_longest as zip_longest


def args2kwargs(thrift_spec, *args):
arg_names = [item[1][1] for item in sorted(thrift_spec.items())]
return dict(zip(arg_names, args))
def args_to_kwargs(thrift_spec, *args, **kwargs):
for item, value in zip_longest(sorted(thrift_spec.items()), args):
arg_name = item[1][1]
required = item[1][-1]
if value is not None:
kwargs[item[1][1]] = value
if required and arg_name not in kwargs:
raise ValueError(arg_name)
return kwargs


def parse_spec(ttype, spec=None):
Expand Down Expand Up @@ -192,9 +202,15 @@ def __dir__(self):
return self._service.thrift_services

def _req(self, _api, *args, **kwargs):
_kw = args2kwargs(getattr(self._service, _api + "_args").thrift_spec,
*args)
kwargs.update(_kw)
try:
kwargs = args_to_kwargs(getattr(self._service, _api + "_args").thrift_spec,
*args, **kwargs)
except ValueError as e:
raise TApplicationException(
TApplicationException.UNKNOWN_METHOD,
'{arg} is required argument for {service}.{api}'.format(
arg=e.args[0], service=self._service.__name__, api=_api))

result_cls = getattr(self._service, _api + "_result")

self._send(_api, **kwargs)
Expand Down

0 comments on commit 90cbc37

Please sign in to comment.