diff --git a/chalice/local.py b/chalice/local.py index f6ce45102..aa9f6bd3f 100644 --- a/chalice/local.py +++ b/chalice/local.py @@ -16,7 +16,16 @@ from six.moves.BaseHTTPServer import HTTPServer from six.moves.BaseHTTPServer import BaseHTTPRequestHandler from six.moves.socketserver import ThreadingMixIn -from typing import List, Any, Dict, Tuple, Callable, Optional, Union # noqa +from typing import ( + List, + Any, + Dict, + Tuple, + Callable, + Optional, + Union, + cast, +) # noqa from chalice.app import Chalice # noqa from chalice.app import CORSConfig # noqa @@ -47,7 +56,9 @@ def time(self): def create_local_server(app_obj, config, host, port): # type: (Chalice, Config, str, int) -> LocalDevServer - return LocalDevServer(app_obj, config, host, port) + local_app_obj = LocalChalice(app_obj) + casted_local_app_obj = cast(Chalice, local_app_obj) + return LocalDevServer(casted_local_app_obj, config, host, port) class LocalARNBuilder(object): @@ -661,3 +672,32 @@ def shutdown(self): # type: () -> None if self._server is not None: self._server.shutdown() + + +class LocalChalice(object): + def __init__(self, chalice): + # type: (Chalice) -> None + self._current_request_lookup = {} # type: Dict[int, Optional[Request]] + self._chalice = chalice + + @property + def current_request(self): # noqa + # type: () -> Optional[Request] + thread_id = threading.current_thread().ident + assert thread_id is not None + return self._current_request_lookup.get(thread_id, None) + + @current_request.setter + def current_request(self, value): # noqa + # type: (Optional[Request]) -> None + thread_id = threading.current_thread().ident + assert thread_id is not None + self._current_request_lookup[thread_id] = value + + def __getattr__(self, name): + # type: (str) -> Any + return getattr(self._chalice, name) + + def __call__(self, *args, **kwargs): + # type: (Any, Any) -> Any + return self._chalice(*args, **kwargs) diff --git a/tests/unit/test_local.py b/tests/unit/test_local.py index e21cb3932..529bb3615 100644 --- a/tests/unit/test_local.py +++ b/tests/unit/test_local.py @@ -20,6 +20,7 @@ from chalice.local import ForbiddenError from chalice.local import InvalidAuthorizerError from chalice.local import LocalDevServer +from chalice.local import LocalChalice AWS_REQUEST_ID_PATTERN = re.compile( @@ -667,6 +668,21 @@ def test_can_provide_host_to_local_server(sample_app): assert dev_server.host == '0.0.0.0' +def test_wraps_sample_app_with_local_chalice(sample_app): + dev_server = local.create_local_server( + sample_app, None, "127.0.0.1", 23456 + ) + assert isinstance(dev_server.app_object, LocalChalice) + assert dev_server.app_object._chalice is sample_app + assert dev_server.app_object.app_name is sample_app.app_name + dev_server.app_object.current_request = "foo" + assert dev_server.app_object.current_request == "foo" + assert ( + dev_server.app_object.current_request + is not sample_app.current_request + ) + + class TestLambdaContext(object): def test_can_get_remaining_time_once(self, lambda_context_args): time_source = FakeTimeSource([0, 5])