Skip to content

Commit

Permalink
changes to make app.current_request thread safe when running chalice …
Browse files Browse the repository at this point in the history
…locally

- fixes race conditions that can occur when chalice is being run locally
  and it handling multiple concurrent requests
  • Loading branch information
Joel Tetrault committed Mar 10, 2020
1 parent e2eac21 commit 33e7702
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 2 deletions.
44 changes: 42 additions & 2 deletions chalice/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@
from six.moves.BaseHTTPServer import HTTPServer
from six.moves.BaseHTTPServer import BaseHTTPRequestHandler
from six.moves.socketserver import ThreadingMixIn
from typing import List, Any, Dict, Tuple, Callable, Optional, Union # noqa
from typing import (
List,
Any,
Dict,
Tuple,
Callable,
Optional,
Union,
cast,
) # noqa

from chalice.app import Chalice # noqa
from chalice.app import CORSConfig # noqa
Expand Down Expand Up @@ -47,7 +56,9 @@ def time(self):

def create_local_server(app_obj, config, host, port):
# type: (Chalice, Config, str, int) -> LocalDevServer
return LocalDevServer(app_obj, config, host, port)
local_app_obj = LocalChalice(app_obj)
casted_local_app_obj = cast(Chalice, local_app_obj)
return LocalDevServer(casted_local_app_obj, config, host, port)


class LocalARNBuilder(object):
Expand Down Expand Up @@ -661,3 +672,32 @@ def shutdown(self):
# type: () -> None
if self._server is not None:
self._server.shutdown()


class LocalChalice(object):
def __init__(self, chalice):
# type: (Chalice) -> None
self._current_request_lookup = {} # type: Dict[int, Optional[Request]]
self._chalice = chalice

@property
def current_request(self): # noqa
# type: () -> Optional[Request]
thread_id = threading.current_thread().ident
assert thread_id is not None
return self._current_request_lookup.get(thread_id, None)

@current_request.setter
def current_request(self, value): # noqa
# type: (Optional[Request]) -> None
thread_id = threading.current_thread().ident
assert thread_id is not None
self._current_request_lookup[thread_id] = value

def __getattr__(self, name):
# type: (str) -> Any
return getattr(self._chalice, name)

def __call__(self, *args, **kwargs):
# type: (Any, Any) -> Any
return self._chalice(*args, **kwargs)
16 changes: 16 additions & 0 deletions tests/unit/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from chalice.local import ForbiddenError
from chalice.local import InvalidAuthorizerError
from chalice.local import LocalDevServer
from chalice.local import LocalChalice


AWS_REQUEST_ID_PATTERN = re.compile(
Expand Down Expand Up @@ -667,6 +668,21 @@ def test_can_provide_host_to_local_server(sample_app):
assert dev_server.host == '0.0.0.0'


def test_wraps_sample_app_with_local_chalice(sample_app):
dev_server = local.create_local_server(
sample_app, None, "127.0.0.1", 23456
)
assert isinstance(dev_server.app_object, LocalChalice)
assert dev_server.app_object._chalice is sample_app
assert dev_server.app_object.app_name is sample_app.app_name
dev_server.app_object.current_request = "foo"
assert dev_server.app_object.current_request == "foo"
assert (
dev_server.app_object.current_request
is not sample_app.current_request
)


class TestLambdaContext(object):
def test_can_get_remaining_time_once(self, lambda_context_args):
time_source = FakeTimeSource([0, 5])
Expand Down

0 comments on commit 33e7702

Please sign in to comment.