Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changes to make app.current_request thread safe #1358

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions chalice/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@
from six.moves.BaseHTTPServer import HTTPServer
from six.moves.BaseHTTPServer import BaseHTTPRequestHandler
from six.moves.socketserver import ThreadingMixIn
from typing import List, Any, Dict, Tuple, Callable, Optional, Union # noqa
from typing import (
List,
Any,
Dict,
Tuple,
Callable,
Optional,
Union,
cast,
) # noqa

from chalice.app import Chalice # noqa
from chalice.app import CORSConfig # noqa
Expand Down Expand Up @@ -47,7 +56,9 @@ def time(self):

def create_local_server(app_obj, config, host, port):
# type: (Chalice, Config, str, int) -> LocalDevServer
return LocalDevServer(app_obj, config, host, port)
local_app_obj = LocalChalice(app_obj)
casted_local_app_obj = cast(Chalice, local_app_obj)
return LocalDevServer(casted_local_app_obj, config, host, port)


class LocalARNBuilder(object):
Expand Down Expand Up @@ -661,3 +672,32 @@ def shutdown(self):
# type: () -> None
if self._server is not None:
self._server.shutdown()


class LocalChalice(object):
def __init__(self, chalice):
# type: (Chalice) -> None
self._current_request_lookup = {} # type: Dict[int, Optional[Request]]
self._chalice = chalice

@property
def current_request(self): # noqa
# type: () -> Optional[Request]
thread_id = threading.current_thread().ident
assert thread_id is not None
return self._current_request_lookup.get(thread_id, None)

@current_request.setter
def current_request(self, value): # noqa
# type: (Optional[Request]) -> None
thread_id = threading.current_thread().ident
assert thread_id is not None
self._current_request_lookup[thread_id] = value

def __getattr__(self, name):
# type: (str) -> Any
return getattr(self._chalice, name)

def __call__(self, *args, **kwargs):
# type: (Any, Any) -> Any
return self._chalice(*args, **kwargs)
16 changes: 16 additions & 0 deletions tests/unit/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from chalice.local import ForbiddenError
from chalice.local import InvalidAuthorizerError
from chalice.local import LocalDevServer
from chalice.local import LocalChalice


AWS_REQUEST_ID_PATTERN = re.compile(
Expand Down Expand Up @@ -667,6 +668,21 @@ def test_can_provide_host_to_local_server(sample_app):
assert dev_server.host == '0.0.0.0'


def test_wraps_sample_app_with_local_chalice(sample_app):
dev_server = local.create_local_server(
sample_app, None, "127.0.0.1", 23456
)
assert isinstance(dev_server.app_object, LocalChalice)
assert dev_server.app_object._chalice is sample_app
assert dev_server.app_object.app_name is sample_app.app_name
dev_server.app_object.current_request = "foo"
assert dev_server.app_object.current_request == "foo"
assert (
dev_server.app_object.current_request
is not sample_app.current_request
)


class TestLambdaContext(object):
def test_can_get_remaining_time_once(self, lambda_context_args):
time_source = FakeTimeSource([0, 5])
Expand Down