diff --git a/falcon/testing/client.py b/falcon/testing/client.py index fce0f4d0a..d4d4e7403 100644 --- a/falcon/testing/client.py +++ b/falcon/testing/client.py @@ -17,19 +17,36 @@ This package includes utilities for simulating HTTP requests against a WSGI callable, without having to stand up a WSGI server. """ +from __future__ import annotations import asyncio import datetime as dt import inspect import json as json_module import time +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import cast from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Mapping from typing import Optional +from typing import overload from typing import Sequence +from typing import TextIO +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union import warnings import wsgiref.validate +from typing_extensions import Literal +from typing_extensions import Self +from typing_extensions import TypeGuard + from falcon.asgi_spec import ScopeType from falcon.constants import COMBINED_METHODS from falcon.constants import MEDIA_JSON @@ -45,6 +62,10 @@ from falcon.util import http_date_to_dt from falcon.util import to_query_str +if TYPE_CHECKING: + from falcon import App as WSGI + from falcon.asgi import App as ASGI + warnings.filterwarnings( 'error', ('Unknown REQUEST_METHOD: ' + "'({})'".format('|'.join(COMBINED_METHODS))), @@ -53,19 +74,27 @@ 0, ) +_C = TypeVar('_C', bound=Callable) + -def _simulate_method_alias(method, version_added='3.1', replace_name=None): +def _simulate_method_alias( + method: _C, version_added: str = '3.1', replace_name: Optional[str] = None +) -> _C: return_type = inspect.signature(method).return_annotation - def alias(client, *args, **kwargs) -> return_type: + def alias( + client: Any, *args: Any, **kwargs: Any + ) -> return_type: # type:ignore[valid-type] return method(client, *args, **kwargs) - async def async_alias(client, *args, **kwargs) -> return_type: + async def async_alias( + client, *args, **kwargs + ) -> return_type: # type:ignore[valid-type] return await method(client, *args, **kwargs) alias = async_alias if inspect.iscoroutinefunction(method) else alias - alias.__doc__ = method.__doc__ + '\n .. versionadded:: {}\n'.format( + alias.__doc__ = (method.__doc__ or '') + '\n .. versionadded:: {}\n'.format( version_added ) if replace_name: @@ -74,7 +103,7 @@ async def async_alias(client, *args, **kwargs) -> return_type: else: alias.__name__ = method.__name__.partition('simulate_')[-1] - return alias + return cast('_C', alias) class Cookie: @@ -101,7 +130,7 @@ class Cookie: included in unscripted requests from the client. """ - def __init__(self, morsel): + def __init__(self, morsel: http_cookies.Morsel[str]) -> None: self._name = morsel.key self._value = morsel.value @@ -194,7 +223,7 @@ class _ResultBase: if the encoding can not be determined. """ - def __init__(self, status, headers): + def __init__(self, status: str, headers: Iterable[Tuple[str, str]]) -> None: self._status = status self._status_code = int(status[:3]) self._headers = CaseInsensitiveDict(headers) @@ -234,7 +263,7 @@ def cookies(self) -> Dict[str, Cookie]: return self._cookies @property - def encoding(self) -> str: + def encoding(self) -> str | None: return self._encoding @@ -247,7 +276,7 @@ class ResultBodyStream: collected. """ - def __init__(self, chunks: Sequence[bytes]): + def __init__(self, chunks: Sequence[bytes]) -> None: self._chunks = chunks self._chunk_pos = 0 @@ -315,10 +344,12 @@ class Result(_ResultBase): if the response is not valid JSON. """ - def __init__(self, iterable, status, headers): + def __init__( + self, iterable: Iterable[bytes], status: str, headers: Iterable[Tuple[str, str]] + ) -> None: super().__init__(status, headers) - self._text = None + self._text: Optional[str] = None self._content = b''.join(iterable) @property @@ -341,13 +372,13 @@ def text(self) -> str: return self._text @property - def json(self) -> Optional[Union[dict, list, str, int, float, bool]]: + def json(self) -> Union[dict, list, str, int, float, bool, None]: if not self.text: return None return json_module.loads(self.text) - def __repr__(self): + def __repr__(self) -> str: content_type = self.headers.get('Content-Type', '') if len(self.content) > 40: @@ -402,7 +433,14 @@ class StreamedResult(_ResultBase): stream (ResultStream): Raw response body, as a byte stream. """ - def __init__(self, body_chunks, status, headers, task, req_event_emitter): + def __init__( + self, + body_chunks: Sequence[bytes], + status: str, + headers: Iterable[Tuple[str, str]], + task: asyncio.Task[Any], + req_event_emitter: helpers.ASGIRequestEventEmitter, + ) -> None: super().__init__(status, headers) self._task = task @@ -424,35 +462,39 @@ async def finalize(self): await self._task +_HeadersType = Union[None, Mapping[str, str], Iterable[Tuple[str, str]]] +_AnyApp = Union['WSGI', 'ASGI'] +_AnyResult = Union[Result, StreamedResult] + + # NOTE(kgriffs): The default of asgi_disconnect_ttl was chosen to be # relatively long (5 minutes) to help testers notice when something # appears to be "hanging", which might indicates that the app is # not handling the reception of events correctly. def simulate_request( - app, - method='GET', - path='/', - query_string=None, - headers=None, - content_type=None, - body=None, - json=None, - file_wrapper=None, - wsgierrors=None, - params=None, - params_csv=False, - protocol='http', - host=helpers.DEFAULT_HOST, - remote_addr=None, - extras=None, - http_version='1.1', - port=None, - root_path=None, - cookies=None, - asgi_chunk_size=4096, - asgi_disconnect_ttl=300, -) -> _ResultBase: - + app: _AnyApp, + method: str = 'GET', + path: str = '/', + query_string: Optional[str] = None, + headers: _HeadersType = None, + content_type: Optional[str] = None, + body: Union[str, bytes, None] = None, + json: Any = None, + file_wrapper: Optional[Callable] = None, + wsgierrors: Optional[TextIO] = None, + params: Optional[Mapping[str, Union[str, list[str]]]] = None, + params_csv: bool = False, + protocol: str = 'http', + host: str = helpers.DEFAULT_HOST, + remote_addr: Optional[str] = None, + extras: Optional[Mapping[str, Any]] = None, + http_version: str = '1.1', + port: Optional[int] = None, + root_path: Optional[str] = None, + cookies: _HeadersType = None, + asgi_chunk_size: int = 4096, + asgi_disconnect_ttl: int = 300, +) -> Result: """Simulate a request to a WSGI or ASGI application. Performs a request against a WSGI or ASGI application. In the case of @@ -587,6 +629,7 @@ def simulate_request( asgi_chunk_size=asgi_chunk_size, asgi_disconnect_ttl=asgi_disconnect_ttl, cookies=cookies, + _stream_result=False, ) path, query_string, headers, body, extras = _prepare_sim_args( @@ -618,6 +661,7 @@ def simulate_request( cookies=cookies, ) + assert extras is not None if 'REQUEST_METHOD' in extras and extras['REQUEST_METHOD'] != method: # NOTE(vytas): Even given the duct tape nature of overriding # arbitrary environ variables, changing the method can potentially @@ -638,40 +682,95 @@ def simulate_request( return Result(helpers.closed_wsgi_iterable(iterable), srmock.status, srmock.headers) +@overload +async def _simulate_request_asgi( # type:ignore[overload-overlap] + app: ASGI, + method: str = ..., + path: str = ..., + query_string: Optional[str] = ..., + headers: _HeadersType = ..., + content_type: Optional[str] = ..., + body: Union[str, bytes, None] = ..., + json: Any = ..., + params: Optional[Mapping[str, Union[str, list[str]]]] = ..., + params_csv: bool = ..., + protocol: str = ..., + host: str = ..., + remote_addr: Optional[str] = ..., + extras: Optional[Mapping[str, Any]] = ..., + http_version: str = ..., + port: Optional[int] = ..., + root_path: Optional[str] = ..., + asgi_chunk_size: int = ..., + asgi_disconnect_ttl: int = ..., + cookies: _HeadersType = ..., + _one_shot: Literal[False] = ..., + _stream_result: Literal[True] = ..., +) -> StreamedResult: + ... + + +@overload +async def _simulate_request_asgi( + app: ASGI, + method: str = ..., + path: str = ..., + query_string: Optional[str] = ..., + headers: _HeadersType = ..., + content_type: Optional[str] = ..., + body: Union[str, bytes, None] = ..., + json: Any = ..., + params: Optional[Mapping[str, Union[str, list[str]]]] = ..., + params_csv: bool = ..., + protocol: str = ..., + host: str = ..., + remote_addr: Optional[str] = ..., + extras: Optional[Mapping[str, Any]] = ..., + http_version: str = ..., + port: Optional[int] = ..., + root_path: Optional[str] = ..., + asgi_chunk_size: int = ..., + asgi_disconnect_ttl: int = ..., + cookies: _HeadersType = ..., + _one_shot: bool = ..., + _stream_result: Literal[False] = ..., +) -> Result: + ... + + # NOTE(kgriffs): The default of asgi_disconnect_ttl was chosen to be # relatively long (5 minutes) to help testers notice when something # appears to be "hanging", which might indicates that the app is # not handling the reception of events correctly. async def _simulate_request_asgi( - app, - method='GET', - path='/', - query_string=None, - headers=None, - content_type=None, - body=None, - json=None, - params=None, - params_csv=True, - protocol='http', - host=helpers.DEFAULT_HOST, - remote_addr=None, - extras=None, - http_version='1.1', - port=None, - root_path=None, - asgi_chunk_size=4096, - asgi_disconnect_ttl=300, - cookies=None, + app: ASGI, + method: str = 'GET', + path: str = '/', + query_string: Optional[str] = None, + headers: _HeadersType = None, + content_type: Optional[str] = None, + body: Union[str, bytes, None] = None, + json: Any = None, + params: Optional[Mapping[str, Union[str, list[str]]]] = None, + params_csv: bool = True, + protocol: str = 'http', + host: str = helpers.DEFAULT_HOST, + remote_addr: Optional[str] = None, + extras: Optional[Mapping[str, Any]] = None, + http_version: str = '1.1', + port: Optional[int] = None, + root_path: Optional[str] = None, + asgi_chunk_size: int = 4096, + asgi_disconnect_ttl: int = 300, + cookies: _HeadersType = None, # NOTE(kgriffs): These are undocumented because they are only # meant to be used internally by the framework (i.e., they are # not part of the public interface.) In case we ever expose # simulate_request_asgi() as part of the public interface, we # don't want these kwargs to be documented. - _one_shot=True, - _stream_result=False, -) -> _ResultBase: - + _one_shot: bool = True, + _stream_result: bool = False, +) -> _AnyResult: """Simulate a request to an ASGI application. Keyword Args: @@ -803,6 +902,7 @@ async def _simulate_request_asgi( cookies=cookies, ) + assert extras is not None if 'method' in extras and extras['method'] != method.upper(): raise ValueError( 'ASGI scope extras may not override the request method. ' @@ -813,7 +913,7 @@ async def _simulate_request_asgi( # --------------------------------------------------------------------- if asgi_disconnect_ttl == 0: # Special case - disconnect_at = 0 + disconnect_at = 0.0 else: disconnect_at = time.time() + max(0, asgi_disconnect_ttl) @@ -989,7 +1089,7 @@ async def get_events_sse(): """ - def __init__(self, app, headers=None): + def __init__(self, app: ASGI, headers: _HeadersType = None) -> None: if not _is_asgi_app(app): raise CompatibilityError('ASGIConductor may only be used with an ASGI app') @@ -998,9 +1098,9 @@ def __init__(self, app, headers=None): self._shutting_down = asyncio.Condition() self._lifespan_event_collector = helpers.ASGIResponseEventCollector() - self._lifespan_task = None + self._lifespan_task: Optional[asyncio.Task[Any]] = None - async def __aenter__(self): + async def __aenter__(self) -> Self: lifespan_scope = { 'type': ScopeType.LIFESPAN, 'asgi': { @@ -1024,7 +1124,7 @@ async def __aenter__(self): return self - async def __aexit__(self, ex_type, ex, tb): + async def __aexit__(self, ex_type: Any, ex: Any, tb: Any) -> bool: if ex_type: return False @@ -1034,18 +1134,21 @@ async def __aexit__(self, ex_type, ex, tb): self._shutting_down.notify() await _wait_for_shutdown(self._lifespan_event_collector.events) + assert self._lifespan_task is not None await self._lifespan_task return True - async def simulate_get(self, path='/', **kwargs) -> _ResultBase: + async def simulate_get(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a GET request to an ASGI application. (See also: :py:meth:`falcon.testing.simulate_get`) """ return await self.simulate_request('GET', path, **kwargs) - def simulate_get_stream(self, path='/', **kwargs): + def simulate_get_stream( + self, path: str = '/', **kwargs: Any + ) -> _AsyncContextManager: """Simulate a GET request to an ASGI application with a streamed response. (See also: :py:meth:`falcon.testing.simulate_get` for a list of @@ -1075,11 +1178,11 @@ def simulate_get_stream(self, path='/', **kwargs): """ - kwargs['_stream_result'] = True - - return _AsyncContextManager(self.simulate_request('GET', path, **kwargs)) + return _AsyncContextManager( + self.simulate_request('GET', path, **kwargs, _stream_result=True) + ) - def simulate_ws(self, path='/', **kwargs): + def simulate_ws(self, path: str = '/', **kwargs: Any) -> _WSContextManager: """Simulate a WebSocket connection to an ASGI application. All keyword arguments are passed through to @@ -1107,49 +1210,67 @@ def simulate_ws(self, path='/', **kwargs): return _WSContextManager(ws, task_req) - async def simulate_head(self, path='/', **kwargs) -> _ResultBase: + async def simulate_head(self, path='/', **kwargs) -> Result: """Simulate a HEAD request to an ASGI application. (See also: :py:meth:`falcon.testing.simulate_head`) """ - return await self.simulate_request('HEAD', path, **kwargs) + return await self.simulate_request('HEAD', path, **kwargs, _stream_result=False) - async def simulate_post(self, path='/', **kwargs) -> _ResultBase: + async def simulate_post(self, path='/', **kwargs) -> Result: """Simulate a POST request to an ASGI application. (See also: :py:meth:`falcon.testing.simulate_post`) """ - return await self.simulate_request('POST', path, **kwargs) + return await self.simulate_request('POST', path, **kwargs, _stream_result=False) - async def simulate_put(self, path='/', **kwargs) -> _ResultBase: + async def simulate_put(self, path='/', **kwargs) -> Result: """Simulate a PUT request to an ASGI application. (See also: :py:meth:`falcon.testing.simulate_put`) """ - return await self.simulate_request('PUT', path, **kwargs) + return await self.simulate_request('PUT', path, **kwargs, _stream_result=False) - async def simulate_options(self, path='/', **kwargs) -> _ResultBase: + async def simulate_options(self, path='/', **kwargs) -> Result: """Simulate an OPTIONS request to an ASGI application. (See also: :py:meth:`falcon.testing.simulate_options`) """ - return await self.simulate_request('OPTIONS', path, **kwargs) + return await self.simulate_request( + 'OPTIONS', path, **kwargs, _stream_result=False + ) - async def simulate_patch(self, path='/', **kwargs) -> _ResultBase: + async def simulate_patch(self, path='/', **kwargs) -> Result: """Simulate a PATCH request to an ASGI application. (See also: :py:meth:`falcon.testing.simulate_patch`) """ - return await self.simulate_request('PATCH', path, **kwargs) + return await self.simulate_request( + 'PATCH', path, **kwargs, _stream_result=False + ) - async def simulate_delete(self, path='/', **kwargs) -> _ResultBase: + async def simulate_delete(self, path='/', **kwargs) -> Result: """Simulate a DELETE request to an ASGI application. (See also: :py:meth:`falcon.testing.simulate_delete`) """ - return await self.simulate_request('DELETE', path, **kwargs) + return await self.simulate_request( + 'DELETE', path, **kwargs, _stream_result=False + ) + + @overload + async def simulate_request( # type: ignore[overload-overlap] + self, *args: Any, _stream_result: Literal[False] = ..., **kwargs: Any + ) -> Result: + ... + + @overload + async def simulate_request( + self, *args: Any, _stream_result: Literal[True] = ..., **kwargs: Any + ) -> StreamedResult: + ... - async def simulate_request(self, *args, **kwargs) -> _ResultBase: + async def simulate_request(self, *args: Any, **kwargs: Any) -> _AnyResult: """Simulate a request to an ASGI application. Wraps :py:meth:`falcon.testing.simulate_request` to perform a @@ -1163,7 +1284,7 @@ async def simulate_request(self, *args, **kwargs) -> _ResultBase: # set to None. additional_headers = kwargs.get('headers', {}) or {} - merged_headers = self._default_headers.copy() + merged_headers = dict(self._default_headers) merged_headers.update(additional_headers) kwargs['headers'] = merged_headers @@ -1185,7 +1306,7 @@ async def simulate_request(self, *args, **kwargs) -> _ResultBase: websocket = _simulate_method_alias(simulate_ws, replace_name='websocket') -def simulate_get(app, path, **kwargs) -> _ResultBase: +def simulate_get(app: _AnyApp, path: str, **kwargs: Any) -> Result: """Simulate a GET request to a WSGI or ASGI application. Equivalent to:: @@ -1288,7 +1409,7 @@ def simulate_get(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'GET', path, **kwargs) -def simulate_head(app, path, **kwargs) -> _ResultBase: +def simulate_head(app: _AnyApp, path: str, **kwargs: Any) -> Result: """Simulate a HEAD request to a WSGI or ASGI application. Equivalent to:: @@ -1385,7 +1506,7 @@ def simulate_head(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'HEAD', path, **kwargs) -def simulate_post(app, path, **kwargs) -> _ResultBase: +def simulate_post(app: _AnyApp, path: str, **kwargs: Any) -> Result: """Simulate a POST request to a WSGI or ASGI application. Equivalent to:: @@ -1496,7 +1617,7 @@ def simulate_post(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'POST', path, **kwargs) -def simulate_put(app, path, **kwargs) -> _ResultBase: +def simulate_put(app: _AnyApp, path: str, **kwargs: Any) -> Result: """Simulate a PUT request to a WSGI or ASGI application. Equivalent to:: @@ -1607,7 +1728,7 @@ def simulate_put(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'PUT', path, **kwargs) -def simulate_options(app, path, **kwargs) -> _ResultBase: +def simulate_options(app: _AnyApp, path: str, **kwargs: Any) -> Result: """Simulate an OPTIONS request to a WSGI or ASGI application. Equivalent to:: @@ -1696,7 +1817,7 @@ def simulate_options(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'OPTIONS', path, **kwargs) -def simulate_patch(app, path, **kwargs) -> _ResultBase: +def simulate_patch(app: _AnyApp, path: str, **kwargs: Any) -> Result: """Simulate a PATCH request to a WSGI or ASGI application. Equivalent to:: @@ -1802,7 +1923,7 @@ def simulate_patch(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'PATCH', path, **kwargs) -def simulate_delete(app, path, **kwargs) -> _ResultBase: +def simulate_delete(app: _AnyApp, path: str, **kwargs: Any) -> Result: """Simulate a DELETE request to a WSGI or ASGI application. Equivalent to:: @@ -1908,7 +2029,10 @@ def simulate_delete(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'DELETE', path, **kwargs) -class TestClient: +_A = TypeVar('_A', bound=_AnyApp) + + +class TestClient(Generic[_A]): """Simulate requests to a WSGI or ASGI application. This class provides a contextual wrapper for Falcon's ``simulate_*()`` @@ -1973,12 +2097,12 @@ class TestClient: # NOTE(aryaniyaps): Prevent pytest from collecting tests on the class. __test__ = False - def __init__(self, app, headers=None): - self.app = app + def __init__(self: TestClient[_A], app: _A, headers: _HeadersType = None) -> None: + self.app: _A = app self._default_headers = headers - self._conductor = None + self._conductor: Optional[ASGIConductor] = None - async def __aenter__(self): + async def __aenter__(self: TestClient[ASGI]) -> ASGIConductor: if not _is_asgi_app(self.app): raise CompatibilityError( 'a conductor context manager may only be used with a Falcon ASGI app' @@ -1993,7 +2117,8 @@ async def __aenter__(self): return self._conductor - async def __aexit__(self, ex_type, ex, tb): + async def __aexit__(self: TestClient[ASGI], ex_type: Any, ex: Any, tb: Any) -> bool: + assert self._conductor is not None result = await self._conductor.__aexit__(ex_type, ex, tb) # NOTE(kgriffs): Reset to allow this instance of TestClient to be @@ -2002,56 +2127,56 @@ async def __aexit__(self, ex_type, ex, tb): return result - def simulate_get(self, path='/', **kwargs) -> _ResultBase: + def simulate_get(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a GET request to a WSGI application. (See also: :py:meth:`falcon.testing.simulate_get`) """ return self.simulate_request('GET', path, **kwargs) - def simulate_head(self, path='/', **kwargs) -> _ResultBase: + def simulate_head(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a HEAD request to a WSGI application. (See also: :py:meth:`falcon.testing.simulate_head`) """ return self.simulate_request('HEAD', path, **kwargs) - def simulate_post(self, path='/', **kwargs) -> _ResultBase: + def simulate_post(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a POST request to a WSGI application. (See also: :py:meth:`falcon.testing.simulate_post`) """ return self.simulate_request('POST', path, **kwargs) - def simulate_put(self, path='/', **kwargs) -> _ResultBase: + def simulate_put(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a PUT request to a WSGI application. (See also: :py:meth:`falcon.testing.simulate_put`) """ return self.simulate_request('PUT', path, **kwargs) - def simulate_options(self, path='/', **kwargs) -> _ResultBase: + def simulate_options(self, path: str = '/', **kwargs: Any) -> Result: """Simulate an OPTIONS request to a WSGI application. (See also: :py:meth:`falcon.testing.simulate_options`) """ return self.simulate_request('OPTIONS', path, **kwargs) - def simulate_patch(self, path='/', **kwargs) -> _ResultBase: + def simulate_patch(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a PATCH request to a WSGI application. (See also: :py:meth:`falcon.testing.simulate_patch`) """ return self.simulate_request('PATCH', path, **kwargs) - def simulate_delete(self, path='/', **kwargs) -> _ResultBase: + def simulate_delete(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a DELETE request to a WSGI application. (See also: :py:meth:`falcon.testing.simulate_delete`) """ return self.simulate_request('DELETE', path, **kwargs) - def simulate_request(self, *args, **kwargs) -> _ResultBase: + def simulate_request(self, *args: Any, **kwargs: Any) -> Result: """Simulate a request to a WSGI application. Wraps :py:meth:`falcon.testing.simulate_request` to perform a @@ -2065,7 +2190,7 @@ def simulate_request(self, *args, **kwargs) -> _ResultBase: # set to None. additional_headers = kwargs.get('headers', {}) or {} - merged_headers = self._default_headers.copy() + merged_headers = dict(self._default_headers) merged_headers.update(additional_headers) kwargs['headers'] = merged_headers @@ -2088,25 +2213,28 @@ def simulate_request(self, *args, **kwargs) -> _ResultBase: class _AsyncContextManager: - def __init__(self, coro): + def __init__(self, coro: Awaitable[StreamedResult]) -> None: self._coro = coro - self._obj = None + self._obj: Optional[StreamedResult] = None - async def __aenter__(self): + async def __aenter__(self) -> StreamedResult: self._obj = await self._coro return self._obj - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> Any: + assert self._obj is not None await self._obj.finalize() self._obj = None class _WSContextManager: - def __init__(self, ws, task_req): + def __init__( + self, ws: helpers.ASGIWebSocketSimulator, task_req: asyncio.Task[None] + ) -> None: self._ws = ws self._task_req = task_req - async def __aenter__(self): + async def __aenter__(self) -> helpers.ASGIWebSocketSimulator: ready_waiter = create_task(self._ws.wait_ready()) # NOTE(kgriffs): Wait on both so that in the case that the request @@ -2131,7 +2259,7 @@ async def __aenter__(self): return self._ws - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: await self._ws.close() await self._task_req @@ -2174,7 +2302,7 @@ def _prepare_sim_args( return path, query_string, headers, body, extras -def _is_asgi_app(app): +def _is_asgi_app(app) -> TypeGuard[ASGI]: app_args = inspect.getfullargspec(app).args num_app_args = len(app_args) diff --git a/falcon/util/sync.py b/falcon/util/sync.py index db9e21e37..ba1f8a491 100644 --- a/falcon/util/sync.py +++ b/falcon/util/sync.py @@ -11,6 +11,8 @@ from typing import TypeVar from typing import Union +from typing_extensions import ParamSpec + __all__ = [ 'async_to_sync', @@ -190,11 +192,14 @@ def _wrap_non_coroutine_unsafe( return wrap_sync_to_async_unsafe(func) +Params = ParamSpec('Params') Result = TypeVar('Result') def async_to_sync( - coroutine: Callable[..., Awaitable[Result]], *args: Any, **kwargs: Any + coroutine: Callable[Params, Awaitable[Result]], + *args: Params.args, + **kwargs: Params.kwargs, ) -> Result: """Invoke a coroutine function from a synchronous caller. diff --git a/setup.cfg b/setup.cfg index e7c308429..6033cb40c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = falcon -version = attr: falcon.__version__ +version = attr: falcon.version.__version__ description = The ultra-reliable, fast ASGI+WSGI framework for building data plane APIs at scale. long_description_content_type = text/x-rst url = https://falconframework.org @@ -54,6 +54,7 @@ include_package_data = True packages = find: python_requires = >=3.7 install_requires = + typing-extensions >= 4.2.0 tests_require = testtools requests